// File: quantrs2_core/matrix_ops.rs

//! Matrix operations for quantum gates, built on SciRS2 sparse matrices.
//!
//! This module provides dense and sparse (CSR) matrix types for quantum
//! computing, together with dense/sparse conversions, tensor products, and
//! specialized quantum operations such as the partial trace.

7use crate::error::{QuantRS2Error, QuantRS2Result};
8use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
9use num_complex::Complex64;
10use scirs2_sparse::csr::CsrMatrix;
11use std::fmt::Debug;
12
13/// Trait for quantum matrix operations
14pub trait QuantumMatrix: Debug + Send + Sync {
15    /// Get the dimension of the matrix (assumed square)
16    fn dim(&self) -> usize;
17
18    /// Convert to dense representation
19    fn to_dense(&self) -> Array2<Complex64>;
20
21    /// Convert to sparse representation
22    fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>>;
23
24    /// Check if the matrix is unitary
25    fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool>;
26
27    /// Compute the tensor product with another matrix
28    fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>>;
29
30    /// Apply to a state vector
31    fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>>;
32}
33
34/// Dense matrix representation
35#[derive(Debug, Clone)]
36pub struct DenseMatrix {
37    data: Array2<Complex64>,
38}
39
40impl DenseMatrix {
41    /// Create a new dense matrix
42    pub fn new(data: Array2<Complex64>) -> QuantRS2Result<Self> {
43        if data.nrows() != data.ncols() {
44            return Err(QuantRS2Error::InvalidInput(
45                "Matrix must be square".to_string(),
46            ));
47        }
48        Ok(Self { data })
49    }
50
51    /// Create from a flat vector (column-major order)
52    pub fn from_vec(data: Vec<Complex64>, dim: usize) -> QuantRS2Result<Self> {
53        if data.len() != dim * dim {
54            return Err(QuantRS2Error::InvalidInput(format!(
55                "Expected {} elements, got {}",
56                dim * dim,
57                data.len()
58            )));
59        }
60        let matrix = Array2::from_shape_vec((dim, dim), data)
61            .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))?;
62        Self::new(matrix)
63    }
64
65    /// Get a reference to the underlying array
66    pub fn as_array(&self) -> &Array2<Complex64> {
67        &self.data
68    }
69
70    /// Check if matrix is hermitian
71    pub fn is_hermitian(&self, tolerance: f64) -> bool {
72        let n = self.data.nrows();
73        for i in 0..n {
74            for j in i..n {
75                let diff = (self.data[[i, j]] - self.data[[j, i]].conj()).norm();
76                if diff > tolerance {
77                    return false;
78                }
79            }
80        }
81        true
82    }
83}
84
85impl QuantumMatrix for DenseMatrix {
86    fn dim(&self) -> usize {
87        self.data.nrows()
88    }
89
90    fn to_dense(&self) -> Array2<Complex64> {
91        self.data.clone()
92    }
93
94    fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>> {
95        let n = self.dim();
96        let mut rows = Vec::new();
97        let mut cols = Vec::new();
98        let mut data = Vec::new();
99
100        let tolerance = 1e-14;
101        for i in 0..n {
102            for j in 0..n {
103                let val = self.data[[i, j]];
104                if val.norm() > tolerance {
105                    rows.push(i);
106                    cols.push(j);
107                    data.push(val);
108                }
109            }
110        }
111
112        CsrMatrix::new(data, rows, cols, (n, n))
113            .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))
114    }
115
116    fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool> {
117        let n = self.dim();
118        let conj_transpose = self.data.t().mapv(|x| x.conj());
119        let product = self.data.dot(&conj_transpose);
120
121        // Check if product is identity
122        for i in 0..n {
123            for j in 0..n {
124                let expected = if i == j {
125                    Complex64::new(1.0, 0.0)
126                } else {
127                    Complex64::new(0.0, 0.0)
128                };
129                let diff = (product[[i, j]] - expected).norm();
130                if diff > tolerance {
131                    return Ok(false);
132                }
133            }
134        }
135        Ok(true)
136    }
137
138    fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>> {
139        let other_dense = other.to_dense();
140        let n1 = self.dim();
141        let n2 = other_dense.nrows();
142        let n = n1 * n2;
143
144        let mut result = Array2::zeros((n, n));
145
146        for i1 in 0..n1 {
147            for j1 in 0..n1 {
148                let val1 = self.data[[i1, j1]];
149                for i2 in 0..n2 {
150                    for j2 in 0..n2 {
151                        let val2 = other_dense[[i2, j2]];
152                        result[[i1 * n2 + i2, j1 * n2 + j2]] = val1 * val2;
153                    }
154                }
155            }
156        }
157
158        Ok(result)
159    }
160
161    fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>> {
162        if state.len() != self.dim() {
163            return Err(QuantRS2Error::InvalidInput(format!(
164                "State dimension {} doesn't match matrix dimension {}",
165                state.len(),
166                self.dim()
167            )));
168        }
169        Ok(self.data.dot(state))
170    }
171}
172
173/// Sparse matrix representation for quantum gates
174#[derive(Clone)]
175pub struct SparseMatrix {
176    csr: CsrMatrix<Complex64>,
177    dim: usize,
178}
179
180impl Debug for SparseMatrix {
181    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182        f.debug_struct("SparseMatrix")
183            .field("dim", &self.dim)
184            .field("nnz", &self.csr.nnz())
185            .finish()
186    }
187}
188
189impl SparseMatrix {
190    /// Create a new sparse matrix
191    pub fn new(csr: CsrMatrix<Complex64>) -> QuantRS2Result<Self> {
192        let (rows, cols) = csr.shape();
193        if rows != cols {
194            return Err(QuantRS2Error::InvalidInput(
195                "Matrix must be square".to_string(),
196            ));
197        }
198        Ok(Self { csr, dim: rows })
199    }
200
201    /// Create from triplets
202    pub fn from_triplets(
203        rows: Vec<usize>,
204        cols: Vec<usize>,
205        data: Vec<Complex64>,
206        dim: usize,
207    ) -> QuantRS2Result<Self> {
208        let csr = CsrMatrix::new(data, rows, cols, (dim, dim))
209            .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))?;
210        Self::new(csr)
211    }
212}
213
214impl QuantumMatrix for SparseMatrix {
215    fn dim(&self) -> usize {
216        self.dim
217    }
218
219    fn to_dense(&self) -> Array2<Complex64> {
220        let dense_vec = self.csr.to_dense();
221        let rows = dense_vec.len();
222        let cols = if rows > 0 { dense_vec[0].len() } else { 0 };
223
224        let mut flat = Vec::with_capacity(rows * cols);
225        for row in dense_vec {
226            flat.extend(row);
227        }
228
229        Array2::from_shape_vec((rows, cols), flat).unwrap()
230    }
231
232    fn to_sparse(&self) -> QuantRS2Result<CsrMatrix<Complex64>> {
233        Ok(self.csr.clone())
234    }
235
236    fn is_unitary(&self, tolerance: f64) -> QuantRS2Result<bool> {
237        // Convert to dense for unitary check
238        let dense = DenseMatrix::new(self.to_dense())?;
239        dense.is_unitary(tolerance)
240    }
241
242    fn tensor_product(&self, other: &dyn QuantumMatrix) -> QuantRS2Result<Array2<Complex64>> {
243        // For now, convert to dense and compute
244        // TODO: Implement sparse tensor product
245        let dense = DenseMatrix::new(self.to_dense())?;
246        dense.tensor_product(other)
247    }
248
249    fn apply(&self, state: &ArrayView1<Complex64>) -> QuantRS2Result<Array1<Complex64>> {
250        if state.len() != self.dim() {
251            return Err(QuantRS2Error::InvalidInput(format!(
252                "State dimension {} doesn't match matrix dimension {}",
253                state.len(),
254                self.dim()
255            )));
256        }
257        // Convert to dense and apply
258        let dense = self.to_dense();
259        Ok(dense.dot(state))
260    }
261}
262
263/// Compute the partial trace of a matrix
264pub fn partial_trace(
265    matrix: &Array2<Complex64>,
266    keep_qubits: &[usize],
267    total_qubits: usize,
268) -> QuantRS2Result<Array2<Complex64>> {
269    let full_dim = 1 << total_qubits;
270    if matrix.nrows() != full_dim || matrix.ncols() != full_dim {
271        return Err(QuantRS2Error::InvalidInput(format!(
272            "Matrix dimension {} doesn't match {} qubits",
273            matrix.nrows(),
274            total_qubits
275        )));
276    }
277
278    let keep_dim = 1 << keep_qubits.len();
279    let trace_qubits: Vec<usize> = (0..total_qubits)
280        .filter(|q| !keep_qubits.contains(q))
281        .collect();
282    let trace_dim = 1 << trace_qubits.len();
283
284    let mut result = Array2::zeros((keep_dim, keep_dim));
285
286    // Iterate over all basis states
287    for i in 0..keep_dim {
288        for j in 0..keep_dim {
289            let mut sum = Complex64::new(0.0, 0.0);
290
291            // Sum over traced out qubits
292            for t in 0..trace_dim {
293                let row_idx = reconstruct_index(i, t, keep_qubits, &trace_qubits, total_qubits);
294                let col_idx = reconstruct_index(j, t, keep_qubits, &trace_qubits, total_qubits);
295                sum += matrix[[row_idx, col_idx]];
296            }
297
298            result[[i, j]] = sum;
299        }
300    }
301
302    Ok(result)
303}
304
/// Rebuild a full basis-state index from a kept-subsystem index and a
/// traced-subsystem index.
///
/// Bit `k` of `keep_idx` is placed at bit position `keep_qubits[k]`, and bit
/// `k` of `trace_idx` at bit position `trace_qubits[k]`; all other bits are zero.
/// `_total_qubits` is accepted but not used.
fn reconstruct_index(
    keep_idx: usize,
    trace_idx: usize,
    keep_qubits: &[usize],
    trace_qubits: &[usize],
    _total_qubits: usize,
) -> usize {
    /// Scatter the low bits of `packed`: bit `k` becomes bit `qubits[k]` of the result.
    fn scatter(packed: usize, qubits: &[usize]) -> usize {
        qubits
            .iter()
            .enumerate()
            .filter(|&(bit, _)| (packed >> bit) & 1 == 1)
            .fold(0, |acc, (_, &qubit)| acc | (1usize << qubit))
    }

    scatter(keep_idx, keep_qubits) | scatter(trace_idx, trace_qubits)
}

