1use crate::traits::{ComplexField, LinearOperator};
9use ndarray::{Array1, Array2};
10use num_traits::{FromPrimitive, Zero};
11use std::ops::Range;
12
13#[cfg(feature = "rayon")]
14use rayon::prelude::*;
15
16#[derive(Debug, Clone)]
21pub struct CsrMatrix<T: ComplexField> {
22 pub num_rows: usize,
24 pub num_cols: usize,
26 pub values: Vec<T>,
28 pub col_indices: Vec<usize>,
30 pub row_ptrs: Vec<usize>,
33}
34
35impl<T: ComplexField> CsrMatrix<T> {
36 pub fn new(num_rows: usize, num_cols: usize) -> Self {
38 Self {
39 num_rows,
40 num_cols,
41 values: Vec::new(),
42 col_indices: Vec::new(),
43 row_ptrs: vec![0; num_rows + 1],
44 }
45 }
46
47 pub fn with_capacity(num_rows: usize, num_cols: usize, nnz_estimate: usize) -> Self {
49 Self {
50 num_rows,
51 num_cols,
52 values: Vec::with_capacity(nnz_estimate),
53 col_indices: Vec::with_capacity(nnz_estimate),
54 row_ptrs: vec![0; num_rows + 1],
55 }
56 }
57
58 pub fn from_raw_parts(
70 num_rows: usize,
71 num_cols: usize,
72 row_ptrs: Vec<usize>,
73 col_indices: Vec<usize>,
74 values: Vec<T>,
75 ) -> Self {
76 assert_eq!(
77 row_ptrs.len(),
78 num_rows + 1,
79 "row_ptrs must have num_rows + 1 elements"
80 );
81 assert_eq!(
82 col_indices.len(),
83 values.len(),
84 "col_indices and values must have the same length"
85 );
86 assert_eq!(
87 row_ptrs[num_rows],
88 values.len(),
89 "row_ptrs[num_rows] must equal nnz"
90 );
91
92 Self {
93 num_rows,
94 num_cols,
95 row_ptrs,
96 col_indices,
97 values,
98 }
99 }
100
101 pub fn from_dense(dense: &Array2<T>, threshold: T::Real) -> Self {
105 let num_rows = dense.nrows();
106 let num_cols = dense.ncols();
107
108 let mut values = Vec::new();
109 let mut col_indices = Vec::new();
110 let mut row_ptrs = vec![0usize; num_rows + 1];
111
112 for i in 0..num_rows {
113 for j in 0..num_cols {
114 let val = dense[[i, j]];
115 if val.norm() > threshold {
116 values.push(val);
117 col_indices.push(j);
118 }
119 }
120 row_ptrs[i + 1] = values.len();
121 }
122
123 Self {
124 num_rows,
125 num_cols,
126 values,
127 col_indices,
128 row_ptrs,
129 }
130 }
131
132 pub fn from_triplets(
136 num_rows: usize,
137 num_cols: usize,
138 mut triplets: Vec<(usize, usize, T)>,
139 ) -> Self {
140 if triplets.is_empty() {
141 return Self::new(num_rows, num_cols);
142 }
143
144 triplets.sort_by(|a, b| {
146 if a.0 != b.0 {
147 a.0.cmp(&b.0)
148 } else {
149 a.1.cmp(&b.1)
150 }
151 });
152
153 let mut values = Vec::with_capacity(triplets.len());
154 let mut col_indices = Vec::with_capacity(triplets.len());
155 let mut row_ptrs = vec![0usize; num_rows + 1];
156
157 let mut prev_row = usize::MAX;
158 let mut prev_col = usize::MAX;
159
160 for (row, col, val) in triplets {
161 if row == prev_row && col == prev_col {
162 if let Some(last) = values.last_mut() {
164 *last += val;
165 }
166 } else {
167 if row != prev_row {
169 let start = if prev_row == usize::MAX {
170 0
171 } else {
172 prev_row + 1
173 };
174 for item in row_ptrs.iter_mut().take(row + 1).skip(start) {
175 *item = values.len();
176 }
177 }
178
179 values.push(val);
181 col_indices.push(col);
182
183 prev_row = row;
184 prev_col = col;
185 }
186 }
187
188 let last_row = if prev_row == usize::MAX {
190 0
191 } else {
192 prev_row + 1
193 };
194 for item in row_ptrs.iter_mut().take(num_rows + 1).skip(last_row) {
195 *item = values.len();
196 }
197
198 Self {
199 num_rows,
200 num_cols,
201 values,
202 col_indices,
203 row_ptrs,
204 }
205 }
206
207 pub fn nnz(&self) -> usize {
209 self.values.len()
210 }
211
212 pub fn sparsity(&self) -> f64 {
214 let total = self.num_rows * self.num_cols;
215 if total == 0 {
216 0.0
217 } else {
218 self.nnz() as f64 / total as f64
219 }
220 }
221
222 pub fn row_range(&self, row: usize) -> Range<usize> {
224 self.row_ptrs[row]..self.row_ptrs[row + 1]
225 }
226
227 pub fn row_entries(&self, row: usize) -> impl Iterator<Item = (usize, T)> + '_ {
229 let range = self.row_range(row);
230 self.col_indices[range.clone()]
231 .iter()
232 .copied()
233 .zip(self.values[range].iter().copied())
234 }
235
236 pub fn matvec(&self, x: &Array1<T>) -> Array1<T> {
241 assert_eq!(x.len(), self.num_cols, "Input vector size mismatch");
242
243 #[cfg(feature = "rayon")]
245 {
246 if self.num_rows >= 246 {
248 return self.matvec_parallel(x);
249 }
250 }
251
252 self.matvec_sequential(x)
253 }
254
255 fn matvec_sequential(&self, x: &Array1<T>) -> Array1<T> {
257 let mut y = Array1::from_elem(self.num_rows, T::zero());
258
259 for i in 0..self.num_rows {
260 let mut sum = T::zero();
261 for idx in self.row_range(i) {
262 let j = self.col_indices[idx];
263 sum += self.values[idx] * x[j];
264 }
265 y[i] = sum;
266 }
267
268 y
269 }
270
271 #[cfg(feature = "rayon")]
273 fn matvec_parallel(&self, x: &Array1<T>) -> Array1<T>
274 where
275 T: Send + Sync,
276 {
277 let x_slice = x.as_slice().expect("Array should be contiguous");
278
279 let results: Vec<T> = (0..self.num_rows)
280 .into_par_iter()
281 .map(|i| {
282 let mut sum = T::zero();
283 for idx in self.row_range(i) {
284 let j = self.col_indices[idx];
285 sum += self.values[idx] * x_slice[j];
286 }
287 sum
288 })
289 .collect();
290
291 Array1::from_vec(results)
292 }
293
294 pub fn matvec_add(&self, x: &Array1<T>, y: &mut Array1<T>) {
296 assert_eq!(x.len(), self.num_cols, "Input vector size mismatch");
297 assert_eq!(y.len(), self.num_rows, "Output vector size mismatch");
298
299 for i in 0..self.num_rows {
300 for idx in self.row_range(i) {
301 let j = self.col_indices[idx];
302 y[i] += self.values[idx] * x[j];
303 }
304 }
305 }
306
307 pub fn matvec_transpose(&self, x: &Array1<T>) -> Array1<T> {
309 assert_eq!(x.len(), self.num_rows, "Input vector size mismatch");
310
311 let mut y = Array1::from_elem(self.num_cols, T::zero());
312
313 for i in 0..self.num_rows {
314 for idx in self.row_range(i) {
315 let j = self.col_indices[idx];
316 y[j] += self.values[idx] * x[i];
317 }
318 }
319
320 y
321 }
322
323 pub fn matvec_hermitian(&self, x: &Array1<T>) -> Array1<T> {
325 assert_eq!(x.len(), self.num_rows, "Input vector size mismatch");
326
327 let mut y = Array1::from_elem(self.num_cols, T::zero());
328
329 for i in 0..self.num_rows {
330 for idx in self.row_range(i) {
331 let j = self.col_indices[idx];
332 y[j] += self.values[idx].conj() * x[i];
333 }
334 }
335
336 y
337 }
338
339 pub fn get(&self, i: usize, j: usize) -> T {
341 for idx in self.row_range(i) {
342 if self.col_indices[idx] == j {
343 return self.values[idx];
344 }
345 }
346 T::zero()
347 }
348
349 pub fn diagonal(&self) -> Array1<T> {
351 let n = self.num_rows.min(self.num_cols);
352 let mut diag = Array1::from_elem(n, T::zero());
353
354 for i in 0..n {
355 for idx in self.row_range(i) {
356 if self.col_indices[idx] == i {
357 diag[i] = self.values[idx];
358 break;
359 }
360 }
361 }
362
363 diag
364 }
365
366 pub fn scale(&mut self, scalar: T) {
368 for val in &mut self.values {
369 *val *= scalar;
370 }
371 }
372
373 pub fn add_diagonal(&mut self, scalar: T) {
375 let n = self.num_rows.min(self.num_cols);
376
377 for i in 0..n {
378 for idx in self.row_range(i) {
379 if self.col_indices[idx] == i {
380 self.values[idx] += scalar;
381 break;
382 }
383 }
384 }
385 }
386
387 pub fn identity(n: usize) -> Self {
389 Self {
390 num_rows: n,
391 num_cols: n,
392 values: vec![T::one(); n],
393 col_indices: (0..n).collect(),
394 row_ptrs: (0..=n).collect(),
395 }
396 }
397
398 pub fn from_diagonal(diag: &Array1<T>) -> Self {
400 let n = diag.len();
401 Self {
402 num_rows: n,
403 num_cols: n,
404 values: diag.to_vec(),
405 col_indices: (0..n).collect(),
406 row_ptrs: (0..=n).collect(),
407 }
408 }
409
410 pub fn to_dense(&self) -> Array2<T> {
412 let mut dense = Array2::from_elem((self.num_rows, self.num_cols), T::zero());
413
414 for i in 0..self.num_rows {
415 for idx in self.row_range(i) {
416 let j = self.col_indices[idx];
417 dense[[i, j]] = self.values[idx];
418 }
419 }
420
421 dense
422 }
423}
424
425impl<T: ComplexField> LinearOperator<T> for CsrMatrix<T> {
426 fn num_rows(&self) -> usize {
427 self.num_rows
428 }
429
430 fn num_cols(&self) -> usize {
431 self.num_cols
432 }
433
434 fn apply(&self, x: &Array1<T>) -> Array1<T> {
435 self.matvec(x)
436 }
437
438 fn apply_transpose(&self, x: &Array1<T>) -> Array1<T> {
439 self.matvec_transpose(x)
440 }
441
442 fn apply_hermitian(&self, x: &Array1<T>) -> Array1<T> {
443 self.matvec_hermitian(x)
444 }
445}
446
447pub struct CsrBuilder<T: ComplexField> {
449 num_rows: usize,
450 num_cols: usize,
451 values: Vec<T>,
452 col_indices: Vec<usize>,
453 row_ptrs: Vec<usize>,
454 current_row: usize,
455}
456
457impl<T: ComplexField> CsrBuilder<T> {
458 pub fn new(num_rows: usize, num_cols: usize) -> Self {
460 Self {
461 num_rows,
462 num_cols,
463 values: Vec::new(),
464 col_indices: Vec::new(),
465 row_ptrs: vec![0],
466 current_row: 0,
467 }
468 }
469
470 pub fn with_capacity(num_rows: usize, num_cols: usize, nnz_estimate: usize) -> Self {
472 Self {
473 num_rows,
474 num_cols,
475 values: Vec::with_capacity(nnz_estimate),
476 col_indices: Vec::with_capacity(nnz_estimate),
477 row_ptrs: Vec::with_capacity(num_rows + 1),
478 current_row: 0,
479 }
480 }
481
482 pub fn add_row_entries(&mut self, entries: impl Iterator<Item = (usize, T)>) {
484 for (col, val) in entries {
485 if val.norm() > T::Real::zero() {
486 self.values.push(val);
487 self.col_indices.push(col);
488 }
489 }
490 self.row_ptrs.push(self.values.len());
491 self.current_row += 1;
492 }
493
494 pub fn finish(mut self) -> CsrMatrix<T> {
496 while self.current_row < self.num_rows {
498 self.row_ptrs.push(self.values.len());
499 self.current_row += 1;
500 }
501
502 CsrMatrix {
503 num_rows: self.num_rows,
504 num_cols: self.num_cols,
505 values: self.values,
506 col_indices: self.col_indices,
507 row_ptrs: self.row_ptrs,
508 }
509 }
510}
511
512#[derive(Debug, Clone)]
517pub struct BlockedCsr<T: ComplexField> {
518 pub num_rows: usize,
520 pub num_cols: usize,
522 pub block_size: usize,
524 pub num_block_rows: usize,
526 pub num_block_cols: usize,
528 pub blocks: Vec<Array2<T>>,
530 pub block_col_indices: Vec<usize>,
532 pub block_row_ptrs: Vec<usize>,
534}
535
536impl<T: ComplexField> BlockedCsr<T> {
537 pub fn new(num_rows: usize, num_cols: usize, block_size: usize) -> Self {
539 let num_block_rows = num_rows.div_ceil(block_size);
540 let num_block_cols = num_cols.div_ceil(block_size);
541
542 Self {
543 num_rows,
544 num_cols,
545 block_size,
546 num_block_rows,
547 num_block_cols,
548 blocks: Vec::new(),
549 block_col_indices: Vec::new(),
550 block_row_ptrs: vec![0; num_block_rows + 1],
551 }
552 }
553
554 pub fn matvec(&self, x: &Array1<T>) -> Array1<T> {
556 assert_eq!(x.len(), self.num_cols, "Input vector size mismatch");
557
558 let mut y = Array1::from_elem(self.num_rows, T::zero());
559
560 for block_i in 0..self.num_block_rows {
561 let row_start = block_i * self.block_size;
562 let row_end = (row_start + self.block_size).min(self.num_rows);
563 let local_rows = row_end - row_start;
564
565 for idx in self.block_row_ptrs[block_i]..self.block_row_ptrs[block_i + 1] {
566 let block_j = self.block_col_indices[idx];
567 let block = &self.blocks[idx];
568
569 let col_start = block_j * self.block_size;
570 let col_end = (col_start + self.block_size).min(self.num_cols);
571 let local_cols = col_end - col_start;
572
573 let x_local: Array1<T> = Array1::from_iter((col_start..col_end).map(|j| x[j]));
575
576 for i in 0..local_rows {
578 let mut sum = T::zero();
579 for j in 0..local_cols {
580 sum += block[[i, j]] * x_local[j];
581 }
582 y[row_start + i] += sum;
583 }
584 }
585 }
586
587 y
588 }
589}
590
591impl<T: ComplexField> CsrMatrix<T> {
598 pub fn matmul(&self, other: &CsrMatrix<T>) -> CsrMatrix<T> {
600 assert_eq!(
601 self.num_cols, other.num_rows,
602 "Matrix dimension mismatch: A.cols ({}) != B.rows ({})",
603 self.num_cols, other.num_rows
604 );
605
606 let m = self.num_rows;
607 let n = other.num_cols;
608
609 if m == 0 || n == 0 || self.nnz() == 0 || other.nnz() == 0 {
610 return CsrMatrix::new(m, n);
611 }
612
613 let tol = T::Real::from_f64(1e-15).unwrap();
614
615 let mut triplets: Vec<(usize, usize, T)> = Vec::with_capacity(self.nnz() * 2);
616
617 let mut sparse_accumulator = vec![T::zero(); n];
619 let mut active_indices = Vec::with_capacity(n);
620 let mut occupied = vec![false; n];
621
622 for i in 0..m {
623 for (k, a_ik) in self.row_entries(i) {
624 for (j, b_kj) in other.row_entries(k) {
625 if !occupied[j] {
626 occupied[j] = true;
627 active_indices.push(j);
628 }
629 sparse_accumulator[j] += a_ik * b_kj;
630 }
631 }
632
633 if active_indices.is_empty() {
634 continue;
635 }
636
637 active_indices.sort_unstable();
639
640 for &j in &active_indices {
641 let val = sparse_accumulator[j];
642 if val.norm() > tol {
643 triplets.push((i, j, val));
644 }
645 sparse_accumulator[j] = T::zero();
647 occupied[j] = false;
648 }
649 active_indices.clear();
650 }
651
652 CsrMatrix::from_triplets(m, n, triplets)
653 }
654}
655
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;
    use num_complex::Complex64;

    /// Shorthand for building a complex number.
    fn c(re: f64, im: f64) -> Complex64 {
        Complex64::new(re, im)
    }

    /// Dense conversion keeps exactly the non-zero entries at their positions.
    #[test]
    fn test_csr_from_dense() {
        let dense = array![
            [c(1.0, 0.0), c(0.0, 0.0), c(2.0, 0.0)],
            [c(0.0, 0.0), c(3.0, 0.0), c(0.0, 0.0)],
            [c(4.0, 0.0), c(0.0, 0.0), c(5.0, 0.0)],
        ];

        let csr = CsrMatrix::from_dense(&dense, 1e-15);

        assert_eq!(csr.num_rows, 3);
        assert_eq!(csr.num_cols, 3);
        assert_eq!(csr.nnz(), 5);

        let expected = [
            (0, 0, 1.0),
            (0, 2, 2.0),
            (1, 1, 3.0),
            (2, 0, 4.0),
            (2, 2, 5.0),
        ];
        for &(i, j, re) in &expected {
            assert_relative_eq!(csr.get(i, j).re, re);
        }
    }

    /// A * x on a full 2x2 complex matrix.
    #[test]
    fn test_csr_matvec() {
        let dense = array![[c(1.0, 0.0), c(2.0, 0.0)], [c(3.0, 0.0), c(4.0, 0.0)],];
        let x = array![c(1.0, 0.0), c(2.0, 0.0)];

        let y = CsrMatrix::from_dense(&dense, 1e-15).matvec(&x);

        assert_relative_eq!(y[0].re, 5.0, epsilon = 1e-10);
        assert_relative_eq!(y[1].re, 11.0, epsilon = 1e-10);
    }

    /// Triplet assembly stores every distinct position once.
    #[test]
    fn test_csr_from_triplets() {
        let triplets = vec![
            (0, 0, c(1.0, 0.0)),
            (0, 2, c(2.0, 0.0)),
            (1, 1, c(3.0, 0.0)),
            (2, 0, c(4.0, 0.0)),
            (2, 2, c(5.0, 0.0)),
        ];

        let csr = CsrMatrix::from_triplets(3, 3, triplets);

        assert_eq!(csr.nnz(), 5);
        assert_relative_eq!(csr.get(0, 0).re, 1.0);
        assert_relative_eq!(csr.get(1, 1).re, 3.0);
    }

    /// Triplets sharing a position are summed.
    #[test]
    fn test_csr_triplets_duplicate() {
        let triplets = vec![
            (0, 0, c(1.0, 0.0)),
            (0, 0, c(2.0, 0.0)),
            (1, 1, c(3.0, 0.0)),
        ];

        let csr = CsrMatrix::from_triplets(2, 2, triplets);

        assert_relative_eq!(csr.get(0, 0).re, 3.0);
    }

    /// The identity has ones on the diagonal and nothing elsewhere.
    #[test]
    fn test_csr_identity() {
        let id: CsrMatrix<Complex64> = CsrMatrix::identity(3);

        assert_eq!(id.nnz(), 3);
        for k in 0..3 {
            assert_relative_eq!(id.get(k, k).re, 1.0);
        }
        assert_relative_eq!(id.get(0, 1).norm(), 0.0);
    }

    /// Row-by-row construction through the builder.
    #[test]
    fn test_csr_builder() {
        let mut builder: CsrBuilder<Complex64> = CsrBuilder::new(3, 3);

        let rows = vec![
            vec![(0, c(1.0, 0.0)), (2, c(2.0, 0.0))],
            vec![(1, c(3.0, 0.0))],
            vec![(0, c(4.0, 0.0)), (2, c(5.0, 0.0))],
        ];
        for row in rows {
            builder.add_row_entries(row.into_iter());
        }

        let csr = builder.finish();

        assert_eq!(csr.nnz(), 5);
        assert_relative_eq!(csr.get(0, 0).re, 1.0);
        assert_relative_eq!(csr.get(1, 1).re, 3.0);
    }

    /// dense -> CSR -> dense reproduces the input.
    #[test]
    fn test_csr_to_dense_roundtrip() {
        let original = array![[c(1.0, 0.5), c(0.0, 0.0)], [c(2.0, -1.0), c(3.0, 0.0)],];

        let recovered = CsrMatrix::from_dense(&original, 1e-15).to_dense();

        for i in 0..2 {
            for j in 0..2 {
                let difference = original[[i, j]] - recovered[[i, j]];
                assert_relative_eq!(difference.norm(), 0.0, epsilon = 1e-10);
            }
        }
    }

    /// The `LinearOperator` implementation forwards to the matrix.
    #[test]
    fn test_linear_operator_impl() {
        let dense = array![[c(1.0, 0.0), c(2.0, 0.0)], [c(3.0, 0.0), c(4.0, 0.0)],];
        let x = array![c(1.0, 0.0), c(2.0, 0.0)];

        let csr = CsrMatrix::from_dense(&dense, 1e-15);

        let y = csr.apply(&x);
        assert_relative_eq!(y[0].re, 5.0, epsilon = 1e-10);
        assert_relative_eq!(y[1].re, 11.0, epsilon = 1e-10);

        assert!(csr.is_square());
        assert_eq!(csr.num_rows(), 2);
        assert_eq!(csr.num_cols(), 2);
    }

    /// The matrix type also works with a real scalar type.
    #[test]
    fn test_f64_csr() {
        let dense = array![[1.0_f64, 2.0], [3.0, 4.0],];
        let x = array![1.0_f64, 2.0];

        let y = CsrMatrix::from_dense(&dense, 1e-15).matvec(&x);

        assert_relative_eq!(y[0], 5.0, epsilon = 1e-10);
        assert_relative_eq!(y[1], 11.0, epsilon = 1e-10);
    }
}