1use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::Complex64;
8use std::collections::HashMap;
9
10use crate::error::{Result, SimulatorError};
11
12#[derive(Debug, Clone)]
14pub struct CSRMatrix {
15 pub values: Vec<Complex64>,
17 pub col_indices: Vec<usize>,
19 pub row_ptr: Vec<usize>,
21 pub num_rows: usize,
23 pub num_cols: usize,
25}
26
27impl CSRMatrix {
28 pub fn new(
30 values: Vec<Complex64>,
31 col_indices: Vec<usize>,
32 row_ptr: Vec<usize>,
33 num_rows: usize,
34 num_cols: usize,
35 ) -> Self {
36 assert_eq!(values.len(), col_indices.len());
37 assert_eq!(row_ptr.len(), num_rows + 1);
38
39 Self {
40 values,
41 col_indices,
42 row_ptr,
43 num_rows,
44 num_cols,
45 }
46 }
47
48 pub fn from_dense(matrix: &Array2<Complex64>) -> Self {
50 let num_rows = matrix.nrows();
51 let num_cols = matrix.ncols();
52 let mut values = Vec::new();
53 let mut col_indices = Vec::new();
54 let mut row_ptr = vec![0];
55
56 for i in 0..num_rows {
57 for j in 0..num_cols {
58 let val = matrix[[i, j]];
59 if val.norm() > 1e-15 {
60 values.push(val);
61 col_indices.push(j);
62 }
63 }
64 row_ptr.push(values.len());
65 }
66
67 Self::new(values, col_indices, row_ptr, num_rows, num_cols)
68 }
69
70 pub fn to_dense(&self) -> Array2<Complex64> {
72 let mut dense = Array2::zeros((self.num_rows, self.num_cols));
73
74 for i in 0..self.num_rows {
75 let start = self.row_ptr[i];
76 let end = self.row_ptr[i + 1];
77
78 for idx in start..end {
79 dense[[i, self.col_indices[idx]]] = self.values[idx];
80 }
81 }
82
83 dense
84 }
85
86 pub fn nnz(&self) -> usize {
88 self.values.len()
89 }
90
91 pub fn matvec(&self, vec: &Array1<Complex64>) -> Result<Array1<Complex64>> {
93 if vec.len() != self.num_cols {
94 return Err(SimulatorError::DimensionMismatch(format!(
95 "Vector length {} doesn't match matrix columns {}",
96 vec.len(),
97 self.num_cols
98 )));
99 }
100
101 let mut result = Array1::zeros(self.num_rows);
102
103 for i in 0..self.num_rows {
104 let start = self.row_ptr[i];
105 let end = self.row_ptr[i + 1];
106
107 let mut sum = Complex64::new(0.0, 0.0);
108 for idx in start..end {
109 sum += self.values[idx] * vec[self.col_indices[idx]];
110 }
111 result[i] = sum;
112 }
113
114 Ok(result)
115 }
116
117 pub fn matmul(&self, other: &Self) -> Result<Self> {
119 if self.num_cols != other.num_rows {
120 return Err(SimulatorError::DimensionMismatch(format!(
121 "Matrix dimensions incompatible: {}x{} * {}x{}",
122 self.num_rows, self.num_cols, other.num_rows, other.num_cols
123 )));
124 }
125
126 let mut values = Vec::new();
127 let mut col_indices = Vec::new();
128 let mut row_ptr = vec![0];
129
130 let other_csc = other.to_csc();
132
133 for i in 0..self.num_rows {
134 let mut row_values: HashMap<usize, Complex64> = HashMap::new();
135
136 let a_start = self.row_ptr[i];
137 let a_end = self.row_ptr[i + 1];
138
139 for a_idx in a_start..a_end {
140 let k = self.col_indices[a_idx];
141 let a_val = self.values[a_idx];
142
143 let b_start = other_csc.col_ptr[k];
145 let b_end = other_csc.col_ptr[k + 1];
146
147 for b_idx in b_start..b_end {
148 let j = other_csc.row_indices[b_idx];
149 let b_val = other_csc.values[b_idx];
150
151 *row_values.entry(j).or_insert(Complex64::new(0.0, 0.0)) += a_val * b_val;
152 }
153 }
154
155 let mut sorted_cols: Vec<_> = row_values.into_iter().collect();
157 sorted_cols.sort_by_key(|(col, _)| *col);
158
159 for (col, val) in sorted_cols {
160 if val.norm() > 1e-15 {
161 values.push(val);
162 col_indices.push(col);
163 }
164 }
165
166 row_ptr.push(values.len());
167 }
168
169 Ok(Self::new(
170 values,
171 col_indices,
172 row_ptr,
173 self.num_rows,
174 other.num_cols,
175 ))
176 }
177
178 fn to_csc(&self) -> CSCMatrix {
180 let mut values = Vec::new();
181 let mut row_indices = Vec::new();
182 let mut col_ptr = vec![0; self.num_cols + 1];
183
184 for &col in &self.col_indices {
186 col_ptr[col + 1] += 1;
187 }
188
189 for i in 1..=self.num_cols {
191 col_ptr[i] += col_ptr[i - 1];
192 }
193
194 let mut current_pos = col_ptr[0..self.num_cols].to_vec();
196 values.resize(self.nnz(), Complex64::new(0.0, 0.0));
197 row_indices.resize(self.nnz(), 0);
198
199 for i in 0..self.num_rows {
201 let start = self.row_ptr[i];
202 let end = self.row_ptr[i + 1];
203
204 for idx in start..end {
205 let col = self.col_indices[idx];
206 let pos = current_pos[col];
207
208 values[pos] = self.values[idx];
209 row_indices[pos] = i;
210 current_pos[col] += 1;
211 }
212 }
213
214 CSCMatrix {
215 values,
216 row_indices,
217 col_ptr,
218 num_rows: self.num_rows,
219 num_cols: self.num_cols,
220 }
221 }
222}
223
224#[derive(Debug, Clone)]
226struct CSCMatrix {
227 values: Vec<Complex64>,
228 row_indices: Vec<usize>,
229 col_ptr: Vec<usize>,
230 num_rows: usize,
231 num_cols: usize,
232}
233
234#[derive(Debug)]
236pub struct SparseMatrixBuilder {
237 triplets: Vec<(usize, usize, Complex64)>,
238 num_rows: usize,
239 num_cols: usize,
240}
241
242impl SparseMatrixBuilder {
243 pub const fn new(num_rows: usize, num_cols: usize) -> Self {
245 Self {
246 triplets: Vec::new(),
247 num_rows,
248 num_cols,
249 }
250 }
251
252 pub fn add(&mut self, row: usize, col: usize, value: Complex64) {
254 if row < self.num_rows && col < self.num_cols && value.norm() > 1e-15 {
255 self.triplets.push((row, col, value));
256 }
257 }
258
259 pub fn set_value(&mut self, row: usize, col: usize, value: Complex64) {
261 self.add(row, col, value);
262 }
263
264 pub fn build(mut self) -> CSRMatrix {
266 self.triplets.sort_by_key(|(r, c, _)| (*r, *c));
268
269 let mut combined_triplets = Vec::new();
271 let mut last_pos: Option<(usize, usize)> = None;
272
273 for (r, c, v) in self.triplets {
274 if Some((r, c)) == last_pos {
275 if let Some(last) = combined_triplets.last_mut() {
276 let (_, _, ref mut last_val) = last;
277 *last_val += v;
278 }
279 } else {
280 combined_triplets.push((r, c, v));
281 last_pos = Some((r, c));
282 }
283 }
284
285 let mut values = Vec::new();
287 let mut col_indices = Vec::new();
288 let mut row_ptr = vec![0];
289 let mut current_row = 0;
290
291 for (r, c, v) in combined_triplets {
292 while current_row < r {
293 row_ptr.push(values.len());
294 current_row += 1;
295 }
296
297 if v.norm() > 1e-15 {
298 values.push(v);
299 col_indices.push(c);
300 }
301 }
302
303 while row_ptr.len() <= self.num_rows {
304 row_ptr.push(values.len());
305 }
306
307 CSRMatrix::new(values, col_indices, row_ptr, self.num_rows, self.num_cols)
308 }
309}
310
/// Namespace for constructors of common quantum gates as sparse [`CSRMatrix`] values.
pub struct SparseGates;
313
314impl SparseGates {
315 pub fn x() -> CSRMatrix {
317 let mut builder = SparseMatrixBuilder::new(2, 2);
318 builder.add(0, 1, Complex64::new(1.0, 0.0));
319 builder.add(1, 0, Complex64::new(1.0, 0.0));
320 builder.build()
321 }
322
323 pub fn y() -> CSRMatrix {
325 let mut builder = SparseMatrixBuilder::new(2, 2);
326 builder.add(0, 1, Complex64::new(0.0, -1.0));
327 builder.add(1, 0, Complex64::new(0.0, 1.0));
328 builder.build()
329 }
330
331 pub fn z() -> CSRMatrix {
333 let mut builder = SparseMatrixBuilder::new(2, 2);
334 builder.add(0, 0, Complex64::new(1.0, 0.0));
335 builder.add(1, 1, Complex64::new(-1.0, 0.0));
336 builder.build()
337 }
338
339 pub fn cnot() -> CSRMatrix {
341 let mut builder = SparseMatrixBuilder::new(4, 4);
342 builder.add(0, 0, Complex64::new(1.0, 0.0));
343 builder.add(1, 1, Complex64::new(1.0, 0.0));
344 builder.add(2, 3, Complex64::new(1.0, 0.0));
345 builder.add(3, 2, Complex64::new(1.0, 0.0));
346 builder.build()
347 }
348
349 pub fn cz() -> CSRMatrix {
351 let mut builder = SparseMatrixBuilder::new(4, 4);
352 builder.add(0, 0, Complex64::new(1.0, 0.0));
353 builder.add(1, 1, Complex64::new(1.0, 0.0));
354 builder.add(2, 2, Complex64::new(1.0, 0.0));
355 builder.add(3, 3, Complex64::new(-1.0, 0.0));
356 builder.build()
357 }
358
359 pub fn rotation(axis: &str, angle: f64) -> Result<CSRMatrix> {
361 let (c, s) = (angle.cos(), angle.sin());
362 let half_angle = angle / 2.0;
363 let (ch, sh) = (half_angle.cos(), half_angle.sin());
364
365 let mut builder = SparseMatrixBuilder::new(2, 2);
366
367 match axis {
368 "x" | "X" => {
369 builder.add(0, 0, Complex64::new(ch, 0.0));
370 builder.add(0, 1, Complex64::new(0.0, -sh));
371 builder.add(1, 0, Complex64::new(0.0, -sh));
372 builder.add(1, 1, Complex64::new(ch, 0.0));
373 }
374 "y" | "Y" => {
375 builder.add(0, 0, Complex64::new(ch, 0.0));
376 builder.add(0, 1, Complex64::new(-sh, 0.0));
377 builder.add(1, 0, Complex64::new(sh, 0.0));
378 builder.add(1, 1, Complex64::new(ch, 0.0));
379 }
380 "z" | "Z" => {
381 builder.add(0, 0, Complex64::new(ch, -sh));
382 builder.add(1, 1, Complex64::new(ch, sh));
383 }
384 _ => {
385 return Err(SimulatorError::InvalidConfiguration(format!(
386 "Unknown rotation axis: {axis}"
387 )))
388 }
389 }
390
391 Ok(builder.build())
392 }
393
394 pub fn controlled_rotation(axis: &str, angle: f64) -> Result<CSRMatrix> {
396 let single_qubit = Self::rotation(axis, angle)?;
397
398 let mut builder = SparseMatrixBuilder::new(4, 4);
399
400 builder.add(0, 0, Complex64::new(1.0, 0.0));
402 builder.add(1, 1, Complex64::new(1.0, 0.0));
403
404 builder.add(2, 2, single_qubit.values[0]);
406 if single_qubit.values.len() > 1 {
407 builder.add(2, 3, single_qubit.values[1]);
408 }
409 if single_qubit.values.len() > 2 {
410 builder.add(3, 2, single_qubit.values[2]);
411 }
412 if single_qubit.values.len() > 3 {
413 builder.add(3, 3, single_qubit.values[3]);
414 }
415
416 Ok(builder.build())
417 }
418}
419
420pub fn apply_sparse_gate(
422 state: &mut Array1<Complex64>,
423 gate: &CSRMatrix,
424 qubits: &[usize],
425 num_qubits: usize,
426) -> Result<()> {
427 let gate_qubits = qubits.len();
428 let gate_dim = 1 << gate_qubits;
429
430 if gate.num_rows != gate_dim || gate.num_cols != gate_dim {
431 return Err(SimulatorError::DimensionMismatch(format!(
432 "Gate dimension {} doesn't match qubit count {}",
433 gate.num_rows, gate_qubits
434 )));
435 }
436
437 let mut masks = vec![0usize; gate_qubits];
439 for (i, &qubit) in qubits.iter().enumerate() {
440 masks[i] = 1 << qubit;
441 }
442
443 let state_dim = 1 << num_qubits;
445 let mut new_state = Array1::zeros(state_dim);
446
447 for i in 0..state_dim {
448 let mut gate_idx = 0;
450 for (j, &mask) in masks.iter().enumerate() {
451 if i & mask != 0 {
452 gate_idx |= 1 << j;
453 }
454 }
455
456 let row_start = gate.row_ptr[gate_idx];
458 let row_end = gate.row_ptr[gate_idx + 1];
459
460 for idx in row_start..row_end {
461 let gate_col = gate.col_indices[idx];
462 let gate_val = gate.values[idx];
463
464 let mut j = i;
466 for (k, &mask) in masks.iter().enumerate() {
467 if gate_col & (1 << k) != 0 {
468 j |= mask;
469 } else {
470 j &= !mask;
471 }
472 }
473
474 new_state[i] += gate_val * state[j];
475 }
476 }
477
478 state.assign(&new_state);
479 Ok(())
480}
481
482pub fn optimize_sparse_gates(gates: Vec<CSRMatrix>) -> Result<CSRMatrix> {
484 if gates.is_empty() {
485 return Err(SimulatorError::InvalidInput(
486 "Empty gate sequence".to_string(),
487 ));
488 }
489
490 let mut result = gates[0].clone();
491 for gate in gates.into_iter().skip(1) {
492 result = result.matmul(&gate)?;
493
494 result.values.retain(|&v| v.norm() > 1e-15);
496 }
497
498 Ok(result)
499}
500
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts two dense matrices have the same shape and agree entry-wise within 1e-10.
    fn assert_dense_close(actual: &Array2<Complex64>, expected: &Array2<Complex64>) {
        assert_eq!(actual.shape(), expected.shape());
        for ((r, c), a) in actual.indexed_iter() {
            let e = expected[[r, c]];
            assert!((*a - e).norm() < 1e-10, "entry ({r}, {c}): {a:?} != {e:?}");
        }
    }

    #[test]
    fn test_sparse_matrix_construction() {
        let mut builder = SparseMatrixBuilder::new(3, 3);
        builder.add(0, 0, Complex64::new(1.0, 0.0));
        builder.add(1, 1, Complex64::new(2.0, 0.0));
        builder.add(2, 2, Complex64::new(3.0, 0.0));
        builder.add(0, 2, Complex64::new(4.0, 0.0));

        let sparse = builder.build();
        assert_eq!(sparse.nnz(), 4);
        assert_eq!(sparse.num_rows, 3);
        assert_eq!(sparse.num_cols, 3);
        assert_eq!(sparse.row_ptr, vec![0, 2, 3, 4]);
        assert_eq!(sparse.col_indices, vec![0, 2, 1, 2]);
    }

    #[test]
    fn test_builder_sums_duplicates_and_drops_cancellations() {
        let mut builder = SparseMatrixBuilder::new(2, 2);
        builder.add(0, 1, Complex64::new(1.0, 0.0));
        builder.add(0, 1, Complex64::new(2.0, 0.0));
        builder.add(1, 0, Complex64::new(1.0, 0.0));
        builder.add(1, 0, Complex64::new(-1.0, 0.0));
        // Out of range: ignored.
        builder.add(5, 0, Complex64::new(1.0, 0.0));

        let m = builder.build();
        assert_eq!(m.nnz(), 1);
        assert_eq!(m.row_ptr, vec![0, 1, 1]);
        assert_eq!(m.col_indices, vec![1]);
        assert!((m.values[0] - Complex64::new(3.0, 0.0)).norm() < 1e-12);
    }

    #[test]
    fn test_sparse_gates() {
        let x = SparseGates::x();
        assert_eq!(x.nnz(), 2);

        let cnot = SparseGates::cnot();
        assert_eq!(cnot.nnz(), 4);

        let rz = SparseGates::rotation("z", 0.5).unwrap();
        assert_eq!(rz.nnz(), 2);
    }

    #[test]
    fn test_sparse_matvec() {
        let x = SparseGates::x();
        let vec = Array1::from_vec(vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)]);

        let result = x.matvec(&vec).unwrap();
        assert!((result[0] - Complex64::new(0.0, 0.0)).norm() < 1e-10);
        assert!((result[1] - Complex64::new(1.0, 0.0)).norm() < 1e-10);
    }

    #[test]
    fn test_sparse_matmul() {
        let x = SparseGates::x();
        let z = SparseGates::z();

        let xz = x.matmul(&z).unwrap();
        let y = SparseGates::y();
        assert_eq!(xz.nnz(), y.nnz());

        // X * Z = -i * Y, checked entry by entry rather than only by entry count.
        let minus_i = Complex64::new(0.0, -1.0);
        let expected = y.to_dense().mapv(|v| minus_i * v);
        assert_dense_close(&xz.to_dense(), &expected);
    }

    #[test]
    fn test_csr_to_dense() {
        let cnot = SparseGates::cnot();
        let dense = cnot.to_dense();

        assert_eq!(dense.shape(), &[4, 4]);
        assert!((dense[[0, 0]] - Complex64::new(1.0, 0.0)).norm() < 1e-10);
        assert!((dense[[3, 2]] - Complex64::new(1.0, 0.0)).norm() < 1e-10);
    }

    #[test]
    fn test_dense_round_trip() {
        let zero = Complex64::new(0.0, 0.0);
        let dense = Array2::from_shape_vec(
            (2, 3),
            vec![
                Complex64::new(1.0, 0.0),
                zero,
                Complex64::new(0.0, 2.0),
                zero,
                zero,
                Complex64::new(-3.0, 0.0),
            ],
        )
        .unwrap();

        let sparse = CSRMatrix::from_dense(&dense);
        assert_eq!(sparse.nnz(), 3);
        assert_eq!(sparse.row_ptr, vec![0, 2, 3]);
        assert_eq!(sparse.col_indices, vec![0, 2, 2]);
        assert_eq!(sparse.to_dense(), dense);
    }
}