332/// Compute the tensor product of multiple matrices efficiently
333pub fn tensor_product_many(matrices: &[&dyn QuantumMatrix]) -> QuantRS2Result<Array2<Complex64>> {
334    if matrices.is_empty() {
335        return Err(QuantRS2Error::InvalidInput(
336            "Cannot compute tensor product of empty list".to_string(),
337        ));
338    }
339
340    if matrices.len() == 1 {
341        return Ok(matrices[0].to_dense());
342    }
343
344    let mut result = matrices[0].to_dense();
345    for matrix in matrices.iter().skip(1) {
346        let dense_result = DenseMatrix::new(result)?;
347        result = dense_result.tensor_product(*matrix)?;
348    }
349
350    Ok(result)
351}
352
353/// Check if two matrices are approximately equal
354pub fn matrices_approx_equal(
355    a: &ArrayView2<Complex64>,
356    b: &ArrayView2<Complex64>,
357    tolerance: f64,
358) -> bool {
359    if a.shape() != b.shape() {
360        return false;
361    }
362
363    for (x, y) in a.iter().zip(b.iter()) {
364        if (x - y).norm() > tolerance {
365            return false;
366        }
367    }
368
369    true
370}
371
#[cfg(test)]
mod tests {
    use super::*;
    use num_complex::Complex64;

