1use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::Complex64;
8use std::collections::HashMap;
9
10use crate::error::{Result, SimulatorError};
11
12#[derive(Debug, Clone)]
14pub struct CSRMatrix {
15 pub values: Vec<Complex64>,
17 pub col_indices: Vec<usize>,
19 pub row_ptr: Vec<usize>,
21 pub num_rows: usize,
23 pub num_cols: usize,
25}
26
27impl CSRMatrix {
28 #[must_use]
30 pub fn new(
31 values: Vec<Complex64>,
32 col_indices: Vec<usize>,
33 row_ptr: Vec<usize>,
34 num_rows: usize,
35 num_cols: usize,
36 ) -> Self {
37 assert_eq!(values.len(), col_indices.len());
38 assert_eq!(row_ptr.len(), num_rows + 1);
39
40 Self {
41 values,
42 col_indices,
43 row_ptr,
44 num_rows,
45 num_cols,
46 }
47 }
48
49 #[must_use]
51 pub fn from_dense(matrix: &Array2<Complex64>) -> Self {
52 let num_rows = matrix.nrows();
53 let num_cols = matrix.ncols();
54 let mut values = Vec::new();
55 let mut col_indices = Vec::new();
56 let mut row_ptr = vec![0];
57
58 for i in 0..num_rows {
59 for j in 0..num_cols {
60 let val = matrix[[i, j]];
61 if val.norm() > 1e-15 {
62 values.push(val);
63 col_indices.push(j);
64 }
65 }
66 row_ptr.push(values.len());
67 }
68
69 Self::new(values, col_indices, row_ptr, num_rows, num_cols)
70 }
71
72 #[must_use]
74 pub fn to_dense(&self) -> Array2<Complex64> {
75 let mut dense = Array2::zeros((self.num_rows, self.num_cols));
76
77 for i in 0..self.num_rows {
78 let start = self.row_ptr[i];
79 let end = self.row_ptr[i + 1];
80
81 for idx in start..end {
82 dense[[i, self.col_indices[idx]]] = self.values[idx];
83 }
84 }
85
86 dense
87 }
88
89 #[must_use]
91 pub fn nnz(&self) -> usize {
92 self.values.len()
93 }
94
95 pub fn matvec(&self, vec: &Array1<Complex64>) -> Result<Array1<Complex64>> {
97 if vec.len() != self.num_cols {
98 return Err(SimulatorError::DimensionMismatch(format!(
99 "Vector length {} doesn't match matrix columns {}",
100 vec.len(),
101 self.num_cols
102 )));
103 }
104
105 let mut result = Array1::zeros(self.num_rows);
106
107 for i in 0..self.num_rows {
108 let start = self.row_ptr[i];
109 let end = self.row_ptr[i + 1];
110
111 let mut sum = Complex64::new(0.0, 0.0);
112 for idx in start..end {
113 sum += self.values[idx] * vec[self.col_indices[idx]];
114 }
115 result[i] = sum;
116 }
117
118 Ok(result)
119 }
120
121 pub fn matmul(&self, other: &Self) -> Result<Self> {
123 if self.num_cols != other.num_rows {
124 return Err(SimulatorError::DimensionMismatch(format!(
125 "Matrix dimensions incompatible: {}x{} * {}x{}",
126 self.num_rows, self.num_cols, other.num_rows, other.num_cols
127 )));
128 }
129
130 let mut values = Vec::new();
131 let mut col_indices = Vec::new();
132 let mut row_ptr = vec![0];
133
134 let other_csc = other.to_csc();
136
137 for i in 0..self.num_rows {
138 let mut row_values: HashMap<usize, Complex64> = HashMap::new();
139
140 let a_start = self.row_ptr[i];
141 let a_end = self.row_ptr[i + 1];
142
143 for a_idx in a_start..a_end {
144 let k = self.col_indices[a_idx];
145 let a_val = self.values[a_idx];
146
147 let b_start = other_csc.col_ptr[k];
149 let b_end = other_csc.col_ptr[k + 1];
150
151 for b_idx in b_start..b_end {
152 let j = other_csc.row_indices[b_idx];
153 let b_val = other_csc.values[b_idx];
154
155 *row_values.entry(j).or_insert(Complex64::new(0.0, 0.0)) += a_val * b_val;
156 }
157 }
158
159 let mut sorted_cols: Vec<_> = row_values.into_iter().collect();
161 sorted_cols.sort_by_key(|(col, _)| *col);
162
163 for (col, val) in sorted_cols {
164 if val.norm() > 1e-15 {
165 values.push(val);
166 col_indices.push(col);
167 }
168 }
169
170 row_ptr.push(values.len());
171 }
172
173 Ok(Self::new(
174 values,
175 col_indices,
176 row_ptr,
177 self.num_rows,
178 other.num_cols,
179 ))
180 }
181
182 fn to_csc(&self) -> CSCMatrix {
184 let mut values = Vec::new();
185 let mut row_indices = Vec::new();
186 let mut col_ptr = vec![0; self.num_cols + 1];
187
188 for &col in &self.col_indices {
190 col_ptr[col + 1] += 1;
191 }
192
193 for i in 1..=self.num_cols {
195 col_ptr[i] += col_ptr[i - 1];
196 }
197
198 let mut current_pos = col_ptr[0..self.num_cols].to_vec();
200 values.resize(self.nnz(), Complex64::new(0.0, 0.0));
201 row_indices.resize(self.nnz(), 0);
202
203 for i in 0..self.num_rows {
205 let start = self.row_ptr[i];
206 let end = self.row_ptr[i + 1];
207
208 for idx in start..end {
209 let col = self.col_indices[idx];
210 let pos = current_pos[col];
211
212 values[pos] = self.values[idx];
213 row_indices[pos] = i;
214 current_pos[col] += 1;
215 }
216 }
217
218 CSCMatrix {
219 values,
220 row_indices,
221 col_ptr,
222 num_rows: self.num_rows,
223 num_cols: self.num_cols,
224 }
225 }
226}
227
228#[derive(Debug, Clone)]
230struct CSCMatrix {
231 values: Vec<Complex64>,
232 row_indices: Vec<usize>,
233 col_ptr: Vec<usize>,
234 num_rows: usize,
235 num_cols: usize,
236}
237
238#[derive(Debug)]
240pub struct SparseMatrixBuilder {
241 triplets: Vec<(usize, usize, Complex64)>,
242 num_rows: usize,
243 num_cols: usize,
244}
245
246impl SparseMatrixBuilder {
247 #[must_use]
249 pub const fn new(num_rows: usize, num_cols: usize) -> Self {
250 Self {
251 triplets: Vec::new(),
252 num_rows,
253 num_cols,
254 }
255 }
256
257 pub fn add(&mut self, row: usize, col: usize, value: Complex64) {
259 if row < self.num_rows && col < self.num_cols && value.norm() > 1e-15 {
260 self.triplets.push((row, col, value));
261 }
262 }
263
264 pub fn set_value(&mut self, row: usize, col: usize, value: Complex64) {
266 self.add(row, col, value);
267 }
268
269 #[must_use]
271 pub fn build(mut self) -> CSRMatrix {
272 self.triplets.sort_by_key(|(r, c, _)| (*r, *c));
274
275 let mut combined_triplets = Vec::new();
277 let mut last_pos: Option<(usize, usize)> = None;
278
279 for (r, c, v) in self.triplets {
280 if Some((r, c)) == last_pos {
281 if let Some(last) = combined_triplets.last_mut() {
282 let (_, _, ref mut last_val) = last;
283 *last_val += v;
284 }
285 } else {
286 combined_triplets.push((r, c, v));
287 last_pos = Some((r, c));
288 }
289 }
290
291 let mut values = Vec::new();
293 let mut col_indices = Vec::new();
294 let mut row_ptr = vec![0];
295 let mut current_row = 0;
296
297 for (r, c, v) in combined_triplets {
298 while current_row < r {
299 row_ptr.push(values.len());
300 current_row += 1;
301 }
302
303 if v.norm() > 1e-15 {
304 values.push(v);
305 col_indices.push(c);
306 }
307 }
308
309 while row_ptr.len() <= self.num_rows {
310 row_ptr.push(values.len());
311 }
312
313 CSRMatrix::new(values, col_indices, row_ptr, self.num_rows, self.num_cols)
314 }
315}
316
317pub struct SparseGates;
319
320impl SparseGates {
321 #[must_use]
323 pub fn x() -> CSRMatrix {
324 let mut builder = SparseMatrixBuilder::new(2, 2);
325 builder.add(0, 1, Complex64::new(1.0, 0.0));
326 builder.add(1, 0, Complex64::new(1.0, 0.0));
327 builder.build()
328 }
329
330 #[must_use]
332 pub fn y() -> CSRMatrix {
333 let mut builder = SparseMatrixBuilder::new(2, 2);
334 builder.add(0, 1, Complex64::new(0.0, -1.0));
335 builder.add(1, 0, Complex64::new(0.0, 1.0));
336 builder.build()
337 }
338
339 #[must_use]
341 pub fn z() -> CSRMatrix {
342 let mut builder = SparseMatrixBuilder::new(2, 2);
343 builder.add(0, 0, Complex64::new(1.0, 0.0));
344 builder.add(1, 1, Complex64::new(-1.0, 0.0));
345 builder.build()
346 }
347
348 #[must_use]
350 pub fn cnot() -> CSRMatrix {
351 let mut builder = SparseMatrixBuilder::new(4, 4);
352 builder.add(0, 0, Complex64::new(1.0, 0.0));
353 builder.add(1, 1, Complex64::new(1.0, 0.0));
354 builder.add(2, 3, Complex64::new(1.0, 0.0));
355 builder.add(3, 2, Complex64::new(1.0, 0.0));
356 builder.build()
357 }
358
359 #[must_use]
361 pub fn cz() -> CSRMatrix {
362 let mut builder = SparseMatrixBuilder::new(4, 4);
363 builder.add(0, 0, Complex64::new(1.0, 0.0));
364 builder.add(1, 1, Complex64::new(1.0, 0.0));
365 builder.add(2, 2, Complex64::new(1.0, 0.0));
366 builder.add(3, 3, Complex64::new(-1.0, 0.0));
367 builder.build()
368 }
369
370 pub fn rotation(axis: &str, angle: f64) -> Result<CSRMatrix> {
372 let (c, s) = (angle.cos(), angle.sin());
373 let half_angle = angle / 2.0;
374 let (ch, sh) = (half_angle.cos(), half_angle.sin());
375
376 let mut builder = SparseMatrixBuilder::new(2, 2);
377
378 match axis {
379 "x" | "X" => {
380 builder.add(0, 0, Complex64::new(ch, 0.0));
381 builder.add(0, 1, Complex64::new(0.0, -sh));
382 builder.add(1, 0, Complex64::new(0.0, -sh));
383 builder.add(1, 1, Complex64::new(ch, 0.0));
384 }
385 "y" | "Y" => {
386 builder.add(0, 0, Complex64::new(ch, 0.0));
387 builder.add(0, 1, Complex64::new(-sh, 0.0));
388 builder.add(1, 0, Complex64::new(sh, 0.0));
389 builder.add(1, 1, Complex64::new(ch, 0.0));
390 }
391 "z" | "Z" => {
392 builder.add(0, 0, Complex64::new(ch, -sh));
393 builder.add(1, 1, Complex64::new(ch, sh));
394 }
395 _ => {
396 return Err(SimulatorError::InvalidConfiguration(format!(
397 "Unknown rotation axis: {axis}"
398 )))
399 }
400 }
401
402 Ok(builder.build())
403 }
404
405 pub fn controlled_rotation(axis: &str, angle: f64) -> Result<CSRMatrix> {
407 let single_qubit = Self::rotation(axis, angle)?;
408
409 let mut builder = SparseMatrixBuilder::new(4, 4);
410
411 builder.add(0, 0, Complex64::new(1.0, 0.0));
413 builder.add(1, 1, Complex64::new(1.0, 0.0));
414
415 builder.add(2, 2, single_qubit.values[0]);
417 if single_qubit.values.len() > 1 {
418 builder.add(2, 3, single_qubit.values[1]);
419 }
420 if single_qubit.values.len() > 2 {
421 builder.add(3, 2, single_qubit.values[2]);
422 }
423 if single_qubit.values.len() > 3 {
424 builder.add(3, 3, single_qubit.values[3]);
425 }
426
427 Ok(builder.build())
428 }
429}
430
431pub fn apply_sparse_gate(
433 state: &mut Array1<Complex64>,
434 gate: &CSRMatrix,
435 qubits: &[usize],
436 num_qubits: usize,
437) -> Result<()> {
438 let gate_qubits = qubits.len();
439 let gate_dim = 1 << gate_qubits;
440
441 if gate.num_rows != gate_dim || gate.num_cols != gate_dim {
442 return Err(SimulatorError::DimensionMismatch(format!(
443 "Gate dimension {} doesn't match qubit count {}",
444 gate.num_rows, gate_qubits
445 )));
446 }
447
448 let mut masks = vec![0usize; gate_qubits];
450 for (i, &qubit) in qubits.iter().enumerate() {
451 masks[i] = 1 << qubit;
452 }
453
454 let state_dim = 1 << num_qubits;
456 let mut new_state = Array1::zeros(state_dim);
457
458 for i in 0..state_dim {
459 let mut gate_idx = 0;
461 for (j, &mask) in masks.iter().enumerate() {
462 if i & mask != 0 {
463 gate_idx |= 1 << j;
464 }
465 }
466
467 let row_start = gate.row_ptr[gate_idx];
469 let row_end = gate.row_ptr[gate_idx + 1];
470
471 for idx in row_start..row_end {
472 let gate_col = gate.col_indices[idx];
473 let gate_val = gate.values[idx];
474
475 let mut j = i;
477 for (k, &mask) in masks.iter().enumerate() {
478 if gate_col & (1 << k) != 0 {
479 j |= mask;
480 } else {
481 j &= !mask;
482 }
483 }
484
485 new_state[i] += gate_val * state[j];
486 }
487 }
488
489 state.assign(&new_state);
490 Ok(())
491}
492
493pub fn optimize_sparse_gates(gates: Vec<CSRMatrix>) -> Result<CSRMatrix> {
495 if gates.is_empty() {
496 return Err(SimulatorError::InvalidInput(
497 "Empty gate sequence".to_string(),
498 ));
499 }
500
501 let mut result = gates[0].clone();
502 for gate in gates.into_iter().skip(1) {
503 result = result.matmul(&gate)?;
504
505 result.values.retain(|&v| v.norm() > 1e-15);
507 }
508
509 Ok(result)
510}
511
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_matrix_construction() {
        let mut builder = SparseMatrixBuilder::new(3, 3);
        builder.add(0, 0, Complex64::new(1.0, 0.0));
        builder.add(1, 1, Complex64::new(2.0, 0.0));
        builder.add(2, 2, Complex64::new(3.0, 0.0));
        builder.add(0, 2, Complex64::new(4.0, 0.0));

        let sparse = builder.build();
        assert_eq!(sparse.nnz(), 4);
        assert_eq!(sparse.num_rows, 3);
        assert_eq!(sparse.num_cols, 3);
        // Entries are stored row-major, columns ascending within each row.
        assert_eq!(sparse.row_ptr, vec![0, 2, 3, 4]);
        assert_eq!(sparse.col_indices, vec![0, 2, 1, 2]);
    }

    #[test]
    fn test_builder_sums_duplicates_and_drops_cancellations() {
        let mut builder = SparseMatrixBuilder::new(2, 2);
        builder.add(0, 0, Complex64::new(1.0, 0.0));
        builder.add(0, 0, Complex64::new(2.0, 0.0));
        builder.add(1, 1, Complex64::new(1.0, 0.0));
        builder.add(1, 1, Complex64::new(-1.0, 0.0));
        // Out of range: ignored by `add`.
        builder.add(5, 0, Complex64::new(1.0, 0.0));

        let m = builder.build();
        assert_eq!(m.nnz(), 1);
        assert_eq!(m.row_ptr, vec![0, 1, 1]);
        assert_eq!(m.col_indices, vec![0]);
        assert!((m.values[0] - Complex64::new(3.0, 0.0)).norm() < 1e-10);
    }

    #[test]
    fn test_sparse_gates() {
        let x = SparseGates::x();
        assert_eq!(x.nnz(), 2);

        let cnot = SparseGates::cnot();
        assert_eq!(cnot.nnz(), 4);

        let rz = SparseGates::rotation("z", 0.5).expect("Failed to create rotation gate");
        assert_eq!(rz.nnz(), 2);

        assert!(SparseGates::rotation("w", 0.5).is_err());
    }

    #[test]
    fn test_sparse_matvec() {
        let x = SparseGates::x();
        let vec = Array1::from_vec(vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)]);

        let result = x
            .matvec(&vec)
            .expect("Failed to perform matrix-vector multiplication");
        assert!((result[0] - Complex64::new(0.0, 0.0)).norm() < 1e-10);
        assert!((result[1] - Complex64::new(1.0, 0.0)).norm() < 1e-10);
    }

    #[test]
    fn test_sparse_matvec_dimension_mismatch() {
        let x = SparseGates::x();
        assert!(x.matvec(&Array1::zeros(3)).is_err());
    }

    #[test]
    fn test_sparse_matmul() {
        let x = SparseGates::x();
        let z = SparseGates::z();

        let xz = x
            .matmul(&z)
            .expect("Failed to perform matrix multiplication");
        let y_expected = SparseGates::y();

        // X·Z = -iY, so it shares Y's sparsity pattern.
        assert_eq!(xz.nnz(), y_expected.nnz());

        // Check the actual entries, not only the count: X·Z = [[0, -1], [1, 0]].
        let dense = xz.to_dense();
        assert!(dense[[0, 0]].norm() < 1e-10);
        assert!((dense[[0, 1]] - Complex64::new(-1.0, 0.0)).norm() < 1e-10);
        assert!((dense[[1, 0]] - Complex64::new(1.0, 0.0)).norm() < 1e-10);
        assert!(dense[[1, 1]].norm() < 1e-10);
    }

    #[test]
    fn test_csr_to_dense() {
        let cnot = SparseGates::cnot();
        let dense = cnot.to_dense();

        assert_eq!(dense.shape(), &[4, 4]);
        assert!((dense[[0, 0]] - Complex64::new(1.0, 0.0)).norm() < 1e-10);
        assert!((dense[[3, 2]] - Complex64::new(1.0, 0.0)).norm() < 1e-10);

        let roundtrip = CSRMatrix::from_dense(&dense);
        assert_eq!(roundtrip.nnz(), cnot.nnz());
        assert_eq!(roundtrip.to_dense(), dense);
    }
}