1use scirs2_core::ndarray::{Array2, ArrayView2};
10use scirs2_core::Complex64;
11
12pub fn apply_unitary(
16 unitary: &ArrayView2<Complex64>,
17 state: &mut [Complex64],
18) -> Result<(), String> {
19 let n = state.len();
20
21 if unitary.shape() != [n, n] {
23 return Err(format!(
24 "Unitary matrix shape {:?} doesn't match state dimension {}",
25 unitary.shape(),
26 n
27 ));
28 }
29
30 let mut result = vec![Complex64::new(0.0, 0.0); n];
32
33 #[cfg(feature = "advanced_math")]
35 {
36 for i in 0..n {
38 for j in 0..n {
39 result[i] += unitary[[i, j]] * state[j];
40 }
41 }
42 }
43
44 #[cfg(not(feature = "advanced_math"))]
45 {
46 for i in 0..n {
48 for j in 0..n {
49 result[i] += unitary[[i, j]] * state[j];
50 }
51 }
52 }
53
54 state.copy_from_slice(&result);
56 Ok(())
57}
58
59#[must_use]
63pub fn tensor_product(a: &ArrayView2<Complex64>, b: &ArrayView2<Complex64>) -> Array2<Complex64> {
64 let (m, n) = a.dim();
65 let (p, q) = b.dim();
66
67 let mut result = Array2::zeros((m * p, n * q));
68
69 for i in 0..m {
70 for j in 0..n {
71 for k in 0..p {
72 for l in 0..q {
73 result[[i * p + k, j * q + l]] = a[[i, j]] * b[[k, l]];
74 }
75 }
76 }
77 }
78
79 result
80}
81
82pub fn partial_trace(
86 density_matrix: &ArrayView2<Complex64>,
87 qubits_to_trace: &[usize],
88 total_qubits: usize,
89) -> Result<Array2<Complex64>, String> {
90 let dim = 1 << total_qubits;
91
92 if density_matrix.shape() != [dim, dim] {
93 return Err(format!(
94 "Density matrix shape {:?} doesn't match {} qubits",
95 density_matrix.shape(),
96 total_qubits
97 ));
98 }
99
100 let traced_qubits = qubits_to_trace.len();
102 let remaining_qubits = total_qubits - traced_qubits;
103 let remaining_dim = 1 << remaining_qubits;
104
105 let mut result = Array2::zeros((remaining_dim, remaining_dim));
106
107 for i in 0..remaining_dim {
110 for j in 0..remaining_dim {
111 let mut sum = Complex64::new(0.0, 0.0);
112
113 for k in 0..(1 << traced_qubits) {
115 let full_i = i + (k << remaining_qubits);
117 let full_j = j + (k << remaining_qubits);
118
119 if full_i < dim && full_j < dim {
120 sum += density_matrix[[full_i, full_j]];
121 }
122 }
123
124 result[[i, j]] = sum;
125 }
126 }
127
128 Ok(result)
129}
130
131#[must_use]
133pub fn is_unitary(matrix: &ArrayView2<Complex64>, tolerance: f64) -> bool {
134 let n = matrix.nrows();
135 if matrix.ncols() != n {
136 return false; }
138
139 let mut product: Array2<Complex64> = Array2::zeros((n, n));
141
142 #[cfg(feature = "advanced_math")]
143 {
144 let conjugate_transpose = matrix.t().mapv(|x| x.conj());
146 product = conjugate_transpose.dot(matrix);
147 }
148
149 #[cfg(not(feature = "advanced_math"))]
150 {
151 for i in 0..n {
153 for j in 0..n {
154 for k in 0..n {
155 product[[i, j]] += matrix[[k, i]].conj() * matrix[[k, j]];
156 }
157 }
158 }
159 }
160
161 for i in 0..n {
163 for j in 0..n {
164 let expected = if i == j {
165 Complex64::new(1.0, 0.0)
166 } else {
167 Complex64::new(0.0, 0.0)
168 };
169
170 let diff: Complex64 = product[[i, j]] - expected;
171 if diff.norm() > tolerance {
172 return false;
173 }
174 }
175 }
176
177 true
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::arr2;

    /// Builds a complex number with the given real part and zero imaginary part.
    fn re(value: f64) -> Complex64 {
        Complex64::new(value, 0.0)
    }

    /// Applying the Hadamard gate to |0> yields equal amplitudes on both basis states.
    #[test]
    fn test_apply_unitary() {
        let amp = 1.0 / std::f64::consts::SQRT_2;
        let hadamard = arr2(&[[re(amp), re(amp)], [re(amp), re(-amp)]]);

        let mut state = vec![re(1.0), re(0.0)];

        apply_unitary(&hadamard.view(), &mut state)
            .expect("unitary application should succeed");

        let expected = re(amp);
        for amplitude in &state {
            assert!((*amplitude - expected).norm() < 1e-10);
        }
    }

    /// The Kronecker product of two 2x2 matrices is 4x4 with the expected corner entries.
    #[test]
    fn test_tensor_product() {
        let left = arr2(&[[re(1.0), re(2.0)], [re(3.0), re(4.0)]]);
        let right = arr2(&[[re(5.0), re(6.0)], [re(7.0), re(8.0)]]);

        let kron = tensor_product(&left.view(), &right.view());

        assert_eq!(kron.dim(), (4, 4));
        assert_eq!(kron[[0, 0]], re(5.0));
        assert_eq!(kron[[0, 1]], re(6.0));
        assert_eq!(kron[[3, 3]], re(32.0));
    }

    /// Pauli-X is accepted as unitary; an upper-triangular matrix of ones is rejected.
    #[test]
    fn test_is_unitary() {
        let pauli_x = arr2(&[[re(0.0), re(1.0)], [re(1.0), re(0.0)]]);
        assert!(is_unitary(&pauli_x.view(), 1e-10));

        let upper_triangular = arr2(&[[re(1.0), re(1.0)], [re(0.0), re(1.0)]]);
        assert!(!is_unitary(&upper_triangular.view(), 1e-10));
    }
}