1use std::collections::BTreeMap;
15
16use cjc_repro::kahan_sum_f64;
17
18use crate::accumulator::binned_sum_f64;
19use crate::error::RuntimeError;
20use crate::tensor::Tensor;
21
/// Sparse matrix in Compressed Sparse Row (CSR) format.
///
/// Invariants relied on by this module (not enforced by the type):
/// `row_offsets` has `nrows + 1` entries, is non-decreasing, and
/// `row_offsets[r]..row_offsets[r + 1]` indexes the stored entries of
/// row `r` inside `values` / `col_indices`, with column indices sorted
/// ascending within each row.
#[derive(Debug, Clone)]
pub struct SparseCsr {
    /// Stored (explicit) entry values, row-major.
    pub values: Vec<f64>,
    /// Column index of each stored entry, parallel to `values`.
    pub col_indices: Vec<usize>,
    /// Per-row start offsets into `values`; length `nrows + 1`.
    pub row_offsets: Vec<usize>,
    /// Number of rows.
    pub nrows: usize,
    /// Number of columns.
    pub ncols: usize,
}
39
40impl SparseCsr {
41 pub fn nnz(&self) -> usize {
43 self.values.len()
44 }
45
46 pub fn get(&self, row: usize, col: usize) -> f64 {
48 if row >= self.nrows || col >= self.ncols {
49 return 0.0;
50 }
51 let start = self.row_offsets[row];
52 let end = self.row_offsets[row + 1];
53 for idx in start..end {
54 if self.col_indices[idx] == col {
55 return self.values[idx];
56 }
57 }
58 0.0
59 }
60
61 pub fn matvec(&self, x: &[f64]) -> Result<Vec<f64>, RuntimeError> {
69 if x.len() != self.ncols {
70 return Err(RuntimeError::DimensionMismatch {
71 expected: self.ncols,
72 got: x.len(),
73 });
74 }
75 let mut y = vec![0.0f64; self.nrows];
76 for row in 0..self.nrows {
77 let start = self.row_offsets[row];
78 let end = self.row_offsets[row + 1];
79 let products: Vec<f64> = (start..end)
80 .map(|idx| self.values[idx] * x[self.col_indices[idx]])
81 .collect();
82 y[row] = kahan_sum_f64(&products);
83 }
84 Ok(y)
85 }
86
87 pub fn to_dense(&self) -> Tensor {
89 let mut data = vec![0.0f64; self.nrows * self.ncols];
90 for row in 0..self.nrows {
91 let start = self.row_offsets[row];
92 let end = self.row_offsets[row + 1];
93 for idx in start..end {
94 data[row * self.ncols + self.col_indices[idx]] = self.values[idx];
95 }
96 }
97 Tensor::from_vec(data, &[self.nrows, self.ncols]).unwrap()
98 }
99
100 pub fn from_coo(coo: &SparseCoo) -> Self {
102 let nnz = coo.values.len();
104 let mut order: Vec<usize> = (0..nnz).collect();
105 order.sort_by_key(|&i| (coo.row_indices[i], coo.col_indices[i]));
106
107 let mut values = Vec::with_capacity(nnz);
108 let mut col_indices = Vec::with_capacity(nnz);
109 let mut row_offsets = vec![0usize; coo.nrows + 1];
110
111 for &i in &order {
112 values.push(coo.values[i]);
113 col_indices.push(coo.col_indices[i]);
114 row_offsets[coo.row_indices[i] + 1] += 1;
115 }
116
117 for i in 1..=coo.nrows {
119 row_offsets[i] += row_offsets[i - 1];
120 }
121
122 SparseCsr {
123 values,
124 col_indices,
125 row_offsets,
126 nrows: coo.nrows,
127 ncols: coo.ncols,
128 }
129 }
130}
131
/// Sparse matrix in Coordinate (COO / triplet) format.
///
/// The three vectors are parallel: entry `i` has value `values[i]` at
/// position `(row_indices[i], col_indices[i])`. Entries may appear in
/// any order.
#[derive(Debug, Clone)]
pub struct SparseCoo {
    /// Entry values, parallel to the index vectors.
    pub values: Vec<f64>,
    /// Row coordinate of each entry.
    pub row_indices: Vec<usize>,
    /// Column coordinate of each entry.
    pub col_indices: Vec<usize>,
    /// Number of rows.
    pub nrows: usize,
    /// Number of columns.
    pub ncols: usize,
}
145
146impl SparseCoo {
147 pub fn new(
149 values: Vec<f64>,
150 row_indices: Vec<usize>,
151 col_indices: Vec<usize>,
152 nrows: usize,
153 ncols: usize,
154 ) -> Self {
155 SparseCoo {
156 values,
157 row_indices,
158 col_indices,
159 nrows,
160 ncols,
161 }
162 }
163
164 pub fn nnz(&self) -> usize {
166 self.values.len()
167 }
168
169 pub fn to_csr(&self) -> SparseCsr {
171 SparseCsr::from_coo(self)
172 }
173
174 pub fn sum(&self) -> f64 {
176 kahan_sum_f64(&self.values)
177 }
178}
179
/// Merges two sorted sparse rows into one, combining aligned entries.
///
/// `a_cols` / `b_cols` must be ascending. Where only one side has an
/// entry, the other side contributes its default (`default_a` /
/// `default_b`). Results that are exactly `0.0` are dropped so the
/// output stays sparse. Returns the merged `(values, cols)` pair, with
/// columns ascending.
fn merge_rows(
    a_vals: &[f64],
    a_cols: &[usize],
    b_vals: &[f64],
    b_cols: &[usize],
    combine: fn(f64, f64) -> f64,
    default_a: f64,
    default_b: f64,
) -> (Vec<f64>, Vec<usize>) {
    use std::cmp::Ordering;

    let mut out_vals = Vec::new();
    let mut out_cols = Vec::new();
    {
        // Single sink for the push-if-nonzero rule shared by all branches.
        let mut emit = |v: f64, c: usize| {
            if v != 0.0 {
                out_vals.push(v);
                out_cols.push(c);
            }
        };

        let mut ia = 0;
        let mut ib = 0;
        // Two-pointer merge over the sorted column lists.
        while ia < a_cols.len() && ib < b_cols.len() {
            match a_cols[ia].cmp(&b_cols[ib]) {
                Ordering::Less => {
                    emit(combine(a_vals[ia], default_b), a_cols[ia]);
                    ia += 1;
                }
                Ordering::Greater => {
                    emit(combine(default_a, b_vals[ib]), b_cols[ib]);
                    ib += 1;
                }
                Ordering::Equal => {
                    emit(combine(a_vals[ia], b_vals[ib]), a_cols[ia]);
                    ia += 1;
                    ib += 1;
                }
            }
        }
        // Drain whichever side still has entries.
        for i in ia..a_cols.len() {
            emit(combine(a_vals[i], default_b), a_cols[i]);
        }
        for i in ib..b_cols.len() {
            emit(combine(default_a, b_vals[i]), b_cols[i]);
        }
    }
    (out_vals, out_cols)
}
247
248fn sparse_binop(
250 a: &SparseCsr,
251 b: &SparseCsr,
252 combine: fn(f64, f64) -> f64,
253 default_a: f64,
254 default_b: f64,
255 op_name: &str,
256) -> Result<SparseCsr, String> {
257 if a.nrows != b.nrows || a.ncols != b.ncols {
258 return Err(format!(
259 "sparse_{}: dimension mismatch: ({}, {}) vs ({}, {})",
260 op_name, a.nrows, a.ncols, b.nrows, b.ncols
261 ));
262 }
263
264 let mut values = Vec::new();
265 let mut col_indices = Vec::new();
266 let mut row_offsets = Vec::with_capacity(a.nrows + 1);
267 row_offsets.push(0);
268
269 for row in 0..a.nrows {
270 let a_start = a.row_offsets[row];
271 let a_end = a.row_offsets[row + 1];
272 let b_start = b.row_offsets[row];
273 let b_end = b.row_offsets[row + 1];
274
275 let (rv, rc) = merge_rows(
276 &a.values[a_start..a_end],
277 &a.col_indices[a_start..a_end],
278 &b.values[b_start..b_end],
279 &b.col_indices[b_start..b_end],
280 combine,
281 default_a,
282 default_b,
283 );
284 values.extend_from_slice(&rv);
285 col_indices.extend_from_slice(&rc);
286 row_offsets.push(values.len());
287 }
288
289 Ok(SparseCsr {
290 values,
291 col_indices,
292 row_offsets,
293 nrows: a.nrows,
294 ncols: a.ncols,
295 })
296}
297
298pub fn sparse_add(a: &SparseCsr, b: &SparseCsr) -> Result<SparseCsr, String> {
300 sparse_binop(a, b, |x, y| x + y, 0.0, 0.0, "add")
301}
302
303pub fn sparse_sub(a: &SparseCsr, b: &SparseCsr) -> Result<SparseCsr, String> {
305 sparse_binop(a, b, |x, y| x - y, 0.0, 0.0, "sub")
306}
307
308pub fn sparse_mul(a: &SparseCsr, b: &SparseCsr) -> Result<SparseCsr, String> {
311 if a.nrows != b.nrows || a.ncols != b.ncols {
312 return Err(format!(
313 "sparse_mul: dimension mismatch: ({}, {}) vs ({}, {})",
314 a.nrows, a.ncols, b.nrows, b.ncols
315 ));
316 }
317
318 let mut values = Vec::new();
319 let mut col_indices = Vec::new();
320 let mut row_offsets = Vec::with_capacity(a.nrows + 1);
321 row_offsets.push(0);
322
323 for row in 0..a.nrows {
324 let a_start = a.row_offsets[row];
325 let a_end = a.row_offsets[row + 1];
326 let b_start = b.row_offsets[row];
327 let b_end = b.row_offsets[row + 1];
328
329 let mut ia = a_start;
330 let mut ib = b_start;
331
332 while ia < a_end && ib < b_end {
334 match a.col_indices[ia].cmp(&b.col_indices[ib]) {
335 std::cmp::Ordering::Less => ia += 1,
336 std::cmp::Ordering::Greater => ib += 1,
337 std::cmp::Ordering::Equal => {
338 let v = a.values[ia] * b.values[ib];
339 if v != 0.0 {
340 values.push(v);
341 col_indices.push(a.col_indices[ia]);
342 }
343 ia += 1;
344 ib += 1;
345 }
346 }
347 }
348 row_offsets.push(values.len());
349 }
350
351 Ok(SparseCsr {
352 values,
353 col_indices,
354 row_offsets,
355 nrows: a.nrows,
356 ncols: a.ncols,
357 })
358}
359
360pub fn sparse_matmul(a: &SparseCsr, b: &SparseCsr) -> Result<SparseCsr, String> {
364 if a.ncols != b.nrows {
365 return Err(format!(
366 "sparse_matmul: inner dimension mismatch: A is ({}, {}), B is ({}, {})",
367 a.nrows, a.ncols, b.nrows, b.ncols
368 ));
369 }
370
371 let mut values = Vec::new();
372 let mut col_indices = Vec::new();
373 let mut row_offsets = Vec::with_capacity(a.nrows + 1);
374 row_offsets.push(0);
375
376 for row in 0..a.nrows {
377 let mut accum: BTreeMap<usize, Vec<f64>> = BTreeMap::new();
379
380 let a_start = a.row_offsets[row];
381 let a_end = a.row_offsets[row + 1];
382
383 for a_idx in a_start..a_end {
384 let k = a.col_indices[a_idx];
385 let a_val = a.values[a_idx];
386
387 let b_start = b.row_offsets[k];
388 let b_end = b.row_offsets[k + 1];
389
390 for b_idx in b_start..b_end {
391 let j = b.col_indices[b_idx];
392 accum.entry(j).or_default().push(a_val * b.values[b_idx]);
393 }
394 }
395
396 for (col, terms) in &accum {
398 let v = binned_sum_f64(&terms);
399 if v != 0.0 {
400 col_indices.push(*col);
401 values.push(v);
402 }
403 }
404 row_offsets.push(values.len());
405 }
406
407 Ok(SparseCsr {
408 values,
409 col_indices,
410 row_offsets,
411 nrows: a.nrows,
412 ncols: b.ncols,
413 })
414}
415
416pub fn sparse_scalar_mul(a: &SparseCsr, s: f64) -> SparseCsr {
418 let values: Vec<f64> = a.values.iter().map(|&v| v * s).collect();
419 SparseCsr {
420 values,
421 col_indices: a.col_indices.clone(),
422 row_offsets: a.row_offsets.clone(),
423 nrows: a.nrows,
424 ncols: a.ncols,
425 }
426}
427
/// Transposes a CSR matrix via a counting sort on column indices
/// (two passes over the stored entries, O(nnz + ncols)).
///
/// Pass 1 counts entries per output row (= input column) and converts
/// the counts to start offsets with an in-place prefix sum. Pass 2
/// walks the input in row order and scatters each entry to its slot,
/// advancing a per-column cursor. Because the input is scanned in
/// row-major order, each output row's column indices (= input row
/// numbers) come out ascending, preserving the sorted-CSR invariant.
pub fn sparse_transpose(a: &SparseCsr) -> SparseCsr {
    // row_counts[c + 1] accumulates the entry count of input column c;
    // the shift-by-one makes the prefix sum below yield start offsets.
    let mut row_counts = vec![0usize; a.ncols + 1];

    for &c in &a.col_indices {
        row_counts[c + 1] += 1;
    }
    // In-place prefix sum: counts -> offsets. After this, row_counts is
    // exactly the row_offsets vector of the transposed matrix.
    for i in 1..=a.ncols {
        row_counts[i] += row_counts[i - 1];
    }

    let nnz = a.values.len();
    let mut new_values = vec![0.0f64; nnz];
    let mut new_col_indices = vec![0usize; nnz];
    // cursor[c] is the next free slot in output row c; it starts at the
    // row's offset and is bumped as entries are placed.
    let mut cursor = row_counts.clone();

    for row in 0..a.nrows {
        let start = a.row_offsets[row];
        let end = a.row_offsets[row + 1];
        for idx in start..end {
            let col = a.col_indices[idx];
            let dest = cursor[col];
            new_values[dest] = a.values[idx];
            // The output column of a transposed entry is its input row.
            new_col_indices[dest] = row;
            cursor[col] += 1;
        }
    }

    SparseCsr {
        values: new_values,
        col_indices: new_col_indices,
        row_offsets: row_counts,
        nrows: a.ncols,
        ncols: a.nrows,
    }
}
467
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a CSR matrix from a row-major dense slice, storing only
    /// the non-zero entries (column indices ascending per row).
    fn csr_from_dense(data: &[f64], nrows: usize, ncols: usize) -> SparseCsr {
        let mut values = Vec::new();
        let mut col_indices = Vec::new();
        let mut row_offsets = vec![0usize];

        for row in 0..nrows {
            for col in 0..ncols {
                let v = data[row * ncols + col];
                if v != 0.0 {
                    values.push(v);
                    col_indices.push(col);
                }
            }
            row_offsets.push(values.len());
        }

        SparseCsr {
            values,
            col_indices,
            row_offsets,
            nrows,
            ncols,
        }
    }

    #[test]
    fn test_sparse_add_basic() {
        let a = csr_from_dense(&[1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 0.0, 4.0, 5.0], 3, 3);
        let b = csr_from_dense(&[0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 4.0, 0.0, 5.0], 3, 3);
        let c = sparse_add(&a, &b).unwrap();
        // Compare every cell against the scalar sum of the operands.
        for row in 0..3 {
            for col in 0..3 {
                let expected = a.get(row, col) + b.get(row, col);
                assert_eq!(c.get(row, col), expected, "mismatch at ({}, {})", row, col);
            }
        }
    }

    #[test]
    fn test_sparse_add_a_plus_a_eq_2a() {
        let a = csr_from_dense(&[1.0, 2.0, 0.0, 3.0], 2, 2);
        let sum = sparse_add(&a, &a).unwrap();
        let doubled = sparse_scalar_mul(&a, 2.0);
        for row in 0..2 {
            for col in 0..2 {
                assert_eq!(sum.get(row, col), doubled.get(row, col));
            }
        }
    }

    #[test]
    fn test_sparse_add_dimension_mismatch() {
        let a = csr_from_dense(&[1.0, 2.0], 1, 2);
        let b = csr_from_dense(&[1.0, 2.0, 3.0], 1, 3);
        assert!(sparse_add(&a, &b).is_err());
    }

    #[test]
    fn test_sparse_sub_basic() {
        let a = csr_from_dense(&[5.0, 3.0, 0.0, 1.0], 2, 2);
        let b = csr_from_dense(&[2.0, 3.0, 1.0, 0.0], 2, 2);
        let c = sparse_sub(&a, &b).unwrap();
        assert_eq!(c.get(0, 0), 3.0);
        assert_eq!(c.get(0, 1), 0.0);
        assert_eq!(c.get(1, 0), -1.0);
        assert_eq!(c.get(1, 1), 1.0);
    }

    #[test]
    fn test_sparse_sub_self_is_zero() {
        // A - A cancels every entry; exact zeros are dropped, so nnz = 0.
        let a = csr_from_dense(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let c = sparse_sub(&a, &a).unwrap();
        assert_eq!(c.nnz(), 0);
    }

    #[test]
    fn test_sparse_mul_hadamard() {
        let a = csr_from_dense(&[1.0, 0.0, 3.0, 4.0], 2, 2);
        let b = csr_from_dense(&[2.0, 5.0, 0.0, 3.0], 2, 2);
        let c = sparse_mul(&a, &b).unwrap();
        assert_eq!(c.get(0, 0), 2.0);
        assert_eq!(c.get(0, 1), 0.0);
        assert_eq!(c.get(1, 0), 0.0);
        assert_eq!(c.get(1, 1), 12.0);
    }

    #[test]
    fn test_sparse_matmul_identity() {
        let a = csr_from_dense(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let eye = csr_from_dense(&[1.0, 0.0, 0.0, 1.0], 2, 2);
        let c = sparse_matmul(&a, &eye).unwrap();
        for row in 0..2 {
            for col in 0..2 {
                assert_eq!(c.get(row, col), a.get(row, col));
            }
        }
    }

    #[test]
    fn test_sparse_matmul_vs_dense() {
        // 2x3 times 3x2; expected entries computed by hand.
        let a_data = [1.0, 2.0, 0.0, 0.0, 3.0, 4.0];
        let b_data = [5.0, 0.0, 6.0, 7.0, 0.0, 8.0];
        let a = csr_from_dense(&a_data, 2, 3);
        let b = csr_from_dense(&b_data, 3, 2);

        let c = sparse_matmul(&a, &b).unwrap();

        assert_eq!(c.get(0, 0), 17.0);
        assert_eq!(c.get(0, 1), 14.0);
        assert_eq!(c.get(1, 0), 18.0);
        assert_eq!(c.get(1, 1), 53.0);
    }

    #[test]
    fn test_sparse_matmul_dimension_mismatch() {
        let a = csr_from_dense(&[1.0, 2.0], 1, 2);
        let b = csr_from_dense(&[1.0, 2.0, 3.0], 1, 3);
        assert!(sparse_matmul(&a, &b).is_err());
    }

    #[test]
    fn test_sparse_scalar_mul_basic() {
        let a = csr_from_dense(&[2.0, 0.0, 0.0, 4.0], 2, 2);
        let c = sparse_scalar_mul(&a, 3.0);
        assert_eq!(c.get(0, 0), 6.0);
        assert_eq!(c.get(1, 1), 12.0);
        assert_eq!(c.nnz(), 2);
    }

    #[test]
    fn test_sparse_transpose_square() {
        let a = csr_from_dense(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let at = sparse_transpose(&a);
        assert_eq!(at.get(0, 0), 1.0);
        assert_eq!(at.get(0, 1), 3.0);
        assert_eq!(at.get(1, 0), 2.0);
        assert_eq!(at.get(1, 1), 4.0);
    }

    #[test]
    fn test_sparse_transpose_rect() {
        let a = csr_from_dense(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let at = sparse_transpose(&a);
        assert_eq!(at.nrows, 3);
        assert_eq!(at.ncols, 2);
        for row in 0..2 {
            for col in 0..3 {
                assert_eq!(
                    at.get(col, row),
                    a.get(row, col),
                    "mismatch at transpose({}, {})",
                    col,
                    row
                );
            }
        }
    }

    #[test]
    fn test_sparse_transpose_double_is_identity() {
        let a = csr_from_dense(&[1.0, 0.0, 2.0, 3.0, 0.0, 4.0], 2, 3);
        let att = sparse_transpose(&sparse_transpose(&a));
        assert_eq!(att.nrows, a.nrows);
        assert_eq!(att.ncols, a.ncols);
        for row in 0..a.nrows {
            for col in 0..a.ncols {
                assert_eq!(att.get(row, col), a.get(row, col));
            }
        }
    }

    #[test]
    fn test_sparse_matmul_determinism() {
        // Same inputs must produce bit-identical CSR output every run.
        let a = csr_from_dense(&[1.0, 2.0, 0.0, 0.0, 3.0, 4.0], 2, 3);
        let b = csr_from_dense(&[5.0, 0.0, 6.0, 7.0, 0.0, 8.0], 3, 2);

        let c1 = sparse_matmul(&a, &b).unwrap();
        let c2 = sparse_matmul(&a, &b).unwrap();

        assert_eq!(c1.values, c2.values);
        assert_eq!(c1.col_indices, c2.col_indices);
        assert_eq!(c1.row_offsets, c2.row_offsets);
    }

    #[test]
    fn test_sparse_add_determinism() {
        let a = csr_from_dense(&[1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 0.0, 4.0, 5.0], 3, 3);
        let b = csr_from_dense(&[0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 4.0, 0.0, 5.0], 3, 3);

        let c1 = sparse_add(&a, &b).unwrap();
        let c2 = sparse_add(&a, &b).unwrap();

        assert_eq!(c1.values, c2.values);
        assert_eq!(c1.col_indices, c2.col_indices);
    }
}
683