quantrs2_sim/
sparse.rs

1//! Sparse matrix operations for efficient quantum circuit simulation.
2//!
3//! This module provides sparse matrix representations and operations
4//! optimized for quantum gates, especially for circuits with limited connectivity.
5
6use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::Complex64;
8use std::collections::HashMap;
9
10use crate::error::{Result, SimulatorError};
11
12/// Compressed Sparse Row (CSR) matrix format
13#[derive(Debug, Clone)]
14pub struct CSRMatrix {
15    /// Non-zero values
16    pub values: Vec<Complex64>,
17    /// Column indices for each value
18    pub col_indices: Vec<usize>,
19    /// Row pointer array
20    pub row_ptr: Vec<usize>,
21    /// Number of rows
22    pub num_rows: usize,
23    /// Number of columns
24    pub num_cols: usize,
25}
26
27impl CSRMatrix {
28    /// Create a new CSR matrix
29    pub fn new(
30        values: Vec<Complex64>,
31        col_indices: Vec<usize>,
32        row_ptr: Vec<usize>,
33        num_rows: usize,
34        num_cols: usize,
35    ) -> Self {
36        assert_eq!(values.len(), col_indices.len());
37        assert_eq!(row_ptr.len(), num_rows + 1);
38
39        Self {
40            values,
41            col_indices,
42            row_ptr,
43            num_rows,
44            num_cols,
45        }
46    }
47
48    /// Create from a dense matrix
49    pub fn from_dense(matrix: &Array2<Complex64>) -> Self {
50        let num_rows = matrix.nrows();
51        let num_cols = matrix.ncols();
52        let mut values = Vec::new();
53        let mut col_indices = Vec::new();
54        let mut row_ptr = vec![0];
55
56        for i in 0..num_rows {
57            for j in 0..num_cols {
58                let val = matrix[[i, j]];
59                if val.norm() > 1e-15 {
60                    values.push(val);
61                    col_indices.push(j);
62                }
63            }
64            row_ptr.push(values.len());
65        }
66
67        Self::new(values, col_indices, row_ptr, num_rows, num_cols)
68    }
69
70    /// Convert to dense matrix
71    pub fn to_dense(&self) -> Array2<Complex64> {
72        let mut dense = Array2::zeros((self.num_rows, self.num_cols));
73
74        for i in 0..self.num_rows {
75            let start = self.row_ptr[i];
76            let end = self.row_ptr[i + 1];
77
78            for idx in start..end {
79                dense[[i, self.col_indices[idx]]] = self.values[idx];
80            }
81        }
82
83        dense
84    }
85
86    /// Get number of non-zero elements
87    pub fn nnz(&self) -> usize {
88        self.values.len()
89    }
90
91    /// Matrix-vector multiplication
92    pub fn matvec(&self, vec: &Array1<Complex64>) -> Result<Array1<Complex64>> {
93        if vec.len() != self.num_cols {
94            return Err(SimulatorError::DimensionMismatch(format!(
95                "Vector length {} doesn't match matrix columns {}",
96                vec.len(),
97                self.num_cols
98            )));
99        }
100
101        let mut result = Array1::zeros(self.num_rows);
102
103        for i in 0..self.num_rows {
104            let start = self.row_ptr[i];
105            let end = self.row_ptr[i + 1];
106
107            let mut sum = Complex64::new(0.0, 0.0);
108            for idx in start..end {
109                sum += self.values[idx] * vec[self.col_indices[idx]];
110            }
111            result[i] = sum;
112        }
113
114        Ok(result)
115    }
116
117    /// Sparse matrix multiplication
118    pub fn matmul(&self, other: &CSRMatrix) -> Result<CSRMatrix> {
119        if self.num_cols != other.num_rows {
120            return Err(SimulatorError::DimensionMismatch(format!(
121                "Matrix dimensions incompatible: {}x{} * {}x{}",
122                self.num_rows, self.num_cols, other.num_rows, other.num_cols
123            )));
124        }
125
126        let mut values = Vec::new();
127        let mut col_indices = Vec::new();
128        let mut row_ptr = vec![0];
129
130        // Convert other to CSC for efficient column access
131        let other_csc = other.to_csc();
132
133        for i in 0..self.num_rows {
134            let mut row_values: HashMap<usize, Complex64> = HashMap::new();
135
136            let a_start = self.row_ptr[i];
137            let a_end = self.row_ptr[i + 1];
138
139            for a_idx in a_start..a_end {
140                let k = self.col_indices[a_idx];
141                let a_val = self.values[a_idx];
142
143                // Multiply row i of A with column k of B
144                let b_start = other_csc.col_ptr[k];
145                let b_end = other_csc.col_ptr[k + 1];
146
147                for b_idx in b_start..b_end {
148                    let j = other_csc.row_indices[b_idx];
149                    let b_val = other_csc.values[b_idx];
150
151                    *row_values.entry(j).or_insert(Complex64::new(0.0, 0.0)) += a_val * b_val;
152                }
153            }
154
155            // Sort by column index and add to result
156            let mut sorted_cols: Vec<_> = row_values.into_iter().collect();
157            sorted_cols.sort_by_key(|(col, _)| *col);
158
159            for (col, val) in sorted_cols {
160                if val.norm() > 1e-15 {
161                    values.push(val);
162                    col_indices.push(col);
163                }
164            }
165
166            row_ptr.push(values.len());
167        }
168
169        Ok(CSRMatrix::new(
170            values,
171            col_indices,
172            row_ptr,
173            self.num_rows,
174            other.num_cols,
175        ))
176    }
177
178    /// Convert to Compressed Sparse Column (CSC) format
179    fn to_csc(&self) -> CSCMatrix {
180        let mut values = Vec::new();
181        let mut row_indices = Vec::new();
182        let mut col_ptr = vec![0; self.num_cols + 1];
183
184        // Count elements per column
185        for &col in &self.col_indices {
186            col_ptr[col + 1] += 1;
187        }
188
189        // Cumulative sum to get column pointers
190        for i in 1..=self.num_cols {
191            col_ptr[i] += col_ptr[i - 1];
192        }
193
194        // Temporary array to track current position in each column
195        let mut current_pos = col_ptr[0..self.num_cols].to_vec();
196        values.resize(self.nnz(), Complex64::new(0.0, 0.0));
197        row_indices.resize(self.nnz(), 0);
198
199        // Fill CSC arrays
200        for i in 0..self.num_rows {
201            let start = self.row_ptr[i];
202            let end = self.row_ptr[i + 1];
203
204            for idx in start..end {
205                let col = self.col_indices[idx];
206                let pos = current_pos[col];
207
208                values[pos] = self.values[idx];
209                row_indices[pos] = i;
210                current_pos[col] += 1;
211            }
212        }
213
214        CSCMatrix {
215            values,
216            row_indices,
217            col_ptr,
218            num_rows: self.num_rows,
219            num_cols: self.num_cols,
220        }
221    }
222}
223
224/// Compressed Sparse Column (CSC) matrix format
225#[derive(Debug, Clone)]
226struct CSCMatrix {
227    values: Vec<Complex64>,
228    row_indices: Vec<usize>,
229    col_ptr: Vec<usize>,
230    num_rows: usize,
231    num_cols: usize,
232}
233
234/// Sparse matrix builder for incremental construction
235#[derive(Debug)]
236pub struct SparseMatrixBuilder {
237    triplets: Vec<(usize, usize, Complex64)>,
238    num_rows: usize,
239    num_cols: usize,
240}
241
242impl SparseMatrixBuilder {
243    /// Create a new builder
244    pub fn new(num_rows: usize, num_cols: usize) -> Self {
245        Self {
246            triplets: Vec::new(),
247            num_rows,
248            num_cols,
249        }
250    }
251
252    /// Add an element to the matrix
253    pub fn add(&mut self, row: usize, col: usize, value: Complex64) {
254        if row < self.num_rows && col < self.num_cols && value.norm() > 1e-15 {
255            self.triplets.push((row, col, value));
256        }
257    }
258
259    /// Set value at specific position (alias for add)
260    pub fn set_value(&mut self, row: usize, col: usize, value: Complex64) {
261        self.add(row, col, value);
262    }
263
264    /// Build the CSR matrix
265    pub fn build(mut self) -> CSRMatrix {
266        // Sort by row, then column
267        self.triplets.sort_by_key(|(r, c, _)| (*r, *c));
268
269        // Combine duplicates
270        let mut combined_triplets = Vec::new();
271        let mut last_pos: Option<(usize, usize)> = None;
272
273        for (r, c, v) in self.triplets {
274            if Some((r, c)) == last_pos {
275                if let Some(last) = combined_triplets.last_mut() {
276                    let (_, _, ref mut last_val) = last;
277                    *last_val += v;
278                }
279            } else {
280                combined_triplets.push((r, c, v));
281                last_pos = Some((r, c));
282            }
283        }
284
285        // Build CSR arrays
286        let mut values = Vec::new();
287        let mut col_indices = Vec::new();
288        let mut row_ptr = vec![0];
289        let mut current_row = 0;
290
291        for (r, c, v) in combined_triplets {
292            while current_row < r {
293                row_ptr.push(values.len());
294                current_row += 1;
295            }
296
297            if v.norm() > 1e-15 {
298                values.push(v);
299                col_indices.push(c);
300            }
301        }
302
303        while row_ptr.len() <= self.num_rows {
304            row_ptr.push(values.len());
305        }
306
307        CSRMatrix::new(values, col_indices, row_ptr, self.num_rows, self.num_cols)
308    }
309}
310
311/// Sparse quantum gate representations
312pub struct SparseGates;
313
314impl SparseGates {
315    /// Create sparse Pauli X gate
316    pub fn x() -> CSRMatrix {
317        let mut builder = SparseMatrixBuilder::new(2, 2);
318        builder.add(0, 1, Complex64::new(1.0, 0.0));
319        builder.add(1, 0, Complex64::new(1.0, 0.0));
320        builder.build()
321    }
322
323    /// Create sparse Pauli Y gate
324    pub fn y() -> CSRMatrix {
325        let mut builder = SparseMatrixBuilder::new(2, 2);
326        builder.add(0, 1, Complex64::new(0.0, -1.0));
327        builder.add(1, 0, Complex64::new(0.0, 1.0));
328        builder.build()
329    }
330
331    /// Create sparse Pauli Z gate
332    pub fn z() -> CSRMatrix {
333        let mut builder = SparseMatrixBuilder::new(2, 2);
334        builder.add(0, 0, Complex64::new(1.0, 0.0));
335        builder.add(1, 1, Complex64::new(-1.0, 0.0));
336        builder.build()
337    }
338
339    /// Create sparse CNOT gate
340    pub fn cnot() -> CSRMatrix {
341        let mut builder = SparseMatrixBuilder::new(4, 4);
342        builder.add(0, 0, Complex64::new(1.0, 0.0));
343        builder.add(1, 1, Complex64::new(1.0, 0.0));
344        builder.add(2, 3, Complex64::new(1.0, 0.0));
345        builder.add(3, 2, Complex64::new(1.0, 0.0));
346        builder.build()
347    }
348
349    /// Create sparse CZ gate
350    pub fn cz() -> CSRMatrix {
351        let mut builder = SparseMatrixBuilder::new(4, 4);
352        builder.add(0, 0, Complex64::new(1.0, 0.0));
353        builder.add(1, 1, Complex64::new(1.0, 0.0));
354        builder.add(2, 2, Complex64::new(1.0, 0.0));
355        builder.add(3, 3, Complex64::new(-1.0, 0.0));
356        builder.build()
357    }
358
359    /// Create sparse rotation gate
360    pub fn rotation(axis: &str, angle: f64) -> Result<CSRMatrix> {
361        let (c, s) = (angle.cos(), angle.sin());
362        let half_angle = angle / 2.0;
363        let (ch, sh) = (half_angle.cos(), half_angle.sin());
364
365        let mut builder = SparseMatrixBuilder::new(2, 2);
366
367        match axis {
368            "x" | "X" => {
369                builder.add(0, 0, Complex64::new(ch, 0.0));
370                builder.add(0, 1, Complex64::new(0.0, -sh));
371                builder.add(1, 0, Complex64::new(0.0, -sh));
372                builder.add(1, 1, Complex64::new(ch, 0.0));
373            }
374            "y" | "Y" => {
375                builder.add(0, 0, Complex64::new(ch, 0.0));
376                builder.add(0, 1, Complex64::new(-sh, 0.0));
377                builder.add(1, 0, Complex64::new(sh, 0.0));
378                builder.add(1, 1, Complex64::new(ch, 0.0));
379            }
380            "z" | "Z" => {
381                builder.add(0, 0, Complex64::new(ch, -sh));
382                builder.add(1, 1, Complex64::new(ch, sh));
383            }
384            _ => {
385                return Err(SimulatorError::InvalidConfiguration(format!(
386                    "Unknown rotation axis: {}",
387                    axis
388                )))
389            }
390        }
391
392        Ok(builder.build())
393    }
394
395    /// Create sparse controlled rotation gate
396    pub fn controlled_rotation(axis: &str, angle: f64) -> Result<CSRMatrix> {
397        let single_qubit = Self::rotation(axis, angle)?;
398
399        let mut builder = SparseMatrixBuilder::new(4, 4);
400
401        // |00⟩ and |01⟩ states unchanged
402        builder.add(0, 0, Complex64::new(1.0, 0.0));
403        builder.add(1, 1, Complex64::new(1.0, 0.0));
404
405        // Apply rotation to |10⟩ and |11⟩ states
406        builder.add(2, 2, single_qubit.values[0]);
407        if single_qubit.values.len() > 1 {
408            builder.add(2, 3, single_qubit.values[1]);
409        }
410        if single_qubit.values.len() > 2 {
411            builder.add(3, 2, single_qubit.values[2]);
412        }
413        if single_qubit.values.len() > 3 {
414            builder.add(3, 3, single_qubit.values[3]);
415        }
416
417        Ok(builder.build())
418    }
419}
420
421/// Apply sparse gate to state vector at specific qubits
422pub fn apply_sparse_gate(
423    state: &mut Array1<Complex64>,
424    gate: &CSRMatrix,
425    qubits: &[usize],
426    num_qubits: usize,
427) -> Result<()> {
428    let gate_qubits = qubits.len();
429    let gate_dim = 1 << gate_qubits;
430
431    if gate.num_rows != gate_dim || gate.num_cols != gate_dim {
432        return Err(SimulatorError::DimensionMismatch(format!(
433            "Gate dimension {} doesn't match qubit count {}",
434            gate.num_rows, gate_qubits
435        )));
436    }
437
438    // Create bit masks for the target qubits
439    let mut masks = vec![0usize; gate_qubits];
440    for (i, &qubit) in qubits.iter().enumerate() {
441        masks[i] = 1 << qubit;
442    }
443
444    // Apply gate to all basis states
445    let state_dim = 1 << num_qubits;
446    let mut new_state = Array1::zeros(state_dim);
447
448    for i in 0..state_dim {
449        // Extract indices for gate qubits
450        let mut gate_idx = 0;
451        for (j, &mask) in masks.iter().enumerate() {
452            if i & mask != 0 {
453                gate_idx |= 1 << j;
454            }
455        }
456
457        // Apply sparse gate row
458        let row_start = gate.row_ptr[gate_idx];
459        let row_end = gate.row_ptr[gate_idx + 1];
460
461        for idx in row_start..row_end {
462            let gate_col = gate.col_indices[idx];
463            let gate_val = gate.values[idx];
464
465            // Reconstruct global index
466            let mut j = i;
467            for (k, &mask) in masks.iter().enumerate() {
468                if gate_col & (1 << k) != 0 {
469                    j |= mask;
470                } else {
471                    j &= !mask;
472                }
473            }
474
475            new_state[i] += gate_val * state[j];
476        }
477    }
478
479    state.assign(&new_state);
480    Ok(())
481}
482
483/// Optimize gate sequence using sparsity
484pub fn optimize_sparse_gates(gates: Vec<CSRMatrix>) -> Result<CSRMatrix> {
485    if gates.is_empty() {
486        return Err(SimulatorError::InvalidInput(
487            "Empty gate sequence".to_string(),
488        ));
489    }
490
491    let mut result = gates[0].clone();
492    for gate in gates.into_iter().skip(1) {
493        result = result.matmul(&gate)?;
494
495        // Threshold small values
496        result.values.retain(|&v| v.norm() > 1e-15);
497    }
498
499    Ok(result)
500}
501
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder output reports the expected shape and stored-entry count.
    #[test]
    fn test_sparse_matrix_construction() {
        let mut builder = SparseMatrixBuilder::new(3, 3);
        builder.add(0, 0, Complex64::new(1.0, 0.0));
        builder.add(1, 1, Complex64::new(2.0, 0.0));
        builder.add(2, 2, Complex64::new(3.0, 0.0));
        builder.add(0, 2, Complex64::new(4.0, 0.0));

        let sparse = builder.build();
        assert_eq!(sparse.nnz(), 4);
        assert_eq!(sparse.num_rows, 3);
        assert_eq!(sparse.num_cols, 3);
    }

    /// Standard gates store only their non-zero entries.
    #[test]
    fn test_sparse_gates() {
        let x = SparseGates::x();
        assert_eq!(x.nnz(), 2);

        let cnot = SparseGates::cnot();
        assert_eq!(cnot.nnz(), 4);

        let rz = SparseGates::rotation("z", 0.5).unwrap();
        assert_eq!(rz.nnz(), 2);
    }

    /// X applied to (1, 0) yields (0, 1).
    #[test]
    fn test_sparse_matvec() {
        let x = SparseGates::x();
        let vec = Array1::from_vec(vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)]);

        let result = x.matvec(&vec).unwrap();
        assert!((result[0] - Complex64::new(0.0, 0.0)).norm() < 1e-10);
        assert!((result[1] - Complex64::new(1.0, 0.0)).norm() < 1e-10);
    }

    /// The product X * Z matches -iY entry by entry, not just in sparsity.
    #[test]
    fn test_sparse_matmul() {
        let x = SparseGates::x();
        let z = SparseGates::z();

        let xz = x.matmul(&z).unwrap();
        let y_expected = SparseGates::y();

        // X * Z = -iY
        assert_eq!(xz.nnz(), y_expected.nnz());

        let minus_i = Complex64::new(0.0, -1.0);
        let xz_dense = xz.to_dense();
        let y_dense = y_expected.to_dense();
        for r in 0..2 {
            for c in 0..2 {
                assert!((xz_dense[[r, c]] - minus_i * y_dense[[r, c]]).norm() < 1e-10);
            }
        }
    }

    /// Dense conversion places CNOT entries at the expected coordinates.
    #[test]
    fn test_csr_to_dense() {
        let cnot = SparseGates::cnot();
        let dense = cnot.to_dense();

        assert_eq!(dense.shape(), &[4, 4]);
        assert!((dense[[0, 0]] - Complex64::new(1.0, 0.0)).norm() < 1e-10);
        assert!((dense[[3, 2]] - Complex64::new(1.0, 0.0)).norm() < 1e-10);
    }

    /// Converting to dense and back reproduces the same CSR arrays.
    #[test]
    fn test_dense_round_trip() {
        let cnot = SparseGates::cnot();
        let rebuilt = CSRMatrix::from_dense(&cnot.to_dense());

        assert_eq!(rebuilt.num_rows, cnot.num_rows);
        assert_eq!(rebuilt.num_cols, cnot.num_cols);
        assert_eq!(rebuilt.row_ptr, cnot.row_ptr);
        assert_eq!(rebuilt.col_indices, cnot.col_indices);
        assert_eq!(rebuilt.nnz(), cnot.nnz());
    }
}