quantrs2_core/
matrix_ops.rs

//! Matrix operations for quantum gates using SciRS2
//!
//! This module provides matrix operations for quantum computing, including
//! dense/sparse conversions, tensor (Kronecker) products, unitarity and
//! Hermiticity checks, and partial traces.
6
7use crate::error::{QuantRS2Error, QuantRS2Result};
8use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
9use num_complex::Complex64;
10use scirs2_linalg::{det, inv};
11use scirs2_sparse::csr::CsrMatrix;
12use std::fmt::Debug;
13
14/// Trait for quantum matrix operations
15pub trait QuantumMatrix: Debug + Send + Sync {
16    /// Get the dimension of the matrix (assumed square)
17    fn dim(&self) -> usize;
18
19    /// Convert to dense representation
20    fn to_dense(&self) -> Array2<Complex64>;
21
22    /// Convert to sparse representation
23    fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>>;
24
25    /// Check if the matrix is unitary
26    fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool>;
27
28    /// Compute the tensor product with another matrix
29    fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>>;
30
31    /// Apply to a state vector
32    fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>>;
33}
34
35/// Dense matrix representation
36#[derive(Debug, Clone)]
37pub struct DenseMatrix {
38    data: Array2<Complex64>,
39}
40
41impl DenseMatrix {
42    /// Create a new dense matrix
43    pub fn new(data: Array2<Complex64>) -> QuantRS2Result<Self> {
44        if data.nrows() != data.ncols() {
45            return Err(QuantRS2Error::InvalidInput(
46                "Matrix must be square".to_string(),
47            ));
48        }
49        Ok(Self { data })
50    }
51
52    /// Create from a flat vector (column-major order)
53    pub fn from_vec(data: Vec<Complex64>, dim: usize) -> QuantRS2Result<Self> {
54        if data.len() != dim * dim {
55            return Err(QuantRS2Error::InvalidInput(format!(
56                "Expected {} elements, got {}",
57                dim * dim,
58                data.len()
59            )));
60        }
61        let matrix = Array2::from_shape_vec((dim, dim), data)
62            .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))?;
63        Self::new(matrix)
64    }
65
66    /// Get a reference to the underlying array
67    pub fn as_array(&self) -> &Array2<Complex64> {
68        &self.data
69    }
70
71    /// Check if matrix is hermitian
72    pub fn is_hermitian(&self, tolerance: f64) -> bool {
73        let n = self.data.nrows();
74        for i in 0..n {
75            for j in i..n {
76                let diff = (self.data[[i, j]] - self.data[[j, i]].conj()).norm();
77                if diff > tolerance {
78                    return false;
79                }
80            }
81        }
82        true
83    }
84}
85
86impl QuantumMatrix for DenseMatrix {
87    fn dim(&self) -> usize {
88        self.data.nrows()
89    }
90
91    fn to_dense(&self) -> Array2<Complex64> {
92        self.data.clone()
93    }
94
95    fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>> {
96        let n = self.dim();
97        let mut rows = Vec::new();
98        let mut cols = Vec::new();
99        let mut data = Vec::new();
100
101        let tolerance = 1e-14;
102        for i in 0..n {
103            for j in 0..n {
104                let val = self.data[[i, j]];
105                if val.norm() > tolerance {
106                    rows.push(i);
107                    cols.push(j);
108                    data.push(val);
109                }
110            }
111        }
112
113        CsrMatrix::new(data, rows, cols, (n, n))
114            .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))
115    }
116
117    fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool> {
118        let n = self.dim();
119        let conj_transpose = self.data.t().mapv(|x| x.conj());
120        let product = self.data.dot(&conj_transpose);
121
122        // Check if product is identity
123        for i in 0..n {
124            for j in 0..n {
125                let expected = if i == j {
126                    Complex64::new(1.0, 0.0)
127                } else {
128                    Complex64::new(0.0, 0.0)
129                };
130                let diff = (product[[i, j]] - expected).norm();
131                if diff > tolerance {
132                    return Ok(false);
133                }
134            }
135        }
136        Ok(true)
137    }
138
139    fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>> {
140        let other_dense = other.to_dense();
141        let n1 = self.dim();
142        let n2 = other_dense.nrows();
143        let n = n1 * n2;
144
145        let mut result = Array2::zeros((n, n));
146
147        for i1 in 0..n1 {
148            for j1 in 0..n1 {
149                let val1 = self.data[[i1, j1]];
150                for i2 in 0..n2 {
151                    for j2 in 0..n2 {
152                        let val2 = other_dense[[i2, j2]];
153                        result[[i1 * n2 + i2, j1 * n2 + j2]] = val1 * val2;
154                    }
155                }
156            }
157        }
158
159        Ok(result)
160    }
161
162    fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>> {
163        if state.len() != self.dim() {
164            return Err(QuantRS2Error::InvalidInput(format!(
165                "State dimension {} doesn't match matrix dimension {}",
166                state.len(),
167                self.dim()
168            )));
169        }
170        Ok(self.data.dot(state))
171    }
172}
173
174/// Sparse matrix representation for quantum gates
175#[derive(Clone)]
176pub struct SparseMatrix {
177    csr: CsrMatrix<Complex64>,
178    dim: usize,
179}
180
181impl Debug for SparseMatrix {
182    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183        f.debug_struct("SparseMatrix")
184            .field("dim", &self.dim)
185            .field("nnz", &self.csr.nnz())
186            .finish()
187    }
188}
189
190impl SparseMatrix {
191    /// Create a new sparse matrix
192    pub fn new(csr: CsrMatrix<Complex64>) -> QuantRS2Result<Self> {
193        let (rows, cols) = csr.shape();
194        if rows != cols {
195            return Err(QuantRS2Error::InvalidInput(
196                "Matrix must be square".to_string(),
197            ));
198        }
199        Ok(Self { csr, dim: rows })
200    }
201
202    /// Create from triplets
203    pub fn from_triplets(
204        rows: Vec<usize>,
205        cols: Vec<usize>,
206        data: Vec<Complex64>,
207        dim: usize,
208    ) -> QuantRS2Result<Self> {
209        let csr = CsrMatrix::new(data, rows, cols, (dim, dim))
210            .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))?;
211        Self::new(csr)
212    }
213}
214
215impl QuantumMatrix for SparseMatrix {
216    fn dim(&self) -> usize {
217        self.dim
218    }
219
220    fn to_dense(&self) -> Array2<Complex64> {
221        let dense_vec = self.csr.to_dense();
222        let rows = dense_vec.len();
223        let cols = if rows > 0 { dense_vec[0].len() } else { 0 };
224
225        let mut flat = Vec::with_capacity(rows * cols);
226        for row in dense_vec {
227            flat.extend(row);
228        }
229
230        Array2::from_shape_vec((rows, cols), flat).unwrap()
231    }
232
233    fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>> {
234        Ok(self.csr.clone())
235    }
236
237    fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool> {
238        // Convert to dense for unitary check
239        let dense = DenseMatrix::new(self.to_dense())?;
240        dense.is_unitary(tolerance)
241    }
242
243    fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>> {
244        // For now, convert to dense and compute
245        // TODO: Implement sparse tensor product
246        let dense = DenseMatrix::new(self.to_dense())?;
247        dense.tensor_product(other)
248    }
249
250    fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>> {
251        if state.len() != self.dim() {
252            return Err(QuantRS2Error::InvalidInput(format!(
253                "State dimension {} doesn't match matrix dimension {}",
254                state.len(),
255                self.dim()
256            )));
257        }
258        // Convert to dense and apply
259        let dense = self.to_dense();
260        Ok(dense.dot(state))
261    }
262}
263
264/// Compute the partial trace of a matrix
265pub fn partial_trace(
266    matrix: &Array2<Complex64>,
267    keep_qubits: &[usize],
268    total_qubits: usize,
269) -> QuantRS2Result<Array2<Complex64>> {
270    let full_dim = 1 << total_qubits;
271    if matrix.nrows() != full_dim || matrix.ncols() != full_dim {
272        return Err(QuantRS2Error::InvalidInput(format!(
273            "Matrix dimension {} doesn't match {} qubits",
274            matrix.nrows(),
275            total_qubits
276        )));
277    }
278
279    let keep_dim = 1 << keep_qubits.len();
280    let trace_qubits: Vec<usize> = (0..total_qubits)
281        .filter(|q| !keep_qubits.contains(q))
282        .collect();
283    let trace_dim = 1 << trace_qubits.len();
284
285    let mut result = Array2::zeros((keep_dim, keep_dim));
286
287    // Iterate over all basis states
288    for i in 0..keep_dim {
289        for j in 0..keep_dim {
290            let mut sum = Complex64::new(0.0, 0.0);
291
292            // Sum over traced out qubits
293            for t in 0..trace_dim {
294                let row_idx = reconstruct_index(i, t, keep_qubits, &trace_qubits, total_qubits);
295                let col_idx = reconstruct_index(j, t, keep_qubits, &trace_qubits, total_qubits);
296                sum += matrix[[row_idx, col_idx]];
297            }
298
299            result[[i, j]] = sum;
300        }
301    }
302
303    Ok(result)
304}
305
/// Rebuild a full `total_qubits`-bit basis index from a kept-subsystem index
/// and a traced-subsystem index.
///
/// Bit `i` of `keep_idx` is placed at bit position `keep_qubits[i]`, and bit
/// `i` of `trace_idx` at position `trace_qubits[i]`. Bits of either input
/// beyond the length of its qubit list are ignored.
fn reconstruct_index(
    keep_idx: usize,
    trace_idx: usize,
    keep_qubits: &[usize],
    trace_qubits: &[usize],
    total_qubits: usize,
) -> usize {
    /// Scatter the low bits of `packed` into the bit positions listed in `qubits`.
    fn scatter(packed: usize, qubits: &[usize]) -> usize {
        qubits
            .iter()
            .enumerate()
            .filter(|&(i, _)| (packed >> i) & 1 == 1)
            .fold(0usize, |acc, (_, &q)| acc | (1usize << q))
    }

    // `total_qubits` was previously unused. Use it to catch indices that would
    // address bits outside the full state space.
    debug_assert!(
        keep_qubits.iter().chain(trace_qubits).all(|&q| q < total_qubits),
        "qubit index out of range for {} qubits",
        total_qubits
    );

    scatter(keep_idx, keep_qubits) | scatter(trace_idx, trace_qubits)
}
332
333/// Compute the tensor product of multiple matrices efficiently
334pub fn tensor_product_many(matrices: &[&dyn QuantumMatrix]) -> QuantRS2Result<Array2<Complex64>> {
335    if matrices.is_empty() {
336        return Err(QuantRS2Error::InvalidInput(
337            "Cannot compute tensor product of empty list".to_string(),
338        ));
339    }
340
341    if matrices.len() == 1 {
342        return Ok(matrices[0].to_dense());
343    }
344
345    let mut result = matrices[0].to_dense();
346    for matrix in matrices.iter().skip(1) {
347        let dense_result = DenseMatrix::new(result)?;
348        result = dense_result.tensor_product(*matrix)?;
349    }
350
351    Ok(result)
352}
353
354/// Check if two matrices are approximately equal
355pub fn matrices_approx_equal(
356    a: &ArrayView2<Complex64>,
357    b: &ArrayView2<Complex64>,
358    tolerance: f64,
359) -> bool {
360    if a.shape() != b.shape() {
361        return false;
362    }
363
364    for (x, y) in a.iter().zip(b.iter()) {
365        if (x - y).norm() > tolerance {
366            return false;
367        }
368    }
369
370    true
371}
372
#[cfg(test)]
mod tests {
    use super::*;
    use num_complex::Complex64;