    /// Shorthand for a complex literal.
    fn c(re: f64, im: f64) -> Complex64 {
        Complex64::new(re, im)
    }

    /// Build a 2x2 `DenseMatrix` from entries given in row-major order.
    fn mat2(entries: [Complex64; 4]) -> DenseMatrix {
        DenseMatrix::new(Array2::from_shape_vec((2, 2), entries.to_vec()).unwrap()).unwrap()
    }

    #[test]
    fn test_dense_matrix_creation() {
        let matrix = mat2([c(1.0, 0.0), c(0.0, 0.0), c(0.0, 0.0), c(1.0, 0.0)]);
        assert_eq!(matrix.dim(), 2);
    }

    #[test]
    fn test_unitary_check() {
        // Hadamard gate
        let sqrt2 = 1.0 / 2.0_f64.sqrt();
        let matrix = mat2([c(sqrt2, 0.0), c(sqrt2, 0.0), c(sqrt2, 0.0), c(-sqrt2, 0.0)]);
        assert!(matrix.is_unitary(1e-10).unwrap());
    }

    #[test]
    fn test_non_unitary_rejected() {
        // diag(2, 1) scales the first basis state, so it cannot be unitary.
        let matrix = mat2([c(2.0, 0.0), c(0.0, 0.0), c(0.0, 0.0), c(1.0, 0.0)]);
        assert!(!matrix.is_unitary(1e-10).unwrap());
    }

