1use ndarray::Array1;
15use num_complex::Complex64;
16use std::ops::Range;
17
18#[derive(Debug, Clone)]
23pub struct CsrMatrix {
24 pub num_rows: usize,
26 pub num_cols: usize,
28 pub values: Vec<Complex64>,
30 pub col_indices: Vec<usize>,
32 pub row_ptrs: Vec<usize>,
35}
36
37impl CsrMatrix {
38 pub fn new(num_rows: usize, num_cols: usize) -> Self {
40 Self {
41 num_rows,
42 num_cols,
43 values: Vec::new(),
44 col_indices: Vec::new(),
45 row_ptrs: vec![0; num_rows + 1],
46 }
47 }
48
49 pub fn with_capacity(num_rows: usize, num_cols: usize, nnz_estimate: usize) -> Self {
51 Self {
52 num_rows,
53 num_cols,
54 values: Vec::with_capacity(nnz_estimate),
55 col_indices: Vec::with_capacity(nnz_estimate),
56 row_ptrs: vec![0; num_rows + 1],
57 }
58 }
59
60 pub fn from_dense(
64 dense: &ndarray::Array2<Complex64>,
65 threshold: f64,
66 ) -> Self {
67 let num_rows = dense.nrows();
68 let num_cols = dense.ncols();
69
70 let mut values = Vec::new();
71 let mut col_indices = Vec::new();
72 let mut row_ptrs = vec![0usize; num_rows + 1];
73
74 for i in 0..num_rows {
75 for j in 0..num_cols {
76 let val = dense[[i, j]];
77 if val.norm() > threshold {
78 values.push(val);
79 col_indices.push(j);
80 }
81 }
82 row_ptrs[i + 1] = values.len();
83 }
84
85 Self {
86 num_rows,
87 num_cols,
88 values,
89 col_indices,
90 row_ptrs,
91 }
92 }
93
94 pub fn from_triplets(
98 num_rows: usize,
99 num_cols: usize,
100 mut triplets: Vec<(usize, usize, Complex64)>,
101 ) -> Self {
102 if triplets.is_empty() {
103 return Self::new(num_rows, num_cols);
104 }
105
106 triplets.sort_by(|a, b| {
108 if a.0 != b.0 {
109 a.0.cmp(&b.0)
110 } else {
111 a.1.cmp(&b.1)
112 }
113 });
114
115 let mut values = Vec::with_capacity(triplets.len());
116 let mut col_indices = Vec::with_capacity(triplets.len());
117 let mut row_ptrs = vec![0usize; num_rows + 1];
118
119 let mut prev_row = usize::MAX;
120 let mut prev_col = usize::MAX;
121
122 for (row, col, val) in triplets {
123 if row == prev_row && col == prev_col {
124 if let Some(last) = values.last_mut() {
126 *last += val;
127 }
128 } else {
129 values.push(val);
131 col_indices.push(col);
132
133 if row != prev_row {
135 let start = if prev_row == usize::MAX { 0 } else { prev_row + 1 };
136 for r in start..=row {
137 row_ptrs[r] = values.len() - 1;
138 }
139 }
140
141 prev_row = row;
142 prev_col = col;
143 }
144 }
145
146 let last_row = if prev_row == usize::MAX { 0 } else { prev_row + 1 };
148 for r in last_row..=num_rows {
149 row_ptrs[r] = values.len();
150 }
151
152 Self {
153 num_rows,
154 num_cols,
155 values,
156 col_indices,
157 row_ptrs,
158 }
159 }
160
161 pub fn nnz(&self) -> usize {
163 self.values.len()
164 }
165
166 pub fn sparsity(&self) -> f64 {
168 let total = self.num_rows * self.num_cols;
169 if total == 0 {
170 0.0
171 } else {
172 self.nnz() as f64 / total as f64
173 }
174 }
175
176 pub fn row_range(&self, row: usize) -> Range<usize> {
178 self.row_ptrs[row]..self.row_ptrs[row + 1]
179 }
180
181 pub fn row_entries(&self, row: usize) -> impl Iterator<Item = (usize, Complex64)> + '_ {
183 let range = self.row_range(row);
184 self.col_indices[range.clone()]
185 .iter()
186 .copied()
187 .zip(self.values[range].iter().copied())
188 }
189
190 pub fn matvec(&self, x: &Array1<Complex64>) -> Array1<Complex64> {
192 assert_eq!(x.len(), self.num_cols, "Input vector size mismatch");
193
194 let mut y = Array1::zeros(self.num_rows);
195
196 for i in 0..self.num_rows {
197 let mut sum = Complex64::new(0.0, 0.0);
198 for idx in self.row_range(i) {
199 let j = self.col_indices[idx];
200 sum += self.values[idx] * x[j];
201 }
202 y[i] = sum;
203 }
204
205 y
206 }
207
208 pub fn matvec_add(&self, x: &Array1<Complex64>, y: &mut Array1<Complex64>) {
210 assert_eq!(x.len(), self.num_cols, "Input vector size mismatch");
211 assert_eq!(y.len(), self.num_rows, "Output vector size mismatch");
212
213 for i in 0..self.num_rows {
214 for idx in self.row_range(i) {
215 let j = self.col_indices[idx];
216 y[i] += self.values[idx] * x[j];
217 }
218 }
219 }
220
221 pub fn matvec_transpose(&self, x: &Array1<Complex64>) -> Array1<Complex64> {
223 assert_eq!(x.len(), self.num_rows, "Input vector size mismatch");
224
225 let mut y = Array1::zeros(self.num_cols);
226
227 for i in 0..self.num_rows {
228 for idx in self.row_range(i) {
229 let j = self.col_indices[idx];
230 y[j] += self.values[idx] * x[i];
231 }
232 }
233
234 y
235 }
236
237 pub fn matvec_hermitian(&self, x: &Array1<Complex64>) -> Array1<Complex64> {
239 assert_eq!(x.len(), self.num_rows, "Input vector size mismatch");
240
241 let mut y = Array1::zeros(self.num_cols);
242
243 for i in 0..self.num_rows {
244 for idx in self.row_range(i) {
245 let j = self.col_indices[idx];
246 y[j] += self.values[idx].conj() * x[i];
247 }
248 }
249
250 y
251 }
252
253 pub fn get(&self, i: usize, j: usize) -> Complex64 {
255 for idx in self.row_range(i) {
256 if self.col_indices[idx] == j {
257 return self.values[idx];
258 }
259 }
260 Complex64::new(0.0, 0.0)
261 }
262
263 pub fn diagonal(&self) -> Array1<Complex64> {
265 let n = self.num_rows.min(self.num_cols);
266 let mut diag = Array1::zeros(n);
267
268 for i in 0..n {
269 diag[i] = self.get(i, i);
270 }
271
272 diag
273 }
274
275 pub fn scale(&mut self, scalar: Complex64) {
277 for val in &mut self.values {
278 *val *= scalar;
279 }
280 }
281
282 pub fn add_diagonal(&mut self, scalar: Complex64) {
284 let n = self.num_rows.min(self.num_cols);
285
286 for i in 0..n {
287 for idx in self.row_range(i) {
288 if self.col_indices[idx] == i {
289 self.values[idx] += scalar;
290 break;
291 }
292 }
293 }
294 }
295
296 pub fn identity(n: usize) -> Self {
298 Self {
299 num_rows: n,
300 num_cols: n,
301 values: vec![Complex64::new(1.0, 0.0); n],
302 col_indices: (0..n).collect(),
303 row_ptrs: (0..=n).collect(),
304 }
305 }
306
307 pub fn from_diagonal(diag: &Array1<Complex64>) -> Self {
309 let n = diag.len();
310 Self {
311 num_rows: n,
312 num_cols: n,
313 values: diag.to_vec(),
314 col_indices: (0..n).collect(),
315 row_ptrs: (0..=n).collect(),
316 }
317 }
318
319 pub fn to_dense(&self) -> ndarray::Array2<Complex64> {
321 let mut dense = ndarray::Array2::zeros((self.num_rows, self.num_cols));
322
323 for i in 0..self.num_rows {
324 for idx in self.row_range(i) {
325 let j = self.col_indices[idx];
326 dense[[i, j]] = self.values[idx];
327 }
328 }
329
330 dense
331 }
332}
333
334pub struct CsrBuilder {
336 num_rows: usize,
337 num_cols: usize,
338 values: Vec<Complex64>,
339 col_indices: Vec<usize>,
340 row_ptrs: Vec<usize>,
341 current_row: usize,
342}
343
344impl CsrBuilder {
345 pub fn new(num_rows: usize, num_cols: usize) -> Self {
347 Self {
348 num_rows,
349 num_cols,
350 values: Vec::new(),
351 col_indices: Vec::new(),
352 row_ptrs: vec![0],
353 current_row: 0,
354 }
355 }
356
357 pub fn with_capacity(num_rows: usize, num_cols: usize, nnz_estimate: usize) -> Self {
359 Self {
360 num_rows,
361 num_cols,
362 values: Vec::with_capacity(nnz_estimate),
363 col_indices: Vec::with_capacity(nnz_estimate),
364 row_ptrs: Vec::with_capacity(num_rows + 1),
365 current_row: 0,
366 }
367 }
368
369 pub fn add_row_entries(&mut self, entries: impl Iterator<Item = (usize, Complex64)>) {
371 for (col, val) in entries {
372 if val.norm() > 0.0 {
373 self.values.push(val);
374 self.col_indices.push(col);
375 }
376 }
377 self.row_ptrs.push(self.values.len());
378 self.current_row += 1;
379 }
380
381 pub fn finish(mut self) -> CsrMatrix {
383 while self.current_row < self.num_rows {
385 self.row_ptrs.push(self.values.len());
386 self.current_row += 1;
387 }
388
389 CsrMatrix {
390 num_rows: self.num_rows,
391 num_cols: self.num_cols,
392 values: self.values,
393 col_indices: self.col_indices,
394 row_ptrs: self.row_ptrs,
395 }
396 }
397}
398
399#[derive(Debug, Clone)]
404pub struct BlockedCsr {
405 pub num_rows: usize,
407 pub num_cols: usize,
409 pub block_size: usize,
411 pub num_block_rows: usize,
413 pub num_block_cols: usize,
415 pub blocks: Vec<ndarray::Array2<Complex64>>,
418 pub block_col_indices: Vec<usize>,
420 pub block_row_ptrs: Vec<usize>,
422}
423
424impl BlockedCsr {
425 pub fn new(num_rows: usize, num_cols: usize, block_size: usize) -> Self {
427 let num_block_rows = (num_rows + block_size - 1) / block_size;
428 let num_block_cols = (num_cols + block_size - 1) / block_size;
429
430 Self {
431 num_rows,
432 num_cols,
433 block_size,
434 num_block_rows,
435 num_block_cols,
436 blocks: Vec::new(),
437 block_col_indices: Vec::new(),
438 block_row_ptrs: vec![0; num_block_rows + 1],
439 }
440 }
441
442 pub fn matvec(&self, x: &Array1<Complex64>) -> Array1<Complex64> {
444 assert_eq!(x.len(), self.num_cols, "Input vector size mismatch");
445
446 let mut y = Array1::zeros(self.num_rows);
447
448 for block_i in 0..self.num_block_rows {
449 let row_start = block_i * self.block_size;
450 let row_end = (row_start + self.block_size).min(self.num_rows);
451 let local_rows = row_end - row_start;
452
453 for idx in self.block_row_ptrs[block_i]..self.block_row_ptrs[block_i + 1] {
454 let block_j = self.block_col_indices[idx];
455 let block = &self.blocks[idx];
456
457 let col_start = block_j * self.block_size;
458 let col_end = (col_start + self.block_size).min(self.num_cols);
459 let local_cols = col_end - col_start;
460
461 let x_local: Array1<Complex64> =
463 Array1::from_iter((col_start..col_end).map(|j| x[j]));
464
465 for i in 0..local_rows {
467 let mut sum = Complex64::new(0.0, 0.0);
468 for j in 0..local_cols {
469 sum += block[[i, j]] * x_local[j];
470 }
471 y[row_start + i] += sum;
472 }
473 }
474 }
475
476 y
477 }
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Shorthand for a complex scalar in test fixtures.
    fn c(re: f64, im: f64) -> Complex64 {
        Complex64::new(re, im)
    }

    /// Asserts that `actual` is within 1e-10 of `expected`.
    fn assert_close(actual: Complex64, expected: Complex64) {
        assert!(
            (actual - expected).norm() < 1e-10,
            "expected {:?}, got {:?}",
            expected,
            actual
        );
    }

    #[test]
    fn test_csr_from_dense() {
        let dense = array![
            [c(1.0, 0.0), c(0.0, 0.0), c(2.0, 0.0)],
            [c(0.0, 0.0), c(3.0, 0.0), c(0.0, 0.0)],
            [c(4.0, 0.0), c(0.0, 0.0), c(5.0, 0.0)],
        ];

        let csr = CsrMatrix::from_dense(&dense, 1e-15);

        assert_eq!(csr.num_rows, 3);
        assert_eq!(csr.num_cols, 3);
        assert_eq!(csr.nnz(), 5);

        // Stored entries.
        assert_eq!(csr.get(0, 0), c(1.0, 0.0));
        assert_eq!(csr.get(0, 2), c(2.0, 0.0));
        assert_eq!(csr.get(1, 1), c(3.0, 0.0));
        assert_eq!(csr.get(2, 0), c(4.0, 0.0));
        assert_eq!(csr.get(2, 2), c(5.0, 0.0));

        // Structural zeros.
        assert_eq!(csr.get(0, 1), c(0.0, 0.0));
        assert_eq!(csr.get(1, 0), c(0.0, 0.0));
    }

    #[test]
    fn test_csr_matvec() {
        let dense = array![[c(1.0, 0.0), c(2.0, 0.0)], [c(3.0, 0.0), c(4.0, 0.0)]];

        let csr = CsrMatrix::from_dense(&dense, 1e-15);
        let x = array![c(1.0, 0.0), c(2.0, 0.0)];

        let y = csr.matvec(&x);

        assert_close(y[0], c(5.0, 0.0));
        assert_close(y[1], c(11.0, 0.0));
    }

    #[test]
    fn test_csr_matvec_add() {
        let dense = array![[c(1.0, 0.0), c(2.0, 0.0)], [c(3.0, 0.0), c(4.0, 0.0)]];
        let csr = CsrMatrix::from_dense(&dense, 1e-15);
        let x = array![c(1.0, 0.0), c(2.0, 0.0)];
        let mut y = array![c(1.0, 0.0), c(1.0, 0.0)];

        csr.matvec_add(&x, &mut y);

        assert_close(y[0], c(6.0, 0.0));
        assert_close(y[1], c(12.0, 0.0));
    }

    #[test]
    fn test_csr_from_triplets() {
        let triplets = vec![
            (0, 0, c(1.0, 0.0)),
            (0, 2, c(2.0, 0.0)),
            (1, 1, c(3.0, 0.0)),
            (2, 0, c(4.0, 0.0)),
            (2, 2, c(5.0, 0.0)),
        ];

        let csr = CsrMatrix::from_triplets(3, 3, triplets);

        assert_eq!(csr.nnz(), 5);
        assert_eq!(csr.get(0, 0), c(1.0, 0.0));
        assert_eq!(csr.get(1, 1), c(3.0, 0.0));
    }

    #[test]
    fn test_csr_from_triplets_empty_rows() {
        // Rows 0, 2 and 4 are empty, and the input is unsorted.
        let triplets = vec![(3, 1, c(2.0, 0.0)), (1, 0, c(1.0, 0.0))];

        let csr = CsrMatrix::from_triplets(5, 2, triplets);

        assert_eq!(csr.row_ptrs, vec![0, 0, 1, 1, 2, 2]);
        assert_eq!(csr.get(1, 0), c(1.0, 0.0));
        assert_eq!(csr.get(3, 1), c(2.0, 0.0));
    }

    #[test]
    fn test_csr_triplets_duplicate() {
        let triplets = vec![
            (0, 0, c(1.0, 0.0)),
            (0, 0, c(2.0, 0.0)),
            (1, 1, c(3.0, 0.0)),
        ];

        let csr = CsrMatrix::from_triplets(2, 2, triplets);

        // Duplicate coordinates are summed.
        assert_eq!(csr.get(0, 0), c(3.0, 0.0));
    }

    #[test]
    fn test_csr_identity() {
        let id = CsrMatrix::identity(3);

        assert_eq!(id.nnz(), 3);
        for i in 0..3 {
            assert_eq!(id.get(i, i), c(1.0, 0.0));
        }
        assert_eq!(id.get(0, 1), c(0.0, 0.0));
    }

    #[test]
    fn test_csr_builder() {
        let mut builder = CsrBuilder::new(3, 3);

        builder.add_row_entries([(0, c(1.0, 0.0)), (2, c(2.0, 0.0))].into_iter());
        builder.add_row_entries([(1, c(3.0, 0.0))].into_iter());
        builder.add_row_entries([(0, c(4.0, 0.0)), (2, c(5.0, 0.0))].into_iter());

        let csr = builder.finish();

        assert_eq!(csr.nnz(), 5);
        assert_eq!(csr.get(0, 0), c(1.0, 0.0));
        assert_eq!(csr.get(1, 1), c(3.0, 0.0));
    }

    #[test]
    fn test_csr_to_dense_roundtrip() {
        let original = array![[c(1.0, 0.5), c(0.0, 0.0)], [c(2.0, -1.0), c(3.0, 0.0)]];

        let csr = CsrMatrix::from_dense(&original, 1e-15);
        let recovered = csr.to_dense();

        for ((i, j), &expected) in original.indexed_iter() {
            assert_close(recovered[[i, j]], expected);
        }
    }

    #[test]
    fn test_csr_transpose_matvec() {
        let dense = array![
            [c(1.0, 0.0), c(2.0, 0.0)],
            [c(3.0, 0.0), c(4.0, 0.0)],
            [c(5.0, 0.0), c(6.0, 0.0)],
        ];

        let csr = CsrMatrix::from_dense(&dense, 1e-15);
        let x = array![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0)];

        let y = csr.matvec_transpose(&x);

        assert_close(y[0], c(22.0, 0.0));
        assert_close(y[1], c(28.0, 0.0));
    }

    #[test]
    fn test_csr_hermitian_matvec() {
        // A = [[1+i, 2], [0, 3i]]  =>  A^H = [[1-i, 0], [2, -3i]].
        let dense = array![[c(1.0, 1.0), c(2.0, 0.0)], [c(0.0, 0.0), c(0.0, 3.0)]];
        let csr = CsrMatrix::from_dense(&dense, 1e-15);
        let x = array![c(1.0, 0.0), c(0.0, 1.0)];

        let y = csr.matvec_hermitian(&x);

        assert_close(y[0], c(1.0, -1.0));
        assert_close(y[1], c(5.0, 0.0));
    }

    #[test]
    fn test_csr_scale_and_diagonal() {
        let triplets = vec![(0, 0, c(1.0, 0.0)), (1, 1, c(2.0, 0.0)), (0, 2, c(5.0, 0.0))];
        let mut csr = CsrMatrix::from_triplets(2, 3, triplets);

        csr.scale(c(0.0, 2.0));
        let diag = csr.diagonal();

        assert_eq!(diag.len(), 2);
        assert_close(diag[0], c(0.0, 2.0));
        assert_close(diag[1], c(0.0, 4.0));
        assert_close(csr.get(0, 2), c(0.0, 10.0));
    }

    #[test]
    fn test_csr_add_diagonal_existing_entries() {
        let mut id = CsrMatrix::identity(2);

        id.add_diagonal(c(3.0, 0.0));

        assert_eq!(id.get(0, 0), c(4.0, 0.0));
        assert_eq!(id.get(1, 1), c(4.0, 0.0));
        assert_eq!(id.get(0, 1), c(0.0, 0.0));
    }

    #[test]
    fn test_csr_row_entries() {
        let dense = array![
            [c(1.0, 0.0), c(0.0, 0.0), c(2.0, 0.0)],
            [c(0.0, 0.0), c(3.0, 0.0), c(0.0, 0.0)],
            [c(4.0, 0.0), c(0.0, 0.0), c(5.0, 0.0)],
        ];
        let csr = CsrMatrix::from_dense(&dense, 1e-15);

        let entries: Vec<_> = csr.row_entries(2).collect();

        assert_eq!(entries, vec![(0, c(4.0, 0.0)), (2, c(5.0, 0.0))]);
    }

    #[test]
    fn test_blocked_csr_without_blocks_is_zero() {
        let blocked = BlockedCsr::new(2, 3, 2);

        assert_eq!(blocked.num_block_rows, 1);
        assert_eq!(blocked.num_block_cols, 2);

        let y = blocked.matvec(&array![c(1.0, 0.0), c(2.0, 0.0), c(3.0, 0.0)]);

        assert_eq!(y.len(), 2);
        assert!(y.iter().all(|v| *v == c(0.0, 0.0)));
    }

    #[test]
    fn test_blocked_csr_matvec() {
        // 3x3 matrix with 2x2 blocks: block rows {0,1},{2}; block cols {0,1},{2}.
        let mut blocked = BlockedCsr::new(3, 3, 2);
        blocked.blocks = vec![
            array![[c(1.0, 0.0), c(2.0, 0.0)], [c(3.0, 0.0), c(4.0, 0.0)]],
            array![[c(5.0, 0.0)]],
        ];
        blocked.block_col_indices = vec![0, 1];
        blocked.block_row_ptrs = vec![0, 1, 2];

        let y = blocked.matvec(&array![c(1.0, 0.0), c(1.0, 0.0), c(2.0, 0.0)]);

        assert_close(y[0], c(3.0, 0.0));
        assert_close(y[1], c(7.0, 0.0));
        assert_close(y[2], c(10.0, 0.0));
    }
}