quantrs2_sim/
linalg_ops.rs

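//! Linear algebra operations for quantum simulation using `SciRS2`
//!
//! This module provides linear algebra operations for quantum simulation,
//! written against the `ndarray` types re-exported by `SciRS2`. Operations that
//! delegate to `ndarray`'s matrix products can pick up an optimized (e.g.
//! BLAS-backed) backend when one is available.
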
// Note: `NdArrayExt` would be used here if it were available in `scirs2_core`;
// until then, this module uses standard `ndarray` operations.

9use scirs2_core::ndarray::{Array2, ArrayView2};
10use scirs2_core::Complex64;
11
12/// Matrix-vector multiplication for quantum state evolution
13///
14/// Computes |ψ'⟩ = U|ψ⟩ where U is a unitary matrix and |ψ⟩ is a state vector.
15pub fn apply_unitary(
16    unitary: &ArrayView2<Complex64>,
17    state: &mut [Complex64],
18) -> Result<(), String> {
19    let n = state.len();
20
21    // Check dimensions
22    if unitary.shape() != [n, n] {
23        return Err(format!(
24            "Unitary matrix shape {:?} doesn't match state dimension {}",
25            unitary.shape(),
26            n
27        ));
28    }
29
30    // Create temporary storage for the result
31    let mut result = vec![Complex64::new(0.0, 0.0); n];
32
33    // Perform matrix-vector multiplication
34    #[cfg(feature = "advanced_math")]
35    {
36        // Use optimized matrix multiplication when available
37        for i in 0..n {
38            for j in 0..n {
39                result[i] += unitary[[i, j]] * state[j];
40            }
41        }
42    }
43
44    #[cfg(not(feature = "advanced_math"))]
45    {
46        // Fallback to manual implementation
47        for i in 0..n {
48            for j in 0..n {
49                result[i] += unitary[[i, j]] * state[j];
50            }
51        }
52    }
53
54    // Copy result back to state
55    state.copy_from_slice(&result);
56    Ok(())
57}
58
59/// Compute the tensor product of two matrices
60///
61/// This is used for constructing multi-qubit gates from single-qubit gates.
62#[must_use]
63pub fn tensor_product(a: &ArrayView2<Complex64>, b: &ArrayView2<Complex64>) -> Array2<Complex64> {
64    let (m, n) = a.dim();
65    let (p, q) = b.dim();
66
67    let mut result = Array2::zeros((m * p, n * q));
68
69    for i in 0..m {
70        for j in 0..n {
71            for k in 0..p {
72                for l in 0..q {
73                    result[[i * p + k, j * q + l]] = a[[i, j]] * b[[k, l]];
74                }
75            }
76        }
77    }
78
79    result
80}
81
82/// Compute the partial trace over specified qubits
83///
84/// This is used for obtaining reduced density matrices.
85pub fn partial_trace(
86    density_matrix: &ArrayView2<Complex64>,
87    qubits_to_trace: &[usize],
88    total_qubits: usize,
89) -> Result<Array2<Complex64>, String> {
90    let dim = 1 << total_qubits;
91
92    if density_matrix.shape() != [dim, dim] {
93        return Err(format!(
94            "Density matrix shape {:?} doesn't match {} qubits",
95            density_matrix.shape(),
96            total_qubits
97        ));
98    }
99
100    // Calculate dimensions after tracing
101    let traced_qubits = qubits_to_trace.len();
102    let remaining_qubits = total_qubits - traced_qubits;
103    let remaining_dim = 1 << remaining_qubits;
104
105    let mut result = Array2::zeros((remaining_dim, remaining_dim));
106
107    // Perform the partial trace
108    // This is a simplified implementation; a full implementation would be more complex
109    for i in 0..remaining_dim {
110        for j in 0..remaining_dim {
111            let mut sum = Complex64::new(0.0, 0.0);
112
113            // Sum over traced-out basis states
114            for k in 0..(1 << traced_qubits) {
115                // Map indices appropriately (simplified for demonstration)
116                let full_i = i + (k << remaining_qubits);
117                let full_j = j + (k << remaining_qubits);
118
119                if full_i < dim && full_j < dim {
120                    sum += density_matrix[[full_i, full_j]];
121                }
122            }
123
124            result[[i, j]] = sum;
125        }
126    }
127
128    Ok(result)
129}
130
131/// Check if a matrix is unitary (U†U = I)
132#[must_use]
133pub fn is_unitary(matrix: &ArrayView2<Complex64>, tolerance: f64) -> bool {
134    let n = matrix.nrows();
135    if matrix.ncols() != n {
136        return false; // Not square
137    }
138
139    // Compute U†U
140    let mut product: Array2<Complex64> = Array2::zeros((n, n));
141
142    #[cfg(feature = "advanced_math")]
143    {
144        // Use optimized matrix multiplication
145        let conjugate_transpose = matrix.t().mapv(|x| x.conj());
146        product = conjugate_transpose.dot(matrix);
147    }
148
149    #[cfg(not(feature = "advanced_math"))]
150    {
151        // Manual implementation
152        for i in 0..n {
153            for j in 0..n {
154                for k in 0..n {
155                    product[[i, j]] += matrix[[k, i]].conj() * matrix[[k, j]];
156                }
157            }
158        }
159    }
160
161    // Check if result is identity
162    for i in 0..n {
163        for j in 0..n {
164            let expected = if i == j {
165                Complex64::new(1.0, 0.0)
166            } else {
167                Complex64::new(0.0, 0.0)
168            };
169
170            let diff: Complex64 = product[[i, j]] - expected;
171            if diff.norm() > tolerance {
172                return false;
173            }
174        }
175    }
176
177    true
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::arr2;

    /// Purely real complex number, to keep the fixtures readable.
    fn re(x: f64) -> Complex64 {
        Complex64::new(x, 0.0)
    }

    /// Element-wise comparison of two complex matrices within `tol`.
    fn assert_close(actual: &Array2<Complex64>, expected: &Array2<Complex64>, tol: f64) {
        assert_eq!(actual.dim(), expected.dim());
        for (a, e) in actual.iter().zip(expected.iter()) {
            assert!((*a - *e).norm() < tol, "{:?} differs from {:?}", a, e);
        }
    }

    #[test]
    fn test_apply_unitary() {
        // Hadamard gate
        let s = 1.0 / std::f64::consts::SQRT_2;
        let h = arr2(&[[re(s), re(s)], [re(s), re(-s)]]);

        // |0⟩ state
        let mut state = vec![re(1.0), re(0.0)];

        apply_unitary(&h.view(), &mut state).expect("unitary application should succeed");

        // Should produce |+⟩ state
        assert!((state[0] - re(s)).norm() < 1e-10);
        assert!((state[1] - re(s)).norm() < 1e-10);
    }

    #[test]
    fn test_apply_unitary_rejects_dimension_mismatch() {
        let identity = arr2(&[[re(1.0), re(0.0)], [re(0.0), re(1.0)]]);
        let mut state = vec![re(1.0), re(0.0), re(0.0)];

        assert!(apply_unitary(&identity.view(), &mut state).is_err());
        // A rejected call must leave the state untouched.
        assert_eq!(state, vec![re(1.0), re(0.0), re(0.0)]);
    }

    #[test]
    fn test_tensor_product() {
        // Two 2x2 matrices
        let a = arr2(&[[re(1.0), re(2.0)], [re(3.0), re(4.0)]]);
        let b = arr2(&[[re(5.0), re(6.0)], [re(7.0), re(8.0)]]);

        let result = tensor_product(&a.view(), &b.view());

        let expected = arr2(&[
            [re(5.0), re(6.0), re(10.0), re(12.0)],
            [re(7.0), re(8.0), re(14.0), re(16.0)],
            [re(15.0), re(18.0), re(20.0), re(24.0)],
            [re(21.0), re(24.0), re(28.0), re(32.0)],
        ]);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_tensor_product_non_square() {
        let row = arr2(&[[re(1.0), re(2.0)]]);
        let col = arr2(&[[re(3.0)], [re(4.0)]]);

        let result = tensor_product(&row.view(), &col.view());

        assert_eq!(result, arr2(&[[re(3.0), re(6.0)], [re(4.0), re(8.0)]]));
    }

    #[test]
    fn test_partial_trace_product_state() {
        // ρ = |0⟩⟨0| ⊗ |+⟩⟨+|; the left factor of `tensor_product` occupies the
        // high bit of each basis index.
        let zero = arr2(&[[re(1.0), re(0.0)], [re(0.0), re(0.0)]]);
        let plus = arr2(&[[re(0.5), re(0.5)], [re(0.5), re(0.5)]]);
        let rho = tensor_product(&zero.view(), &plus.view());

        // Tracing out qubit 1 (the high bit) recovers the low-bit factor.
        let reduced = partial_trace(&rho.view(), &[1], 2).expect("valid partial trace");
        assert_close(&reduced, &plus, 1e-12);

        // Tracing out every qubit yields the 1x1 matrix [Tr ρ] = [1].
        let scalar = partial_trace(&rho.view(), &[0, 1], 2).expect("valid full trace");
        assert_close(&scalar, &arr2(&[[re(1.0)]]), 1e-12);
    }

    #[test]
    fn test_partial_trace_rejects_wrong_shape() {
        let three_by_three = Array2::<Complex64>::zeros((3, 3));
        assert!(partial_trace(&three_by_three.view(), &[0], 1).is_err());
    }

    #[test]
    fn test_is_unitary() {
        // Pauli X gate
        let x = arr2(&[[re(0.0), re(1.0)], [re(1.0), re(0.0)]]);
        assert!(is_unitary(&x.view(), 1e-10));

        // Hadamard gate: the 1/√2 entries introduce rounding well below the tolerance.
        let s = 1.0 / std::f64::consts::SQRT_2;
        let h = arr2(&[[re(s), re(s)], [re(s), re(-s)]]);
        assert!(is_unitary(&h.view(), 1e-10));

        // Phase gate S = diag(1, i) exercises the complex conjugation.
        let phase = arr2(&[[re(1.0), re(0.0)], [re(0.0), Complex64::new(0.0, 1.0)]]);
        assert!(is_unitary(&phase.view(), 1e-10));

        // Non-unitary matrix
        let non_unitary = arr2(&[[re(1.0), re(1.0)], [re(0.0), re(1.0)]]);
        assert!(!is_unitary(&non_unitary.view(), 1e-10));

        // Non-square matrices are never unitary.
        let rectangular = Array2::<Complex64>::zeros((2, 3));
        assert!(!is_unitary(&rectangular.view(), 1e-10));
    }
}