1use crate::adaptive_gate_fusion::QuantumGate;
6use quantrs2_circuit::prelude::*;
7use quantrs2_core::prelude::*;
8use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, Axis};
9use scirs2_core::Complex64;
10
11use super::types::{
12 AdvancedContractionAlgorithms, ContractionStrategy, IndexType, Tensor, TensorIndex,
13 TensorNetwork, TensorNetworkSimulator,
14};
15
16pub(super) fn pauli_x() -> Array2<Complex64> {
17 Array2::from_shape_vec(
18 (2, 2),
19 vec![
20 Complex64::new(0.0, 0.0),
21 Complex64::new(1.0, 0.0),
22 Complex64::new(1.0, 0.0),
23 Complex64::new(0.0, 0.0),
24 ],
25 )
26 .expect("Pauli-X matrix has valid 2x2 shape")
27}
28pub(super) fn pauli_y() -> Array2<Complex64> {
29 Array2::from_shape_vec(
30 (2, 2),
31 vec![
32 Complex64::new(0.0, 0.0),
33 Complex64::new(0.0, -1.0),
34 Complex64::new(0.0, 1.0),
35 Complex64::new(0.0, 0.0),
36 ],
37 )
38 .expect("Pauli-Y matrix has valid 2x2 shape")
39}
40pub(super) fn pauli_z() -> Array2<Complex64> {
41 Array2::from_shape_vec(
42 (2, 2),
43 vec![
44 Complex64::new(1.0, 0.0),
45 Complex64::new(0.0, 0.0),
46 Complex64::new(0.0, 0.0),
47 Complex64::new(-1.0, 0.0),
48 ],
49 )
50 .expect("Pauli-Z matrix has valid 2x2 shape")
51}
52pub(super) fn pauli_h() -> Array2<Complex64> {
53 let inv_sqrt2 = 1.0 / 2.0_f64.sqrt();
54 Array2::from_shape_vec(
55 (2, 2),
56 vec![
57 Complex64::new(inv_sqrt2, 0.0),
58 Complex64::new(inv_sqrt2, 0.0),
59 Complex64::new(inv_sqrt2, 0.0),
60 Complex64::new(-inv_sqrt2, 0.0),
61 ],
62 )
63 .expect("Hadamard matrix has valid 2x2 shape")
64}
65pub(super) fn cnot_matrix() -> Array2<Complex64> {
66 Array2::from_shape_vec(
67 (4, 4),
68 vec![
69 Complex64::new(1.0, 0.0),
70 Complex64::new(0.0, 0.0),
71 Complex64::new(0.0, 0.0),
72 Complex64::new(0.0, 0.0),
73 Complex64::new(0.0, 0.0),
74 Complex64::new(1.0, 0.0),
75 Complex64::new(0.0, 0.0),
76 Complex64::new(0.0, 0.0),
77 Complex64::new(0.0, 0.0),
78 Complex64::new(0.0, 0.0),
79 Complex64::new(0.0, 0.0),
80 Complex64::new(1.0, 0.0),
81 Complex64::new(0.0, 0.0),
82 Complex64::new(0.0, 0.0),
83 Complex64::new(1.0, 0.0),
84 Complex64::new(0.0, 0.0),
85 ],
86 )
87 .expect("CNOT matrix has valid 4x4 shape")
88}
89pub(super) fn rotation_x(theta: f64) -> Array2<Complex64> {
90 let cos_half = (theta / 2.0).cos();
91 let sin_half = (theta / 2.0).sin();
92 Array2::from_shape_vec(
93 (2, 2),
94 vec![
95 Complex64::new(cos_half, 0.0),
96 Complex64::new(0.0, -sin_half),
97 Complex64::new(0.0, -sin_half),
98 Complex64::new(cos_half, 0.0),
99 ],
100 )
101 .expect("Rotation-X matrix has valid 2x2 shape")
102}
103pub(super) fn rotation_y(theta: f64) -> Array2<Complex64> {
104 let cos_half = (theta / 2.0).cos();
105 let sin_half = (theta / 2.0).sin();
106 Array2::from_shape_vec(
107 (2, 2),
108 vec![
109 Complex64::new(cos_half, 0.0),
110 Complex64::new(-sin_half, 0.0),
111 Complex64::new(sin_half, 0.0),
112 Complex64::new(cos_half, 0.0),
113 ],
114 )
115 .expect("Rotation-Y matrix has valid 2x2 shape")
116}
117pub(super) fn rotation_z(theta: f64) -> Array2<Complex64> {
118 let exp_neg = Complex64::from_polar(1.0, -theta / 2.0);
119 let exp_pos = Complex64::from_polar(1.0, theta / 2.0);
120 Array2::from_shape_vec(
121 (2, 2),
122 vec![
123 exp_neg,
124 Complex64::new(0.0, 0.0),
125 Complex64::new(0.0, 0.0),
126 exp_pos,
127 ],
128 )
129 .expect("Rotation-Z matrix has valid 2x2 shape")
130}
131pub(super) fn s_gate() -> Array2<Complex64> {
133 Array2::from_shape_vec(
134 (2, 2),
135 vec![
136 Complex64::new(1.0, 0.0),
137 Complex64::new(0.0, 0.0),
138 Complex64::new(0.0, 0.0),
139 Complex64::new(0.0, 1.0),
140 ],
141 )
142 .expect("S gate matrix has valid 2x2 shape")
143}
144pub(super) fn t_gate() -> Array2<Complex64> {
146 let phase = Complex64::from_polar(1.0, std::f64::consts::PI / 4.0);
147 Array2::from_shape_vec(
148 (2, 2),
149 vec![
150 Complex64::new(1.0, 0.0),
151 Complex64::new(0.0, 0.0),
152 Complex64::new(0.0, 0.0),
153 phase,
154 ],
155 )
156 .expect("T gate matrix has valid 2x2 shape")
157}
158pub(super) fn cz_gate() -> Array2<Complex64> {
160 Array2::from_shape_vec(
161 (4, 4),
162 vec![
163 Complex64::new(1.0, 0.0),
164 Complex64::new(0.0, 0.0),
165 Complex64::new(0.0, 0.0),
166 Complex64::new(0.0, 0.0),
167 Complex64::new(0.0, 0.0),
168 Complex64::new(1.0, 0.0),
169 Complex64::new(0.0, 0.0),
170 Complex64::new(0.0, 0.0),
171 Complex64::new(0.0, 0.0),
172 Complex64::new(0.0, 0.0),
173 Complex64::new(1.0, 0.0),
174 Complex64::new(0.0, 0.0),
175 Complex64::new(0.0, 0.0),
176 Complex64::new(0.0, 0.0),
177 Complex64::new(0.0, 0.0),
178 Complex64::new(-1.0, 0.0),
179 ],
180 )
181 .expect("CZ gate matrix has valid 4x4 shape")
182}
183pub(super) fn swap_gate() -> Array2<Complex64> {
185 Array2::from_shape_vec(
186 (4, 4),
187 vec![
188 Complex64::new(1.0, 0.0),
189 Complex64::new(0.0, 0.0),
190 Complex64::new(0.0, 0.0),
191 Complex64::new(0.0, 0.0),
192 Complex64::new(0.0, 0.0),
193 Complex64::new(0.0, 0.0),
194 Complex64::new(1.0, 0.0),
195 Complex64::new(0.0, 0.0),
196 Complex64::new(0.0, 0.0),
197 Complex64::new(1.0, 0.0),
198 Complex64::new(0.0, 0.0),
199 Complex64::new(0.0, 0.0),
200 Complex64::new(0.0, 0.0),
201 Complex64::new(0.0, 0.0),
202 Complex64::new(0.0, 0.0),
203 Complex64::new(1.0, 0.0),
204 ],
205 )
206 .expect("SWAP gate matrix has valid 4x4 shape")
207}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Absolute tolerance for comparing gate matrices built with floating-point trig.
    const GATE_TOL: f64 = 1e-12;

    /// `true` when `a` and `b` have the same shape and every pair of corresponding
    /// entries differs by less than `tol` in complex modulus.
    fn matrices_close(a: &Array2<Complex64>, b: &Array2<Complex64>, tol: f64) -> bool {
        a.shape() == b.shape() && a.iter().zip(b.iter()).all(|(x, y)| (*x - *y).norm() < tol)
    }

    /// `true` when `m` is square and (conjugate transpose of `m`) * `m` is the
    /// identity within `tol`.
    fn is_unitary(m: &Array2<Complex64>, tol: f64) -> bool {
        if m.nrows() != m.ncols() {
            return false;
        }
        let adjoint = m.t().mapv(|z| z.conj());
        matrices_close(&adjoint.dot(m), &Array2::eye(m.nrows()), tol)
    }

    #[test]
    fn test_tensor_creation() {
        let data = Array3::zeros((2, 2, 1));
        let indices = vec![
            TensorIndex {
                id: 0,
                dimension: 2,
                index_type: IndexType::Physical(0),
            },
            TensorIndex {
                id: 1,
                dimension: 2,
                index_type: IndexType::Physical(0),
            },
        ];
        let tensor = Tensor::new(data, indices, "test".to_string());
        assert_eq!(tensor.rank(), 2);
        assert_eq!(tensor.label, "test");
    }

    #[test]
    fn test_tensor_network_creation() {
        let network = TensorNetwork::new(3);
        assert_eq!(network.num_qubits, 3);
        assert_eq!(network.tensors.len(), 0);
    }

    #[test]
    fn test_simulator_initialization() {
        let mut sim = TensorNetworkSimulator::new(2);
        sim.initialize_zero_state()
            .expect("Failed to initialize zero state");
        assert_eq!(sim.network.tensors.len(), 2);
    }

    #[test]
    fn test_single_qubit_gate() {
        let mut sim = TensorNetworkSimulator::new(1);
        sim.initialize_zero_state()
            .expect("Failed to initialize zero state");
        let initial_tensors = sim.network.tensors.len();
        let h_gate = QuantumGate::new(
            crate::adaptive_gate_fusion::GateType::Hadamard,
            vec![0],
            vec![],
        );
        sim.apply_gate(h_gate)
            .expect("Failed to apply Hadamard gate");
        assert_eq!(sim.network.tensors.len(), initial_tensors + 1);
    }

    #[test]
    fn test_measurement() {
        let mut sim = TensorNetworkSimulator::new(1);
        sim.initialize_zero_state()
            .expect("Failed to initialize zero state");
        let result = sim.measure(0).expect("Failed to measure qubit");
        let _: bool = result;
    }

    #[test]
    fn test_contraction_strategies() {
        let _sim = TensorNetworkSimulator::new(2);
        let strat1 = ContractionStrategy::Sequential;
        let strat2 = ContractionStrategy::Greedy;
        let strat3 = ContractionStrategy::Custom(vec![0, 1]);
        assert_ne!(strat1, strat2);
        assert_ne!(strat2, strat3);
        assert_ne!(strat1, strat3);
    }

    #[test]
    fn test_gate_matrices() {
        let h = pauli_h();
        assert_abs_diff_eq!(h[[0, 0]].re, 1.0 / 2.0_f64.sqrt(), epsilon = 1e-10);
        let x = pauli_x();
        assert_abs_diff_eq!(x[[0, 1]].re, 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(x[[1, 0]].re, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_all_gate_matrices_are_unitary() {
        let gates = vec![
            ("X", pauli_x()),
            ("Y", pauli_y()),
            ("Z", pauli_z()),
            ("H", pauli_h()),
            ("S", s_gate()),
            ("T", t_gate()),
            ("RX", rotation_x(0.37)),
            ("RY", rotation_y(-1.2)),
            ("RZ", rotation_z(2.5)),
            ("CNOT", cnot_matrix()),
            ("CZ", cz_gate()),
            ("SWAP", swap_gate()),
        ];
        for (name, gate) in &gates {
            assert!(is_unitary(gate, GATE_TOL), "{} gate matrix is not unitary", name);
        }
    }

    #[test]
    fn test_gate_matrix_identities() {
        use std::f64::consts::PI;
        let minus_i = Complex64::new(0.0, -1.0);
        let id2: Array2<Complex64> = Array2::eye(2);
        let id4: Array2<Complex64> = Array2::eye(4);

        // A rotation by zero is the identity.
        for rot in [rotation_x(0.0), rotation_y(0.0), rotation_z(0.0)].iter() {
            assert!(matrices_close(rot, &id2, GATE_TOL));
        }

        // A rotation by pi equals the matching Pauli matrix times the global phase -i.
        let pairs = [
            (rotation_x(PI), pauli_x()),
            (rotation_y(PI), pauli_y()),
            (rotation_z(PI), pauli_z()),
        ];
        for (rot, pauli) in pairs.iter() {
            assert!(matrices_close(rot, &pauli.mapv(|z| z * minus_i), GATE_TOL));
        }

        // X, Y, Z, H, CNOT, CZ and SWAP are all their own inverses.
        for gate in [pauli_x(), pauli_y(), pauli_z(), pauli_h()].iter() {
            assert!(matrices_close(&gate.dot(gate), &id2, GATE_TOL));
        }
        for gate in [cnot_matrix(), cz_gate(), swap_gate()].iter() {
            assert!(matrices_close(&gate.dot(gate), &id4, GATE_TOL));
        }

        // Phase-gate ladder: T^2 = S and S^2 = Z.
        let s = s_gate();
        let t = t_gate();
        assert!(matrices_close(&t.dot(&t), &s, GATE_TOL));
        assert!(matrices_close(&s.dot(&s), &pauli_z(), GATE_TOL));
    }

    #[test]
    fn test_enhanced_tensor_contraction() {
        let mut id_gen = 0;
        let tensor_a = Tensor::identity(0, &mut id_gen);
        let tensor_b = Tensor::identity(0, &mut id_gen);
        let result = tensor_a.contract(&tensor_b, 1, 0);
        assert!(result.is_ok());
        let contracted = result.expect("Failed to contract tensors");
        assert!(!contracted.data.is_empty());
    }

    #[test]
    fn test_contraction_cost_estimation() {
        let network = TensorNetwork::new(2);
        let mut id_gen = 0;
        let tensor_a = Tensor::identity(0, &mut id_gen);
        let tensor_b = Tensor::identity(1, &mut id_gen);
        let cost = network.estimate_contraction_cost(&tensor_a, &tensor_b);
        assert!(cost > 0.0);
        assert!(cost.is_finite());
    }

    #[test]
    fn test_optimal_contraction_order() {
        let mut network = TensorNetwork::new(3);
        let mut id_gen = 0;
        for i in 0..3 {
            let tensor = Tensor::identity(i, &mut id_gen);
            network.add_tensor(tensor);
        }
        let order = network.find_optimal_contraction_order();
        assert!(order.is_ok());
        let order_vec = order.expect("Failed to find optimal contraction order");
        assert!(!order_vec.is_empty());
    }

    #[test]
    fn test_greedy_contraction_strategy() {
        let mut simulator =
            TensorNetworkSimulator::new(2).with_strategy(ContractionStrategy::Greedy);
        let mut id_gen = 0;
        for i in 0..2 {
            let tensor = Tensor::identity(i, &mut id_gen);
            simulator.network.add_tensor(tensor);
        }
        let result = simulator.contract_greedy();
        assert!(result.is_ok());
        let amplitude = result.expect("Failed to contract network");
        // `norm() >= 0.0` was tautological for any non-NaN value; also reject infinities.
        assert!(amplitude.norm().is_finite());
    }

    #[test]
    fn test_basis_state_boundary_conditions() {
        let mut network = TensorNetwork::new(2);
        let mut id_gen = 0;
        for i in 0..2 {
            let tensor = Tensor::identity(i, &mut id_gen);
            network.add_tensor(tensor);
        }
        let result = network.set_basis_state_boundary(1);
        assert!(result.is_ok());
    }

    #[test]
    fn test_full_state_vector_contraction() {
        let simulator = TensorNetworkSimulator::new(2);
        let result = simulator.contract_network_to_state_vector();
        assert!(result.is_ok());
        let state_vector = result.expect("Failed to contract network to state vector");
        assert_eq!(state_vector.len(), 4);
        assert!((state_vector[0].norm() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_advanced_contraction_algorithms() {
        let mut id_gen = 0;
        let tensor = Tensor::identity(0, &mut id_gen);
        let qr_result = AdvancedContractionAlgorithms::hotqr_decomposition(&tensor);
        assert!(qr_result.is_ok());
        let (q, r) = qr_result.expect("Failed to perform HOTQR decomposition");
        assert_eq!(q.label, "Q");
        assert_eq!(r.label, "R");
    }

    #[test]
    fn test_tree_contraction() {
        let mut id_gen = 0;
        let tensors = vec![
            Tensor::identity(0, &mut id_gen),
            Tensor::identity(1, &mut id_gen),
        ];
        let result = AdvancedContractionAlgorithms::tree_contraction(&tensors);
        assert!(result.is_ok());
        let amplitude = result.expect("Failed to perform tree contraction");
        // `norm() >= 0.0` was tautological for any non-NaN value; also reject infinities.
        assert!(amplitude.norm().is_finite());
    }
}