    /// Shorthand for a purely real complex number.
    fn re(x: f64) -> Complex64 {
        Complex64::new(x, 0.0)
    }

    /// Build a 2×2 `DenseMatrix` from real entries listed in row-major order.
    fn mat2(entries: [f64; 4]) -> DenseMatrix {
        let values = entries.iter().map(|&x| re(x)).collect::<Vec<_>>();
        DenseMatrix::new(Array2::from_shape_vec((2, 2), values).unwrap()).unwrap()
    }

    #[test]
    fn test_dense_matrix_creation() {
        let matrix = mat2([1.0, 0.0, 0.0, 1.0]);
        assert_eq!(matrix.dim(), 2);
    }

    #[test]
    fn test_non_square_rejected() {
        let data = Array2::from_elem((2, 3), re(0.0));
        assert!(DenseMatrix::new(data).is_err());
    }

    #[test]
    fn test_from_vec_is_row_major() {
        let matrix =
            DenseMatrix::from_vec(vec![re(1.0), re(2.0), re(3.0), re(4.0)], 2).unwrap();
        assert_eq!(matrix.as_array()[[0, 1]], re(2.0));
        assert_eq!(matrix.as_array()[[1, 0]], re(3.0));
        assert!(DenseMatrix::from_vec(vec![re(1.0)], 2).is_err());
    }

