Skip to main content

quantrs2_sim/tensor/
functions.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use crate::adaptive_gate_fusion::QuantumGate;
6use quantrs2_circuit::prelude::*;
7use quantrs2_core::prelude::*;
8use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, Axis};
9use scirs2_core::Complex64;
10
11use super::types::{
12    AdvancedContractionAlgorithms, ContractionStrategy, IndexType, Tensor, TensorIndex,
13    TensorNetwork, TensorNetworkSimulator,
14};
15
16pub(super) fn pauli_x() -> Array2<Complex64> {
17    Array2::from_shape_vec(
18        (2, 2),
19        vec![
20            Complex64::new(0.0, 0.0),
21            Complex64::new(1.0, 0.0),
22            Complex64::new(1.0, 0.0),
23            Complex64::new(0.0, 0.0),
24        ],
25    )
26    .expect("Pauli-X matrix has valid 2x2 shape")
27}
28pub(super) fn pauli_y() -> Array2<Complex64> {
29    Array2::from_shape_vec(
30        (2, 2),
31        vec![
32            Complex64::new(0.0, 0.0),
33            Complex64::new(0.0, -1.0),
34            Complex64::new(0.0, 1.0),
35            Complex64::new(0.0, 0.0),
36        ],
37    )
38    .expect("Pauli-Y matrix has valid 2x2 shape")
39}
40pub(super) fn pauli_z() -> Array2<Complex64> {
41    Array2::from_shape_vec(
42        (2, 2),
43        vec![
44            Complex64::new(1.0, 0.0),
45            Complex64::new(0.0, 0.0),
46            Complex64::new(0.0, 0.0),
47            Complex64::new(-1.0, 0.0),
48        ],
49    )
50    .expect("Pauli-Z matrix has valid 2x2 shape")
51}
52pub(super) fn pauli_h() -> Array2<Complex64> {
53    let inv_sqrt2 = 1.0 / 2.0_f64.sqrt();
54    Array2::from_shape_vec(
55        (2, 2),
56        vec![
57            Complex64::new(inv_sqrt2, 0.0),
58            Complex64::new(inv_sqrt2, 0.0),
59            Complex64::new(inv_sqrt2, 0.0),
60            Complex64::new(-inv_sqrt2, 0.0),
61        ],
62    )
63    .expect("Hadamard matrix has valid 2x2 shape")
64}
65pub(super) fn cnot_matrix() -> Array2<Complex64> {
66    Array2::from_shape_vec(
67        (4, 4),
68        vec![
69            Complex64::new(1.0, 0.0),
70            Complex64::new(0.0, 0.0),
71            Complex64::new(0.0, 0.0),
72            Complex64::new(0.0, 0.0),
73            Complex64::new(0.0, 0.0),
74            Complex64::new(1.0, 0.0),
75            Complex64::new(0.0, 0.0),
76            Complex64::new(0.0, 0.0),
77            Complex64::new(0.0, 0.0),
78            Complex64::new(0.0, 0.0),
79            Complex64::new(0.0, 0.0),
80            Complex64::new(1.0, 0.0),
81            Complex64::new(0.0, 0.0),
82            Complex64::new(0.0, 0.0),
83            Complex64::new(1.0, 0.0),
84            Complex64::new(0.0, 0.0),
85        ],
86    )
87    .expect("CNOT matrix has valid 4x4 shape")
88}
89pub(super) fn rotation_x(theta: f64) -> Array2<Complex64> {
90    let cos_half = (theta / 2.0).cos();
91    let sin_half = (theta / 2.0).sin();
92    Array2::from_shape_vec(
93        (2, 2),
94        vec![
95            Complex64::new(cos_half, 0.0),
96            Complex64::new(0.0, -sin_half),
97            Complex64::new(0.0, -sin_half),
98            Complex64::new(cos_half, 0.0),
99        ],
100    )
101    .expect("Rotation-X matrix has valid 2x2 shape")
102}
103pub(super) fn rotation_y(theta: f64) -> Array2<Complex64> {
104    let cos_half = (theta / 2.0).cos();
105    let sin_half = (theta / 2.0).sin();
106    Array2::from_shape_vec(
107        (2, 2),
108        vec![
109            Complex64::new(cos_half, 0.0),
110            Complex64::new(-sin_half, 0.0),
111            Complex64::new(sin_half, 0.0),
112            Complex64::new(cos_half, 0.0),
113        ],
114    )
115    .expect("Rotation-Y matrix has valid 2x2 shape")
116}
117pub(super) fn rotation_z(theta: f64) -> Array2<Complex64> {
118    let exp_neg = Complex64::from_polar(1.0, -theta / 2.0);
119    let exp_pos = Complex64::from_polar(1.0, theta / 2.0);
120    Array2::from_shape_vec(
121        (2, 2),
122        vec![
123            exp_neg,
124            Complex64::new(0.0, 0.0),
125            Complex64::new(0.0, 0.0),
126            exp_pos,
127        ],
128    )
129    .expect("Rotation-Z matrix has valid 2x2 shape")
130}
131/// S gate (phase gate)
132pub(super) fn s_gate() -> Array2<Complex64> {
133    Array2::from_shape_vec(
134        (2, 2),
135        vec![
136            Complex64::new(1.0, 0.0),
137            Complex64::new(0.0, 0.0),
138            Complex64::new(0.0, 0.0),
139            Complex64::new(0.0, 1.0),
140        ],
141    )
142    .expect("S gate matrix has valid 2x2 shape")
143}
144/// T gate (Ï€/8 gate)
145pub(super) fn t_gate() -> Array2<Complex64> {
146    let phase = Complex64::from_polar(1.0, std::f64::consts::PI / 4.0);
147    Array2::from_shape_vec(
148        (2, 2),
149        vec![
150            Complex64::new(1.0, 0.0),
151            Complex64::new(0.0, 0.0),
152            Complex64::new(0.0, 0.0),
153            phase,
154        ],
155    )
156    .expect("T gate matrix has valid 2x2 shape")
157}
158/// CZ gate (controlled-Z)
159pub(super) fn cz_gate() -> Array2<Complex64> {
160    Array2::from_shape_vec(
161        (4, 4),
162        vec![
163            Complex64::new(1.0, 0.0),
164            Complex64::new(0.0, 0.0),
165            Complex64::new(0.0, 0.0),
166            Complex64::new(0.0, 0.0),
167            Complex64::new(0.0, 0.0),
168            Complex64::new(1.0, 0.0),
169            Complex64::new(0.0, 0.0),
170            Complex64::new(0.0, 0.0),
171            Complex64::new(0.0, 0.0),
172            Complex64::new(0.0, 0.0),
173            Complex64::new(1.0, 0.0),
174            Complex64::new(0.0, 0.0),
175            Complex64::new(0.0, 0.0),
176            Complex64::new(0.0, 0.0),
177            Complex64::new(0.0, 0.0),
178            Complex64::new(-1.0, 0.0),
179        ],
180    )
181    .expect("CZ gate matrix has valid 4x4 shape")
182}
183/// SWAP gate
184pub(super) fn swap_gate() -> Array2<Complex64> {
185    Array2::from_shape_vec(
186        (4, 4),
187        vec![
188            Complex64::new(1.0, 0.0),
189            Complex64::new(0.0, 0.0),
190            Complex64::new(0.0, 0.0),
191            Complex64::new(0.0, 0.0),
192            Complex64::new(0.0, 0.0),
193            Complex64::new(0.0, 0.0),
194            Complex64::new(1.0, 0.0),
195            Complex64::new(0.0, 0.0),
196            Complex64::new(0.0, 0.0),
197            Complex64::new(1.0, 0.0),
198            Complex64::new(0.0, 0.0),
199            Complex64::new(0.0, 0.0),
200            Complex64::new(0.0, 0.0),
201            Complex64::new(0.0, 0.0),
202            Complex64::new(0.0, 0.0),
203            Complex64::new(1.0, 0.0),
204        ],
205    )
206    .expect("SWAP gate matrix has valid 4x4 shape")
207}
#[cfg(test)]
mod tests {
    //! Unit tests for the gate-matrix helpers and the tensor-network types.
    //!
    //! Fallible calls are unwrapped with `expect` directly (rather than being
    //! preceded by `assert!(result.is_ok())`) so that a failing test prints
    //! the underlying error value.

