1use crate::error::{QuantRS2Error, QuantRS2Result};
8use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
9use num_complex::Complex64;
10use scirs2_sparse::csr::CsrMatrix;
11use std::fmt::Debug;
12
13pub trait QuantumMatrix: Debug + Send + Sync {
15 fn dim(&self) -> usize;
17
18 fn to_dense(&self) -> Array2<Complex64>;
20
21 fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>>;
23
24 fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool>;
26
27 fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>>;
29
30 fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>>;
32}
33
34#[derive(Debug, Clone)]
36pub struct DenseMatrix {
37 data: Array2<Complex64>,
38}
39
40impl DenseMatrix {
41 pub fn new(data: Array2<Complex64>) -> QuantRS2Result<Self> {
43 if data.nrows() != data.ncols() {
44 return Err(QuantRS2Error::InvalidInput(
45 "Matrix must be square".to_string(),
46 ));
47 }
48 Ok(Self { data })
49 }
50
51 pub fn from_vec(data: Vec<Complex64>, dim: usize) -> QuantRS2Result<Self> {
53 if data.len() != dim * dim {
54 return Err(QuantRS2Error::InvalidInput(format!(
55 "Expected {} elements, got {}",
56 dim * dim,
57 data.len()
58 )));
59 }
60 let matrix = Array2::from_shape_vec((dim, dim), data)
61 .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))?;
62 Self::new(matrix)
63 }
64
65 pub fn as_array(&self) -> &Array2<Complex64> {
67 &self.data
68 }
69
70 pub fn is_hermitian(&self, tolerance: f64) -> bool {
72 let n = self.data.nrows();
73 for i in 0..n {
74 for j in i..n {
75 let diff = (self.data[[i, j]] - self.data[[j, i]].conj()).norm();
76 if diff > tolerance {
77 return false;
78 }
79 }
80 }
81 true
82 }
83}
84
85impl QuantumMatrix for DenseMatrix {
86 fn dim(&self) -> usize {
87 self.data.nrows()
88 }
89
90 fn to_dense(&self) -> Array2<Complex64> {
91 self.data.clone()
92 }
93
94 fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>> {
95 let n = self.dim();
96 let mut rows = Vec::new();
97 let mut cols = Vec::new();
98 let mut data = Vec::new();
99
100 let tolerance = 1e-14;
101 for i in 0..n {
102 for j in 0..n {
103 let val = self.data[[i, j]];
104 if val.norm() > tolerance {
105 rows.push(i);
106 cols.push(j);
107 data.push(val);
108 }
109 }
110 }
111
112 CsrMatrix::new(data, rows, cols, (n, n))
113 .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))
114 }
115
116 fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool> {
117 let n = self.dim();
118 let conj_transpose = self.data.t().mapv(|x| x.conj());
119 let product = self.data.dot(&conj_transpose);
120
121 for i in 0..n {
123 for j in 0..n {
124 let expected = if i == j {
125 Complex64::new(1.0, 0.0)
126 } else {
127 Complex64::new(0.0, 0.0)
128 };
129 let diff = (product[[i, j]] - expected).norm();
130 if diff > tolerance {
131 return Ok(false);
132 }
133 }
134 }
135 Ok(true)
136 }
137
138 fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>> {
139 let other_dense = other.to_dense();
140 let n1 = self.dim();
141 let n2 = other_dense.nrows();
142 let n = n1 * n2;
143
144 let mut result = Array2::zeros((n, n));
145
146 for i1 in 0..n1 {
147 for j1 in 0..n1 {
148 let val1 = self.data[[i1, j1]];
149 for i2 in 0..n2 {
150 for j2 in 0..n2 {
151 let val2 = other_dense[[i2, j2]];
152 result[[i1 * n2 + i2, j1 * n2 + j2]] = val1 * val2;
153 }
154 }
155 }
156 }
157
158 Ok(result)
159 }
160
161 fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>> {
162 if state.len() != self.dim() {
163 return Err(QuantRS2Error::InvalidInput(format!(
164 "State dimension {} doesn't match matrix dimension {}",
165 state.len(),
166 self.dim()
167 )));
168 }
169 Ok(self.data.dot(state))
170 }
171}
172
173#[derive(Clone)]
175pub struct SparseMatrix {
176 csr: CsrMatrix<Complex64>,
177 dim: usize,
178}
179
180impl Debug for SparseMatrix {
181 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182 f.debug_struct("SparseMatrix")
183 .field("dim", &self.dim)
184 .field("nnz", &self.csr.nnz())
185 .finish()
186 }
187}
188
189impl SparseMatrix {
190 pub fn new(csr: CsrMatrix<Complex64>) -> QuantRS2Result<Self> {
192 let (rows, cols) = csr.shape();
193 if rows != cols {
194 return Err(QuantRS2Error::InvalidInput(
195 "Matrix must be square".to_string(),
196 ));
197 }
198 Ok(Self { csr, dim: rows })
199 }
200
201 pub fn from_triplets(
203 rows: Vec<usize>,
204 cols: Vec<usize>,
205 data: Vec<Complex64>,
206 dim: usize,
207 ) -> QuantRS2Result<Self> {
208 let csr = CsrMatrix::new(data, rows, cols, (dim, dim))
209 .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))?;
210 Self::new(csr)
211 }
212}
213
214impl QuantumMatrix for SparseMatrix {
215 fn dim(&self) -> usize {
216 self.dim
217 }
218
219 fn to_dense(&self) -> Array2<Complex64> {
220 let dense_vec = self.csr.to_dense();
221 let rows = dense_vec.len();
222 let cols = if rows > 0 { dense_vec[0].len() } else { 0 };
223
224 let mut flat = Vec::with_capacity(rows * cols);
225 for row in dense_vec {
226 flat.extend(row);
227 }
228
229 Array2::from_shape_vec((rows, cols), flat).unwrap()
230 }
231
232 fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>> {
233 Ok(self.csr.clone())
234 }
235
236 fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool> {
237 let dense = DenseMatrix::new(self.to_dense())?;
239 dense.is_unitary(tolerance)
240 }
241
242 fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>> {
243 let dense = DenseMatrix::new(self.to_dense())?;
246 dense.tensor_product(other)
247 }
248
249 fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>> {
250 if state.len() != self.dim() {
251 return Err(QuantRS2Error::InvalidInput(format!(
252 "State dimension {} doesn't match matrix dimension {}",
253 state.len(),
254 self.dim()
255 )));
256 }
257 let dense = self.to_dense();
259 Ok(dense.dot(state))
260 }
261}
262
263pub fn partial_trace(
265 matrix: &Array2<Complex64>,
266 keep_qubits: &[usize],
267 total_qubits: usize,
268) -> QuantRS2Result<Array2<Complex64>> {
269 let full_dim = 1 << total_qubits;
270 if matrix.nrows() != full_dim || matrix.ncols() != full_dim {
271 return Err(QuantRS2Error::InvalidInput(format!(
272 "Matrix dimension {} doesn't match {} qubits",
273 matrix.nrows(),
274 total_qubits
275 )));
276 }
277
278 let keep_dim = 1 << keep_qubits.len();
279 let trace_qubits: Vec<usize> = (0..total_qubits)
280 .filter(|q| !keep_qubits.contains(q))
281 .collect();
282 let trace_dim = 1 << trace_qubits.len();
283
284 let mut result = Array2::zeros((keep_dim, keep_dim));
285
286 for i in 0..keep_dim {
288 for j in 0..keep_dim {
289 let mut sum = Complex64::new(0.0, 0.0);
290
291 for t in 0..trace_dim {
293 let row_idx = reconstruct_index(i, t, keep_qubits, &trace_qubits, total_qubits);
294 let col_idx = reconstruct_index(j, t, keep_qubits, &trace_qubits, total_qubits);
295 sum += matrix[[row_idx, col_idx]];
296 }
297
298 result[[i, j]] = sum;
299 }
300 }
301
302 Ok(result)
303}
304
/// Builds a full basis-state index from a packed "kept" index and a packed "traced" index.
///
/// Bit `i` of `keep_idx` is placed at bit position `keep_qubits[i]`, and bit `i` of `trace_idx`
/// at bit position `trace_qubits[i]`; the two scattered values are OR-ed together.
/// `_total_qubits` is unused and kept only for signature compatibility.
fn reconstruct_index(
    keep_idx: usize,
    trace_idx: usize,
    keep_qubits: &[usize],
    trace_qubits: &[usize],
    _total_qubits: usize,
) -> usize {
    // Scatter the low bits of `packed` to the bit positions listed in `positions`.
    let scatter = |packed: usize, positions: &[usize]| -> usize {
        positions
            .iter()
            .enumerate()
            .filter(|&(bit, _)| (packed >> bit) & 1 == 1)
            .fold(0usize, |acc, (_, &qubit)| acc | (1usize << qubit))
    };

    scatter(keep_idx, keep_qubits) | scatter(trace_idx, trace_qubits)
}
331
332pub fn tensor_product_many(matrices: &[&dyn QuantumMatrix]) -> QuantRS2Result<Array2<Complex64>> {
334 if matrices.is_empty() {
335 return Err(QuantRS2Error::InvalidInput(
336 "Cannot compute tensor product of empty list".to_string(),
337 ));
338 }
339
340 if matrices.len() == 1 {
341 return Ok(matrices[0].to_dense());
342 }
343
344 let mut result = matrices[0].to_dense();
345 for matrix in matrices.iter().skip(1) {
346 let dense_result = DenseMatrix::new(result)?;
347 result = dense_result.tensor_product(*matrix)?;
348 }
349
350 Ok(result)
351}
352
353pub fn matrices_approx_equal(
355 a: &ArrayView2<Complex64>,
356 b: &ArrayView2<Complex64>,
357 tolerance: f64,
358) -> bool {
359 if a.shape() != b.shape() {
360 return false;
361 }
362
363 for (x, y) in a.iter().zip(b.iter()) {
364 if (x - y).norm() > tolerance {
365 return false;
366 }
367 }
368
369 true
370}
371
#[cfg(test)]
mod tests {
    use super::*;
    use num_complex::Complex64;

