1use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::Complex64;
8use std::collections::HashMap;
9
10use crate::error::{Result, SimulatorError};
11
12#[derive(Debug, Clone)]
14pub struct CSRMatrix {
15 pub values: Vec<Complex64>,
17 pub col_indices: Vec<usize>,
19 pub row_ptr: Vec<usize>,
21 pub num_rows: usize,
23 pub num_cols: usize,
25}
26
27impl CSRMatrix {
28 pub fn new(
30 values: Vec<Complex64>,
31 col_indices: Vec<usize>,
32 row_ptr: Vec<usize>,
33 num_rows: usize,
34 num_cols: usize,
35 ) -> Self {
36 assert_eq!(values.len(), col_indices.len());
37 assert_eq!(row_ptr.len(), num_rows + 1);
38
39 Self {
40 values,
41 col_indices,
42 row_ptr,
43 num_rows,
44 num_cols,
45 }
46 }
47
48 pub fn from_dense(matrix: &Array2<Complex64>) -> Self {
50 let num_rows = matrix.nrows();
51 let num_cols = matrix.ncols();
52 let mut values = Vec::new();
53 let mut col_indices = Vec::new();
54 let mut row_ptr = vec![0];
55
56 for i in 0..num_rows {
57 for j in 0..num_cols {
58 let val = matrix[[i, j]];
59 if val.norm() > 1e-15 {
60 values.push(val);
61 col_indices.push(j);
62 }
63 }
64 row_ptr.push(values.len());
65 }
66
67 Self::new(values, col_indices, row_ptr, num_rows, num_cols)
68 }
69
70 pub fn to_dense(&self) -> Array2<Complex64> {
72 let mut dense = Array2::zeros((self.num_rows, self.num_cols));
73
74 for i in 0..self.num_rows {
75 let start = self.row_ptr[i];
76 let end = self.row_ptr[i + 1];
77
78 for idx in start..end {
79 dense[[i, self.col_indices[idx]]] = self.values[idx];
80 }
81 }
82
83 dense
84 }
85
86 pub fn nnz(&self) -> usize {
88 self.values.len()
89 }
90
91 pub fn matvec(&self, vec: &Array1<Complex64>) -> Result<Array1<Complex64>> {
93 if vec.len() != self.num_cols {
94 return Err(SimulatorError::DimensionMismatch(format!(
95 "Vector length {} doesn't match matrix columns {}",
96 vec.len(),
97 self.num_cols
98 )));
99 }
100
101 let mut result = Array1::zeros(self.num_rows);
102
103 for i in 0..self.num_rows {
104 let start = self.row_ptr[i];
105 let end = self.row_ptr[i + 1];
106
107 let mut sum = Complex64::new(0.0, 0.0);
108 for idx in start..end {
109 sum += self.values[idx] * vec[self.col_indices[idx]];
110 }
111 result[i] = sum;
112 }
113
114 Ok(result)
115 }
116
117 pub fn matmul(&self, other: &CSRMatrix) -> Result<CSRMatrix> {
119 if self.num_cols != other.num_rows {
120 return Err(SimulatorError::DimensionMismatch(format!(
121 "Matrix dimensions incompatible: {}x{} * {}x{}",
122 self.num_rows, self.num_cols, other.num_rows, other.num_cols
123 )));
124 }
125
126 let mut values = Vec::new();
127 let mut col_indices = Vec::new();
128 let mut row_ptr = vec![0];
129
130 let other_csc = other.to_csc();
132
133 for i in 0..self.num_rows {
134 let mut row_values: HashMap<usize, Complex64> = HashMap::new();
135
136 let a_start = self.row_ptr[i];
137 let a_end = self.row_ptr[i + 1];
138
139 for a_idx in a_start..a_end {
140 let k = self.col_indices[a_idx];
141 let a_val = self.values[a_idx];
142
143 let b_start = other_csc.col_ptr[k];
145 let b_end = other_csc.col_ptr[k + 1];
146
147 for b_idx in b_start..b_end {
148 let j = other_csc.row_indices[b_idx];
149 let b_val = other_csc.values[b_idx];
150
151 *row_values.entry(j).or_insert(Complex64::new(0.0, 0.0)) += a_val * b_val;
152 }
153 }
154
155 let mut sorted_cols: Vec<_> = row_values.into_iter().collect();
157 sorted_cols.sort_by_key(|(col, _)| *col);
158
159 for (col, val) in sorted_cols {
160 if val.norm() > 1e-15 {
161 values.push(val);
162 col_indices.push(col);
163 }
164 }
165
166 row_ptr.push(values.len());
167 }
168
169 Ok(CSRMatrix::new(
170 values,
171 col_indices,
172 row_ptr,
173 self.num_rows,
174 other.num_cols,
175 ))
176 }
177
178 fn to_csc(&self) -> CSCMatrix {
180 let mut values = Vec::new();
181 let mut row_indices = Vec::new();
182 let mut col_ptr = vec![0; self.num_cols + 1];
183
184 for &col in &self.col_indices {
186 col_ptr[col + 1] += 1;
187 }
188
189 for i in 1..=self.num_cols {
191 col_ptr[i] += col_ptr[i - 1];
192 }
193
194 let mut current_pos = col_ptr[0..self.num_cols].to_vec();
196 values.resize(self.nnz(), Complex64::new(0.0, 0.0));
197 row_indices.resize(self.nnz(), 0);
198
199 for i in 0..self.num_rows {
201 let start = self.row_ptr[i];
202 let end = self.row_ptr[i + 1];
203
204 for idx in start..end {
205 let col = self.col_indices[idx];
206 let pos = current_pos[col];
207
208 values[pos] = self.values[idx];
209 row_indices[pos] = i;
210 current_pos[col] += 1;
211 }
212 }
213
214 CSCMatrix {
215 values,
216 row_indices,
217 col_ptr,
218 num_rows: self.num_rows,
219 num_cols: self.num_cols,
220 }
221 }
222}
223
224#[derive(Debug, Clone)]
226struct CSCMatrix {
227 values: Vec<Complex64>,
228 row_indices: Vec<usize>,
229 col_ptr: Vec<usize>,
230 num_rows: usize,
231 num_cols: usize,
232}
233
234#[derive(Debug)]
236pub struct SparseMatrixBuilder {
237 triplets: Vec<(usize, usize, Complex64)>,
238 num_rows: usize,
239 num_cols: usize,
240}
241
242impl SparseMatrixBuilder {
243 pub fn new(num_rows: usize, num_cols: usize) -> Self {
245 Self {
246 triplets: Vec::new(),
247 num_rows,
248 num_cols,
249 }
250 }
251
252 pub fn add(&mut self, row: usize, col: usize, value: Complex64) {
254 if row < self.num_rows && col < self.num_cols && value.norm() > 1e-15 {
255 self.triplets.push((row, col, value));
256 }
257 }
258
259 pub fn set_value(&mut self, row: usize, col: usize, value: Complex64) {
261 self.add(row, col, value);
262 }
263
264 pub fn build(mut self) -> CSRMatrix {
266 self.triplets.sort_by_key(|(r, c, _)| (*r, *c));
268
269 let mut combined_triplets = Vec::new();
271 let mut last_pos: Option<(usize, usize)> = None;
272
273 for (r, c, v) in self.triplets {
274 if Some((r, c)) == last_pos {
275 if let Some(last) = combined_triplets.last_mut() {
276 let (_, _, ref mut last_val) = last;
277 *last_val += v;
278 }
279 } else {
280 combined_triplets.push((r, c, v));
281 last_pos = Some((r, c));
282 }
283 }
284
285 let mut values = Vec::new();
287 let mut col_indices = Vec::new();
288 let mut row_ptr = vec![0];
289 let mut current_row = 0;
290
291 for (r, c, v) in combined_triplets {
292 while current_row < r {
293 row_ptr.push(values.len());
294 current_row += 1;
295 }
296
297 if v.norm() > 1e-15 {
298 values.push(v);
299 col_indices.push(c);
300 }
301 }
302
303 while row_ptr.len() <= self.num_rows {
304 row_ptr.push(values.len());
305 }
306
307 CSRMatrix::new(values, col_indices, row_ptr, self.num_rows, self.num_cols)
308 }
309}
310
/// Namespace for constructors of common quantum gates as sparse
/// [`CSRMatrix`] values.
pub struct SparseGates;
313
314impl SparseGates {
315 pub fn x() -> CSRMatrix {
317 let mut builder = SparseMatrixBuilder::new(2, 2);
318 builder.add(0, 1, Complex64::new(1.0, 0.0));
319 builder.add(1, 0, Complex64::new(1.0, 0.0));
320 builder.build()
321 }
322
323 pub fn y() -> CSRMatrix {
325 let mut builder = SparseMatrixBuilder::new(2, 2);
326 builder.add(0, 1, Complex64::new(0.0, -1.0));
327 builder.add(1, 0, Complex64::new(0.0, 1.0));
328 builder.build()
329 }
330
331 pub fn z() -> CSRMatrix {
333 let mut builder = SparseMatrixBuilder::new(2, 2);
334 builder.add(0, 0, Complex64::new(1.0, 0.0));
335 builder.add(1, 1, Complex64::new(-1.0, 0.0));
336 builder.build()
337 }
338
339 pub fn cnot() -> CSRMatrix {
341 let mut builder = SparseMatrixBuilder::new(4, 4);
342 builder.add(0, 0, Complex64::new(1.0, 0.0));
343 builder.add(1, 1, Complex64::new(1.0, 0.0));
344 builder.add(2, 3, Complex64::new(1.0, 0.0));
345 builder.add(3, 2, Complex64::new(1.0, 0.0));
346 builder.build()
347 }
348
349 pub fn cz() -> CSRMatrix {
351 let mut builder = SparseMatrixBuilder::new(4, 4);
352 builder.add(0, 0, Complex64::new(1.0, 0.0));
353 builder.add(1, 1, Complex64::new(1.0, 0.0));
354 builder.add(2, 2, Complex64::new(1.0, 0.0));
355 builder.add(3, 3, Complex64::new(-1.0, 0.0));
356 builder.build()
357 }
358
359 pub fn rotation(axis: &str, angle: f64) -> Result<CSRMatrix> {
361 let (c, s) = (angle.cos(), angle.sin());
362 let half_angle = angle / 2.0;
363 let (ch, sh) = (half_angle.cos(), half_angle.sin());
364
365 let mut builder = SparseMatrixBuilder::new(2, 2);
366
367 match axis {
368 "x" | "X" => {
369 builder.add(0, 0, Complex64::new(ch, 0.0));
370 builder.add(0, 1, Complex64::new(0.0, -sh));
371 builder.add(1, 0, Complex64::new(0.0, -sh));
372 builder.add(1, 1, Complex64::new(ch, 0.0));
373 }
374 "y" | "Y" => {
375 builder.add(0, 0, Complex64::new(ch, 0.0));
376 builder.add(0, 1, Complex64::new(-sh, 0.0));
377 builder.add(1, 0, Complex64::new(sh, 0.0));
378 builder.add(1, 1, Complex64::new(ch, 0.0));
379 }
380 "z" | "Z" => {
381 builder.add(0, 0, Complex64::new(ch, -sh));
382 builder.add(1, 1, Complex64::new(ch, sh));
383 }
384 _ => {
385 return Err(SimulatorError::InvalidConfiguration(format!(
386 "Unknown rotation axis: {}",
387 axis
388 )))
389 }
390 }
391
392 Ok(builder.build())
393 }
394
395 pub fn controlled_rotation(axis: &str, angle: f64) -> Result<CSRMatrix> {
397 let single_qubit = Self::rotation(axis, angle)?;
398
399 let mut builder = SparseMatrixBuilder::new(4, 4);
400
401 builder.add(0, 0, Complex64::new(1.0, 0.0));
403 builder.add(1, 1, Complex64::new(1.0, 0.0));
404
405 builder.add(2, 2, single_qubit.values[0]);
407 if single_qubit.values.len() > 1 {
408 builder.add(2, 3, single_qubit.values[1]);
409 }
410 if single_qubit.values.len() > 2 {
411 builder.add(3, 2, single_qubit.values[2]);
412 }
413 if single_qubit.values.len() > 3 {
414 builder.add(3, 3, single_qubit.values[3]);
415 }
416
417 Ok(builder.build())
418 }
419}
420
421pub fn apply_sparse_gate(
423 state: &mut Array1<Complex64>,
424 gate: &CSRMatrix,
425 qubits: &[usize],
426 num_qubits: usize,
427) -> Result<()> {
428 let gate_qubits = qubits.len();
429 let gate_dim = 1 << gate_qubits;
430
431 if gate.num_rows != gate_dim || gate.num_cols != gate_dim {
432 return Err(SimulatorError::DimensionMismatch(format!(
433 "Gate dimension {} doesn't match qubit count {}",
434 gate.num_rows, gate_qubits
435 )));
436 }
437
438 let mut masks = vec![0usize; gate_qubits];
440 for (i, &qubit) in qubits.iter().enumerate() {
441 masks[i] = 1 << qubit;
442 }
443
444 let state_dim = 1 << num_qubits;
446 let mut new_state = Array1::zeros(state_dim);
447
448 for i in 0..state_dim {
449 let mut gate_idx = 0;
451 for (j, &mask) in masks.iter().enumerate() {
452 if i & mask != 0 {
453 gate_idx |= 1 << j;
454 }
455 }
456
457 let row_start = gate.row_ptr[gate_idx];
459 let row_end = gate.row_ptr[gate_idx + 1];
460
461 for idx in row_start..row_end {
462 let gate_col = gate.col_indices[idx];
463 let gate_val = gate.values[idx];
464
465 let mut j = i;
467 for (k, &mask) in masks.iter().enumerate() {
468 if gate_col & (1 << k) != 0 {
469 j |= mask;
470 } else {
471 j &= !mask;
472 }
473 }
474
475 new_state[i] += gate_val * state[j];
476 }
477 }
478
479 state.assign(&new_state);
480 Ok(())
481}
482
483pub fn optimize_sparse_gates(gates: Vec<CSRMatrix>) -> Result<CSRMatrix> {
485 if gates.is_empty() {
486 return Err(SimulatorError::InvalidInput(
487 "Empty gate sequence".to_string(),
488 ));
489 }
490
491 let mut result = gates[0].clone();
492 for gate in gates.into_iter().skip(1) {
493 result = result.matmul(&gate)?;
494
495 result.values.retain(|&v| v.norm() > 1e-15);
497 }
498
499 Ok(result)
500}
501
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a purely real complex number.
    fn real(re: f64) -> Complex64 {
        Complex64::new(re, 0.0)
    }

    /// Asserts that two complex numbers agree to within 1e-10.
    fn assert_close(actual: Complex64, expected: Complex64) {
        assert!((actual - expected).norm() < 1e-10);
    }

    #[test]
    fn test_sparse_matrix_construction() {
        let mut builder = SparseMatrixBuilder::new(3, 3);
        for &(row, col, value) in &[(0, 0, 1.0), (1, 1, 2.0), (2, 2, 3.0), (0, 2, 4.0)] {
            builder.add(row, col, real(value));
        }

        let sparse = builder.build();
        assert_eq!(sparse.nnz(), 4);
        assert_eq!(sparse.num_rows, 3);
        assert_eq!(sparse.num_cols, 3);
    }

    #[test]
    fn test_sparse_gates() {
        assert_eq!(SparseGates::x().nnz(), 2);
        assert_eq!(SparseGates::cnot().nnz(), 4);

        let rz = SparseGates::rotation("z", 0.5).expect("z is a valid rotation axis");
        assert_eq!(rz.nnz(), 2);
    }

    #[test]
    fn test_sparse_matvec() {
        let input = Array1::from_vec(vec![real(1.0), real(0.0)]);

        let flipped = SparseGates::x()
            .matvec(&input)
            .expect("vector length matches the gate");
        assert_close(flipped[0], real(0.0));
        assert_close(flipped[1], real(1.0));
    }

    #[test]
    fn test_sparse_matmul() {
        let product = SparseGates::x()
            .matmul(&SparseGates::z())
            .expect("2x2 gates are compatible");
        assert_eq!(product.nnz(), SparseGates::y().nnz());
    }

    #[test]
    fn test_csr_to_dense() {
        let dense = SparseGates::cnot().to_dense();

        assert_eq!(dense.shape(), &[4, 4]);
        assert_close(dense[[0, 0]], real(1.0));
        assert_close(dense[[3, 2]], real(1.0));
    }
}
563}