1use ndarray::Array1;
15use num_complex::Complex64;
16use std::ops::Range;
17
18#[derive(Debug, Clone)]
23pub struct CsrMatrix {
24 pub num_rows: usize,
26 pub num_cols: usize,
28 pub values: Vec<Complex64>,
30 pub col_indices: Vec<usize>,
32 pub row_ptrs: Vec<usize>,
35}
36
37impl CsrMatrix {
38 pub fn new(num_rows: usize, num_cols: usize) -> Self {
40 Self {
41 num_rows,
42 num_cols,
43 values: Vec::new(),
44 col_indices: Vec::new(),
45 row_ptrs: vec![0; num_rows + 1],
46 }
47 }
48
49 pub fn with_capacity(num_rows: usize, num_cols: usize, nnz_estimate: usize) -> Self {
51 Self {
52 num_rows,
53 num_cols,
54 values: Vec::with_capacity(nnz_estimate),
55 col_indices: Vec::with_capacity(nnz_estimate),
56 row_ptrs: vec![0; num_rows + 1],
57 }
58 }
59
60 pub fn from_dense(dense: &ndarray::Array2<Complex64>, threshold: f64) -> Self {
64 let num_rows = dense.nrows();
65 let num_cols = dense.ncols();
66
67 let mut values = Vec::new();
68 let mut col_indices = Vec::new();
69 let mut row_ptrs = vec![0usize; num_rows + 1];
70
71 for i in 0..num_rows {
72 for j in 0..num_cols {
73 let val = dense[[i, j]];
74 if val.norm() > threshold {
75 values.push(val);
76 col_indices.push(j);
77 }
78 }
79 row_ptrs[i + 1] = values.len();
80 }
81
82 Self {
83 num_rows,
84 num_cols,
85 values,
86 col_indices,
87 row_ptrs,
88 }
89 }
90
91 pub fn from_triplets(
95 num_rows: usize,
96 num_cols: usize,
97 mut triplets: Vec<(usize, usize, Complex64)>,
98 ) -> Self {
99 if triplets.is_empty() {
100 return Self::new(num_rows, num_cols);
101 }
102
103 triplets.sort_by(|a, b| {
105 if a.0 != b.0 {
106 a.0.cmp(&b.0)
107 } else {
108 a.1.cmp(&b.1)
109 }
110 });
111
112 let mut values = Vec::with_capacity(triplets.len());
113 let mut col_indices = Vec::with_capacity(triplets.len());
114 let mut row_ptrs = vec![0usize; num_rows + 1];
115
116 let mut prev_row = usize::MAX;
117 let mut prev_col = usize::MAX;
118
119 for (row, col, val) in triplets {
120 if row == prev_row && col == prev_col {
121 if let Some(last) = values.last_mut() {
123 *last += val;
124 }
125 } else {
126 values.push(val);
128 col_indices.push(col);
129
130 if row != prev_row {
132 let start = if prev_row == usize::MAX {
133 0
134 } else {
135 prev_row + 1
136 };
137 for item in row_ptrs.iter_mut().take(row + 1).skip(start) {
138 *item = values.len() - 1;
139 }
140 }
141
142 prev_row = row;
143 prev_col = col;
144 }
145 }
146
147 let last_row = if prev_row == usize::MAX {
149 0
150 } else {
151 prev_row + 1
152 };
153 for item in row_ptrs.iter_mut().take(num_rows + 1).skip(last_row) {
154 *item = values.len();
155 }
156
157 Self {
158 num_rows,
159 num_cols,
160 values,
161 col_indices,
162 row_ptrs,
163 }
164 }
165
166 pub fn nnz(&self) -> usize {
168 self.values.len()
169 }
170
171 pub fn sparsity(&self) -> f64 {
173 let total = self.num_rows * self.num_cols;
174 if total == 0 {
175 0.0
176 } else {
177 self.nnz() as f64 / total as f64
178 }
179 }
180
181 pub fn row_range(&self, row: usize) -> Range<usize> {
183 self.row_ptrs[row]..self.row_ptrs[row + 1]
184 }
185
186 pub fn row_entries(&self, row: usize) -> impl Iterator<Item = (usize, Complex64)> + '_ {
188 let range = self.row_range(row);
189 self.col_indices[range.clone()]
190 .iter()
191 .copied()
192 .zip(self.values[range].iter().copied())
193 }
194
195 pub fn matvec(&self, x: &Array1<Complex64>) -> Array1<Complex64> {
197 assert_eq!(x.len(), self.num_cols, "Input vector size mismatch");
198
199 let mut y = Array1::zeros(self.num_rows);
200
201 for i in 0..self.num_rows {
202 let mut sum = Complex64::new(0.0, 0.0);
203 for idx in self.row_range(i) {
204 let j = self.col_indices[idx];
205 sum += self.values[idx] * x[j];
206 }
207 y[i] = sum;
208 }
209
210 y
211 }
212
213 pub fn matvec_add(&self, x: &Array1<Complex64>, y: &mut Array1<Complex64>) {
215 assert_eq!(x.len(), self.num_cols, "Input vector size mismatch");
216 assert_eq!(y.len(), self.num_rows, "Output vector size mismatch");
217
218 for i in 0..self.num_rows {
219 for idx in self.row_range(i) {
220 let j = self.col_indices[idx];
221 y[i] += self.values[idx] * x[j];
222 }
223 }
224 }
225
226 pub fn matvec_transpose(&self, x: &Array1<Complex64>) -> Array1<Complex64> {
228 assert_eq!(x.len(), self.num_rows, "Input vector size mismatch");
229
230 let mut y = Array1::zeros(self.num_cols);
231
232 for i in 0..self.num_rows {
233 for idx in self.row_range(i) {
234 let j = self.col_indices[idx];
235 y[j] += self.values[idx] * x[i];
236 }
237 }
238
239 y
240 }
241
242 pub fn matvec_hermitian(&self, x: &Array1<Complex64>) -> Array1<Complex64> {
244 assert_eq!(x.len(), self.num_rows, "Input vector size mismatch");
245
246 let mut y = Array1::zeros(self.num_cols);
247
248 for i in 0..self.num_rows {
249 for idx in self.row_range(i) {
250 let j = self.col_indices[idx];
251 y[j] += self.values[idx].conj() * x[i];
252 }
253 }
254
255 y
256 }
257
258 pub fn get(&self, i: usize, j: usize) -> Complex64 {
260 for idx in self.row_range(i) {
261 if self.col_indices[idx] == j {
262 return self.values[idx];
263 }
264 }
265 Complex64::new(0.0, 0.0)
266 }
267
268 pub fn diagonal(&self) -> Array1<Complex64> {
270 let n = self.num_rows.min(self.num_cols);
271 let mut diag = Array1::zeros(n);
272
273 for i in 0..n {
274 diag[i] = self.get(i, i);
275 }
276
277 diag
278 }
279
280 pub fn scale(&mut self, scalar: Complex64) {
282 for val in &mut self.values {
283 *val *= scalar;
284 }
285 }
286
287 pub fn add_diagonal(&mut self, scalar: Complex64) {
289 let n = self.num_rows.min(self.num_cols);
290
291 for i in 0..n {
292 for idx in self.row_range(i) {
293 if self.col_indices[idx] == i {
294 self.values[idx] += scalar;
295 break;
296 }
297 }
298 }
299 }
300
301 pub fn identity(n: usize) -> Self {
303 Self {
304 num_rows: n,
305 num_cols: n,
306 values: vec![Complex64::new(1.0, 0.0); n],
307 col_indices: (0..n).collect(),
308 row_ptrs: (0..=n).collect(),
309 }
310 }
311
312 pub fn from_diagonal(diag: &Array1<Complex64>) -> Self {
314 let n = diag.len();
315 Self {
316 num_rows: n,
317 num_cols: n,
318 values: diag.to_vec(),
319 col_indices: (0..n).collect(),
320 row_ptrs: (0..=n).collect(),
321 }
322 }
323
324 pub fn to_dense(&self) -> ndarray::Array2<Complex64> {
326 let mut dense = ndarray::Array2::zeros((self.num_rows, self.num_cols));
327
328 for i in 0..self.num_rows {
329 for idx in self.row_range(i) {
330 let j = self.col_indices[idx];
331 dense[[i, j]] = self.values[idx];
332 }
333 }
334
335 dense
336 }
337}
338
339pub struct CsrBuilder {
341 num_rows: usize,
342 num_cols: usize,
343 values: Vec<Complex64>,
344 col_indices: Vec<usize>,
345 row_ptrs: Vec<usize>,
346 current_row: usize,
347}
348
349impl CsrBuilder {
350 pub fn new(num_rows: usize, num_cols: usize) -> Self {
352 Self {
353 num_rows,
354 num_cols,
355 values: Vec::new(),
356 col_indices: Vec::new(),
357 row_ptrs: vec![0],
358 current_row: 0,
359 }
360 }
361
362 pub fn with_capacity(num_rows: usize, num_cols: usize, nnz_estimate: usize) -> Self {
364 Self {
365 num_rows,
366 num_cols,
367 values: Vec::with_capacity(nnz_estimate),
368 col_indices: Vec::with_capacity(nnz_estimate),
369 row_ptrs: Vec::with_capacity(num_rows + 1),
370 current_row: 0,
371 }
372 }
373
374 pub fn add_row_entries(&mut self, entries: impl Iterator<Item = (usize, Complex64)>) {
376 for (col, val) in entries {
377 if val.norm() > 0.0 {
378 self.values.push(val);
379 self.col_indices.push(col);
380 }
381 }
382 self.row_ptrs.push(self.values.len());
383 self.current_row += 1;
384 }
385
386 pub fn finish(mut self) -> CsrMatrix {
388 while self.current_row < self.num_rows {
390 self.row_ptrs.push(self.values.len());
391 self.current_row += 1;
392 }
393
394 CsrMatrix {
395 num_rows: self.num_rows,
396 num_cols: self.num_cols,
397 values: self.values,
398 col_indices: self.col_indices,
399 row_ptrs: self.row_ptrs,
400 }
401 }
402}
403
404#[derive(Debug, Clone)]
409pub struct BlockedCsr {
410 pub num_rows: usize,
412 pub num_cols: usize,
414 pub block_size: usize,
416 pub num_block_rows: usize,
418 pub num_block_cols: usize,
420 pub blocks: Vec<ndarray::Array2<Complex64>>,
423 pub block_col_indices: Vec<usize>,
425 pub block_row_ptrs: Vec<usize>,
427}
428
429impl BlockedCsr {
430 pub fn new(num_rows: usize, num_cols: usize, block_size: usize) -> Self {
432 let num_block_rows = num_rows.div_ceil(block_size);
433 let num_block_cols = num_cols.div_ceil(block_size);
434
435 Self {
436 num_rows,
437 num_cols,
438 block_size,
439 num_block_rows,
440 num_block_cols,
441 blocks: Vec::new(),
442 block_col_indices: Vec::new(),
443 block_row_ptrs: vec![0; num_block_rows + 1],
444 }
445 }
446
447 pub fn matvec(&self, x: &Array1<Complex64>) -> Array1<Complex64> {
449 assert_eq!(x.len(), self.num_cols, "Input vector size mismatch");
450
451 let mut y = Array1::zeros(self.num_rows);
452
453 for block_i in 0..self.num_block_rows {
454 let row_start = block_i * self.block_size;
455 let row_end = (row_start + self.block_size).min(self.num_rows);
456 let local_rows = row_end - row_start;
457
458 for idx in self.block_row_ptrs[block_i]..self.block_row_ptrs[block_i + 1] {
459 let block_j = self.block_col_indices[idx];
460 let block = &self.blocks[idx];
461
462 let col_start = block_j * self.block_size;
463 let col_end = (col_start + self.block_size).min(self.num_cols);
464 let local_cols = col_end - col_start;
465
466 let x_local: Array1<Complex64> =
468 Array1::from_iter((col_start..col_end).map(|j| x[j]));
469
470 for i in 0..local_rows {
472 let mut sum = Complex64::new(0.0, 0.0);
473 for j in 0..local_cols {
474 sum += block[[i, j]] * x_local[j];
475 }
476 y[row_start + i] += sum;
477 }
478 }
479 }
480
481 y
482 }
483}
484
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Shorthand for a complex literal in fixtures.
    fn c(re: f64, im: f64) -> Complex64 {
        Complex64::new(re, im)
    }

    /// Asserts that `actual` is within 1e-10 of `expected`.
    fn assert_close(actual: Complex64, expected: Complex64) {
        assert!((actual - expected).norm() < 1e-10);
    }

    #[test]
    fn test_csr_from_dense() {
        let dense = array![
            [c(1.0, 0.0), c(0.0, 0.0), c(2.0, 0.0)],
            [c(0.0, 0.0), c(3.0, 0.0), c(0.0, 0.0)],
            [c(4.0, 0.0), c(0.0, 0.0), c(5.0, 0.0)],
        ];
        let csr = CsrMatrix::from_dense(&dense, 1e-15);

        assert_eq!(csr.num_rows, 3);
        assert_eq!(csr.num_cols, 3);
        assert_eq!(csr.nnz(), 5);

        let stored = [
            (0, 0, 1.0),
            (0, 2, 2.0),
            (1, 1, 3.0),
            (2, 0, 4.0),
            (2, 2, 5.0),
        ];
        for (i, j, re) in stored {
            assert_eq!(csr.get(i, j), c(re, 0.0));
        }

        for (i, j) in [(0, 1), (1, 0)] {
            assert_eq!(csr.get(i, j), c(0.0, 0.0));
        }
    }

    #[test]
    fn test_csr_matvec() {
        let dense = array![[c(1.0, 0.0), c(2.0, 0.0)], [c(3.0, 0.0), c(4.0, 0.0)]];
        let csr = CsrMatrix::from_dense(&dense, 1e-15);

        let y = csr.matvec(&array![c(1.0, 0.0), c(2.0, 0.0)]);

        assert_close(y[0], c(5.0, 0.0));
        assert_close(y[1], c(11.0, 0.0));
    }

    #[test]
    fn test_csr_from_triplets() {
        let triplets = vec![
            (0, 0, c(1.0, 0.0)),
            (0, 2, c(2.0, 0.0)),
            (1, 1, c(3.0, 0.0)),
            (2, 0, c(4.0, 0.0)),
            (2, 2, c(5.0, 0.0)),
        ];
        let csr = CsrMatrix::from_triplets(3, 3, triplets);

        assert_eq!(csr.nnz(), 5);
        assert_eq!(csr.get(0, 0), c(1.0, 0.0));
        assert_eq!(csr.get(1, 1), c(3.0, 0.0));
    }

    #[test]
    fn test_csr_triplets_duplicate() {
        let triplets = vec![
            (0, 0, c(1.0, 0.0)),
            (0, 0, c(2.0, 0.0)),
            (1, 1, c(3.0, 0.0)),
        ];
        let csr = CsrMatrix::from_triplets(2, 2, triplets);

        assert_eq!(csr.get(0, 0), c(3.0, 0.0));
    }

    #[test]
    fn test_csr_identity() {
        let id = CsrMatrix::identity(3);

        assert_eq!(id.nnz(), 3);
        for k in 0..3 {
            assert_eq!(id.get(k, k), c(1.0, 0.0));
        }
        assert_eq!(id.get(0, 1), c(0.0, 0.0));
    }

    #[test]
    fn test_csr_builder() {
        let mut builder = CsrBuilder::new(3, 3);
        builder.add_row_entries([(0, c(1.0, 0.0)), (2, c(2.0, 0.0))].into_iter());
        builder.add_row_entries([(1, c(3.0, 0.0))].into_iter());
        builder.add_row_entries([(0, c(4.0, 0.0)), (2, c(5.0, 0.0))].into_iter());

        let csr = builder.finish();

        assert_eq!(csr.nnz(), 5);
        assert_eq!(csr.get(0, 0), c(1.0, 0.0));
        assert_eq!(csr.get(1, 1), c(3.0, 0.0));
    }

    #[test]
    fn test_csr_to_dense_roundtrip() {
        let original = array![[c(1.0, 0.5), c(0.0, 0.0)], [c(2.0, -1.0), c(3.0, 0.0)]];

        let recovered = CsrMatrix::from_dense(&original, 1e-15).to_dense();

        for ((i, j), &expected) in original.indexed_iter() {
            assert_close(recovered[[i, j]], expected);
        }
    }

    #[test]
    fn test_csr_transpose_matvec() {
        let dense = array![
            [c(1.0, 0.0), c(2.0, 0.0)],
            [c(3.0, 0.0), c(4.0, 0.0)],
            [c(5.0, 0.0), c(6.0, 0.0)],
        ];
        let csr = CsrMatrix::from_dense(&dense, 1e-15);

        let y = csr.matvec_transpose(&array![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0)]);

        assert_close(y[0], c(22.0, 0.0));
        assert_close(y[1], c(28.0, 0.0));
    }
}