quantrs2_core/
matrix_ops.rs

//! Matrix operations for quantum gates using SciRS2
//!
//! This module provides matrix types and helpers for quantum computing,
//! including sparse/dense conversions, tensor (Kronecker) products, and
//! specialized quantum operations such as partial traces.
6
7use crate::error::{QuantRS2Error, QuantRS2Result};
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
9use scirs2_core::Complex64;
10// use scirs2_sparse::csr::CsrMatrix;
11use crate::linalg_stubs::CsrMatrix;
12use std::fmt::Debug;
13
14/// Trait for quantum matrix operations
15pub trait QuantumMatrix: Debug + Send + Sync {
16    /// Get the dimension of the matrix (assumed square)
17    fn dim(&self) -> usize;
18
19    /// Convert to dense representation
20    fn to_dense(&self) -> Array2<Complex64>;
21
22    /// Convert to sparse representation
23    fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>>;
24
25    /// Check if the matrix is unitary
26    fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool>;
27
28    /// Compute the tensor product with another matrix
29    fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>>;
30
31    /// Apply to a state vector
32    fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>>;
33}
34
35/// Dense matrix representation
36#[derive(Debug, Clone)]
37pub struct DenseMatrix {
38    data: Array2<Complex64>,
39}
40
41impl DenseMatrix {
42    /// Create a new dense matrix
43    pub fn new(data: Array2<Complex64>) -> QuantRS2Result<Self> {
44        if data.nrows() != data.ncols() {
45            return Err(QuantRS2Error::InvalidInput(
46                "Matrix must be square".to_string(),
47            ));
48        }
49        Ok(Self { data })
50    }
51
52    /// Create from a flat vector (column-major order)
53    pub fn from_vec(data: Vec<Complex64>, dim: usize) -> QuantRS2Result<Self> {
54        if data.len() != dim * dim {
55            return Err(QuantRS2Error::InvalidInput(format!(
56                "Expected {} elements, got {}",
57                dim * dim,
58                data.len()
59            )));
60        }
61        let matrix = Array2::from_shape_vec((dim, dim), data)
62            .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))?;
63        Self::new(matrix)
64    }
65
66    /// Get a reference to the underlying array
67    pub const fn as_array(&self) -> &Array2<Complex64> {
68        &self.data
69    }
70
71    /// Check if matrix is hermitian
72    pub fn is_hermitian(&self, tolerance: f64) -> bool {
73        let n = self.data.nrows();
74        for i in 0..n {
75            for j in i..n {
76                let diff = (self.data[[i, j]] - self.data[[j, i]].conj()).norm();
77                if diff > tolerance {
78                    return false;
79                }
80            }
81        }
82        true
83    }
84}
85
86impl QuantumMatrix for DenseMatrix {
87    fn dim(&self) -> usize {
88        self.data.nrows()
89    }
90
91    fn to_dense(&self) -> Array2<Complex64> {
92        self.data.clone()
93    }
94
95    fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>> {
96        let n = self.dim();
97        let mut rows = Vec::new();
98        let mut cols = Vec::new();
99        let mut data = Vec::new();
100
101        let tolerance = 1e-14;
102        for i in 0..n {
103            for j in 0..n {
104                let val = self.data[[i, j]];
105                if val.norm() > tolerance {
106                    rows.push(i);
107                    cols.push(j);
108                    data.push(val);
109                }
110            }
111        }
112
113        CsrMatrix::new(data, rows, cols, (n, n)).map_err(|e| QuantRS2Error::InvalidInput(e))
114    }
115
116    fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool> {
117        let n = self.dim();
118        let conj_transpose = self.data.t().mapv(|x| x.conj());
119        let product = self.data.dot(&conj_transpose);
120
121        // Check if product is identity
122        for i in 0..n {
123            for j in 0..n {
124                let expected = if i == j {
125                    Complex64::new(1.0, 0.0)
126                } else {
127                    Complex64::new(0.0, 0.0)
128                };
129                let diff = (product[[i, j]] - expected).norm();
130                if diff > tolerance {
131                    return Ok(false);
132                }
133            }
134        }
135        Ok(true)
136    }
137
138    fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>> {
139        let other_dense = other.to_dense();
140        let n1 = self.dim();
141        let n2 = other_dense.nrows();
142        let n = n1 * n2;
143
144        let mut result = Array2::zeros((n, n));
145
146        for i1 in 0..n1 {
147            for j1 in 0..n1 {
148                let val1 = self.data[[i1, j1]];
149                for i2 in 0..n2 {
150                    for j2 in 0..n2 {
151                        let val2 = other_dense[[i2, j2]];
152                        result[[i1 * n2 + i2, j1 * n2 + j2]] = val1 * val2;
153                    }
154                }
155            }
156        }
157
158        Ok(result)
159    }
160
161    fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>> {
162        if state.len() != self.dim() {
163            return Err(QuantRS2Error::InvalidInput(format!(
164                "State dimension {} doesn't match matrix dimension {}",
165                state.len(),
166                self.dim()
167            )));
168        }
169        Ok(self.data.dot(state))
170    }
171}
172
173/// Sparse matrix representation for quantum gates
174#[derive(Clone)]
175pub struct SparseMatrix {
176    csr: CsrMatrix<Complex64>,
177    dim: usize,
178}
179
180impl Debug for SparseMatrix {
181    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182        f.debug_struct("SparseMatrix")
183            .field("dim", &self.dim)
184            .field("nnz", &self.csr.nnz())
185            .finish()
186    }
187}
188
189impl SparseMatrix {
190    /// Create a new sparse matrix
191    pub fn new(csr: CsrMatrix<Complex64>) -> QuantRS2Result<Self> {
192        let (rows, cols) = csr.shape();
193        if rows != cols {
194            return Err(QuantRS2Error::InvalidInput(
195                "Matrix must be square".to_string(),
196            ));
197        }
198        Ok(Self { csr, dim: rows })
199    }
200
201    /// Create from triplets
202    pub fn from_triplets(
203        rows: Vec<usize>,
204        cols: Vec<usize>,
205        data: Vec<Complex64>,
206        dim: usize,
207    ) -> QuantRS2Result<Self> {
208        let csr = CsrMatrix::new(data, rows, cols, (dim, dim))
209            .map_err(|e| QuantRS2Error::InvalidInput(e))?;
210        Self::new(csr)
211    }
212}
213
214impl QuantumMatrix for SparseMatrix {
215    fn dim(&self) -> usize {
216        self.dim
217    }
218
219    fn to_dense(&self) -> Array2<Complex64> {
220        self.csr.to_dense()
221    }
222
223    fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>> {
224        Ok(self.csr.clone())
225    }
226
227    fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool> {
228        // Convert to dense for unitary check
229        let dense = DenseMatrix::new(self.to_dense())?;
230        dense.is_unitary(tolerance)
231    }
232
233    fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>> {
234        // For now, convert to dense and compute
235        // TODO: Implement sparse tensor product
236        let dense = DenseMatrix::new(self.to_dense())?;
237        dense.tensor_product(other)
238    }
239
240    fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>> {
241        if state.len() != self.dim() {
242            return Err(QuantRS2Error::InvalidInput(format!(
243                "State dimension {} doesn't match matrix dimension {}",
244                state.len(),
245                self.dim()
246            )));
247        }
248        // Convert to dense and apply
249        let dense = self.to_dense();
250        Ok(dense.dot(state))
251    }
252}
253
254/// Compute the partial trace of a matrix
255pub fn partial_trace(
256    matrix: &Array2<Complex64>,
257    keep_qubits: &[usize],
258    total_qubits: usize,
259) -> QuantRS2Result<Array2<Complex64>> {
260    let full_dim = 1 << total_qubits;
261    if matrix.nrows() != full_dim || matrix.ncols() != full_dim {
262        return Err(QuantRS2Error::InvalidInput(format!(
263            "Matrix dimension {} doesn't match {} qubits",
264            matrix.nrows(),
265            total_qubits
266        )));
267    }
268
269    let keep_dim = 1 << keep_qubits.len();
270    let trace_qubits: Vec<usize> = (0..total_qubits)
271        .filter(|q| !keep_qubits.contains(q))
272        .collect();
273    let trace_dim = 1 << trace_qubits.len();
274
275    let mut result = Array2::zeros((keep_dim, keep_dim));
276
277    // Iterate over all basis states
278    for i in 0..keep_dim {
279        for j in 0..keep_dim {
280            let mut sum = Complex64::new(0.0, 0.0);
281
282            // Sum over traced out qubits
283            for t in 0..trace_dim {
284                let row_idx = reconstruct_index(i, t, keep_qubits, &trace_qubits, total_qubits);
285                let col_idx = reconstruct_index(j, t, keep_qubits, &trace_qubits, total_qubits);
286                sum += matrix[[row_idx, col_idx]];
287            }
288
289            result[[i, j]] = sum;
290        }
291    }
292
293    Ok(result)
294}
295
/// Rebuild a full basis-state index from a kept-subsystem index and a
/// traced-subsystem index.
///
/// Bit `k` of `keep_idx` is placed at bit position `keep_qubits[k]`, and bit
/// `k` of `trace_idx` at bit position `trace_qubits[k]`. The two results are
/// OR-ed together. `_total_qubits` is accepted but unused.
fn reconstruct_index(
    keep_idx: usize,
    trace_idx: usize,
    keep_qubits: &[usize],
    trace_qubits: &[usize],
    _total_qubits: usize,
) -> usize {
    /// Scatter the low bits of `packed` onto the positions listed in `qubits`.
    fn scatter(packed: usize, qubits: &[usize]) -> usize {
        qubits
            .iter()
            .enumerate()
            .filter(|&(k, _)| (packed >> k) & 1 == 1)
            .fold(0usize, |acc, (_, &q)| acc | (1usize << q))
    }

    scatter(keep_idx, keep_qubits) | scatter(trace_idx, trace_qubits)
}
322
323/// Compute the tensor product of multiple matrices efficiently
324pub fn tensor_product_many(matrices: &[&dyn QuantumMatrix]) -> QuantRS2Result<Array2<Complex64>> {
325    if matrices.is_empty() {
326        return Err(QuantRS2Error::InvalidInput(
327            "Cannot compute tensor product of empty list".to_string(),
328        ));
329    }
330
331    if matrices.len() == 1 {
332        return Ok(matrices[0].to_dense());
333    }
334
335    let mut result = matrices[0].to_dense();
336    for matrix in matrices.iter().skip(1) {
337        let dense_result = DenseMatrix::new(result)?;
338        result = dense_result.tensor_product(*matrix)?;
339    }
340
341    Ok(result)
342}
343
344/// Check if two matrices are approximately equal
345pub fn matrices_approx_equal(
346    a: &ArrayView2<Complex64>,
347    b: &ArrayView2<Complex64>,
348    tolerance: f64,
349) -> bool {
350    if a.shape() != b.shape() {
351        return false;
352    }
353
354    for (x, y) in a.iter().zip(b.iter()) {
355        if (x - y).norm() > tolerance {
356            return false;
357        }
358    }
359
360    true
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::Complex64;

    /// Shorthand for `Complex64::new(re, im)`.
    fn c(re: f64, im: f64) -> Complex64 {
        Complex64::new(re, im)
    }

    /// Build a 2x2 array from four entries given in row-major order.
    fn array_2x2(entries: [Complex64; 4], context: &str) -> Array2<Complex64> {
        Array2::from_shape_vec((2, 2), entries.to_vec()).expect(context)
    }

    #[test]
    fn test_dense_matrix_creation() {
        let identity = array_2x2(
            [c(1.0, 0.0), c(0.0, 0.0), c(0.0, 0.0), c(1.0, 0.0)],
            "Matrix data creation should succeed",
        );

        let matrix = DenseMatrix::new(identity).expect("DenseMatrix creation should succeed");
        assert_eq!(matrix.dim(), 2);
    }

    #[test]
    fn test_unitary_check() {
        // Hadamard gate
        let h = 1.0 / 2.0_f64.sqrt();
        let hadamard = array_2x2(
            [c(h, 0.0), c(h, 0.0), c(h, 0.0), c(-h, 0.0)],
            "Hadamard matrix data creation should succeed",
        );

        let matrix = DenseMatrix::new(hadamard).expect("DenseMatrix creation should succeed");
        let unitary = matrix
            .is_unitary(1e-10)
            .expect("Unitary check should succeed");
        assert!(unitary);
    }

    #[test]
    fn test_tensor_product() {
        // Identity ⊗ Pauli-X
        let id = DenseMatrix::new(array_2x2(
            [c(1.0, 0.0), c(0.0, 0.0), c(0.0, 0.0), c(1.0, 0.0)],
            "Identity matrix data creation should succeed",
        ))
        .expect("Identity DenseMatrix creation should succeed");

        let x = DenseMatrix::new(array_2x2(
            [c(0.0, 0.0), c(1.0, 0.0), c(1.0, 0.0), c(0.0, 0.0)],
            "Pauli-X matrix data creation should succeed",
        ))
        .expect("Pauli-X DenseMatrix creation should succeed");

        let result = id
            .tensor_product(&x)
            .expect("Tensor product should succeed");
        assert_eq!(result.shape(), &[4, 4]);

        // I ⊗ X is block-diagonal with a copy of X in each diagonal block.
        assert_eq!(result[[0, 1]], c(1.0, 0.0));
        assert_eq!(result[[2, 3]], c(1.0, 0.0));
    }
}