1use crate::traits::{ComplexField, LinearOperator};
9use ndarray::{Array1, Array2};
10use num_traits::{FromPrimitive, Zero};
11use std::ops::Range;
12
13#[cfg(feature = "rayon")]
14use rayon::prelude::*;
15
/// Sparse matrix stored in Compressed Sparse Row (CSR) format.
///
/// The stored entries of row `i` live at positions `row_ptrs[i]..row_ptrs[i + 1]`
/// of `values` / `col_indices`: `values[k]` is the entry in column `col_indices[k]`.
#[derive(Debug, Clone)]
pub struct CsrMatrix<T: ComplexField> {
    /// Number of rows.
    pub num_rows: usize,
    /// Number of columns.
    pub num_cols: usize,
    /// Stored entries, laid out row by row.
    pub values: Vec<T>,
    /// Column index of each entry in `values` (same length as `values`).
    pub col_indices: Vec<usize>,
    /// Row offsets into `values` / `col_indices`; `num_rows + 1` elements, with
    /// `row_ptrs[num_rows] == values.len()` (checked by `from_raw_parts`).
    pub row_ptrs: Vec<usize>,
}
34
35impl<T: ComplexField> CsrMatrix<T> {
36 pub fn new(num_rows: usize, num_cols: usize) -> Self {
38 Self {
39 num_rows,
40 num_cols,
41 values: Vec::new(),
42 col_indices: Vec::new(),
43 row_ptrs: vec![0; num_rows + 1],
44 }
45 }
46
47 pub fn with_capacity(num_rows: usize, num_cols: usize, nnz_estimate: usize) -> Self {
49 Self {
50 num_rows,
51 num_cols,
52 values: Vec::with_capacity(nnz_estimate),
53 col_indices: Vec::with_capacity(nnz_estimate),
54 row_ptrs: vec![0; num_rows + 1],
55 }
56 }
57
58 pub fn from_raw_parts(
70 num_rows: usize,
71 num_cols: usize,
72 row_ptrs: Vec<usize>,
73 col_indices: Vec<usize>,
74 values: Vec<T>,
75 ) -> Self {
76 assert_eq!(
77 row_ptrs.len(),
78 num_rows + 1,
79 "row_ptrs must have num_rows + 1 elements"
80 );
81 assert_eq!(
82 col_indices.len(),
83 values.len(),
84 "col_indices and values must have the same length"
85 );
86 assert_eq!(
87 row_ptrs[num_rows],
88 values.len(),
89 "row_ptrs[num_rows] must equal nnz"
90 );
91
92 Self {
93 num_rows,
94 num_cols,
95 row_ptrs,
96 col_indices,
97 values,
98 }
99 }
100
101 pub fn from_dense(dense: &Array2<T>, threshold: T::Real) -> Self {
105 let num_rows = dense.nrows();
106 let num_cols = dense.ncols();
107
108 let mut values = Vec::new();
109 let mut col_indices = Vec::new();
110 let mut row_ptrs = vec![0usize; num_rows + 1];
111
112 for i in 0..num_rows {
113 for j in 0..num_cols {
114 let val = dense[[i, j]];
115 if val.norm() > threshold {
116 values.push(val);
117 col_indices.push(j);
118 }
119 }
120 row_ptrs[i + 1] = values.len();
121 }
122
123 Self {
124 num_rows,
125 num_cols,
126 values,
127 col_indices,
128 row_ptrs,
129 }
130 }
131
132 pub fn from_triplets(
136 num_rows: usize,
137 num_cols: usize,
138 mut triplets: Vec<(usize, usize, T)>,
139 ) -> Self {
140 if triplets.is_empty() {
141 return Self::new(num_rows, num_cols);
142 }
143
144 triplets.sort_by(|a, b| {
146 if a.0 != b.0 {
147 a.0.cmp(&b.0)
148 } else {
149 a.1.cmp(&b.1)
150 }
151 });
152
153 let mut values = Vec::with_capacity(triplets.len());
154 let mut col_indices = Vec::with_capacity(triplets.len());
155 let mut row_ptrs = vec![0usize; num_rows + 1];
156
157 let mut prev_row = usize::MAX;
158 let mut prev_col = usize::MAX;
159
160 for (row, col, val) in triplets {
161 if row == prev_row && col == prev_col {
162 if let Some(last) = values.last_mut() {
164 *last += val;
165 }
166 } else {
167 values.push(val);
169 col_indices.push(col);
170
171 if row != prev_row {
173 let start = if prev_row == usize::MAX {
174 0
175 } else {
176 prev_row + 1
177 };
178 for item in row_ptrs.iter_mut().take(row + 1).skip(start) {
179 *item = values.len() - 1;
180 }
181 }
182
183 prev_row = row;
184 prev_col = col;
185 }
186 }
187
188 let last_row = if prev_row == usize::MAX {
190 0
191 } else {
192 prev_row + 1
193 };
194 for item in row_ptrs.iter_mut().take(num_rows + 1).skip(last_row) {
195 *item = values.len();
196 }
197
198 Self {
199 num_rows,
200 num_cols,
201 values,
202 col_indices,
203 row_ptrs,
204 }
205 }
206
207 pub fn nnz(&self) -> usize {
209 self.values.len()
210 }
211
212 pub fn sparsity(&self) -> f64 {
214 let total = self.num_rows * self.num_cols;
215 if total == 0 {
216 0.0
217 } else {
218 self.nnz() as f64 / total as f64
219 }
220 }
221
222 pub fn row_range(&self, row: usize) -> Range<usize> {
224 self.row_ptrs[row]..self.row_ptrs[row + 1]
225 }
226
227 pub fn row_entries(&self, row: usize) -> impl Iterator<Item = (usize, T)> + '_ {
229 let range = self.row_range(row);
230 self.col_indices[range.clone()]
231 .iter()
232 .copied()
233 .zip(self.values[range].iter().copied())
234 }
235
236 pub fn matvec(&self, x: &Array1<T>) -> Array1<T> {
241 assert_eq!(x.len(), self.num_cols, "Input vector size mismatch");
242
243 #[cfg(feature = "rayon")]
245 {
246 if self.num_rows >= 246 {
248 return self.matvec_parallel(x);
249 }
250 }
251
252 self.matvec_sequential(x)
253 }
254
255 fn matvec_sequential(&self, x: &Array1<T>) -> Array1<T> {
257 let mut y = Array1::from_elem(self.num_rows, T::zero());
258
259 for i in 0..self.num_rows {
260 let mut sum = T::zero();
261 for idx in self.row_range(i) {
262 let j = self.col_indices[idx];
263 sum += self.values[idx] * x[j];
264 }
265 y[i] = sum;
266 }
267
268 y
269 }
270
271 #[cfg(feature = "rayon")]
273 fn matvec_parallel(&self, x: &Array1<T>) -> Array1<T>
274 where
275 T: Send + Sync,
276 {
277 let x_slice = x.as_slice().expect("Array should be contiguous");
278
279 let results: Vec<T> = (0..self.num_rows)
280 .into_par_iter()
281 .map(|i| {
282 let mut sum = T::zero();
283 for idx in self.row_range(i) {
284 let j = self.col_indices[idx];
285 sum += self.values[idx] * x_slice[j];
286 }
287 sum
288 })
289 .collect();
290
291 Array1::from_vec(results)
292 }
293
294 pub fn matvec_add(&self, x: &Array1<T>, y: &mut Array1<T>) {
296 assert_eq!(x.len(), self.num_cols, "Input vector size mismatch");
297 assert_eq!(y.len(), self.num_rows, "Output vector size mismatch");
298
299 for i in 0..self.num_rows {
300 for idx in self.row_range(i) {
301 let j = self.col_indices[idx];
302 y[i] += self.values[idx] * x[j];
303 }
304 }
305 }
306
307 pub fn matvec_transpose(&self, x: &Array1<T>) -> Array1<T> {
309 assert_eq!(x.len(), self.num_rows, "Input vector size mismatch");
310
311 let mut y = Array1::from_elem(self.num_cols, T::zero());
312
313 for i in 0..self.num_rows {
314 for idx in self.row_range(i) {
315 let j = self.col_indices[idx];
316 y[j] += self.values[idx] * x[i];
317 }
318 }
319
320 y
321 }
322
323 pub fn matvec_hermitian(&self, x: &Array1<T>) -> Array1<T> {
325 assert_eq!(x.len(), self.num_rows, "Input vector size mismatch");
326
327 let mut y = Array1::from_elem(self.num_cols, T::zero());
328
329 for i in 0..self.num_rows {
330 for idx in self.row_range(i) {
331 let j = self.col_indices[idx];
332 y[j] += self.values[idx].conj() * x[i];
333 }
334 }
335
336 y
337 }
338
339 pub fn get(&self, i: usize, j: usize) -> T {
341 for idx in self.row_range(i) {
342 if self.col_indices[idx] == j {
343 return self.values[idx];
344 }
345 }
346 T::zero()
347 }
348
349 pub fn diagonal(&self) -> Array1<T> {
351 let n = self.num_rows.min(self.num_cols);
352 let mut diag = Array1::from_elem(n, T::zero());
353
354 for i in 0..n {
355 diag[i] = self.get(i, i);
356 }
357
358 diag
359 }
360
361 pub fn scale(&mut self, scalar: T) {
363 for val in &mut self.values {
364 *val *= scalar;
365 }
366 }
367
368 pub fn add_diagonal(&mut self, scalar: T) {
370 let n = self.num_rows.min(self.num_cols);
371
372 for i in 0..n {
373 for idx in self.row_range(i) {
374 if self.col_indices[idx] == i {
375 self.values[idx] += scalar;
376 break;
377 }
378 }
379 }
380 }
381
382 pub fn identity(n: usize) -> Self {
384 Self {
385 num_rows: n,
386 num_cols: n,
387 values: vec![T::one(); n],
388 col_indices: (0..n).collect(),
389 row_ptrs: (0..=n).collect(),
390 }
391 }
392
393 pub fn from_diagonal(diag: &Array1<T>) -> Self {
395 let n = diag.len();
396 Self {
397 num_rows: n,
398 num_cols: n,
399 values: diag.to_vec(),
400 col_indices: (0..n).collect(),
401 row_ptrs: (0..=n).collect(),
402 }
403 }
404
405 pub fn to_dense(&self) -> Array2<T> {
407 let mut dense = Array2::from_elem((self.num_rows, self.num_cols), T::zero());
408
409 for i in 0..self.num_rows {
410 for idx in self.row_range(i) {
411 let j = self.col_indices[idx];
412 dense[[i, j]] = self.values[idx];
413 }
414 }
415
416 dense
417 }
418}
419
420impl<T: ComplexField> LinearOperator<T> for CsrMatrix<T> {
421 fn num_rows(&self) -> usize {
422 self.num_rows
423 }
424
425 fn num_cols(&self) -> usize {
426 self.num_cols
427 }
428
429 fn apply(&self, x: &Array1<T>) -> Array1<T> {
430 self.matvec(x)
431 }
432
433 fn apply_transpose(&self, x: &Array1<T>) -> Array1<T> {
434 self.matvec_transpose(x)
435 }
436
437 fn apply_hermitian(&self, x: &Array1<T>) -> Array1<T> {
438 self.matvec_hermitian(x)
439 }
440}
441
442pub struct CsrBuilder<T: ComplexField> {
444 num_rows: usize,
445 num_cols: usize,
446 values: Vec<T>,
447 col_indices: Vec<usize>,
448 row_ptrs: Vec<usize>,
449 current_row: usize,
450}
451
452impl<T: ComplexField> CsrBuilder<T> {
453 pub fn new(num_rows: usize, num_cols: usize) -> Self {
455 Self {
456 num_rows,
457 num_cols,
458 values: Vec::new(),
459 col_indices: Vec::new(),
460 row_ptrs: vec![0],
461 current_row: 0,
462 }
463 }
464
465 pub fn with_capacity(num_rows: usize, num_cols: usize, nnz_estimate: usize) -> Self {
467 Self {
468 num_rows,
469 num_cols,
470 values: Vec::with_capacity(nnz_estimate),
471 col_indices: Vec::with_capacity(nnz_estimate),
472 row_ptrs: Vec::with_capacity(num_rows + 1),
473 current_row: 0,
474 }
475 }
476
477 pub fn add_row_entries(&mut self, entries: impl Iterator<Item = (usize, T)>) {
479 for (col, val) in entries {
480 if val.norm() > T::Real::zero() {
481 self.values.push(val);
482 self.col_indices.push(col);
483 }
484 }
485 self.row_ptrs.push(self.values.len());
486 self.current_row += 1;
487 }
488
489 pub fn finish(mut self) -> CsrMatrix<T> {
491 while self.current_row < self.num_rows {
493 self.row_ptrs.push(self.values.len());
494 self.current_row += 1;
495 }
496
497 CsrMatrix {
498 num_rows: self.num_rows,
499 num_cols: self.num_cols,
500 values: self.values,
501 col_indices: self.col_indices,
502 row_ptrs: self.row_ptrs,
503 }
504 }
505}
506
/// Block-sparse matrix: CSR structure over `block_size x block_size` dense blocks.
///
/// Block row `bi` owns blocks `block_row_ptrs[bi]..block_row_ptrs[bi + 1]`;
/// `blocks[k]` sits at block column `block_col_indices[k]`. Blocks on the
/// bottom/right edge are only read up to the matrix bounds.
#[derive(Debug, Clone)]
pub struct BlockedCsr<T: ComplexField> {
    /// Number of scalar rows.
    pub num_rows: usize,
    /// Number of scalar columns.
    pub num_cols: usize,
    /// Edge length of each block.
    pub block_size: usize,
    /// `ceil(num_rows / block_size)` as computed by `new`.
    pub num_block_rows: usize,
    /// `ceil(num_cols / block_size)` as computed by `new`.
    pub num_block_cols: usize,
    /// Dense blocks, ordered by block row.
    pub blocks: Vec<Array2<T>>,
    /// Block-column index of each entry in `blocks`.
    pub block_col_indices: Vec<usize>,
    /// Offsets into `blocks` per block row (`num_block_rows + 1` elements).
    pub block_row_ptrs: Vec<usize>,
}
530
531impl<T: ComplexField> BlockedCsr<T> {
532 pub fn new(num_rows: usize, num_cols: usize, block_size: usize) -> Self {
534 let num_block_rows = num_rows.div_ceil(block_size);
535 let num_block_cols = num_cols.div_ceil(block_size);
536
537 Self {
538 num_rows,
539 num_cols,
540 block_size,
541 num_block_rows,
542 num_block_cols,
543 blocks: Vec::new(),
544 block_col_indices: Vec::new(),
545 block_row_ptrs: vec![0; num_block_rows + 1],
546 }
547 }
548
549 pub fn matvec(&self, x: &Array1<T>) -> Array1<T> {
551 assert_eq!(x.len(), self.num_cols, "Input vector size mismatch");
552
553 let mut y = Array1::from_elem(self.num_rows, T::zero());
554
555 for block_i in 0..self.num_block_rows {
556 let row_start = block_i * self.block_size;
557 let row_end = (row_start + self.block_size).min(self.num_rows);
558 let local_rows = row_end - row_start;
559
560 for idx in self.block_row_ptrs[block_i]..self.block_row_ptrs[block_i + 1] {
561 let block_j = self.block_col_indices[idx];
562 let block = &self.blocks[idx];
563
564 let col_start = block_j * self.block_size;
565 let col_end = (col_start + self.block_size).min(self.num_cols);
566 let local_cols = col_end - col_start;
567
568 let x_local: Array1<T> = Array1::from_iter((col_start..col_end).map(|j| x[j]));
570
571 for i in 0..local_rows {
573 let mut sum = T::zero();
574 for j in 0..local_cols {
575 sum += block[[i, j]] * x_local[j];
576 }
577 y[row_start + i] += sum;
578 }
579 }
580 }
581
582 y
583 }
584}
585
586impl<T: ComplexField> CsrMatrix<T> {
593 pub fn matmul(&self, other: &CsrMatrix<T>) -> CsrMatrix<T> {
595 assert_eq!(
596 self.num_cols, other.num_rows,
597 "Matrix dimension mismatch: A.cols ({}) != B.rows ({})",
598 self.num_cols, other.num_rows
599 );
600
601 let m = self.num_rows;
602 let n = other.num_cols;
603
604 if m == 0 || n == 0 || self.nnz() == 0 || other.nnz() == 0 {
605 return CsrMatrix::new(m, n);
606 }
607
608 let tol = T::Real::from_f64(1e-15).unwrap();
609
610 let mut triplets: Vec<(usize, usize, T)> = Vec::with_capacity(self.nnz() * 4);
611
612 for i in 0..m {
613 let mut row_data: Vec<(usize, T)> = Vec::new();
614
615 for (k, a_ik) in self.row_entries(i) {
616 for (j, b_kj) in other.row_entries(k) {
617 row_data.push((j, a_ik * b_kj));
618 }
619 }
620
621 if row_data.is_empty() {
622 continue;
623 }
624
625 row_data.sort_by_key(|&(j, _)| j);
626
627 let mut current_j = row_data[0].0;
628 let mut current_val = row_data[0].1;
629
630 for &(j, val) in &row_data[1..] {
631 if j == current_j {
632 current_val += val;
633 } else {
634 if current_val.norm() > tol {
635 triplets.push((i, current_j, current_val));
636 }
637 current_j = j;
638 current_val = val;
639 }
640 }
641
642 if current_val.norm() > tol {
643 triplets.push((i, current_j, current_val));
644 }
645 }
646
647 CsrMatrix::from_triplets(m, n, triplets)
648 }
649}
650
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;
    use num_complex::Complex64;

    #[test]
    fn test_csr_from_dense() {
        let dense = array![
            [
                Complex64::new(1.0, 0.0),
                Complex64::new(0.0, 0.0),
                Complex64::new(2.0, 0.0)
            ],
            [
                Complex64::new(0.0, 0.0),
                Complex64::new(3.0, 0.0),
                Complex64::new(0.0, 0.0)
            ],
            [
                Complex64::new(4.0, 0.0),
                Complex64::new(0.0, 0.0),
                Complex64::new(5.0, 0.0)
            ],
        ];

        let csr = CsrMatrix::from_dense(&dense, 1e-15);

        assert_eq!(csr.num_rows, 3);
        assert_eq!(csr.num_cols, 3);
        assert_eq!(csr.nnz(), 5);

        assert_relative_eq!(csr.get(0, 0).re, 1.0);
        assert_relative_eq!(csr.get(0, 2).re, 2.0);
        assert_relative_eq!(csr.get(1, 1).re, 3.0);
        assert_relative_eq!(csr.get(2, 0).re, 4.0);
        assert_relative_eq!(csr.get(2, 2).re, 5.0);
    }

    #[test]
    fn test_csr_matvec() {
        let dense = array![
            [Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)],
            [Complex64::new(3.0, 0.0), Complex64::new(4.0, 0.0)],
        ];

        let csr = CsrMatrix::from_dense(&dense, 1e-15);
        let x = array![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        let y = csr.matvec(&x);

        assert_relative_eq!(y[0].re, 5.0, epsilon = 1e-10);
        assert_relative_eq!(y[1].re, 11.0, epsilon = 1e-10);
    }

    #[test]
    fn test_csr_matvec_transpose_and_hermitian() {
        let dense = array![
            [Complex64::new(1.0, 1.0), Complex64::new(2.0, 0.0)],
            [Complex64::new(0.0, 0.0), Complex64::new(3.0, 0.0)],
        ];
        let csr = CsrMatrix::from_dense(&dense, 1e-15);
        let x = array![Complex64::new(1.0, 0.0), Complex64::new(1.0, 0.0)];

        let yt = csr.matvec_transpose(&x);
        assert_relative_eq!(yt[0].re, 1.0, epsilon = 1e-12);
        assert_relative_eq!(yt[0].im, 1.0, epsilon = 1e-12);
        assert_relative_eq!(yt[1].re, 5.0, epsilon = 1e-12);

        let yh = csr.matvec_hermitian(&x);
        assert_relative_eq!(yh[0].re, 1.0, epsilon = 1e-12);
        assert_relative_eq!(yh[0].im, -1.0, epsilon = 1e-12);
        assert_relative_eq!(yh[1].re, 5.0, epsilon = 1e-12);
    }

    #[test]
    fn test_csr_from_triplets() {
        let triplets = vec![
            (0, 0, Complex64::new(1.0, 0.0)),
            (0, 2, Complex64::new(2.0, 0.0)),
            (1, 1, Complex64::new(3.0, 0.0)),
            (2, 0, Complex64::new(4.0, 0.0)),
            (2, 2, Complex64::new(5.0, 0.0)),
        ];

        let csr = CsrMatrix::from_triplets(3, 3, triplets);

        assert_eq!(csr.nnz(), 5);
        assert_relative_eq!(csr.get(0, 0).re, 1.0);
        assert_relative_eq!(csr.get(1, 1).re, 3.0);
    }

    #[test]
    fn test_csr_triplets_duplicate() {
        let triplets = vec![
            (0, 0, Complex64::new(1.0, 0.0)),
            (0, 0, Complex64::new(2.0, 0.0)),
            (1, 1, Complex64::new(3.0, 0.0)),
        ];

        let csr = CsrMatrix::from_triplets(2, 2, triplets);

        assert_relative_eq!(csr.get(0, 0).re, 3.0);
    }

    #[test]
    fn test_csr_triplets_empty_rows() {
        let triplets = vec![
            (2, 1, Complex64::new(2.0, 0.0)),
            (0, 0, Complex64::new(1.0, 0.0)),
        ];

        let csr = CsrMatrix::from_triplets(3, 2, triplets);

        assert_eq!(csr.row_ptrs, vec![0, 1, 1, 2]);
        assert_relative_eq!(csr.get(1, 0).norm(), 0.0);
        assert_relative_eq!(csr.get(2, 1).re, 2.0);
    }

    #[test]
    fn test_csr_identity() {
        let id: CsrMatrix<Complex64> = CsrMatrix::identity(3);

        assert_eq!(id.nnz(), 3);
        assert_relative_eq!(id.get(0, 0).re, 1.0);
        assert_relative_eq!(id.get(1, 1).re, 1.0);
        assert_relative_eq!(id.get(2, 2).re, 1.0);
        assert_relative_eq!(id.get(0, 1).norm(), 0.0);
    }

    #[test]
    fn test_csr_builder() {
        let mut builder: CsrBuilder<Complex64> = CsrBuilder::new(3, 3);

        builder.add_row_entries(
            [(0, Complex64::new(1.0, 0.0)), (2, Complex64::new(2.0, 0.0))].into_iter(),
        );
        builder.add_row_entries([(1, Complex64::new(3.0, 0.0))].into_iter());
        builder.add_row_entries(
            [(0, Complex64::new(4.0, 0.0)), (2, Complex64::new(5.0, 0.0))].into_iter(),
        );

        let csr = builder.finish();

        assert_eq!(csr.nnz(), 5);
        assert_relative_eq!(csr.get(0, 0).re, 1.0);
        assert_relative_eq!(csr.get(1, 1).re, 3.0);
    }

    #[test]
    fn test_csr_builder_pads_missing_rows() {
        let mut builder: CsrBuilder<Complex64> = CsrBuilder::new(3, 3);
        builder.add_row_entries([(1, Complex64::new(7.0, 0.0))].into_iter());

        let csr = builder.finish();

        assert_eq!(csr.row_ptrs, vec![0, 1, 1, 1]);
        assert_relative_eq!(csr.get(0, 1).re, 7.0);
    }

    #[test]
    fn test_csr_matmul() {
        let dense = array![
            [Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)],
            [Complex64::new(3.0, 0.0), Complex64::new(4.0, 0.0)],
        ];
        let a = CsrMatrix::from_dense(&dense, 1e-15);

        let square = a.matmul(&a);
        assert_relative_eq!(square.get(0, 0).re, 7.0, epsilon = 1e-12);
        assert_relative_eq!(square.get(0, 1).re, 10.0, epsilon = 1e-12);
        assert_relative_eq!(square.get(1, 0).re, 15.0, epsilon = 1e-12);
        assert_relative_eq!(square.get(1, 1).re, 22.0, epsilon = 1e-12);

        let id: CsrMatrix<Complex64> = CsrMatrix::identity(2);
        let same = a.matmul(&id).to_dense();
        for i in 0..2 {
            for j in 0..2 {
                assert_relative_eq!((same[[i, j]] - dense[[i, j]]).norm(), 0.0, epsilon = 1e-12);
            }
        }
    }

    #[test]
    fn test_csr_to_dense_roundtrip() {
        let original = array![
            [Complex64::new(1.0, 0.5), Complex64::new(0.0, 0.0)],
            [Complex64::new(2.0, -1.0), Complex64::new(3.0, 0.0)],
        ];

        let csr = CsrMatrix::from_dense(&original, 1e-15);
        let recovered = csr.to_dense();

        for i in 0..2 {
            for j in 0..2 {
                assert_relative_eq!(
                    (original[[i, j]] - recovered[[i, j]]).norm(),
                    0.0,
                    epsilon = 1e-10
                );
            }
        }
    }

    #[test]
    fn test_linear_operator_impl() {
        let dense = array![
            [Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)],
            [Complex64::new(3.0, 0.0), Complex64::new(4.0, 0.0)],
        ];

        let csr = CsrMatrix::from_dense(&dense, 1e-15);
        let x = array![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        let y = csr.apply(&x);

        assert_relative_eq!(y[0].re, 5.0, epsilon = 1e-10);
        assert_relative_eq!(y[1].re, 11.0, epsilon = 1e-10);

        assert!(csr.is_square());
        assert_eq!(csr.num_rows(), 2);
        assert_eq!(csr.num_cols(), 2);
    }

    #[test]
    fn test_f64_csr() {
        let dense = array![[1.0_f64, 2.0], [3.0, 4.0]];

        let csr = CsrMatrix::from_dense(&dense, 1e-15);
        let x = array![1.0_f64, 2.0];

        let y = csr.matvec(&x);
        assert_relative_eq!(y[0], 5.0, epsilon = 1e-10);
        assert_relative_eq!(y[1], 11.0, epsilon = 1e-10);
    }
}