quantrs2_sim/
sparse.rs

1//! Sparse matrix operations for efficient quantum circuit simulation.
2//!
3//! This module provides sparse matrix representations and operations
4//! optimized for quantum gates, especially for circuits with limited connectivity.
5
6use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::Complex64;
8use std::collections::HashMap;
9
10use crate::error::{Result, SimulatorError};
11
12/// Compressed Sparse Row (CSR) matrix format
13#[derive(Debug, Clone)]
14pub struct CSRMatrix {
15    /// Non-zero values
16    pub values: Vec<Complex64>,
17    /// Column indices for each value
18    pub col_indices: Vec<usize>,
19    /// Row pointer array
20    pub row_ptr: Vec<usize>,
21    /// Number of rows
22    pub num_rows: usize,
23    /// Number of columns
24    pub num_cols: usize,
25}
26
27impl CSRMatrix {
28    /// Create a new CSR matrix
29    #[must_use]
30    pub fn new(
31        values: Vec<Complex64>,
32        col_indices: Vec<usize>,
33        row_ptr: Vec<usize>,
34        num_rows: usize,
35        num_cols: usize,
36    ) -> Self {
37        assert_eq!(values.len(), col_indices.len());
38        assert_eq!(row_ptr.len(), num_rows + 1);
39
40        Self {
41            values,
42            col_indices,
43            row_ptr,
44            num_rows,
45            num_cols,
46        }
47    }
48
49    /// Create from a dense matrix
50    #[must_use]
51    pub fn from_dense(matrix: &Array2<Complex64>) -> Self {
52        let num_rows = matrix.nrows();
53        let num_cols = matrix.ncols();
54        let mut values = Vec::new();
55        let mut col_indices = Vec::new();
56        let mut row_ptr = vec![0];
57
58        for i in 0..num_rows {
59            for j in 0..num_cols {
60                let val = matrix[[i, j]];
61                if val.norm() > 1e-15 {
62                    values.push(val);
63                    col_indices.push(j);
64                }
65            }
66            row_ptr.push(values.len());
67        }
68
69        Self::new(values, col_indices, row_ptr, num_rows, num_cols)
70    }
71
72    /// Convert to dense matrix
73    #[must_use]
74    pub fn to_dense(&self) -> Array2<Complex64> {
75        let mut dense = Array2::zeros((self.num_rows, self.num_cols));
76
77        for i in 0..self.num_rows {
78            let start = self.row_ptr[i];
79            let end = self.row_ptr[i + 1];
80
81            for idx in start..end {
82                dense[[i, self.col_indices[idx]]] = self.values[idx];
83            }
84        }
85
86        dense
87    }
88
89    /// Get number of non-zero elements
90    #[must_use]
91    pub fn nnz(&self) -> usize {
92        self.values.len()
93    }
94
95    /// Matrix-vector multiplication
96    pub fn matvec(&self, vec: &Array1<Complex64>) -> Result<Array1<Complex64>> {
97        if vec.len() != self.num_cols {
98            return Err(SimulatorError::DimensionMismatch(format!(
99                "Vector length {} doesn't match matrix columns {}",
100                vec.len(),
101                self.num_cols
102            )));
103        }
104
105        let mut result = Array1::zeros(self.num_rows);
106
107        for i in 0..self.num_rows {
108            let start = self.row_ptr[i];
109            let end = self.row_ptr[i + 1];
110
111            let mut sum = Complex64::new(0.0, 0.0);
112            for idx in start..end {
113                sum += self.values[idx] * vec[self.col_indices[idx]];
114            }
115            result[i] = sum;
116        }
117
118        Ok(result)
119    }
120
121    /// Sparse matrix multiplication
122    pub fn matmul(&self, other: &Self) -> Result<Self> {
123        if self.num_cols != other.num_rows {
124            return Err(SimulatorError::DimensionMismatch(format!(
125                "Matrix dimensions incompatible: {}x{} * {}x{}",
126                self.num_rows, self.num_cols, other.num_rows, other.num_cols
127            )));
128        }
129
130        let mut values = Vec::new();
131        let mut col_indices = Vec::new();
132        let mut row_ptr = vec![0];
133
134        // Convert other to CSC for efficient column access
135        let other_csc = other.to_csc();
136
137        for i in 0..self.num_rows {
138            let mut row_values: HashMap<usize, Complex64> = HashMap::new();
139
140            let a_start = self.row_ptr[i];
141            let a_end = self.row_ptr[i + 1];
142
143            for a_idx in a_start..a_end {
144                let k = self.col_indices[a_idx];
145                let a_val = self.values[a_idx];
146
147                // Multiply row i of A with column k of B
148                let b_start = other_csc.col_ptr[k];
149                let b_end = other_csc.col_ptr[k + 1];
150
151                for b_idx in b_start..b_end {
152                    let j = other_csc.row_indices[b_idx];
153                    let b_val = other_csc.values[b_idx];
154
155                    *row_values.entry(j).or_insert(Complex64::new(0.0, 0.0)) += a_val * b_val;
156                }
157            }
158
159            // Sort by column index and add to result
160            let mut sorted_cols: Vec<_> = row_values.into_iter().collect();
161            sorted_cols.sort_by_key(|(col, _)| *col);
162
163            for (col, val) in sorted_cols {
164                if val.norm() > 1e-15 {
165                    values.push(val);
166                    col_indices.push(col);
167                }
168            }
169
170            row_ptr.push(values.len());
171        }
172
173        Ok(Self::new(
174            values,
175            col_indices,
176            row_ptr,
177            self.num_rows,
178            other.num_cols,
179        ))
180    }
181
182    /// Convert to Compressed Sparse Column (CSC) format
183    fn to_csc(&self) -> CSCMatrix {
184        let mut values = Vec::new();
185        let mut row_indices = Vec::new();
186        let mut col_ptr = vec![0; self.num_cols + 1];
187
188        // Count elements per column
189        for &col in &self.col_indices {
190            col_ptr[col + 1] += 1;
191        }
192
193        // Cumulative sum to get column pointers
194        for i in 1..=self.num_cols {
195            col_ptr[i] += col_ptr[i - 1];
196        }
197
198        // Temporary array to track current position in each column
199        let mut current_pos = col_ptr[0..self.num_cols].to_vec();
200        values.resize(self.nnz(), Complex64::new(0.0, 0.0));
201        row_indices.resize(self.nnz(), 0);
202
203        // Fill CSC arrays
204        for i in 0..self.num_rows {
205            let start = self.row_ptr[i];
206            let end = self.row_ptr[i + 1];
207
208            for idx in start..end {
209                let col = self.col_indices[idx];
210                let pos = current_pos[col];
211
212                values[pos] = self.values[idx];
213                row_indices[pos] = i;
214                current_pos[col] += 1;
215            }
216        }
217
218        CSCMatrix {
219            values,
220            row_indices,
221            col_ptr,
222            num_rows: self.num_rows,
223            num_cols: self.num_cols,
224        }
225    }
226}
227
228/// Compressed Sparse Column (CSC) matrix format
229#[derive(Debug, Clone)]
230struct CSCMatrix {
231    values: Vec<Complex64>,
232    row_indices: Vec<usize>,
233    col_ptr: Vec<usize>,
234    num_rows: usize,
235    num_cols: usize,
236}
237
238/// Sparse matrix builder for incremental construction
239#[derive(Debug)]
240pub struct SparseMatrixBuilder {
241    triplets: Vec<(usize, usize, Complex64)>,
242    num_rows: usize,
243    num_cols: usize,
244}
245
246impl SparseMatrixBuilder {
247    /// Create a new builder
248    #[must_use]
249    pub const fn new(num_rows: usize, num_cols: usize) -> Self {
250        Self {
251            triplets: Vec::new(),
252            num_rows,
253            num_cols,
254        }
255    }
256
257    /// Add an element to the matrix
258    pub fn add(&mut self, row: usize, col: usize, value: Complex64) {
259        if row < self.num_rows && col < self.num_cols && value.norm() > 1e-15 {
260            self.triplets.push((row, col, value));
261        }
262    }
263
264    /// Set value at specific position (alias for add)
265    pub fn set_value(&mut self, row: usize, col: usize, value: Complex64) {
266        self.add(row, col, value);
267    }
268
269    /// Build the CSR matrix
270    #[must_use]
271    pub fn build(mut self) -> CSRMatrix {
272        // Sort by row, then column
273        self.triplets.sort_by_key(|(r, c, _)| (*r, *c));
274
275        // Combine duplicates
276        let mut combined_triplets = Vec::new();
277        let mut last_pos: Option<(usize, usize)> = None;
278
279        for (r, c, v) in self.triplets {
280            if Some((r, c)) == last_pos {
281                if let Some(last) = combined_triplets.last_mut() {
282                    let (_, _, ref mut last_val) = last;
283                    *last_val += v;
284                }
285            } else {
286                combined_triplets.push((r, c, v));
287                last_pos = Some((r, c));
288            }
289        }
290
291        // Build CSR arrays
292        let mut values = Vec::new();
293        let mut col_indices = Vec::new();
294        let mut row_ptr = vec![0];
295        let mut current_row = 0;
296
297        for (r, c, v) in combined_triplets {
298            while current_row < r {
299                row_ptr.push(values.len());
300                current_row += 1;
301            }
302
303            if v.norm() > 1e-15 {
304                values.push(v);
305                col_indices.push(c);
306            }
307        }
308
309        while row_ptr.len() <= self.num_rows {
310            row_ptr.push(values.len());
311        }
312
313        CSRMatrix::new(values, col_indices, row_ptr, self.num_rows, self.num_cols)
314    }
315}
316
317/// Sparse quantum gate representations
318pub struct SparseGates;
319
320impl SparseGates {
321    /// Create sparse Pauli X gate
322    #[must_use]
323    pub fn x() -> CSRMatrix {
324        let mut builder = SparseMatrixBuilder::new(2, 2);
325        builder.add(0, 1, Complex64::new(1.0, 0.0));
326        builder.add(1, 0, Complex64::new(1.0, 0.0));
327        builder.build()
328    }
329
330    /// Create sparse Pauli Y gate
331    #[must_use]
332    pub fn y() -> CSRMatrix {
333        let mut builder = SparseMatrixBuilder::new(2, 2);
334        builder.add(0, 1, Complex64::new(0.0, -1.0));
335        builder.add(1, 0, Complex64::new(0.0, 1.0));
336        builder.build()
337    }
338
339    /// Create sparse Pauli Z gate
340    #[must_use]
341    pub fn z() -> CSRMatrix {
342        let mut builder = SparseMatrixBuilder::new(2, 2);
343        builder.add(0, 0, Complex64::new(1.0, 0.0));
344        builder.add(1, 1, Complex64::new(-1.0, 0.0));
345        builder.build()
346    }
347
348    /// Create sparse CNOT gate
349    #[must_use]
350    pub fn cnot() -> CSRMatrix {
351        let mut builder = SparseMatrixBuilder::new(4, 4);
352        builder.add(0, 0, Complex64::new(1.0, 0.0));
353        builder.add(1, 1, Complex64::new(1.0, 0.0));
354        builder.add(2, 3, Complex64::new(1.0, 0.0));
355        builder.add(3, 2, Complex64::new(1.0, 0.0));
356        builder.build()
357    }
358
359    /// Create sparse CZ gate
360    #[must_use]
361    pub fn cz() -> CSRMatrix {
362        let mut builder = SparseMatrixBuilder::new(4, 4);
363        builder.add(0, 0, Complex64::new(1.0, 0.0));
364        builder.add(1, 1, Complex64::new(1.0, 0.0));
365        builder.add(2, 2, Complex64::new(1.0, 0.0));
366        builder.add(3, 3, Complex64::new(-1.0, 0.0));
367        builder.build()
368    }
369
370    /// Create sparse rotation gate
371    pub fn rotation(axis: &str, angle: f64) -> Result<CSRMatrix> {
372        let (c, s) = (angle.cos(), angle.sin());
373        let half_angle = angle / 2.0;
374        let (ch, sh) = (half_angle.cos(), half_angle.sin());
375
376        let mut builder = SparseMatrixBuilder::new(2, 2);
377
378        match axis {
379            "x" | "X" => {
380                builder.add(0, 0, Complex64::new(ch, 0.0));
381                builder.add(0, 1, Complex64::new(0.0, -sh));
382                builder.add(1, 0, Complex64::new(0.0, -sh));
383                builder.add(1, 1, Complex64::new(ch, 0.0));
384            }
385            "y" | "Y" => {
386                builder.add(0, 0, Complex64::new(ch, 0.0));
387                builder.add(0, 1, Complex64::new(-sh, 0.0));
388                builder.add(1, 0, Complex64::new(sh, 0.0));
389                builder.add(1, 1, Complex64::new(ch, 0.0));
390            }
391            "z" | "Z" => {
392                builder.add(0, 0, Complex64::new(ch, -sh));
393                builder.add(1, 1, Complex64::new(ch, sh));
394            }
395            _ => {
396                return Err(SimulatorError::InvalidConfiguration(format!(
397                    "Unknown rotation axis: {axis}"
398                )))
399            }
400        }
401
402        Ok(builder.build())
403    }
404
405    /// Create sparse controlled rotation gate
406    pub fn controlled_rotation(axis: &str, angle: f64) -> Result<CSRMatrix> {
407        let single_qubit = Self::rotation(axis, angle)?;
408
409        let mut builder = SparseMatrixBuilder::new(4, 4);
410
411        // |00⟩ and |01⟩ states unchanged
412        builder.add(0, 0, Complex64::new(1.0, 0.0));
413        builder.add(1, 1, Complex64::new(1.0, 0.0));
414
415        // Apply rotation to |10⟩ and |11⟩ states
416        builder.add(2, 2, single_qubit.values[0]);
417        if single_qubit.values.len() > 1 {
418            builder.add(2, 3, single_qubit.values[1]);
419        }
420        if single_qubit.values.len() > 2 {
421            builder.add(3, 2, single_qubit.values[2]);
422        }
423        if single_qubit.values.len() > 3 {
424            builder.add(3, 3, single_qubit.values[3]);
425        }
426
427        Ok(builder.build())
428    }
429}
430
431/// Apply sparse gate to state vector at specific qubits
432pub fn apply_sparse_gate(
433    state: &mut Array1<Complex64>,
434    gate: &CSRMatrix,
435    qubits: &[usize],
436    num_qubits: usize,
437) -> Result<()> {
438    let gate_qubits = qubits.len();
439    let gate_dim = 1 << gate_qubits;
440
441    if gate.num_rows != gate_dim || gate.num_cols != gate_dim {
442        return Err(SimulatorError::DimensionMismatch(format!(
443            "Gate dimension {} doesn't match qubit count {}",
444            gate.num_rows, gate_qubits
445        )));
446    }
447
448    // Create bit masks for the target qubits
449    let mut masks = vec![0usize; gate_qubits];
450    for (i, &qubit) in qubits.iter().enumerate() {
451        masks[i] = 1 << qubit;
452    }
453
454    // Apply gate to all basis states
455    let state_dim = 1 << num_qubits;
456    let mut new_state = Array1::zeros(state_dim);
457
458    for i in 0..state_dim {
459        // Extract indices for gate qubits
460        let mut gate_idx = 0;
461        for (j, &mask) in masks.iter().enumerate() {
462            if i & mask != 0 {
463                gate_idx |= 1 << j;
464            }
465        }
466
467        // Apply sparse gate row
468        let row_start = gate.row_ptr[gate_idx];
469        let row_end = gate.row_ptr[gate_idx + 1];
470
471        for idx in row_start..row_end {
472            let gate_col = gate.col_indices[idx];
473            let gate_val = gate.values[idx];
474
475            // Reconstruct global index
476            let mut j = i;
477            for (k, &mask) in masks.iter().enumerate() {
478                if gate_col & (1 << k) != 0 {
479                    j |= mask;
480                } else {
481                    j &= !mask;
482                }
483            }
484
485            new_state[i] += gate_val * state[j];
486        }
487    }
488
489    state.assign(&new_state);
490    Ok(())
491}
492
493/// Optimize gate sequence using sparsity
494pub fn optimize_sparse_gates(gates: Vec<CSRMatrix>) -> Result<CSRMatrix> {
495    if gates.is_empty() {
496        return Err(SimulatorError::InvalidInput(
497            "Empty gate sequence".to_string(),
498        ));
499    }
500
501    let mut result = gates[0].clone();
502    for gate in gates.into_iter().skip(1) {
503        result = result.matmul(&gate)?;
504
505        // Threshold small values
506        result.values.retain(|&v| v.norm() > 1e-15);
507    }
508
509    Ok(result)
510}
511
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing complex amplitudes.
    const TOL: f64 = 1e-10;

    /// The builder keeps every distinct in-range entry and the requested shape.
    #[test]
    fn test_sparse_matrix_construction() {
        let mut builder = SparseMatrixBuilder::new(3, 3);
        builder.add(0, 0, Complex64::new(1.0, 0.0));
        builder.add(1, 1, Complex64::new(2.0, 0.0));
        builder.add(2, 2, Complex64::new(3.0, 0.0));
        builder.add(0, 2, Complex64::new(4.0, 0.0));

        let sparse = builder.build();
        assert_eq!(sparse.nnz(), 4);
        assert_eq!(sparse.num_rows, 3);
        assert_eq!(sparse.num_cols, 3);
    }

    /// Standard gates store the expected number of entries.
    #[test]
    fn test_sparse_gates() {
        let x = SparseGates::x();
        assert_eq!(x.nnz(), 2);

        let cnot = SparseGates::cnot();
        assert_eq!(cnot.nnz(), 4);

        let rz = SparseGates::rotation("z", 0.5).expect("Failed to create rotation gate");
        assert_eq!(rz.nnz(), 2);
    }

    /// X maps |0> to |1>.
    #[test]
    fn test_sparse_matvec() {
        let x = SparseGates::x();
        let vec = Array1::from_vec(vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)]);

        let result = x
            .matvec(&vec)
            .expect("Failed to perform matrix-vector multiplication");
        assert!((result[0] - Complex64::new(0.0, 0.0)).norm() < TOL);
        assert!((result[1] - Complex64::new(1.0, 0.0)).norm() < TOL);
    }

    /// X * Z = -iY, checked entry by entry rather than by entry count alone.
    #[test]
    fn test_sparse_matmul() {
        let x = SparseGates::x();
        let z = SparseGates::z();

        let xz = x
            .matmul(&z)
            .expect("Failed to perform matrix multiplication");
        let y_expected = SparseGates::y();

        assert_eq!(xz.nnz(), y_expected.nnz());

        let minus_i = Complex64::new(0.0, -1.0);
        let actual = xz.to_dense();
        let expected = y_expected.to_dense();
        for row in 0..2 {
            for col in 0..2 {
                assert!((actual[[row, col]] - minus_i * expected[[row, col]]).norm() < TOL);
            }
        }
    }

    /// Dense expansion of CNOT has the right shape and entries.
    #[test]
    fn test_csr_to_dense() {
        let cnot = SparseGates::cnot();
        let dense = cnot.to_dense();

        assert_eq!(dense.shape(), &[4, 4]);
        assert!((dense[[0, 0]] - Complex64::new(1.0, 0.0)).norm() < TOL);
        assert!((dense[[3, 2]] - Complex64::new(1.0, 0.0)).norm() < TOL);
    }
}