    /// Shorthand for a purely real complex number.
    fn re(x: f64) -> Complex64 {
        Complex64::new(x, 0.0)
    }

    /// Builds a dense 2×2 matrix from row-major real entries.
    fn real_2x2(entries: [f64; 4]) -> DenseMatrix {
        let values: Vec<Complex64> = entries.iter().map(|&x| re(x)).collect();
        DenseMatrix::new(Array2::from_shape_vec((2, 2), values).unwrap()).unwrap()
    }

    #[test]
    fn test_dense_matrix_creation() {
        let matrix = real_2x2([1.0, 0.0, 0.0, 1.0]);
        assert_eq!(matrix.dim(), 2);
    }

    #[test]
    fn test_non_square_rejected() {
        let data = Array2::from_shape_vec((1, 2), vec![re(1.0), re(0.0)]).unwrap();
        assert!(DenseMatrix::new(data).is_err());
    }

    #[test]
    fn test_unitary_check() {
        let s = 1.0 / 2.0_f64.sqrt();
        let hadamard = real_2x2([s, s, s, -s]);
        assert!(hadamard.is_unitary(1e-10).unwrap());

        // [[1, 1], [0, 1]] · its adjoint = [[2, 1], [1, 1]] ≠ I.
        let shear = real_2x2([1.0, 1.0, 0.0, 1.0]);
        assert!(!shear.is_unitary(1e-10).unwrap());
    }