    #[test]
    fn test_unitary_check() {
        // Hadamard gate is unitary; a shear matrix is not.
        let s = 1.0 / 2.0_f64.sqrt();
        let hadamard = mat2([s, s, s, -s]);
        assert!(hadamard.is_unitary(1e-10).unwrap());

        let shear = mat2([1.0, 1.0, 0.0, 1.0]);
        assert!(!shear.is_unitary(1e-10).unwrap());
    }

    #[test]
    fn test_is_hermitian() {
        // Pauli-Y is Hermitian; the real antisymmetric [[0, 1], [-1, 0]] is not.
        let pauli_y = DenseMatrix::from_vec(
            vec![
                re(0.0),
                Complex64::new(0.0, -1.0),
                Complex64::new(0.0, 1.0),
                re(0.0),
            ],
            2,
        )
        .unwrap();
        assert!(pauli_y.is_hermitian(1e-12));
        assert!(!mat2([0.0, 1.0, -1.0, 0.0]).is_hermitian(1e-12));
    }

    #[test]
    fn test_tensor_product() {
        // Identity ⊗ Pauli-X
        let id = mat2([1.0, 0.0, 0.0, 1.0]);
        let x = mat2([0.0, 1.0, 1.0, 0.0]);

        let result = id.tensor_product(&x).unwrap();
        assert_eq!(result.shape(), &[4, 4]);

        // Check specific values
        assert_eq!(result[[0, 1]], re(1.0));
        assert_eq!(result[[2, 3]], re(1.0));
    }

