Skip to main content

quantrs2_core/
matrix_ops.rs

//! Matrix operations for quantum gates using SciRS2.
//!
//! This module provides efficient matrix operations for quantum computing,
//! including sparse/dense conversions, tensor products, partial traces, and
//! other specialized quantum operations.
6
7use crate::error::{QuantRS2Error, QuantRS2Result};
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
9use scirs2_core::Complex64;
10// use scirs2_sparse::csr::CsrMatrix;
11use crate::linalg_stubs::CsrMatrix;
12use std::fmt::Debug;
13
14/// Trait for quantum matrix operations
15pub trait QuantumMatrix: Debug + Send + Sync {
16    /// Get the dimension of the matrix (assumed square)
17    fn dim(&self) -> usize;
18
19    /// Convert to dense representation
20    fn to_dense(&self) -> Array2<Complex64>;
21
22    /// Convert to sparse representation
23    fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>>;
24
25    /// Check if the matrix is unitary
26    fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool>;
27
28    /// Compute the tensor product with another matrix
29    fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>>;
30
31    /// Apply to a state vector
32    fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>>;
33}
34
35/// Dense matrix representation
36#[derive(Debug, Clone)]
37pub struct DenseMatrix {
38    data: Array2<Complex64>,
39}
40
41impl DenseMatrix {
42    /// Create a new dense matrix
43    pub fn new(data: Array2<Complex64>) -> QuantRS2Result<Self> {
44        if data.nrows() != data.ncols() {
45            return Err(QuantRS2Error::InvalidInput(
46                "Matrix must be square".to_string(),
47            ));
48        }
49        Ok(Self { data })
50    }
51
52    /// Create from a flat vector (column-major order)
53    pub fn from_vec(data: Vec<Complex64>, dim: usize) -> QuantRS2Result<Self> {
54        if data.len() != dim * dim {
55            return Err(QuantRS2Error::InvalidInput(format!(
56                "Expected {} elements, got {}",
57                dim * dim,
58                data.len()
59            )));
60        }
61        let matrix = Array2::from_shape_vec((dim, dim), data)
62            .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))?;
63        Self::new(matrix)
64    }
65
66    /// Get a reference to the underlying array
67    pub const fn as_array(&self) -> &Array2<Complex64> {
68        &self.data
69    }
70
71    /// Check if matrix is hermitian
72    pub fn is_hermitian(&self, tolerance: f64) -> bool {
73        let n = self.data.nrows();
74        for i in 0..n {
75            for j in i..n {
76                let diff = (self.data[[i, j]] - self.data[[j, i]].conj()).norm();
77                if diff > tolerance {
78                    return false;
79                }
80            }
81        }
82        true
83    }
84}
85
86impl QuantumMatrix for DenseMatrix {
87    fn dim(&self) -> usize {
88        self.data.nrows()
89    }
90
91    fn to_dense(&self) -> Array2<Complex64> {
92        self.data.clone()
93    }
94
95    fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>> {
96        let n = self.dim();
97        let mut rows = Vec::new();
98        let mut cols = Vec::new();
99        let mut data = Vec::new();
100
101        let tolerance = 1e-14;
102        for i in 0..n {
103            for j in 0..n {
104                let val = self.data[[i, j]];
105                if val.norm() > tolerance {
106                    rows.push(i);
107                    cols.push(j);
108                    data.push(val);
109                }
110            }
111        }
112
113        CsrMatrix::new(data, rows, cols, (n, n)).map_err(|e| QuantRS2Error::InvalidInput(e))
114    }
115
116    fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool> {
117        let n = self.dim();
118        let conj_transpose = self.data.t().mapv(|x| x.conj());
119        let product = self.data.dot(&conj_transpose);
120
121        // Check if product is identity
122        for i in 0..n {
123            for j in 0..n {
124                let expected = if i == j {
125                    Complex64::new(1.0, 0.0)
126                } else {
127                    Complex64::new(0.0, 0.0)
128                };
129                let diff = (product[[i, j]] - expected).norm();
130                if diff > tolerance {
131                    return Ok(false);
132                }
133            }
134        }
135        Ok(true)
136    }
137
138    fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>> {
139        let other_dense = other.to_dense();
140        let n1 = self.dim();
141        let n2 = other_dense.nrows();
142        let n = n1 * n2;
143
144        let mut result = Array2::zeros((n, n));
145
146        for i1 in 0..n1 {
147            for j1 in 0..n1 {
148                let val1 = self.data[[i1, j1]];
149                for i2 in 0..n2 {
150                    for j2 in 0..n2 {
151                        let val2 = other_dense[[i2, j2]];
152                        result[[i1 * n2 + i2, j1 * n2 + j2]] = val1 * val2;
153                    }
154                }
155            }
156        }
157
158        Ok(result)
159    }
160
161    fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>> {
162        if state.len() != self.dim() {
163            return Err(QuantRS2Error::InvalidInput(format!(
164                "State dimension {} doesn't match matrix dimension {}",
165                state.len(),
166                self.dim()
167            )));
168        }
169        Ok(self.data.dot(state))
170    }
171}
172
173/// Sparse matrix representation for quantum gates
174#[derive(Clone)]
175pub struct SparseMatrix {
176    csr: CsrMatrix<Complex64>,
177    dim: usize,
178}
179
180impl Debug for SparseMatrix {
181    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182        f.debug_struct("SparseMatrix")
183            .field("dim", &self.dim)
184            .field("nnz", &self.csr.nnz())
185            .finish()
186    }
187}
188
189impl SparseMatrix {
190    /// Create a new sparse matrix
191    pub fn new(csr: CsrMatrix<Complex64>) -> QuantRS2Result<Self> {
192        let (rows, cols) = csr.shape();
193        if rows != cols {
194            return Err(QuantRS2Error::InvalidInput(
195                "Matrix must be square".to_string(),
196            ));
197        }
198        Ok(Self { csr, dim: rows })
199    }
200
201    /// Create from triplets
202    pub fn from_triplets(
203        rows: Vec<usize>,
204        cols: Vec<usize>,
205        data: Vec<Complex64>,
206        dim: usize,
207    ) -> QuantRS2Result<Self> {
208        let csr = CsrMatrix::new(data, rows, cols, (dim, dim))
209            .map_err(|e| QuantRS2Error::InvalidInput(e))?;
210        Self::new(csr)
211    }
212}
213
214impl QuantumMatrix for SparseMatrix {
215    fn dim(&self) -> usize {
216        self.dim
217    }
218
219    fn to_dense(&self) -> Array2<Complex64> {
220        self.csr.to_dense()
221    }
222
223    fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>> {
224        Ok(self.csr.clone())
225    }
226
227    fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool> {
228        // Convert to dense for unitary check
229        let dense = DenseMatrix::new(self.to_dense())?;
230        dense.is_unitary(tolerance)
231    }
232
233    fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>> {
234        // Implement sparse tensor product (Kronecker product) efficiently
235        // Result dimensions: (a.rows * b.rows, a.cols * b.cols)
236        let other_dense = other.to_dense();
237        let (a_rows, a_cols) = self.csr.shape();
238        let b_rows = other_dense.nrows();
239        let b_cols = other_dense.ncols();
240
241        // Early return for zero dimensions
242        if a_rows == 0 || a_cols == 0 || b_rows == 0 || b_cols == 0 {
243            let out_rows = a_rows * b_rows;
244            let out_cols = a_cols * b_cols;
245            return Ok(Array2::zeros((out_rows, out_cols)));
246        }
247
248        let out_rows = a_rows * b_rows;
249        let out_cols = a_cols * b_cols;
250        let mut result = Array2::zeros((out_rows, out_cols));
251
252        // Iterate over non-zero elements of the sparse matrix A
253        // result[(i*p + k, j*q + l)] = a[(i,j)] * b[(k,l)]
254        let dense_a = self.csr.to_dense();
255        for i in 0..a_rows {
256            for j in 0..a_cols {
257                let val_a = dense_a[[i, j]];
258                // Skip near-zero entries for efficiency
259                if val_a.norm() < 1e-14 {
260                    continue;
261                }
262                for k in 0..b_rows {
263                    for l in 0..b_cols {
264                        result[[i * b_rows + k, j * b_cols + l]] = val_a * other_dense[[k, l]];
265                    }
266                }
267            }
268        }
269
270        Ok(result)
271    }
272
273    fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>> {
274        if state.len() != self.dim() {
275            return Err(QuantRS2Error::InvalidInput(format!(
276                "State dimension {} doesn't match matrix dimension {}",
277                state.len(),
278                self.dim()
279            )));
280        }
281        // Convert to dense and apply
282        let dense = self.to_dense();
283        Ok(dense.dot(state))
284    }
285}
286
287/// Compute the sparse Kronecker (tensor) product of two dense matrices.
288///
289/// Given matrices A of shape `(m, n)` and B of shape `(p, q)`,
290/// returns the Kronecker product A ⊗ B of shape `(m*p, n*q)`.
291///
292/// `result[(i*p + k, j*q + l)] = a[(i,j)] * b[(k,l)]`
293pub fn sparse_tensor_product(
294    a: &Array2<scirs2_core::Complex64>,
295    b: &Array2<scirs2_core::Complex64>,
296) -> QuantRS2Result<Array2<scirs2_core::Complex64>> {
297    let a_rows = a.nrows();
298    let a_cols = a.ncols();
299    let b_rows = b.nrows();
300    let b_cols = b.ncols();
301
302    // Handle zero-dimension early return
303    if a_rows == 0 || a_cols == 0 || b_rows == 0 || b_cols == 0 {
304        return Ok(Array2::zeros((a_rows * b_rows, a_cols * b_cols)));
305    }
306
307    let out_rows = a_rows * b_rows;
308    let out_cols = a_cols * b_cols;
309    let mut result = Array2::zeros((out_rows, out_cols));
310
311    for i in 0..a_rows {
312        for j in 0..a_cols {
313            let val_a = a[[i, j]];
314            // Skip near-zero entries for sparse efficiency
315            if val_a.norm() < 1e-14 {
316                continue;
317            }
318            for k in 0..b_rows {
319                for l in 0..b_cols {
320                    result[[i * b_rows + k, j * b_cols + l]] = val_a * b[[k, l]];
321                }
322            }
323        }
324    }
325
326    Ok(result)
327}
328
329/// Compute the partial trace of a matrix
330pub fn partial_trace(
331    matrix: &Array2<Complex64>,
332    keep_qubits: &[usize],
333    total_qubits: usize,
334) -> QuantRS2Result<Array2<Complex64>> {
335    let full_dim = 1 << total_qubits;
336    if matrix.nrows() != full_dim || matrix.ncols() != full_dim {
337        return Err(QuantRS2Error::InvalidInput(format!(
338            "Matrix dimension {} doesn't match {} qubits",
339            matrix.nrows(),
340            total_qubits
341        )));
342    }
343
344    let keep_dim = 1 << keep_qubits.len();
345    let trace_qubits: Vec<usize> = (0..total_qubits)
346        .filter(|q| !keep_qubits.contains(q))
347        .collect();
348    let trace_dim = 1 << trace_qubits.len();
349
350    let mut result = Array2::zeros((keep_dim, keep_dim));
351
352    // Iterate over all basis states
353    for i in 0..keep_dim {
354        for j in 0..keep_dim {
355            let mut sum = Complex64::new(0.0, 0.0);
356
357            // Sum over traced out qubits
358            for t in 0..trace_dim {
359                let row_idx = reconstruct_index(i, t, keep_qubits, &trace_qubits, total_qubits);
360                let col_idx = reconstruct_index(j, t, keep_qubits, &trace_qubits, total_qubits);
361                sum += matrix[[row_idx, col_idx]];
362            }
363
364            result[[i, j]] = sum;
365        }
366    }
367
368    Ok(result)
369}
370
/// Rebuild a full basis-state index from its kept-qubit and traced-qubit parts.
///
/// Bit `b` of `keep_idx` is placed at bit position `keep_qubits[b]`. Bit `b` of `trace_idx`
/// is placed at bit position `trace_qubits[b]`. The two scattered values are OR-ed together.
/// `_total_qubits` is currently unused.
fn reconstruct_index(
    keep_idx: usize,
    trace_idx: usize,
    keep_qubits: &[usize],
    trace_qubits: &[usize],
    _total_qubits: usize,
) -> usize {
    // Scatter the low bits of `packed` onto the qubit positions listed in `targets`.
    let scatter = |packed: usize, targets: &[usize]| -> usize {
        targets
            .iter()
            .enumerate()
            .filter(|&(bit, _)| (packed >> bit) & 1 == 1)
            .fold(0usize, |acc, (_, &qubit)| acc | (1 << qubit))
    };

    scatter(keep_idx, keep_qubits) | scatter(trace_idx, trace_qubits)
}
397
398/// Compute the tensor product of multiple matrices efficiently
399pub fn tensor_product_many(matrices: &[&dyn QuantumMatrix]) -> QuantRS2Result<Array2<Complex64>> {
400    if matrices.is_empty() {
401        return Err(QuantRS2Error::InvalidInput(
402            "Cannot compute tensor product of empty list".to_string(),
403        ));
404    }
405
406    if matrices.len() == 1 {
407        return Ok(matrices[0].to_dense());
408    }
409
410    let mut result = matrices[0].to_dense();
411    for matrix in matrices.iter().skip(1) {
412        let dense_result = DenseMatrix::new(result)?;
413        result = dense_result.tensor_product(*matrix)?;
414    }
415
416    Ok(result)
417}
418
419/// Check if two matrices are approximately equal
420pub fn matrices_approx_equal(
421    a: &ArrayView2<Complex64>,
422    b: &ArrayView2<Complex64>,
423    tolerance: f64,
424) -> bool {
425    if a.shape() != b.shape() {
426        return false;
427    }
428
429    for (x, y) in a.iter().zip(b.iter()) {
430        if (x - y).norm() > tolerance {
431            return false;
432        }
433    }
434
435    true
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::Complex64;

    /// Shorthand for a purely real complex number.
    fn re(x: f64) -> Complex64 {
        Complex64::new(x, 0.0)
    }

    /// Build a 2×2 matrix from row-major entries.
    fn mat2(entries: [Complex64; 4]) -> Array2<Complex64> {
        Array2::from_shape_vec((2, 2), entries.to_vec())
            .expect("2x2 matrix data creation should succeed")
    }

    #[test]
    fn test_dense_matrix_creation() {
        let matrix = DenseMatrix::new(mat2([re(1.0), re(0.0), re(0.0), re(1.0)]))
            .expect("DenseMatrix creation should succeed");
        assert_eq!(matrix.dim(), 2);
    }

    #[test]
    fn test_non_square_rejected() {
        let data = Array2::from_shape_vec((1, 2), vec![re(1.0), re(0.0)])
            .expect("1x2 matrix data creation should succeed");
        assert!(DenseMatrix::new(data).is_err());
    }

    #[test]
    fn test_from_vec_is_row_major() {
        let matrix = DenseMatrix::from_vec(vec![re(1.0), re(2.0), re(3.0), re(4.0)], 2)
            .expect("from_vec should succeed");
        assert_eq!(matrix.as_array()[[0, 1]], re(2.0));
        assert_eq!(matrix.as_array()[[1, 0]], re(3.0));
    }

    #[test]
    fn test_from_vec_rejects_wrong_length() {
        assert!(DenseMatrix::from_vec(vec![re(1.0); 3], 2).is_err());
    }

    #[test]
    fn test_unitary_check() {
        // Hadamard gate
        let s = 1.0 / 2.0_f64.sqrt();
        let matrix = DenseMatrix::new(mat2([re(s), re(s), re(s), re(-s)]))
            .expect("DenseMatrix creation should succeed");
        assert!(matrix
            .is_unitary(1e-10)
            .expect("Unitary check should succeed"));
    }

    #[test]
    fn test_non_unitary_detected() {
        let matrix = DenseMatrix::new(mat2([re(2.0), re(0.0), re(0.0), re(1.0)]))
            .expect("DenseMatrix creation should succeed");
        assert!(!matrix
            .is_unitary(1e-10)
            .expect("Unitary check should succeed"));
    }

    #[test]
    fn test_tensor_product() {
        // Identity ⊗ Pauli-X
        let id = DenseMatrix::new(mat2([re(1.0), re(0.0), re(0.0), re(1.0)]))
            .expect("Identity DenseMatrix creation should succeed");
        let x = DenseMatrix::new(mat2([re(0.0), re(1.0), re(1.0), re(0.0)]))
            .expect("Pauli-X DenseMatrix creation should succeed");

        let result = id
            .tensor_product(&x)
            .expect("Tensor product should succeed");
        assert_eq!(result.shape(), &[4, 4]);

        // Check specific values
        assert_eq!(result[[0, 1]], re(1.0));
        assert_eq!(result[[2, 3]], re(1.0));
    }

    #[test]
    fn test_partial_trace_keeps_requested_qubit() {
        // diag(1, 2, 3, 4) over basis index q0 + 2*q1.
        let mut rho = Array2::zeros((4, 4));
        for k in 0..4 {
            rho[[k, k]] = re((k + 1) as f64);
        }
        let reduced = partial_trace(&rho, &[0], 2).expect("partial trace should succeed");
        assert_eq!(reduced.shape(), &[2, 2]);
        // q0 = 0 collects indices 0 and 2; q0 = 1 collects indices 1 and 3.
        assert_eq!(reduced[[0, 0]], re(4.0));
        assert_eq!(reduced[[1, 1]], re(6.0));
        assert_eq!(reduced[[0, 1]], re(0.0));
    }

    #[test]
    fn test_partial_trace_rejects_wrong_dimension() {
        let rho: Array2<Complex64> = Array2::zeros((3, 3));
        assert!(partial_trace(&rho, &[0], 2).is_err());
    }

    #[test]
    fn test_matrices_approx_equal() {
        let a = mat2([re(1.0), re(0.0), re(0.0), re(1.0)]);
        let mut b = a.clone();
        b[[0, 0]] = re(1.0 + 1e-12);
        assert!(matrices_approx_equal(&a.view(), &b.view(), 1e-10));
        assert!(!matrices_approx_equal(&a.view(), &b.view(), 1e-14));

        let wider = Array2::from_shape_vec((1, 2), vec![re(1.0), re(0.0)])
            .expect("1x2 matrix data creation should succeed");
        assert!(!matrices_approx_equal(&a.view(), &wider.view(), 1.0));
    }
}