    use super::*;
    use approx::assert_abs_diff_eq;

    /// A tensor built from a 2x2x1 zero array and two dimension-2 indices
    /// reports rank 2 and keeps the label it was given.
    #[test]
    fn test_tensor_creation() {
        let indices: Vec<TensorIndex> = (0..2)
            .map(|id| TensorIndex {
                id,
                dimension: 2,
                index_type: IndexType::Physical(0),
            })
            .collect();
        let tensor = Tensor::new(Array3::zeros((2, 2, 1)), indices, "test".to_string());
        assert_eq!(tensor.rank(), 2);
        assert_eq!(tensor.label, "test");
    }

    /// A freshly created network records its qubit count and holds no tensors.
    #[test]
    fn test_tensor_network_creation() {
        let network = TensorNetwork::new(3);
        assert_eq!(network.num_qubits, 3);
        assert_eq!(network.tensors.len(), 0);
    }

    /// Initializing the zero state on a 2-qubit simulator leaves two tensors
    /// in the network.
    #[test]
    fn test_simulator_initialization() {
        let mut sim = TensorNetworkSimulator::new(2);
        sim.initialize_zero_state()
            .expect("Failed to initialize zero state");
        assert_eq!(sim.network.tensors.len(), 2);
    }

    /// Applying a Hadamard gate to qubit 0 grows the network by one tensor.
    #[test]
    fn test_single_qubit_gate() {
        use crate::adaptive_gate_fusion::GateType;

        let mut sim = TensorNetworkSimulator::new(1);
        sim.initialize_zero_state()
            .expect("Failed to initialize zero state");
        let tensors_before = sim.network.tensors.len();

        let hadamard = QuantumGate::new(GateType::Hadamard, vec![0], vec![]);
        sim.apply_gate(hadamard)
            .expect("Failed to apply Hadamard gate");

        assert_eq!(sim.network.tensors.len(), tensors_before + 1);
    }