    #[test]
    fn test_is_hermitian() {
        let pauli_y = DenseMatrix::new(
            Array2::from_shape_vec(
                (2, 2),
                vec![
                    re(0.0),
                    Complex64::new(0.0, -1.0),
                    Complex64::new(0.0, 1.0),
                    re(0.0),
                ],
            )
            .unwrap(),
        )
        .unwrap();
        assert!(pauli_y.is_hermitian(1e-12));

        let shear = real_2x2([1.0, 1.0, 0.0, 1.0]);
        assert!(!shear.is_hermitian(1e-12));
    }

    #[test]
    fn test_tensor_product() {
        let id = real_2x2([1.0, 0.0, 0.0, 1.0]);
        let x = real_2x2([0.0, 1.0, 1.0, 0.0]);

        let result = id.tensor_product(&x).unwrap();
        assert_eq!(result.shape(), &[4, 4]);

        // I ⊗ X is block-diagonal with X in each diagonal block.
        assert_eq!(result[[0, 1]], re(1.0));
        assert_eq!(result[[2, 3]], re(1.0));
        assert_eq!(result[[0, 2]], re(0.0));
    }

    #[test]
    fn test_tensor_product_many_matches_pairwise() {
        let id = real_2x2([1.0, 0.0, 0.0, 1.0]);
        let x = real_2x2([0.0, 1.0, 1.0, 0.0]);

        let factors: [&dyn QuantumMatrix; 2] = [&id, &x];
        let many = tensor_product_many(&factors).unwrap();
        let pairwise = id.tensor_product(&x).unwrap();
        assert!(matrices_approx_equal(&many.view(), &pairwise.view(), 1e-12));

        let empty: [&dyn QuantumMatrix; 0] = [];
        assert!(tensor_product_many(&empty).is_err());
    }

    #[test]
    fn test_partial_trace_of_product_state() {
        // Basis index 1: qubit 0 is |1>, qubit 1 is |0> (qubit q is bit q of the index).
        let mut rho = Array2::<Complex64>::zeros((4, 4));
        rho[[1, 1]] = re(1.0);

        let reduced_q0 = partial_trace(&rho, &[0], 2).unwrap();
        assert_eq!(reduced_q0[[1, 1]], re(1.0));
        assert_eq!(reduced_q0[[0, 0]], re(0.0));

        let reduced_q1 = partial_trace(&rho, &[1], 2).unwrap();
        assert_eq!(reduced_q1[[0, 0]], re(1.0));
        assert_eq!(reduced_q1[[1, 1]], re(0.0));
    }

    #[test]
    fn test_partial_trace_rejects_wrong_dimension() {
        let wrong = Array2::<Complex64>::zeros((3, 3));
        assert!(partial_trace(&wrong, &[0], 2).is_err());
    }

    #[test]
    fn test_matrices_approx_equal() {
        let a = Array2::from_elem((2, 2), re(1.0));
        let b = a.mapv(|z| z + re(1e-12));
        assert!(matrices_approx_equal(&a.view(), &b.view(), 1e-9));
        assert!(!matrices_approx_equal(&a.view(), &b.view(), 1e-15));

        let c = Array2::from_elem((2, 3), re(1.0));
        assert!(!matrices_approx_equal(&a.view(), &c.view(), 1.0));
    }
}