    #[test]
    fn test_tensor_product() {
        // Identity ⊗ Pauli-X
        let id = mat2([c(1.0, 0.0), c(0.0, 0.0), c(0.0, 0.0), c(1.0, 0.0)]);
        let x = mat2([c(0.0, 0.0), c(1.0, 0.0), c(1.0, 0.0), c(0.0, 0.0)]);

        let result = id.tensor_product(&x).unwrap();
        assert_eq!(result.shape(), &[4, 4]);

        // Check specific values
        assert_eq!(result[[0, 1]], c(1.0, 0.0));
        assert_eq!(result[[2, 3]], c(1.0, 0.0));
    }

    #[test]
    fn test_from_vec_is_row_major() {
        let m = DenseMatrix::from_vec(vec![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0), c(4.0, 0.0)], 2)
            .unwrap();
        assert_eq!(m.as_array()[[0, 1]], c(2.0, 0.0));
        assert_eq!(m.as_array()[[1, 0]], c(3.0, 0.0));
    }

    #[test]
    fn test_from_vec_rejects_wrong_length() {
        assert!(DenseMatrix::from_vec(vec![c(1.0, 0.0); 3], 2).is_err());
    }

    #[test]
    fn test_is_hermitian() {
        // Pauli-Y is Hermitian.
        let y = mat2([c(0.0, 0.0), c(0.0, -1.0), c(0.0, 1.0), c(0.0, 0.0)]);
        assert!(y.is_hermitian(1e-12));

        // The raising operator |0><1| is not.
        let raising = mat2([c(0.0, 0.0), c(1.0, 0.0), c(0.0, 0.0), c(0.0, 0.0)]);
        assert!(!raising.is_hermitian(1e-12));
    }

    #[test]
    fn test_partial_trace_diagonal() {
        // rho = diag(1, 2, 3, 4) on two qubits; qubit q is bit q of the index.
        let mut rho = Array2::<Complex64>::zeros((4, 4));
        for k in 0..4 {
            rho[[k, k]] = c((k + 1) as f64, 0.0);
        }

        let keep0 = partial_trace(&rho, &[0], 2).unwrap();
        assert_eq!(keep0[[0, 0]], c(4.0, 0.0));
        assert_eq!(keep0[[1, 1]], c(6.0, 0.0));
        assert_eq!(keep0[[0, 1]], c(0.0, 0.0));

        let keep1 = partial_trace(&rho, &[1], 2).unwrap();
        assert_eq!(keep1[[0, 0]], c(3.0, 0.0));
        assert_eq!(keep1[[1, 1]], c(7.0, 0.0));
    }

    #[test]
    fn test_partial_trace_rejects_wrong_dimension() {
        let rho = Array2::<Complex64>::zeros((3, 3));
        assert!(partial_trace(&rho, &[0], 2).is_err());
    }
}