    /// Measuring qubit 0 of an initialized simulator succeeds and yields a `bool`.
    #[test]
    fn test_measurement() {
        let mut sim = TensorNetworkSimulator::new(1);
        sim.initialize_zero_state()
            .expect("Failed to initialize zero state");
        // The annotation is the check: the outcome must be a plain `bool`.
        let _outcome: bool = sim.measure(0).expect("Failed to measure qubit");
    }

    /// Distinct contraction strategies compare as unequal.
    #[test]
    fn test_contraction_strategies() {
        let _sim = TensorNetworkSimulator::new(2);
        let sequential = ContractionStrategy::Sequential;
        let greedy = ContractionStrategy::Greedy;
        let custom = ContractionStrategy::Custom(vec![0, 1]);
        assert_ne!(sequential, greedy);
        assert_ne!(greedy, custom);
    }

    /// Spot-checks individual entries of the Hadamard and Pauli-X matrices.
    #[test]
    fn test_gate_matrices() {
        let hadamard = pauli_h();
        assert_abs_diff_eq!(
            hadamard[[0, 0]].re,
            1.0 / 2.0_f64.sqrt(),
            epsilon = 1e-10
        );

        let x = pauli_x();
        assert_abs_diff_eq!(x[[0, 1]].re, 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(x[[1, 0]].re, 1.0, epsilon = 1e-10);
    }

    /// Contracting two identity tensors (arguments `1` and `0`) succeeds and
    /// yields a tensor with non-empty data.
    #[test]
    fn test_enhanced_tensor_contraction() {
        let mut id_gen = 0;
        let lhs = Tensor::identity(0, &mut id_gen);
        let rhs = Tensor::identity(0, &mut id_gen);
        let contracted = lhs
            .contract(&rhs, 1, 0)
            .expect("Failed to contract tensors");
        assert!(!contracted.data.is_empty());
    }

    /// The estimated contraction cost of two identity tensors is positive and finite.
    #[test]
    fn test_contraction_cost_estimation() {
        let network = TensorNetwork::new(2);
        let mut id_gen = 0;
        let lhs = Tensor::identity(0, &mut id_gen);
        let rhs = Tensor::identity(1, &mut id_gen);
        let cost = network.estimate_contraction_cost(&lhs, &rhs);
        assert!(cost > 0.0);
        assert!(cost.is_finite());
    }

    /// A network holding three identity tensors yields a non-empty contraction order.
    #[test]
    fn test_optimal_contraction_order() {
        let mut network = TensorNetwork::new(3);
        let mut id_gen = 0;
        for qubit in 0..3 {
            network.add_tensor(Tensor::identity(qubit, &mut id_gen));
        }
        let order = network
            .find_optimal_contraction_order()
            .expect("Failed to find optimal contraction order");
        assert!(!order.is_empty());
    }

    /// Greedy contraction of a two-tensor network succeeds and returns an
    /// amplitude whose norm is non-negative.
    #[test]
    fn test_greedy_contraction_strategy() {
        let mut simulator =
            TensorNetworkSimulator::new(2).with_strategy(ContractionStrategy::Greedy);
        let mut id_gen = 0;
        for qubit in 0..2 {
            simulator
                .network
                .add_tensor(Tensor::identity(qubit, &mut id_gen));
        }
        let amplitude = simulator
            .contract_greedy()
            .expect("Failed to contract network");
        assert!(amplitude.norm() >= 0.0);
    }

    /// Setting the basis-state boundary to `1` on a two-tensor network succeeds.
    #[test]
    fn test_basis_state_boundary_conditions() {
        let mut network = TensorNetwork::new(2);
        let mut id_gen = 0;
        for qubit in 0..2 {
            network.add_tensor(Tensor::identity(qubit, &mut id_gen));
        }
        assert!(network.set_basis_state_boundary(1).is_ok());
    }

    /// Contracting a fresh 2-qubit simulator to a state vector gives four
    /// amplitudes, the first of which has unit norm.
    #[test]
    fn test_full_state_vector_contraction() {
        let simulator = TensorNetworkSimulator::new(2);
        let state_vector = simulator
            .contract_network_to_state_vector()
            .expect("Failed to contract network to state vector");
        assert_eq!(state_vector.len(), 4);
        assert!((state_vector[0].norm() - 1.0).abs() < 1e-10);
    }

    /// HOTQR decomposition of an identity tensor returns factors labelled "Q" and "R".
    #[test]
    fn test_advanced_contraction_algorithms() {
        let mut id_gen = 0;
        let tensor = Tensor::identity(0, &mut id_gen);
        let (q, r) = AdvancedContractionAlgorithms::hotqr_decomposition(&tensor)
            .expect("Failed to perform HOTQR decomposition");
        assert_eq!(q.label, "Q");
        assert_eq!(r.label, "R");
    }

    /// Tree contraction over two identity tensors succeeds and returns an
    /// amplitude whose norm is non-negative.
    #[test]
    fn test_tree_contraction() {
        let mut id_gen = 0;
        let tensors = vec![
            Tensor::identity(0, &mut id_gen),
            Tensor::identity(1, &mut id_gen),
        ];
        let amplitude = AdvancedContractionAlgorithms::tree_contraction(&tensors)
            .expect("Failed to perform tree contraction");
        assert!(amplitude.norm() >= 0.0);
    }
}