1use crate::error::{QuantRS2Error, QuantRS2Result};
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
9use scirs2_core::Complex64;
10use crate::linalg_stubs::CsrMatrix;
12use std::fmt::Debug;
13
14pub trait QuantumMatrix: Debug + Send + Sync {
16 fn dim(&self) -> usize;
18
19 fn to_dense(&self) -> Array2<Complex64>;
21
22 fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>>;
24
25 fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool>;
27
28 fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>>;
30
31 fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>>;
33}
34
35#[derive(Debug, Clone)]
37pub struct DenseMatrix {
38 data: Array2<Complex64>,
39}
40
41impl DenseMatrix {
42 pub fn new(data: Array2<Complex64>) -> QuantRS2Result<Self> {
44 if data.nrows() != data.ncols() {
45 return Err(QuantRS2Error::InvalidInput(
46 "Matrix must be square".to_string(),
47 ));
48 }
49 Ok(Self { data })
50 }
51
52 pub fn from_vec(data: Vec<Complex64>, dim: usize) -> QuantRS2Result<Self> {
54 if data.len() != dim * dim {
55 return Err(QuantRS2Error::InvalidInput(format!(
56 "Expected {} elements, got {}",
57 dim * dim,
58 data.len()
59 )));
60 }
61 let matrix = Array2::from_shape_vec((dim, dim), data)
62 .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))?;
63 Self::new(matrix)
64 }
65
66 pub const fn as_array(&self) -> &Array2<Complex64> {
68 &self.data
69 }
70
71 pub fn is_hermitian(&self, tolerance: f64) -> bool {
73 let n = self.data.nrows();
74 for i in 0..n {
75 for j in i..n {
76 let diff = (self.data[[i, j]] - self.data[[j, i]].conj()).norm();
77 if diff > tolerance {
78 return false;
79 }
80 }
81 }
82 true
83 }
84}
85
86impl QuantumMatrix for DenseMatrix {
87 fn dim(&self) -> usize {
88 self.data.nrows()
89 }
90
91 fn to_dense(&self) -> Array2<Complex64> {
92 self.data.clone()
93 }
94
95 fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>> {
96 let n = self.dim();
97 let mut rows = Vec::new();
98 let mut cols = Vec::new();
99 let mut data = Vec::new();
100
101 let tolerance = 1e-14;
102 for i in 0..n {
103 for j in 0..n {
104 let val = self.data[[i, j]];
105 if val.norm() > tolerance {
106 rows.push(i);
107 cols.push(j);
108 data.push(val);
109 }
110 }
111 }
112
113 CsrMatrix::new(data, rows, cols, (n, n)).map_err(|e| QuantRS2Error::InvalidInput(e))
114 }
115
116 fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool> {
117 let n = self.dim();
118 let conj_transpose = self.data.t().mapv(|x| x.conj());
119 let product = self.data.dot(&conj_transpose);
120
121 for i in 0..n {
123 for j in 0..n {
124 let expected = if i == j {
125 Complex64::new(1.0, 0.0)
126 } else {
127 Complex64::new(0.0, 0.0)
128 };
129 let diff = (product[[i, j]] - expected).norm();
130 if diff > tolerance {
131 return Ok(false);
132 }
133 }
134 }
135 Ok(true)
136 }
137
138 fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>> {
139 let other_dense = other.to_dense();
140 let n1 = self.dim();
141 let n2 = other_dense.nrows();
142 let n = n1 * n2;
143
144 let mut result = Array2::zeros((n, n));
145
146 for i1 in 0..n1 {
147 for j1 in 0..n1 {
148 let val1 = self.data[[i1, j1]];
149 for i2 in 0..n2 {
150 for j2 in 0..n2 {
151 let val2 = other_dense[[i2, j2]];
152 result[[i1 * n2 + i2, j1 * n2 + j2]] = val1 * val2;
153 }
154 }
155 }
156 }
157
158 Ok(result)
159 }
160
161 fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>> {
162 if state.len() != self.dim() {
163 return Err(QuantRS2Error::InvalidInput(format!(
164 "State dimension {} doesn't match matrix dimension {}",
165 state.len(),
166 self.dim()
167 )));
168 }
169 Ok(self.data.dot(state))
170 }
171}
172
173#[derive(Clone)]
175pub struct SparseMatrix {
176 csr: CsrMatrix<Complex64>,
177 dim: usize,
178}
179
180impl Debug for SparseMatrix {
181 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182 f.debug_struct("SparseMatrix")
183 .field("dim", &self.dim)
184 .field("nnz", &self.csr.nnz())
185 .finish()
186 }
187}
188
189impl SparseMatrix {
190 pub fn new(csr: CsrMatrix<Complex64>) -> QuantRS2Result<Self> {
192 let (rows, cols) = csr.shape();
193 if rows != cols {
194 return Err(QuantRS2Error::InvalidInput(
195 "Matrix must be square".to_string(),
196 ));
197 }
198 Ok(Self { csr, dim: rows })
199 }
200
201 pub fn from_triplets(
203 rows: Vec<usize>,
204 cols: Vec<usize>,
205 data: Vec<Complex64>,
206 dim: usize,
207 ) -> QuantRS2Result<Self> {
208 let csr = CsrMatrix::new(data, rows, cols, (dim, dim))
209 .map_err(|e| QuantRS2Error::InvalidInput(e))?;
210 Self::new(csr)
211 }
212}
213
214impl QuantumMatrix for SparseMatrix {
215 fn dim(&self) -> usize {
216 self.dim
217 }
218
219 fn to_dense(&self) -> Array2<Complex64> {
220 self.csr.to_dense()
221 }
222
223 fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>> {
224 Ok(self.csr.clone())
225 }
226
227 fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool> {
228 let dense = DenseMatrix::new(self.to_dense())?;
230 dense.is_unitary(tolerance)
231 }
232
233 fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>> {
234 let other_dense = other.to_dense();
237 let (a_rows, a_cols) = self.csr.shape();
238 let b_rows = other_dense.nrows();
239 let b_cols = other_dense.ncols();
240
241 if a_rows == 0 || a_cols == 0 || b_rows == 0 || b_cols == 0 {
243 let out_rows = a_rows * b_rows;
244 let out_cols = a_cols * b_cols;
245 return Ok(Array2::zeros((out_rows, out_cols)));
246 }
247
248 let out_rows = a_rows * b_rows;
249 let out_cols = a_cols * b_cols;
250 let mut result = Array2::zeros((out_rows, out_cols));
251
252 let dense_a = self.csr.to_dense();
255 for i in 0..a_rows {
256 for j in 0..a_cols {
257 let val_a = dense_a[[i, j]];
258 if val_a.norm() < 1e-14 {
260 continue;
261 }
262 for k in 0..b_rows {
263 for l in 0..b_cols {
264 result[[i * b_rows + k, j * b_cols + l]] = val_a * other_dense[[k, l]];
265 }
266 }
267 }
268 }
269
270 Ok(result)
271 }
272
273 fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>> {
274 if state.len() != self.dim() {
275 return Err(QuantRS2Error::InvalidInput(format!(
276 "State dimension {} doesn't match matrix dimension {}",
277 state.len(),
278 self.dim()
279 )));
280 }
281 let dense = self.to_dense();
283 Ok(dense.dot(state))
284 }
285}
286
287pub fn sparse_tensor_product(
294 a: &Array2<scirs2_core::Complex64>,
295 b: &Array2<scirs2_core::Complex64>,
296) -> QuantRS2Result<Array2<scirs2_core::Complex64>> {
297 let a_rows = a.nrows();
298 let a_cols = a.ncols();
299 let b_rows = b.nrows();
300 let b_cols = b.ncols();
301
302 if a_rows == 0 || a_cols == 0 || b_rows == 0 || b_cols == 0 {
304 return Ok(Array2::zeros((a_rows * b_rows, a_cols * b_cols)));
305 }
306
307 let out_rows = a_rows * b_rows;
308 let out_cols = a_cols * b_cols;
309 let mut result = Array2::zeros((out_rows, out_cols));
310
311 for i in 0..a_rows {
312 for j in 0..a_cols {
313 let val_a = a[[i, j]];
314 if val_a.norm() < 1e-14 {
316 continue;
317 }
318 for k in 0..b_rows {
319 for l in 0..b_cols {
320 result[[i * b_rows + k, j * b_cols + l]] = val_a * b[[k, l]];
321 }
322 }
323 }
324 }
325
326 Ok(result)
327}
328
329pub fn partial_trace(
331 matrix: &Array2<Complex64>,
332 keep_qubits: &[usize],
333 total_qubits: usize,
334) -> QuantRS2Result<Array2<Complex64>> {
335 let full_dim = 1 << total_qubits;
336 if matrix.nrows() != full_dim || matrix.ncols() != full_dim {
337 return Err(QuantRS2Error::InvalidInput(format!(
338 "Matrix dimension {} doesn't match {} qubits",
339 matrix.nrows(),
340 total_qubits
341 )));
342 }
343
344 let keep_dim = 1 << keep_qubits.len();
345 let trace_qubits: Vec<usize> = (0..total_qubits)
346 .filter(|q| !keep_qubits.contains(q))
347 .collect();
348 let trace_dim = 1 << trace_qubits.len();
349
350 let mut result = Array2::zeros((keep_dim, keep_dim));
351
352 for i in 0..keep_dim {
354 for j in 0..keep_dim {
355 let mut sum = Complex64::new(0.0, 0.0);
356
357 for t in 0..trace_dim {
359 let row_idx = reconstruct_index(i, t, keep_qubits, &trace_qubits, total_qubits);
360 let col_idx = reconstruct_index(j, t, keep_qubits, &trace_qubits, total_qubits);
361 sum += matrix[[row_idx, col_idx]];
362 }
363
364 result[[i, j]] = sum;
365 }
366 }
367
368 Ok(result)
369}
370
/// Builds a full basis index from a packed "kept" index and a packed "traced" index.
///
/// Bit `i` of `keep_idx` is placed at bit `keep_qubits[i]` of the result, and
/// bit `i` of `trace_idx` at bit `trace_qubits[i]`. `_total_qubits` is unused.
fn reconstruct_index(
    keep_idx: usize,
    trace_idx: usize,
    keep_qubits: &[usize],
    trace_qubits: &[usize],
    _total_qubits: usize,
) -> usize {
    /// Moves bit `pos` of `packed` to bit position `qubits[pos]`.
    fn scatter(packed: usize, qubits: &[usize]) -> usize {
        qubits
            .iter()
            .enumerate()
            .filter(|&(pos, _)| (packed >> pos) & 1 == 1)
            .fold(0usize, |acc, (_, &qubit)| acc | (1usize << qubit))
    }

    scatter(keep_idx, keep_qubits) | scatter(trace_idx, trace_qubits)
}
397
398pub fn tensor_product_many(matrices: &[&dyn QuantumMatrix]) -> QuantRS2Result<Array2<Complex64>> {
400 if matrices.is_empty() {
401 return Err(QuantRS2Error::InvalidInput(
402 "Cannot compute tensor product of empty list".to_string(),
403 ));
404 }
405
406 if matrices.len() == 1 {
407 return Ok(matrices[0].to_dense());
408 }
409
410 let mut result = matrices[0].to_dense();
411 for matrix in matrices.iter().skip(1) {
412 let dense_result = DenseMatrix::new(result)?;
413 result = dense_result.tensor_product(*matrix)?;
414 }
415
416 Ok(result)
417}
418
419pub fn matrices_approx_equal(
421 a: &ArrayView2<Complex64>,
422 b: &ArrayView2<Complex64>,
423 tolerance: f64,
424) -> bool {
425 if a.shape() != b.shape() {
426 return false;
427 }
428
429 for (x, y) in a.iter().zip(b.iter()) {
430 if (x - y).norm() > tolerance {
431 return false;
432 }
433 }
434
435 true
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::Complex64;

    /// Lifts real values into complex entries with zero imaginary part.
    fn real_entries(values: &[f64]) -> Vec<Complex64> {
        values.iter().map(|&v| Complex64::new(v, 0.0)).collect()
    }

    #[test]
    fn test_dense_matrix_creation() {
        let identity = Array2::from_shape_vec((2, 2), real_entries(&[1.0, 0.0, 0.0, 1.0]))
            .expect("Matrix data creation should succeed");

        let matrix = DenseMatrix::new(identity).expect("DenseMatrix creation should succeed");
        assert_eq!(matrix.dim(), 2);
    }

    #[test]
    fn test_unitary_check() {
        // Hadamard gate: (1/sqrt(2)) * [[1, 1], [1, -1]].
        let h = 1.0 / 2.0_f64.sqrt();
        let hadamard = Array2::from_shape_vec((2, 2), real_entries(&[h, h, h, -h]))
            .expect("Hadamard matrix data creation should succeed");

        let matrix = DenseMatrix::new(hadamard).expect("DenseMatrix creation should succeed");
        let unitary = matrix
            .is_unitary(1e-10)
            .expect("Unitary check should succeed");
        assert!(unitary);
    }

    #[test]
    fn test_tensor_product() {
        let identity_data = Array2::from_shape_vec((2, 2), real_entries(&[1.0, 0.0, 0.0, 1.0]))
            .expect("Identity matrix data creation should succeed");
        let id = DenseMatrix::new(identity_data)
            .expect("Identity DenseMatrix creation should succeed");

        let pauli_x_data = Array2::from_shape_vec((2, 2), real_entries(&[0.0, 1.0, 1.0, 0.0]))
            .expect("Pauli-X matrix data creation should succeed");
        let x = DenseMatrix::new(pauli_x_data)
            .expect("Pauli-X DenseMatrix creation should succeed");

        let kron = id
            .tensor_product(&x)
            .expect("Tensor product should succeed");
        assert_eq!(kron.shape(), &[4, 4]);

        // I (x) X is block-diagonal with two X blocks.
        let one = Complex64::new(1.0, 0.0);
        assert_eq!(kron[[0, 1]], one);
        assert_eq!(kron[[2, 3]], one);
    }
}