quantrs2_core/
matrix_ops.rs

1//! Matrix operations for quantum gates using SciRS2
2//!
3//! This module provides efficient matrix operations for quantum computing,
4//! including sparse/dense conversions, tensor products, and specialized
5//! quantum operations.
6
7use crate::error::{QuantRS2Error, QuantRS2Result};
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
9use scirs2_core::Complex64;
10// use scirs2_sparse::csr::CsrMatrix;
11use crate::linalg_stubs::CsrMatrix;
12use std::fmt::Debug;
13
14/// Trait for quantum matrix operations
15pub trait QuantumMatrix: Debug + Send + Sync {
16    /// Get the dimension of the matrix (assumed square)
17    fn dim(&self) -> usize;
18
19    /// Convert to dense representation
20    fn to_dense(&self) -> Array2<Complex64>;
21
22    /// Convert to sparse representation
23    fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>>;
24
25    /// Check if the matrix is unitary
26    fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool>;
27
28    /// Compute the tensor product with another matrix
29    fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>>;
30
31    /// Apply to a state vector
32    fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>>;
33}
34
35/// Dense matrix representation
36#[derive(Debug, Clone)]
37pub struct DenseMatrix {
38    data: Array2<Complex64>,
39}
40
41impl DenseMatrix {
42    /// Create a new dense matrix
43    pub fn new(data: Array2<Complex64>) -> QuantRS2Result<Self> {
44        if data.nrows() != data.ncols() {
45            return Err(QuantRS2Error::InvalidInput(
46                "Matrix must be square".to_string(),
47            ));
48        }
49        Ok(Self { data })
50    }
51
52    /// Create from a flat vector (column-major order)
53    pub fn from_vec(data: Vec<Complex64>, dim: usize) -> QuantRS2Result<Self> {
54        if data.len() != dim * dim {
55            return Err(QuantRS2Error::InvalidInput(format!(
56                "Expected {} elements, got {}",
57                dim * dim,
58                data.len()
59            )));
60        }
61        let matrix = Array2::from_shape_vec((dim, dim), data)
62            .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))?;
63        Self::new(matrix)
64    }
65
66    /// Get a reference to the underlying array
67    pub fn as_array(&self) -> &Array2<Complex64> {
68        &self.data
69    }
70
71    /// Check if matrix is hermitian
72    pub fn is_hermitian(&self, tolerance: f64) -> bool {
73        let n = self.data.nrows();
74        for i in 0..n {
75            for j in i..n {
76                let diff = (self.data[[i, j]] - self.data[[j, i]].conj()).norm();
77                if diff > tolerance {
78                    return false;
79                }
80            }
81        }
82        true
83    }
84}
85
86impl QuantumMatrix for DenseMatrix {
87    fn dim(&self) -> usize {
88        self.data.nrows()
89    }
90
91    fn to_dense(&self) -> Array2<Complex64> {
92        self.data.clone()
93    }
94
95    fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>> {
96        let n = self.dim();
97        let mut rows = Vec::new();
98        let mut cols = Vec::new();
99        let mut data = Vec::new();
100
101        let tolerance = 1e-14;
102        for i in 0..n {
103            for j in 0..n {
104                let val = self.data[[i, j]];
105                if val.norm() > tolerance {
106                    rows.push(i);
107                    cols.push(j);
108                    data.push(val);
109                }
110            }
111        }
112
113        CsrMatrix::new(data, rows, cols, (n, n))
114            .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))
115    }
116
117    fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool> {
118        let n = self.dim();
119        let conj_transpose = self.data.t().mapv(|x| x.conj());
120        let product = self.data.dot(&conj_transpose);
121
122        // Check if product is identity
123        for i in 0..n {
124            for j in 0..n {
125                let expected = if i == j {
126                    Complex64::new(1.0, 0.0)
127                } else {
128                    Complex64::new(0.0, 0.0)
129                };
130                let diff = (product[[i, j]] - expected).norm();
131                if diff > tolerance {
132                    return Ok(false);
133                }
134            }
135        }
136        Ok(true)
137    }
138
139    fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>> {
140        let other_dense = other.to_dense();
141        let n1 = self.dim();
142        let n2 = other_dense.nrows();
143        let n = n1 * n2;
144
145        let mut result = Array2::zeros((n, n));
146
147        for i1 in 0..n1 {
148            for j1 in 0..n1 {
149                let val1 = self.data[[i1, j1]];
150                for i2 in 0..n2 {
151                    for j2 in 0..n2 {
152                        let val2 = other_dense[[i2, j2]];
153                        result[[i1 * n2 + i2, j1 * n2 + j2]] = val1 * val2;
154                    }
155                }
156            }
157        }
158
159        Ok(result)
160    }
161
162    fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>> {
163        if state.len() != self.dim() {
164            return Err(QuantRS2Error::InvalidInput(format!(
165                "State dimension {} doesn't match matrix dimension {}",
166                state.len(),
167                self.dim()
168            )));
169        }
170        Ok(self.data.dot(state))
171    }
172}
173
174/// Sparse matrix representation for quantum gates
175#[derive(Clone)]
176pub struct SparseMatrix {
177    csr: CsrMatrix<Complex64>,
178    dim: usize,
179}
180
181impl Debug for SparseMatrix {
182    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183        f.debug_struct("SparseMatrix")
184            .field("dim", &self.dim)
185            .field("nnz", &self.csr.nnz())
186            .finish()
187    }
188}
189
190impl SparseMatrix {
191    /// Create a new sparse matrix
192    pub fn new(csr: CsrMatrix<Complex64>) -> QuantRS2Result<Self> {
193        let (rows, cols) = csr.shape();
194        if rows != cols {
195            return Err(QuantRS2Error::InvalidInput(
196                "Matrix must be square".to_string(),
197            ));
198        }
199        Ok(Self { csr, dim: rows })
200    }
201
202    /// Create from triplets
203    pub fn from_triplets(
204        rows: Vec<usize>,
205        cols: Vec<usize>,
206        data: Vec<Complex64>,
207        dim: usize,
208    ) -> QuantRS2Result<Self> {
209        let csr = CsrMatrix::new(data, rows, cols, (dim, dim))
210            .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))?;
211        Self::new(csr)
212    }
213}
214
215impl QuantumMatrix for SparseMatrix {
216    fn dim(&self) -> usize {
217        self.dim
218    }
219
220    fn to_dense(&self) -> Array2<Complex64> {
221        self.csr.to_dense()
222    }
223
224    fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>> {
225        Ok(self.csr.clone())
226    }
227
228    fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool> {
229        // Convert to dense for unitary check
230        let dense = DenseMatrix::new(self.to_dense())?;
231        dense.is_unitary(tolerance)
232    }
233
234    fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>> {
235        // For now, convert to dense and compute
236        // TODO: Implement sparse tensor product
237        let dense = DenseMatrix::new(self.to_dense())?;
238        dense.tensor_product(other)
239    }
240
241    fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>> {
242        if state.len() != self.dim() {
243            return Err(QuantRS2Error::InvalidInput(format!(
244                "State dimension {} doesn't match matrix dimension {}",
245                state.len(),
246                self.dim()
247            )));
248        }
249        // Convert to dense and apply
250        let dense = self.to_dense();
251        Ok(dense.dot(state))
252    }
253}
254
255/// Compute the partial trace of a matrix
256pub fn partial_trace(
257    matrix: &Array2<Complex64>,
258    keep_qubits: &[usize],
259    total_qubits: usize,
260) -> QuantRS2Result<Array2<Complex64>> {
261    let full_dim = 1 << total_qubits;
262    if matrix.nrows() != full_dim || matrix.ncols() != full_dim {
263        return Err(QuantRS2Error::InvalidInput(format!(
264            "Matrix dimension {} doesn't match {} qubits",
265            matrix.nrows(),
266            total_qubits
267        )));
268    }
269
270    let keep_dim = 1 << keep_qubits.len();
271    let trace_qubits: Vec<usize> = (0..total_qubits)
272        .filter(|q| !keep_qubits.contains(q))
273        .collect();
274    let trace_dim = 1 << trace_qubits.len();
275
276    let mut result = Array2::zeros((keep_dim, keep_dim));
277
278    // Iterate over all basis states
279    for i in 0..keep_dim {
280        for j in 0..keep_dim {
281            let mut sum = Complex64::new(0.0, 0.0);
282
283            // Sum over traced out qubits
284            for t in 0..trace_dim {
285                let row_idx = reconstruct_index(i, t, keep_qubits, &trace_qubits, total_qubits);
286                let col_idx = reconstruct_index(j, t, keep_qubits, &trace_qubits, total_qubits);
287                sum += matrix[[row_idx, col_idx]];
288            }
289
290            result[[i, j]] = sum;
291        }
292    }
293
294    Ok(result)
295}
296
/// Builds a full-register index from a kept-subsystem index and a traced-subsystem index.
///
/// Bit `k` of `keep_idx` is placed at bit position `keep_qubits[k]`. Bit `k` of
/// `trace_idx` is placed at bit position `trace_qubits[k]`. The two results are
/// OR-ed together. `_total_qubits` is not used.
fn reconstruct_index(
    keep_idx: usize,
    trace_idx: usize,
    keep_qubits: &[usize],
    trace_qubits: &[usize],
    _total_qubits: usize,
) -> usize {
    // Moves each set bit of `packed` to the position listed for it in `positions`.
    fn scatter(packed: usize, positions: &[usize]) -> usize {
        positions
            .iter()
            .enumerate()
            .filter(|&(bit, _)| (packed >> bit) & 1 == 1)
            .fold(0usize, |acc, (_, &pos)| acc | (1usize << pos))
    }

    scatter(keep_idx, keep_qubits) | scatter(trace_idx, trace_qubits)
}
323
324/// Compute the tensor product of multiple matrices efficiently
325pub fn tensor_product_many(matrices: &[&dyn QuantumMatrix]) -> QuantRS2Result<Array2<Complex64>> {
326    if matrices.is_empty() {
327        return Err(QuantRS2Error::InvalidInput(
328            "Cannot compute tensor product of empty list".to_string(),
329        ));
330    }
331
332    if matrices.len() == 1 {
333        return Ok(matrices[0].to_dense());
334    }
335
336    let mut result = matrices[0].to_dense();
337    for matrix in matrices.iter().skip(1) {
338        let dense_result = DenseMatrix::new(result)?;
339        result = dense_result.tensor_product(*matrix)?;
340    }
341
342    Ok(result)
343}
344
345/// Check if two matrices are approximately equal
346pub fn matrices_approx_equal(
347    a: &ArrayView2<Complex64>,
348    b: &ArrayView2<Complex64>,
349    tolerance: f64,
350) -> bool {
351    if a.shape() != b.shape() {
352        return false;
353    }
354
355    for (x, y) in a.iter().zip(b.iter()) {
356        if (x - y).norm() > tolerance {
357            return false;
358        }
359    }
360
361    true
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::Complex64;

    /// Shorthand for a complex literal.
    fn c(re: f64, im: f64) -> Complex64 {
        Complex64::new(re, im)
    }

    /// Builds a 2x2 array from entries listed in row-major order.
    fn mat2(entries: [Complex64; 4]) -> Array2<Complex64> {
        Array2::from_shape_vec((2, 2), entries.to_vec()).unwrap()
    }

    fn identity2() -> DenseMatrix {
        DenseMatrix::new(mat2([c(1.0, 0.0), c(0.0, 0.0), c(0.0, 0.0), c(1.0, 0.0)])).unwrap()
    }

    fn pauli_x() -> DenseMatrix {
        DenseMatrix::new(mat2([c(0.0, 0.0), c(1.0, 0.0), c(1.0, 0.0), c(0.0, 0.0)])).unwrap()
    }

    #[test]
    fn test_dense_matrix_creation() {
        let matrix = identity2();
        assert_eq!(matrix.dim(), 2);
    }

    #[test]
    fn test_non_square_rejected() {
        assert!(DenseMatrix::new(Array2::<Complex64>::zeros((2, 3))).is_err());
    }

    #[test]
    fn test_from_vec_is_row_major() {
        let m = DenseMatrix::from_vec(
            vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)],
            2,
        )
        .unwrap();
        assert_eq!(m.as_array()[[0, 1]], c(2.0, 0.0));
        assert_eq!(m.as_array()[[1, 0]], c(3.0, 0.0));
        assert!(DenseMatrix::from_vec(vec![c(1.0, 0.0)], 2).is_err());
    }

    #[test]
    fn test_unitary_check() {
        // Hadamard gate
        let s = 1.0 / 2.0_f64.sqrt();
        let hadamard = DenseMatrix::new(mat2([c(s, 0.0), c(s, 0.0), c(s, 0.0), c(-s, 0.0)])).unwrap();
        assert!(hadamard.is_unitary(1e-10).unwrap());

        // Shear matrix [[1, 1], [0, 1]] is not unitary.
        let shear =
            DenseMatrix::new(mat2([c(1.0, 0.0), c(1.0, 0.0), c(0.0, 0.0), c(1.0, 0.0)])).unwrap();
        assert!(!shear.is_unitary(1e-10).unwrap());
    }

    #[test]
    fn test_is_hermitian() {
        // Pauli-Y = [[0, -i], [i, 0]] is Hermitian.
        let pauli_y =
            DenseMatrix::new(mat2([c(0.0, 0.0), c(0.0, -1.0), c(0.0, 1.0), c(0.0, 0.0)])).unwrap();
        assert!(pauli_y.is_hermitian(1e-12));

        // [[0, 1], [0, 0]] is not Hermitian.
        let raising =
            DenseMatrix::new(mat2([c(0.0, 0.0), c(1.0, 0.0), c(0.0, 0.0), c(0.0, 0.0)])).unwrap();
        assert!(!raising.is_hermitian(1e-12));
    }

    #[test]
    fn test_tensor_product() {
        // Identity ⊗ Pauli-X
        let id = identity2();
        let x = pauli_x();

        let result = id.tensor_product(&x).unwrap();
        assert_eq!(result.shape(), &[4, 4]);

        // Check specific values
        assert_eq!(result[[0, 1]], c(1.0, 0.0));
        assert_eq!(result[[2, 3]], c(1.0, 0.0));
    }

    #[test]
    fn test_tensor_product_many() {
        let id = identity2();
        let x = pauli_x();

        let expected = id.tensor_product(&x).unwrap();
        let parts: [&dyn QuantumMatrix; 2] = [&id, &x];
        assert_eq!(tensor_product_many(&parts).unwrap(), expected);

        let single: [&dyn QuantumMatrix; 1] = [&x];
        assert_eq!(tensor_product_many(&single).unwrap(), x.to_dense());

        assert!(tensor_product_many(&[]).is_err());
    }

    #[test]
    fn test_partial_trace_basis_state() {
        // Qubit q is bit q of the index: index 1 means qubit 0 = 1, qubit 1 = 0.
        let mut rho = Array2::<Complex64>::zeros((4, 4));
        rho[[1, 1]] = c(1.0, 0.0);

        let zero = c(0.0, 0.0);
        let one = c(1.0, 0.0);

        let reduced_q0 = partial_trace(&rho, &[0], 2).unwrap();
        assert_eq!(reduced_q0, mat2([zero, zero, zero, one]));

        let reduced_q1 = partial_trace(&rho, &[1], 2).unwrap();
        assert_eq!(reduced_q1, mat2([one, zero, zero, zero]));

        // Keeping every qubit in natural order is the identity operation.
        assert_eq!(partial_trace(&rho, &[0, 1], 2).unwrap(), rho);

        // Dimension mismatch is reported as an error.
        assert!(partial_trace(&rho, &[0], 3).is_err());
    }

    #[test]
    fn test_matrices_approx_equal() {
        let a = mat2([c(1.0, 0.0), c(0.0, 0.0), c(0.0, 0.0), c(1.0, 0.0)]);
        let b = mat2([c(1.0 + 1e-12, 0.0), c(0.0, 0.0), c(0.0, 0.0), c(1.0, 0.0)]);
        assert!(matrices_approx_equal(&a.view(), &b.view(), 1e-10));
        assert!(!matrices_approx_equal(&a.view(), &b.view(), 1e-14));

        let wide = Array2::<Complex64>::zeros((2, 3));
        assert!(!matrices_approx_equal(&a.view(), &wide.view(), 1e-10));
    }
}