1use std::collections::BTreeMap;
2
3use cjc_repro::kahan_sum_f64;
4
5use crate::accumulator::binned_sum_f64;
6use crate::error::RuntimeError;
7use crate::tensor::Tensor;
8
/// Sparse matrix in Compressed Sparse Row (CSR) format.
///
/// Entries produced by `from_coo` and by the operations in this module are
/// sorted by ascending column within each row.
#[derive(Debug, Clone)]
pub struct SparseCsr {
    /// Stored (structurally nonzero) values, row-major.
    pub values: Vec<f64>,
    /// Column index of each entry in `values`.
    pub col_indices: Vec<usize>,
    /// `row_offsets[r]..row_offsets[r + 1]` spans row `r`'s entries in
    /// `values` / `col_indices`; length is `nrows + 1`.
    pub row_offsets: Vec<usize>,
    /// Number of rows.
    pub nrows: usize,
    /// Number of columns.
    pub ncols: usize,
}
22
23impl SparseCsr {
24 pub fn nnz(&self) -> usize {
26 self.values.len()
27 }
28
29 pub fn get(&self, row: usize, col: usize) -> f64 {
31 if row >= self.nrows || col >= self.ncols {
32 return 0.0;
33 }
34 let start = self.row_offsets[row];
35 let end = self.row_offsets[row + 1];
36 for idx in start..end {
37 if self.col_indices[idx] == col {
38 return self.values[idx];
39 }
40 }
41 0.0
42 }
43
44 pub fn matvec(&self, x: &[f64]) -> Result<Vec<f64>, RuntimeError> {
46 if x.len() != self.ncols {
47 return Err(RuntimeError::DimensionMismatch {
48 expected: self.ncols,
49 got: x.len(),
50 });
51 }
52 let mut y = vec![0.0f64; self.nrows];
53 for row in 0..self.nrows {
54 let start = self.row_offsets[row];
55 let end = self.row_offsets[row + 1];
56 let products: Vec<f64> = (start..end)
57 .map(|idx| self.values[idx] * x[self.col_indices[idx]])
58 .collect();
59 y[row] = kahan_sum_f64(&products);
60 }
61 Ok(y)
62 }
63
64 pub fn to_dense(&self) -> Tensor {
66 let mut data = vec![0.0f64; self.nrows * self.ncols];
67 for row in 0..self.nrows {
68 let start = self.row_offsets[row];
69 let end = self.row_offsets[row + 1];
70 for idx in start..end {
71 data[row * self.ncols + self.col_indices[idx]] = self.values[idx];
72 }
73 }
74 Tensor::from_vec(data, &[self.nrows, self.ncols]).unwrap()
75 }
76
77 pub fn from_coo(coo: &SparseCoo) -> Self {
79 let nnz = coo.values.len();
81 let mut order: Vec<usize> = (0..nnz).collect();
82 order.sort_by_key(|&i| (coo.row_indices[i], coo.col_indices[i]));
83
84 let mut values = Vec::with_capacity(nnz);
85 let mut col_indices = Vec::with_capacity(nnz);
86 let mut row_offsets = vec![0usize; coo.nrows + 1];
87
88 for &i in &order {
89 values.push(coo.values[i]);
90 col_indices.push(coo.col_indices[i]);
91 row_offsets[coo.row_indices[i] + 1] += 1;
92 }
93
94 for i in 1..=coo.nrows {
96 row_offsets[i] += row_offsets[i - 1];
97 }
98
99 SparseCsr {
100 values,
101 col_indices,
102 row_offsets,
103 nrows: coo.nrows,
104 ncols: coo.ncols,
105 }
106 }
107}
108
/// Sparse matrix in COOrdinate (triplet) format.
///
/// Parallel arrays: entry `i` is `values[i]` at
/// `(row_indices[i], col_indices[i])`. No ordering is assumed and
/// duplicate coordinates are allowed.
#[derive(Debug, Clone)]
pub struct SparseCoo {
    /// Stored values.
    pub values: Vec<f64>,
    /// Row coordinate of each entry.
    pub row_indices: Vec<usize>,
    /// Column coordinate of each entry.
    pub col_indices: Vec<usize>,
    /// Number of rows.
    pub nrows: usize,
    /// Number of columns.
    pub ncols: usize,
}
118
119impl SparseCoo {
120 pub fn new(
121 values: Vec<f64>,
122 row_indices: Vec<usize>,
123 col_indices: Vec<usize>,
124 nrows: usize,
125 ncols: usize,
126 ) -> Self {
127 SparseCoo {
128 values,
129 row_indices,
130 col_indices,
131 nrows,
132 ncols,
133 }
134 }
135
136 pub fn nnz(&self) -> usize {
137 self.values.len()
138 }
139
140 pub fn to_csr(&self) -> SparseCsr {
141 SparseCsr::from_coo(self)
142 }
143
144 pub fn sum(&self) -> f64 {
145 kahan_sum_f64(&self.values)
146 }
147}
148
/// Merge two sorted CSR rows into one, applying `combine` positionally.
///
/// `a_vals`/`a_cols` and `b_vals`/`b_cols` are the value/column slices of
/// one row from each operand, with columns in ascending order. Where a
/// column exists in only one operand, the other side contributes its
/// `default_*` value (0.0 for add/sub). Results that are exactly `0.0`
/// are dropped so the output stays structurally sparse.
///
/// Returns the merged `(values, cols)` pair, columns ascending.
fn merge_rows(
    a_vals: &[f64],
    a_cols: &[usize],
    b_vals: &[f64],
    b_cols: &[usize],
    combine: fn(f64, f64) -> f64,
    default_a: f64,
    default_b: f64,
) -> (Vec<f64>, Vec<usize>) {
    use std::cmp::Ordering;

    // Upper bound on output size: every column appears at most once.
    let cap = a_cols.len() + b_cols.len();
    let mut values = Vec::with_capacity(cap);
    let mut cols = Vec::with_capacity(cap);

    // Single emission point replaces the previously triplicated
    // push-if-nonzero logic (main loop plus two tail loops).
    let mut emit = |v: f64, c: usize| {
        if v != 0.0 {
            values.push(v);
            cols.push(c);
        }
    };

    let mut ia = 0;
    let mut ib = 0;
    // One loop handles both the overlapping region and the tails: an
    // exhausted side is treated as always-greater so the other drains.
    while ia < a_cols.len() || ib < b_cols.len() {
        let ord = if ib >= b_cols.len() {
            Ordering::Less
        } else if ia >= a_cols.len() {
            Ordering::Greater
        } else {
            a_cols[ia].cmp(&b_cols[ib])
        };
        match ord {
            Ordering::Less => {
                emit(combine(a_vals[ia], default_b), a_cols[ia]);
                ia += 1;
            }
            Ordering::Greater => {
                emit(combine(default_a, b_vals[ib]), b_cols[ib]);
                ib += 1;
            }
            Ordering::Equal => {
                emit(combine(a_vals[ia], b_vals[ib]), a_cols[ia]);
                ia += 1;
                ib += 1;
            }
        }
    }
    (values, cols)
}
216
217fn sparse_binop(
219 a: &SparseCsr,
220 b: &SparseCsr,
221 combine: fn(f64, f64) -> f64,
222 default_a: f64,
223 default_b: f64,
224 op_name: &str,
225) -> Result<SparseCsr, String> {
226 if a.nrows != b.nrows || a.ncols != b.ncols {
227 return Err(format!(
228 "sparse_{}: dimension mismatch: ({}, {}) vs ({}, {})",
229 op_name, a.nrows, a.ncols, b.nrows, b.ncols
230 ));
231 }
232
233 let mut values = Vec::new();
234 let mut col_indices = Vec::new();
235 let mut row_offsets = Vec::with_capacity(a.nrows + 1);
236 row_offsets.push(0);
237
238 for row in 0..a.nrows {
239 let a_start = a.row_offsets[row];
240 let a_end = a.row_offsets[row + 1];
241 let b_start = b.row_offsets[row];
242 let b_end = b.row_offsets[row + 1];
243
244 let (rv, rc) = merge_rows(
245 &a.values[a_start..a_end],
246 &a.col_indices[a_start..a_end],
247 &b.values[b_start..b_end],
248 &b.col_indices[b_start..b_end],
249 combine,
250 default_a,
251 default_b,
252 );
253 values.extend_from_slice(&rv);
254 col_indices.extend_from_slice(&rc);
255 row_offsets.push(values.len());
256 }
257
258 Ok(SparseCsr {
259 values,
260 col_indices,
261 row_offsets,
262 nrows: a.nrows,
263 ncols: a.ncols,
264 })
265}
266
267pub fn sparse_add(a: &SparseCsr, b: &SparseCsr) -> Result<SparseCsr, String> {
269 sparse_binop(a, b, |x, y| x + y, 0.0, 0.0, "add")
270}
271
272pub fn sparse_sub(a: &SparseCsr, b: &SparseCsr) -> Result<SparseCsr, String> {
274 sparse_binop(a, b, |x, y| x - y, 0.0, 0.0, "sub")
275}
276
277pub fn sparse_mul(a: &SparseCsr, b: &SparseCsr) -> Result<SparseCsr, String> {
280 if a.nrows != b.nrows || a.ncols != b.ncols {
281 return Err(format!(
282 "sparse_mul: dimension mismatch: ({}, {}) vs ({}, {})",
283 a.nrows, a.ncols, b.nrows, b.ncols
284 ));
285 }
286
287 let mut values = Vec::new();
288 let mut col_indices = Vec::new();
289 let mut row_offsets = Vec::with_capacity(a.nrows + 1);
290 row_offsets.push(0);
291
292 for row in 0..a.nrows {
293 let a_start = a.row_offsets[row];
294 let a_end = a.row_offsets[row + 1];
295 let b_start = b.row_offsets[row];
296 let b_end = b.row_offsets[row + 1];
297
298 let mut ia = a_start;
299 let mut ib = b_start;
300
301 while ia < a_end && ib < b_end {
303 match a.col_indices[ia].cmp(&b.col_indices[ib]) {
304 std::cmp::Ordering::Less => ia += 1,
305 std::cmp::Ordering::Greater => ib += 1,
306 std::cmp::Ordering::Equal => {
307 let v = a.values[ia] * b.values[ib];
308 if v != 0.0 {
309 values.push(v);
310 col_indices.push(a.col_indices[ia]);
311 }
312 ia += 1;
313 ib += 1;
314 }
315 }
316 }
317 row_offsets.push(values.len());
318 }
319
320 Ok(SparseCsr {
321 values,
322 col_indices,
323 row_offsets,
324 nrows: a.nrows,
325 ncols: a.ncols,
326 })
327}
328
329pub fn sparse_matmul(a: &SparseCsr, b: &SparseCsr) -> Result<SparseCsr, String> {
333 if a.ncols != b.nrows {
334 return Err(format!(
335 "sparse_matmul: inner dimension mismatch: A is ({}, {}), B is ({}, {})",
336 a.nrows, a.ncols, b.nrows, b.ncols
337 ));
338 }
339
340 let mut values = Vec::new();
341 let mut col_indices = Vec::new();
342 let mut row_offsets = Vec::with_capacity(a.nrows + 1);
343 row_offsets.push(0);
344
345 for row in 0..a.nrows {
346 let mut accum: BTreeMap<usize, Vec<f64>> = BTreeMap::new();
348
349 let a_start = a.row_offsets[row];
350 let a_end = a.row_offsets[row + 1];
351
352 for a_idx in a_start..a_end {
353 let k = a.col_indices[a_idx];
354 let a_val = a.values[a_idx];
355
356 let b_start = b.row_offsets[k];
357 let b_end = b.row_offsets[k + 1];
358
359 for b_idx in b_start..b_end {
360 let j = b.col_indices[b_idx];
361 accum.entry(j).or_default().push(a_val * b.values[b_idx]);
362 }
363 }
364
365 for (col, terms) in &accum {
367 let v = binned_sum_f64(&terms);
368 if v != 0.0 {
369 col_indices.push(*col);
370 values.push(v);
371 }
372 }
373 row_offsets.push(values.len());
374 }
375
376 Ok(SparseCsr {
377 values,
378 col_indices,
379 row_offsets,
380 nrows: a.nrows,
381 ncols: b.ncols,
382 })
383}
384
385pub fn sparse_scalar_mul(a: &SparseCsr, s: f64) -> SparseCsr {
387 let values: Vec<f64> = a.values.iter().map(|&v| v * s).collect();
388 SparseCsr {
389 values,
390 col_indices: a.col_indices.clone(),
391 row_offsets: a.row_offsets.clone(),
392 nrows: a.nrows,
393 ncols: a.ncols,
394 }
395}
396
397pub fn sparse_transpose(a: &SparseCsr) -> SparseCsr {
399 let mut row_counts = vec![0usize; a.ncols + 1];
401
402 for &c in &a.col_indices {
404 row_counts[c + 1] += 1;
405 }
406 for i in 1..=a.ncols {
408 row_counts[i] += row_counts[i - 1];
409 }
410
411 let nnz = a.values.len();
412 let mut new_values = vec![0.0f64; nnz];
413 let mut new_col_indices = vec![0usize; nnz];
414 let mut cursor = row_counts.clone();
415
416 for row in 0..a.nrows {
417 let start = a.row_offsets[row];
418 let end = a.row_offsets[row + 1];
419 for idx in start..end {
420 let col = a.col_indices[idx];
421 let dest = cursor[col];
422 new_values[dest] = a.values[idx];
423 new_col_indices[dest] = row;
424 cursor[col] += 1;
425 }
426 }
427
428 SparseCsr {
429 values: new_values,
430 col_indices: new_col_indices,
431 row_offsets: row_counts,
432 nrows: a.ncols,
433 ncols: a.nrows,
434 }
435}
436
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a CSR matrix from a row-major dense slice, storing only the
    /// nonzero entries (columns ascending within each row).
    fn csr_from_dense(data: &[f64], nrows: usize, ncols: usize) -> SparseCsr {
        let mut values = Vec::new();
        let mut col_indices = Vec::new();
        let mut row_offsets = vec![0usize];

        for row in 0..nrows {
            for col in 0..ncols {
                let entry = data[row * ncols + col];
                if entry != 0.0 {
                    values.push(entry);
                    col_indices.push(col);
                }
            }
            row_offsets.push(values.len());
        }

        SparseCsr {
            values,
            col_indices,
            row_offsets,
            nrows,
            ncols,
        }
    }

    #[test]
    fn test_sparse_add_basic() {
        let a = csr_from_dense(&[1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 0.0, 4.0, 5.0], 3, 3);
        let b = csr_from_dense(&[0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 4.0, 0.0, 5.0], 3, 3);
        let c = sparse_add(&a, &b).unwrap();
        // Every position must equal the dense elementwise sum.
        for r in 0..3 {
            for col in 0..3 {
                let expected = a.get(r, col) + b.get(r, col);
                assert_eq!(c.get(r, col), expected, "mismatch at ({}, {})", r, col);
            }
        }
    }

    #[test]
    fn test_sparse_add_a_plus_a_eq_2a() {
        let a = csr_from_dense(&[1.0, 2.0, 0.0, 3.0], 2, 2);
        let sum = sparse_add(&a, &a).unwrap();
        let doubled = sparse_scalar_mul(&a, 2.0);
        for r in 0..2 {
            for c in 0..2 {
                assert_eq!(sum.get(r, c), doubled.get(r, c));
            }
        }
    }

    #[test]
    fn test_sparse_add_dimension_mismatch() {
        let a = csr_from_dense(&[1.0, 2.0], 1, 2);
        let b = csr_from_dense(&[1.0, 2.0, 3.0], 1, 3);
        assert!(sparse_add(&a, &b).is_err());
    }

    #[test]
    fn test_sparse_sub_basic() {
        let a = csr_from_dense(&[5.0, 3.0, 0.0, 1.0], 2, 2);
        let b = csr_from_dense(&[2.0, 3.0, 1.0, 0.0], 2, 2);
        let c = sparse_sub(&a, &b).unwrap();
        assert_eq!(c.get(0, 0), 3.0);
        assert_eq!(c.get(0, 1), 0.0);
        assert_eq!(c.get(1, 0), -1.0);
        assert_eq!(c.get(1, 1), 1.0);
    }

    #[test]
    fn test_sparse_sub_self_is_zero() {
        // a - a cancels every entry; exact zeros must be dropped.
        let a = csr_from_dense(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let c = sparse_sub(&a, &a).unwrap();
        assert_eq!(c.nnz(), 0);
    }

    #[test]
    fn test_sparse_mul_hadamard() {
        let a = csr_from_dense(&[1.0, 0.0, 3.0, 4.0], 2, 2);
        let b = csr_from_dense(&[2.0, 5.0, 0.0, 3.0], 2, 2);
        let c = sparse_mul(&a, &b).unwrap();
        assert_eq!(c.get(0, 0), 2.0);
        assert_eq!(c.get(0, 1), 0.0);
        assert_eq!(c.get(1, 0), 0.0);
        assert_eq!(c.get(1, 1), 12.0);
    }

    #[test]
    fn test_sparse_matmul_identity() {
        // A * I == A.
        let a = csr_from_dense(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let eye = csr_from_dense(&[1.0, 0.0, 0.0, 1.0], 2, 2);
        let c = sparse_matmul(&a, &eye).unwrap();
        for r in 0..2 {
            for col in 0..2 {
                assert_eq!(c.get(r, col), a.get(r, col));
            }
        }
    }

    #[test]
    fn test_sparse_matmul_vs_dense() {
        let a_data = [1.0, 2.0, 0.0, 0.0, 3.0, 4.0];
        let b_data = [5.0, 0.0, 6.0, 7.0, 0.0, 8.0];
        let a = csr_from_dense(&a_data, 2, 3);
        let b = csr_from_dense(&b_data, 3, 2);

        let c = sparse_matmul(&a, &b).unwrap();

        // Hand-computed dense product of a_data (2x3) and b_data (3x2).
        assert_eq!(c.get(0, 0), 17.0);
        assert_eq!(c.get(0, 1), 14.0);
        assert_eq!(c.get(1, 0), 18.0);
        assert_eq!(c.get(1, 1), 53.0);
    }

    #[test]
    fn test_sparse_matmul_dimension_mismatch() {
        let a = csr_from_dense(&[1.0, 2.0], 1, 2);
        let b = csr_from_dense(&[1.0, 2.0, 3.0], 1, 3);
        assert!(sparse_matmul(&a, &b).is_err());
    }

    #[test]
    fn test_sparse_scalar_mul_basic() {
        let a = csr_from_dense(&[2.0, 0.0, 0.0, 4.0], 2, 2);
        let c = sparse_scalar_mul(&a, 3.0);
        assert_eq!(c.get(0, 0), 6.0);
        assert_eq!(c.get(1, 1), 12.0);
        assert_eq!(c.nnz(), 2);
    }

    #[test]
    fn test_sparse_transpose_square() {
        let a = csr_from_dense(&[1.0, 2.0, 3.0, 4.0], 2, 2);
        let at = sparse_transpose(&a);
        assert_eq!(at.get(0, 0), 1.0);
        assert_eq!(at.get(0, 1), 3.0);
        assert_eq!(at.get(1, 0), 2.0);
        assert_eq!(at.get(1, 1), 4.0);
    }

    #[test]
    fn test_sparse_transpose_rect() {
        let a = csr_from_dense(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let at = sparse_transpose(&a);
        assert_eq!(at.nrows, 3);
        assert_eq!(at.ncols, 2);
        for r in 0..2 {
            for c in 0..3 {
                assert_eq!(at.get(c, r), a.get(r, c), "mismatch at transpose({}, {})", c, r);
            }
        }
    }

    #[test]
    fn test_sparse_transpose_double_is_identity() {
        let a = csr_from_dense(&[1.0, 0.0, 2.0, 3.0, 0.0, 4.0], 2, 3);
        let att = sparse_transpose(&sparse_transpose(&a));
        assert_eq!(att.nrows, a.nrows);
        assert_eq!(att.ncols, a.ncols);
        for r in 0..a.nrows {
            for c in 0..a.ncols {
                assert_eq!(att.get(r, c), a.get(r, c));
            }
        }
    }

    #[test]
    fn test_sparse_matmul_determinism() {
        // Two identical products must agree bit-for-bit, not just approximately.
        let a = csr_from_dense(&[1.0, 2.0, 0.0, 0.0, 3.0, 4.0], 2, 3);
        let b = csr_from_dense(&[5.0, 0.0, 6.0, 7.0, 0.0, 8.0], 3, 2);

        let c1 = sparse_matmul(&a, &b).unwrap();
        let c2 = sparse_matmul(&a, &b).unwrap();

        assert_eq!(c1.values, c2.values);
        assert_eq!(c1.col_indices, c2.col_indices);
        assert_eq!(c1.row_offsets, c2.row_offsets);
    }

    #[test]
    fn test_sparse_add_determinism() {
        let a = csr_from_dense(&[1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 0.0, 4.0, 5.0], 3, 3);
        let b = csr_from_dense(&[0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 4.0, 0.0, 5.0], 3, 3);

        let c1 = sparse_add(&a, &b).unwrap();
        let c2 = sparse_add(&a, &b).unwrap();

        assert_eq!(c1.values, c2.values);
        assert_eq!(c1.col_indices, c2.col_indices);
    }
}
652