1use std::ops::AddAssign;
2
/// A sparse matrix stored in compressed sparse row (CSR) format.
///
/// Row `i` owns the half-open range `row_offsets[i]..row_offsets[i + 1]` of the
/// parallel `col_indices` / `values` arrays. Matrices produced by this module's
/// constructors keep the column indices of every row strictly increasing, so the
/// derived `PartialEq` is a structural comparison. Two matrices are equal when they
/// have the same shape, the same stored pattern and the same stored values. An
/// explicitly stored zero is therefore *not* equal to an absent entry.
#[derive(Debug, Clone, PartialEq)]
pub struct CsrMat<T> {
    /// Number of rows.
    nrows: usize,
    /// Number of columns.
    ncols: usize,
    /// `nrows + 1` non-decreasing offsets into `col_indices` / `values`.
    row_offsets: Vec<usize>,
    /// Column index of every stored entry, grouped row by row.
    col_indices: Vec<usize>,
    /// Value of every stored entry, parallel to `col_indices`.
    values: Vec<T>,
}
12
/// A borrowed, read-only view of a single row of a [`CsrMat`].
///
/// `CsrMat::row` slices both fields from the same offset range, so they are
/// parallel: entry `k` of the row sits at column `col_indices[k]` with value
/// `values[k]`.
#[derive(Debug)]
pub struct CsrRow<'a, T> {
    /// Column indices of the entries stored in this row.
    col_indices: &'a [usize],
    /// Values of the entries stored in this row.
    values: &'a [T],
}
19
20impl<'a, T> CsrRow<'a, T> {
21 pub fn col_indices(&self) -> &'a [usize] {
22 self.col_indices
23 }
24
25 pub fn values(&self) -> &'a [T] {
26 self.values
27 }
28
29 pub fn nnz(&self) -> usize {
30 self.col_indices.len()
31 }
32}
33
34impl<T> CsrMat<T> {
35 pub fn try_from_csr_data(
36 nrows: usize,
37 ncols: usize,
38 row_offsets: Vec<usize>,
39 col_indices: Vec<usize>,
40 values: Vec<T>,
41 ) -> Result<Self, String> {
42 if row_offsets.len() != nrows + 1 {
43 return Err(format!(
44 "row_offsets length {} does not match nrows+1={}",
45 row_offsets.len(),
46 nrows + 1
47 ));
48 }
49 if col_indices.len() != values.len() {
50 return Err(format!(
51 "col_indices length {} does not match values length {}",
52 col_indices.len(),
53 values.len()
54 ));
55 }
56 let nnz = *row_offsets.last().unwrap_or(&0);
57 if col_indices.len() != nnz {
58 return Err(format!(
59 "col_indices length {} does not match last row_offset {}",
60 col_indices.len(),
61 nnz
62 ));
63 }
64 for window in row_offsets.windows(2) {
65 if window[0] > window[1] {
66 return Err("row_offsets is not monotonically non-decreasing".into());
67 }
68 }
69 for row in 0..nrows {
70 let start = row_offsets[row];
71 let end = row_offsets[row + 1];
72 let row_cols = &col_indices[start..end];
73 if row_cols.windows(2).any(|window| window[0] >= window[1]) {
74 return Err(format!(
75 "col_indices in row {} are not strictly increasing",
76 row
77 ));
78 }
79 }
80 for &col in &col_indices {
81 if col >= ncols {
82 return Err(format!(
83 "col_index {} out of range for ncols={}",
84 col, ncols
85 ));
86 }
87 }
88 Ok(Self {
89 nrows,
90 ncols,
91 row_offsets,
92 col_indices,
93 values,
94 })
95 }
96
97 pub fn nrows(&self) -> usize {
98 self.nrows
99 }
100
101 pub fn ncols(&self) -> usize {
102 self.ncols
103 }
104
105 pub fn nnz(&self) -> usize {
106 self.values.len()
107 }
108
109 pub fn row_offsets(&self) -> &[usize] {
110 &self.row_offsets
111 }
112
113 pub fn col_indices(&self) -> &[usize] {
114 &self.col_indices
115 }
116
117 pub fn values(&self) -> &[T] {
118 &self.values
119 }
120
121 pub fn values_mut(&mut self) -> &mut [T] {
122 &mut self.values
123 }
124
125 pub fn row(&self, i: usize) -> CsrRow<'_, T> {
126 let start = self.row_offsets[i];
127 let end = self.row_offsets[i + 1];
128 CsrRow {
129 col_indices: &self.col_indices[start..end],
130 values: &self.values[start..end],
131 }
132 }
133
134 pub fn row_iter(&self) -> impl Iterator<Item = CsrRow<'_, T>> {
135 (0..self.nrows).map(move |i| self.row(i))
136 }
137}
138
139impl<T: PartialEq + Copy> CsrMat<T> {
140 pub fn get(&self, row: usize, col: usize) -> Option<&T> {
141 if row >= self.nrows || col >= self.ncols {
142 return None;
143 }
144 let start = self.row_offsets[row];
145 let end = self.row_offsets[row + 1];
146 let slice = &self.col_indices[start..end];
147 match slice.binary_search(&col) {
148 Ok(pos) => Some(&self.values[start + pos]),
149 Err(_) => None,
150 }
151 }
152}
153
154impl<T: Copy> CsrMat<T> {
155 pub fn triplet_iter(&self) -> impl Iterator<Item = (usize, usize, &T)> {
156 (0..self.nrows).flat_map(move |row| {
157 let start = self.row_offsets[row];
158 let end = self.row_offsets[row + 1];
159 (start..end).map(move |idx| (row, self.col_indices[idx], &self.values[idx]))
160 })
161 }
162}
163
164impl CsrMat<f64> {
165 pub fn identity(n: usize) -> Self {
166 let row_offsets: Vec<usize> = (0..=n).collect();
167 let col_indices: Vec<usize> = (0..n).collect();
168 let values = vec![1.0; n];
169 Self {
170 nrows: n,
171 ncols: n,
172 row_offsets,
173 col_indices,
174 values,
175 }
176 }
177
178 pub fn zeros(nrows: usize, ncols: usize) -> Self {
179 let row_offsets = vec![0; nrows + 1];
180 Self {
181 nrows,
182 ncols,
183 row_offsets,
184 col_indices: Vec::new(),
185 values: Vec::new(),
186 }
187 }
188
189 pub fn linear_combination(&self, alpha: f64, other: &Self, beta: f64) -> Result<Self, String> {
190 if self.nrows != other.nrows || self.ncols != other.ncols {
191 return Err(format!(
192 "matrix shape mismatch: lhs={}x{}, rhs={}x{}",
193 self.nrows, self.ncols, other.nrows, other.ncols
194 ));
195 }
196 if self.row_offsets != other.row_offsets || self.col_indices != other.col_indices {
197 return Err("linear_combination requires identical CSR sparsity patterns".into());
198 }
199
200 let values = self
201 .values
202 .iter()
203 .zip(other.values.iter())
204 .map(|(&lhs, &rhs)| alpha * lhs + beta * rhs)
205 .collect();
206
207 Self::try_from_csr_data(
208 self.nrows,
209 self.ncols,
210 self.row_offsets.clone(),
211 self.col_indices.clone(),
212 values,
213 )
214 }
215
216 pub fn diagonal(&self) -> Result<Vec<f64>, String> {
217 let ndiag = self.nrows.min(self.ncols);
218 let mut diagonal = Vec::with_capacity(ndiag);
219 for i in 0..ndiag {
220 let row = self.row(i);
221 let diag_pos = row
222 .col_indices()
223 .binary_search(&i)
224 .map_err(|_| format!("missing diagonal entry at row {i}"))?;
225 diagonal.push(row.values()[diag_pos]);
226 }
227 Ok(diagonal)
228 }
229
230 pub fn submatrix(&self, rows: &[usize], cols: &[usize]) -> Result<Self, String> {
231 let mut col_positions = vec![usize::MAX; self.ncols];
232 for (local_col, &global_col) in cols.iter().enumerate() {
233 if global_col >= self.ncols {
234 return Err(format!(
235 "column index {} out of range for ncols={}",
236 global_col, self.ncols
237 ));
238 }
239 if col_positions[global_col] != usize::MAX {
240 return Err(format!(
241 "duplicate column index {global_col} in submatrix request"
242 ));
243 }
244 col_positions[global_col] = local_col;
245 }
246
247 let mut row_offsets = Vec::with_capacity(rows.len() + 1);
248 let mut col_indices = Vec::new();
249 let mut values = Vec::new();
250 row_offsets.push(0);
251
252 let mut seen_rows = vec![false; self.nrows];
253 for &global_row in rows {
254 if global_row >= self.nrows {
255 return Err(format!(
256 "row index {} out of range for nrows={}",
257 global_row, self.nrows
258 ));
259 }
260 if seen_rows[global_row] {
261 return Err(format!(
262 "duplicate row index {global_row} in submatrix request"
263 ));
264 }
265 seen_rows[global_row] = true;
266
267 let row = self.row(global_row);
268 let mut entries: Vec<(usize, f64)> = row
269 .col_indices()
270 .iter()
271 .zip(row.values().iter())
272 .filter_map(|(&global_col, &value)| {
273 let local_col = col_positions[global_col];
274 (local_col != usize::MAX).then_some((local_col, value))
275 })
276 .collect();
277 entries.sort_unstable_by_key(|(local_col, _)| *local_col);
278
279 for (local_col, value) in entries {
280 col_indices.push(local_col);
281 values.push(value);
282 }
283 row_offsets.push(col_indices.len());
284 }
285
286 Self::try_from_csr_data(rows.len(), cols.len(), row_offsets, col_indices, values)
287 }
288}
289
/// A sparse matrix in coordinate (COO, a.k.a. triplet) format.
///
/// Entries are kept in insertion order in three parallel vectors, and the same
/// position may appear more than once. Repeated positions are summed when the
/// matrix is converted to [`CsrMat`].
#[derive(Debug, Clone)]
pub struct CooMat<T> {
    /// Number of rows.
    nrows: usize,
    /// Number of columns.
    ncols: usize,
    /// Row index of each entry.
    rows: Vec<usize>,
    /// Column index of each entry.
    cols: Vec<usize>,
    /// Value of each entry.
    vals: Vec<T>,
}
299
300impl<T> CooMat<T> {
301 pub fn new(nrows: usize, ncols: usize) -> Self {
302 Self {
303 nrows,
304 ncols,
305 rows: Vec::new(),
306 cols: Vec::new(),
307 vals: Vec::new(),
308 }
309 }
310
311 pub fn push(&mut self, row: usize, col: usize, val: T) {
312 self.rows.push(row);
313 self.cols.push(col);
314 self.vals.push(val);
315 }
316}
317
318impl<T: Copy + Default + AddAssign + PartialEq> From<&CooMat<T>> for CsrMat<T> {
319 fn from(coo: &CooMat<T>) -> Self {
320 let nrows = coo.nrows;
321 let ncols = coo.ncols;
322 let nnz_raw = coo.rows.len();
323
324 if nnz_raw == 0 {
325 return Self {
326 nrows,
327 ncols,
328 row_offsets: vec![0; nrows + 1],
329 col_indices: Vec::new(),
330 values: Vec::new(),
331 };
332 }
333
334 let mut order: Vec<usize> = (0..nnz_raw).collect();
336 order.sort_unstable_by(|&a, &b| {
337 coo.rows[a]
338 .cmp(&coo.rows[b])
339 .then_with(|| coo.cols[a].cmp(&coo.cols[b]))
340 });
341
342 let mut row_offsets = Vec::with_capacity(nrows + 1);
344 let mut col_indices = Vec::with_capacity(nnz_raw);
345 let mut values = Vec::with_capacity(nnz_raw);
346
347 row_offsets.push(0);
348
349 let first_row = coo.rows[order[0]];
351 if first_row > 0 {
352 row_offsets.extend(std::iter::repeat_n(0, first_row));
353 }
354
355 let mut prev_row = first_row;
356 let mut prev_col = coo.cols[order[0]];
357 let mut acc = T::default();
358 acc += coo.vals[order[0]];
359
360 for &idx in &order[1..] {
361 let r = coo.rows[idx];
362 let c = coo.cols[idx];
363
364 if r == prev_row && c == prev_col {
365 acc += coo.vals[idx];
366 } else {
367 col_indices.push(prev_col);
369 values.push(acc);
370
371 for _ in prev_row..r {
373 row_offsets.push(col_indices.len());
374 }
375
376 prev_row = r;
377 prev_col = c;
378 acc = T::default();
379 acc += coo.vals[idx];
380 }
381 }
382
383 col_indices.push(prev_col);
385 values.push(acc);
386
387 while row_offsets.len() <= nrows {
389 row_offsets.push(col_indices.len());
390 }
391
392 debug_assert_eq!(row_offsets.len(), nrows + 1);
393
394 Self {
395 nrows,
396 ncols,
397 row_offsets,
398 col_indices,
399 values,
400 }
401 }
402}
403
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a CSR matrix by pushing `triplets` into a COO matrix and converting it.
    fn csr_from_triplets(
        nrows: usize,
        ncols: usize,
        triplets: &[(usize, usize, f64)],
    ) -> CsrMat<f64> {
        let mut coo = CooMat::new(nrows, ncols);
        for &(row, col, value) in triplets {
            coo.push(row, col, value);
        }
        CsrMat::from(&coo)
    }

    #[test]
    fn csr_basic() {
        let mat =
            CsrMat::try_from_csr_data(2, 2, vec![0, 2, 3], vec![0, 1, 1], vec![1.0, 2.0, 3.0])
                .expect("CSR construction");
        assert_eq!(mat.nrows(), 2);
        assert_eq!(mat.ncols(), 2);
        assert_eq!(mat.nnz(), 3);
        assert_eq!(mat.row(0).col_indices(), &[0, 1]);
        assert_eq!(mat.row(0).values(), &[1.0, 2.0]);
        assert_eq!(mat.get(0, 0), Some(&1.0));
        assert_eq!(mat.get(0, 1), Some(&2.0));
        assert_eq!(mat.get(1, 1), Some(&3.0));
        assert_eq!(mat.get(1, 0), None);
    }

    #[test]
    fn csr_rejects_unsorted_row_columns() {
        let err = CsrMat::try_from_csr_data(1, 3, vec![0, 2], vec![2, 1], vec![1.0, 2.0])
            .expect_err("unsorted row must be rejected");
        assert_eq!(err, "col_indices in row 0 are not strictly increasing");

        let mat = CsrMat::try_from_csr_data(1, 3, vec![0, 2], vec![0, 2], vec![1.0, 2.0])
            .expect("sorted CSR construction");
        assert_eq!(mat.get(0, 0), Some(&1.0));
        assert_eq!(mat.get(0, 2), Some(&2.0));
    }

    #[test]
    fn csr_rejects_inconsistent_structure() {
        let err = CsrMat::try_from_csr_data(2, 2, vec![0, 1], vec![0], vec![1.0]).unwrap_err();
        assert_eq!(err, "row_offsets length 2 does not match nrows+1=3");

        let err =
            CsrMat::try_from_csr_data(1, 2, vec![0, 1], vec![0], vec![1.0, 2.0]).unwrap_err();
        assert_eq!(err, "col_indices length 1 does not match values length 2");

        let err = CsrMat::try_from_csr_data(1, 2, vec![0, 1], vec![0, 1], vec![1.0, 2.0])
            .unwrap_err();
        assert_eq!(err, "col_indices length 2 does not match last row_offset 1");

        let err =
            CsrMat::try_from_csr_data(2, 2, vec![0, 2, 1], vec![0], vec![1.0]).unwrap_err();
        assert_eq!(err, "row_offsets is not monotonically non-decreasing");

        let err = CsrMat::try_from_csr_data(1, 3, vec![0, 2], vec![1, 1], vec![1.0, 2.0])
            .unwrap_err();
        assert_eq!(err, "col_indices in row 0 are not strictly increasing");

        let err = CsrMat::try_from_csr_data(1, 2, vec![0, 1], vec![2], vec![1.0]).unwrap_err();
        assert_eq!(err, "col_index 2 out of range for ncols=2");
    }

    #[test]
    fn coo_to_csr_accumulates_duplicates() {
        let mut coo = CooMat::new(2, 2);
        coo.push(0, 1, 1.5);
        coo.push(0, 1, 2.5);
        coo.push(1, 0, 4.0);
        let csr = CsrMat::from(&coo);
        assert_eq!(csr.nnz(), 2);
        assert_eq!(csr.get(0, 1), Some(&4.0));
        assert_eq!(csr.get(1, 0), Some(&4.0));
    }

    #[test]
    fn coo_to_csr_empty() {
        let coo: CooMat<f64> = CooMat::new(3, 3);
        let csr = CsrMat::from(&coo);
        assert_eq!(csr.nrows(), 3);
        assert_eq!(csr.ncols(), 3);
        assert_eq!(csr.nnz(), 0);
        for i in 0..3 {
            assert_eq!(csr.row(i).nnz(), 0);
        }
    }

    #[test]
    fn coo_to_csr_single_element() {
        let mut coo = CooMat::new(5, 5);
        coo.push(2, 3, 7.0);
        let csr = CsrMat::from(&coo);
        assert_eq!(csr.nnz(), 1);
        assert_eq!(csr.get(2, 3), Some(&7.0));
        assert_eq!(csr.get(0, 0), None);
    }

    #[test]
    fn coo_to_csr_reverse_column_order() {
        let mut coo = CooMat::new(1, 4);
        coo.push(0, 3, 4.0);
        coo.push(0, 1, 2.0);
        coo.push(0, 0, 1.0);
        coo.push(0, 2, 3.0);
        let csr = CsrMat::from(&coo);
        assert_eq!(csr.row(0).col_indices(), &[0, 1, 2, 3]);
        assert_eq!(csr.row(0).values(), &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn coo_to_csr_multiple_duplicates() {
        let mut coo = CooMat::new(2, 2);
        coo.push(0, 0, 1.0);
        coo.push(0, 0, 2.0);
        coo.push(0, 0, 3.0);
        coo.push(0, 0, 4.0);
        coo.push(1, 1, 5.0);
        let csr = CsrMat::from(&coo);
        assert_eq!(csr.nnz(), 2);
        assert_eq!(csr.get(0, 0), Some(&10.0));
        assert_eq!(csr.get(1, 1), Some(&5.0));
    }

    #[test]
    fn coo_to_csr_sorted_columns_per_row() {
        let mut coo = CooMat::new(3, 5);
        coo.push(0, 4, 1.0);
        coo.push(0, 0, 2.0);
        coo.push(1, 3, 3.0);
        coo.push(1, 1, 4.0);
        coo.push(2, 2, 5.0);
        let csr = CsrMat::from(&coo);
        assert_eq!(csr.row(0).col_indices(), &[0, 4]);
        assert_eq!(csr.row(1).col_indices(), &[1, 3]);
        assert_eq!(csr.row(2).col_indices(), &[2]);
    }

    #[test]
    fn coo_to_csr_sparse_rows_with_interior_gaps() {
        let mut coo = CooMat::new(5, 3);
        coo.push(0, 1, 1.0);
        coo.push(4, 2, 2.0);
        let csr = CsrMat::from(&coo);
        assert_eq!(csr.nnz(), 2);
        assert_eq!(csr.get(0, 1), Some(&1.0));
        assert_eq!(csr.get(4, 2), Some(&2.0));
        for i in 1..4 {
            assert_eq!(csr.row(i).nnz(), 0);
        }
    }

    #[test]
    fn coo_to_csr_integer_type() {
        let mut coo: CooMat<i32> = CooMat::new(2, 2);
        coo.push(0, 0, 10);
        coo.push(0, 0, 20);
        coo.push(1, 1, 30);
        let csr = CsrMat::from(&coo);
        assert_eq!(csr.get(0, 0), Some(&30));
        assert_eq!(csr.get(1, 1), Some(&30));
    }

    #[test]
    fn identity_and_zeros_constructors() {
        let id = CsrMat::identity(3);
        assert_eq!(id.nnz(), 3);
        assert_eq!(id.row_offsets(), &[0, 1, 2, 3]);
        assert_eq!(id.diagonal().unwrap(), vec![1.0, 1.0, 1.0]);
        assert_eq!(id.get(0, 1), None);

        let z = CsrMat::zeros(2, 4);
        assert_eq!((z.nrows(), z.ncols(), z.nnz()), (2, 4, 0));
        assert!(z.row_iter().all(|row| row.nnz() == 0));
        assert_eq!(z.diagonal().unwrap_err(), "missing diagonal entry at row 0");
    }

    #[test]
    fn triplet_iter_visits_entries_in_row_major_order() {
        let mat = csr_from_triplets(3, 3, &[(2, 0, 3.0), (0, 2, 1.0), (2, 1, 4.0), (0, 0, 5.0)]);
        let triplets: Vec<(usize, usize, f64)> =
            mat.triplet_iter().map(|(r, c, &v)| (r, c, v)).collect();
        assert_eq!(
            triplets,
            vec![(0, 0, 5.0), (0, 2, 1.0), (2, 0, 3.0), (2, 1, 4.0)]
        );
    }

    #[test]
    fn values_mut_updates_in_place_and_row_iter_covers_all_rows() {
        let mut mat = csr_from_triplets(2, 2, &[(0, 0, 1.0), (1, 1, 2.0)]);
        for value in mat.values_mut() {
            *value *= 10.0;
        }
        assert_eq!(mat.get(0, 0), Some(&10.0));
        assert_eq!(mat.get(1, 1), Some(&20.0));
        let row_nnz: Vec<usize> = mat.row_iter().map(|row| row.nnz()).collect();
        assert_eq!(row_nnz, vec![1, 1]);
    }

    #[test]
    fn linear_combination_matches_shifted_matrix_pattern() {
        let k = csr_from_triplets(2, 2, &[(0, 0, 4.0), (0, 1, 1.0), (1, 1, 3.0)]);
        let m = csr_from_triplets(2, 2, &[(0, 0, 1.0), (0, 1, 0.5), (1, 1, 2.0)]);

        let shifted = k.linear_combination(1.0, &m, -2.0).unwrap();
        assert_eq!(shifted.row(0).col_indices(), &[0, 1]);
        assert_eq!(shifted.row(0).values(), &[2.0, 0.0]);
        assert_eq!(shifted.row(1).values(), &[-1.0]);
    }

    #[test]
    fn linear_combination_rejects_pattern_mismatch() {
        let lhs = csr_from_triplets(2, 2, &[(0, 0, 1.0), (1, 1, 2.0)]);
        let rhs = csr_from_triplets(2, 2, &[(0, 0, 1.0), (0, 1, 2.0), (1, 1, 3.0)]);

        let err = lhs.linear_combination(1.0, &rhs, -1.0).unwrap_err();
        assert!(err.contains("identical CSR sparsity patterns"), "err={err}");
    }

    #[test]
    fn linear_combination_rejects_shape_mismatch() {
        let lhs = CsrMat::zeros(2, 2);
        let rhs = CsrMat::zeros(2, 3);
        let err = lhs.linear_combination(1.0, &rhs, 1.0).unwrap_err();
        assert_eq!(err, "matrix shape mismatch: lhs=2x2, rhs=2x3");
    }

    #[test]
    fn diagonal_extracts_all_present_entries() {
        let mat = csr_from_triplets(3, 3, &[(0, 0, 2.0), (0, 2, 9.0), (1, 1, 3.0), (2, 2, 5.0)]);
        assert_eq!(mat.diagonal().unwrap(), vec![2.0, 3.0, 5.0]);
    }

    #[test]
    fn diagonal_rejects_missing_entry() {
        let mat = csr_from_triplets(2, 2, &[(0, 1, 1.0), (1, 1, 2.0)]);
        let err = mat.diagonal().unwrap_err();
        assert!(err.contains("missing diagonal entry"), "err={err}");
    }

    #[test]
    fn submatrix_preserves_requested_order_with_sorted_local_columns() {
        let mat = csr_from_triplets(
            3,
            4,
            &[
                (0, 0, 1.0),
                (0, 1, 2.0),
                (0, 3, 3.0),
                (1, 0, 4.0),
                (1, 2, 5.0),
                (2, 1, 6.0),
                (2, 3, 7.0),
            ],
        );

        let sub = mat.submatrix(&[2, 0], &[3, 1]).unwrap();
        assert_eq!(sub.nrows(), 2);
        assert_eq!(sub.ncols(), 2);
        assert_eq!(sub.row(0).col_indices(), &[0, 1]);
        assert_eq!(sub.row(0).values(), &[7.0, 6.0]);
        assert_eq!(sub.row(1).col_indices(), &[0, 1]);
        assert_eq!(sub.row(1).values(), &[3.0, 2.0]);
    }

    #[test]
    fn submatrix_rejects_duplicate_indices() {
        let mat = csr_from_triplets(2, 3, &[(0, 0, 1.0), (0, 1, 2.0), (1, 2, 3.0)]);

        let row_err = mat.submatrix(&[0, 0], &[1]).unwrap_err();
        assert!(row_err.contains("duplicate row index"), "err={row_err}");

        let col_err = mat.submatrix(&[0], &[1, 1]).unwrap_err();
        assert!(col_err.contains("duplicate column index"), "err={col_err}");
    }

    #[test]
    fn submatrix_rejects_out_of_range_indices_and_allows_empty_selection() {
        let mat = csr_from_triplets(2, 3, &[(0, 0, 1.0), (1, 2, 3.0)]);
        assert_eq!(
            mat.submatrix(&[2], &[0]).unwrap_err(),
            "row index 2 out of range for nrows=2"
        );
        assert_eq!(
            mat.submatrix(&[0], &[3]).unwrap_err(),
            "column index 3 out of range for ncols=3"
        );
        let empty = mat.submatrix(&[], &[]).unwrap();
        assert_eq!((empty.nrows(), empty.ncols(), empty.nnz()), (0, 0, 0));
    }
}