    #[test]
    fn test_apply() {
        let x = mat2([0.0, 1.0, 1.0, 0.0]);
        let state = Array1::from(vec![re(1.0), re(0.0)]);
        let flipped = x.apply(&state.view()).unwrap();
        assert_eq!(flipped[0], re(0.0));
        assert_eq!(flipped[1], re(1.0));

        let too_long = Array1::from(vec![re(1.0), re(0.0), re(0.0)]);
        assert!(x.apply(&too_long.view()).is_err());
    }

    #[test]
    fn test_partial_trace_of_product_state() {
        // Basis state 2: qubit 1 set, qubit 0 clear (qubit 0 is bit 0).
        let mut rho: Array2<Complex64> = Array2::zeros((4, 4));
        rho[[2, 2]] = re(1.0);

        let qubit0 = partial_trace(&rho, &[0], 2).unwrap();
        let expected0 = mat2([1.0, 0.0, 0.0, 0.0]);
        assert!(matrices_approx_equal(
            &qubit0.view(),
            &expected0.as_array().view(),
            1e-12
        ));

        let qubit1 = partial_trace(&rho, &[1], 2).unwrap();
        let expected1 = mat2([0.0, 0.0, 0.0, 1.0]);
        assert!(matrices_approx_equal(
            &qubit1.view(),
            &expected1.as_array().view(),
            1e-12
        ));

        // A matrix whose size does not match the qubit count is rejected.
        assert!(partial_trace(&rho, &[0], 3).is_err());
    }
}