1use cyanea_core::{CyaneaError, Result, Summarizable};
9
/// A sparse `f64` matrix stored in coordinate (COO / triplet) format.
///
/// Stored entry `i` holds `values[i]` at position `(rows[i], cols[i])`.
/// Positions with no stored entry read as `0.0`. The same position may be
/// stored more than once (for example after repeated `insert` calls), because
/// appending never checks for an existing entry.
#[derive(Debug, Clone)]
pub struct SparseMatrix {
    // Row index of each stored entry; always the same length as `cols` and `values`.
    rows: Vec<usize>,
    // Column index of each stored entry.
    cols: Vec<usize>,
    // Value of each stored entry.
    values: Vec<f64>,
    // Number of rows of the logical (dense) shape; methods such as `to_dense`
    // index dense buffers by stored row indices, so these must stay < n_rows.
    n_rows: usize,
    // Number of columns of the logical (dense) shape; stored column indices
    // are expected to stay < n_cols for the same reason.
    n_cols: usize,
}
19
20impl SparseMatrix {
21 pub fn new(n_rows: usize, n_cols: usize) -> Self {
23 Self {
24 rows: Vec::new(),
25 cols: Vec::new(),
26 values: Vec::new(),
27 n_rows,
28 n_cols,
29 }
30 }
31
32 pub fn from_triplets(
37 rows: Vec<usize>,
38 cols: Vec<usize>,
39 values: Vec<f64>,
40 n_rows: usize,
41 n_cols: usize,
42 ) -> Result<Self> {
43 if rows.len() != cols.len() || cols.len() != values.len() {
44 return Err(CyaneaError::InvalidInput(
45 "rows, cols, and values must have the same length".into(),
46 ));
47 }
48 for (i, (&r, &c)) in rows.iter().zip(cols.iter()).enumerate() {
49 if r >= n_rows || c >= n_cols {
50 return Err(CyaneaError::InvalidInput(format!(
51 "triplet {i} index ({r}, {c}) out of bounds for ({n_rows}, {n_cols})"
52 )));
53 }
54 }
55 Ok(Self {
56 rows,
57 cols,
58 values,
59 n_rows,
60 n_cols,
61 })
62 }
63
64 pub fn insert(&mut self, row: usize, col: usize, value: f64) -> Result<()> {
66 if row >= self.n_rows || col >= self.n_cols {
67 return Err(CyaneaError::InvalidInput(format!(
68 "index ({row}, {col}) out of bounds for ({}, {})",
69 self.n_rows, self.n_cols
70 )));
71 }
72 self.rows.push(row);
73 self.cols.push(col);
74 self.values.push(value);
75 Ok(())
76 }
77
78 pub fn get(&self, row: usize, col: usize) -> f64 {
82 for i in 0..self.values.len() {
83 if self.rows[i] == row && self.cols[i] == col {
84 return self.values[i];
85 }
86 }
87 0.0
88 }
89
90 pub fn nnz(&self) -> usize {
92 self.values.len()
93 }
94
95 pub fn density(&self) -> f64 {
97 let total = self.n_rows as f64 * self.n_cols as f64;
98 if total == 0.0 {
99 return 0.0;
100 }
101 self.values.len() as f64 / total
102 }
103
104 pub fn shape(&self) -> (usize, usize) {
106 (self.n_rows, self.n_cols)
107 }
108
109 pub fn to_dense(&self) -> Vec<Vec<f64>> {
111 let mut dense = vec![vec![0.0; self.n_cols]; self.n_rows];
112 for i in 0..self.values.len() {
113 dense[self.rows[i]][self.cols[i]] = self.values[i];
114 }
115 dense
116 }
117
118 pub fn from_dense(data: &[Vec<f64>], threshold: f64) -> Self {
120 let n_rows = data.len();
121 let n_cols = data.first().map_or(0, |r| r.len());
122 let mut rows = Vec::new();
123 let mut cols = Vec::new();
124 let mut values = Vec::new();
125
126 for (r, row) in data.iter().enumerate() {
127 for (c, &val) in row.iter().enumerate() {
128 if val.abs() > threshold {
129 rows.push(r);
130 cols.push(c);
131 values.push(val);
132 }
133 }
134 }
135
136 Self {
137 rows,
138 cols,
139 values,
140 n_rows,
141 n_cols,
142 }
143 }
144
145 pub fn row_nnz(&self, row: usize) -> usize {
147 self.rows.iter().filter(|&&r| r == row).count()
148 }
149
150 pub fn col_nnz(&self, col: usize) -> usize {
152 self.cols.iter().filter(|&&c| c == col).count()
153 }
154
155 pub fn to_csr(&self) -> (Vec<f64>, Vec<usize>, Vec<usize>) {
162 let nnz = self.values.len();
164 let mut order: Vec<usize> = (0..nnz).collect();
165 order.sort_by_key(|&i| (self.rows[i], self.cols[i]));
166
167 let mut data = Vec::with_capacity(nnz);
168 let mut indices = Vec::with_capacity(nnz);
169 let mut indptr = vec![0usize; self.n_rows + 1];
170
171 for &i in &order {
172 data.push(self.values[i]);
173 indices.push(self.cols[i]);
174 indptr[self.rows[i] + 1] += 1;
175 }
176
177 for i in 1..=self.n_rows {
179 indptr[i] += indptr[i - 1];
180 }
181
182 (data, indices, indptr)
183 }
184
185 pub fn from_csr(
191 data: Vec<f64>,
192 indices: Vec<usize>,
193 indptr: Vec<usize>,
194 n_rows: usize,
195 n_cols: usize,
196 ) -> Result<Self> {
197 if data.len() != indices.len() {
198 return Err(CyaneaError::InvalidInput(
199 "CSR data and indices must have the same length".into(),
200 ));
201 }
202 if indptr.len() != n_rows + 1 {
203 return Err(CyaneaError::InvalidInput(format!(
204 "CSR indptr length ({}) must be n_rows + 1 ({})",
205 indptr.len(),
206 n_rows + 1
207 )));
208 }
209
210 let nnz = data.len();
211 let mut rows = Vec::with_capacity(nnz);
212 let mut cols = Vec::with_capacity(nnz);
213
214 for row in 0..n_rows {
215 let start = indptr[row];
216 let end = indptr[row + 1];
217 for idx in start..end {
218 if idx >= nnz {
219 return Err(CyaneaError::InvalidInput(format!(
220 "CSR indptr references index {idx} but nnz is {nnz}"
221 )));
222 }
223 if indices[idx] >= n_cols {
224 return Err(CyaneaError::InvalidInput(format!(
225 "CSR column index {} out of bounds for n_cols={}",
226 indices[idx], n_cols
227 )));
228 }
229 rows.push(row);
230 cols.push(indices[idx]);
231 }
232 }
233
234 Ok(Self {
235 rows,
236 cols,
237 values: data,
238 n_rows,
239 n_cols,
240 })
241 }
242
243 pub fn iter(&self) -> impl Iterator<Item = (usize, usize, f64)> + '_ {
245 self.rows
246 .iter()
247 .zip(self.cols.iter())
248 .zip(self.values.iter())
249 .map(|((&r, &c), &v)| (r, c, v))
250 }
251
252 pub fn column_sums(&self) -> Vec<f64> {
254 let mut sums = vec![0.0; self.n_cols];
255 for i in 0..self.values.len() {
256 sums[self.cols[i]] += self.values[i];
257 }
258 sums
259 }
260
261 pub fn column_means(&self) -> Vec<f64> {
263 if self.n_rows == 0 {
264 return vec![0.0; self.n_cols];
265 }
266 let sums = self.column_sums();
267 let n = self.n_rows as f64;
268 sums.into_iter().map(|s| s / n).collect()
269 }
270
271 pub fn row_sums(&self) -> Vec<f64> {
273 let mut sums = vec![0.0; self.n_rows];
274 for i in 0..self.values.len() {
275 sums[self.rows[i]] += self.values[i];
276 }
277 sums
278 }
279
280 pub fn scale_rows(&mut self, factors: &[f64]) {
284 for i in 0..self.values.len() {
285 self.values[i] *= factors[self.rows[i]];
286 }
287 }
288
289 pub fn map_values(&mut self, f: impl Fn(f64) -> f64) {
291 for v in &mut self.values {
292 *v = f(*v);
293 }
294 }
295
296 pub fn n_rows(&self) -> usize {
298 self.n_rows
299 }
300
301 pub fn n_cols(&self) -> usize {
303 self.n_cols
304 }
305}
306
307impl Summarizable for SparseMatrix {
308 fn summary(&self) -> String {
309 format!(
310 "SparseMatrix: {}\u{00d7}{}, {} nonzeros ({:.2}% density)",
311 self.n_rows,
312 self.n_cols,
313 self.nnz(),
314 self.density() * 100.0
315 )
316 }
317}
318
#[cfg(test)]
mod tests {
    //! Unit tests for `SparseMatrix`: construction, bounds validation, dense
    //! and CSR conversions, aggregates and in-place transforms.

    use super::*;

    #[test]
    fn test_new_empty() {
        let m = SparseMatrix::new(10, 20);
        assert_eq!(m.shape(), (10, 20));
        assert_eq!(m.nnz(), 0);
        assert_eq!(m.density(), 0.0);
    }

    #[test]
    fn test_from_triplets() {
        let m = SparseMatrix::from_triplets(
            vec![0, 1, 2],
            vec![0, 1, 2],
            vec![1.0, 2.0, 3.0],
            3,
            3,
        )
        .unwrap();
        assert_eq!(m.nnz(), 3);
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(1, 1), 2.0);
        assert_eq!(m.get(0, 1), 0.0);
    }

    #[test]
    fn test_from_triplets_bounds_check() {
        let result = SparseMatrix::from_triplets(vec![5], vec![0], vec![1.0], 3, 3);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_triplets_length_mismatch() {
        let result = SparseMatrix::from_triplets(vec![0, 1], vec![0], vec![1.0], 3, 3);
        assert!(result.is_err());
    }

    #[test]
    fn test_insert() {
        let mut m = SparseMatrix::new(3, 3);
        m.insert(0, 0, 5.0).unwrap();
        assert_eq!(m.get(0, 0), 5.0);
        assert_eq!(m.nnz(), 1);

        // Both the row and the column bound are enforced, and a rejected
        // insert must leave the matrix untouched.
        assert!(m.insert(10, 0, 1.0).is_err());
        assert!(m.insert(0, 10, 1.0).is_err());
        assert_eq!(m.nnz(), 1);
    }

    #[test]
    fn test_get_out_of_bounds_is_zero() {
        let m = SparseMatrix::from_triplets(vec![0], vec![0], vec![1.0], 3, 3).unwrap();
        assert_eq!(m.get(10, 10), 0.0);
    }

    #[test]
    fn test_density() {
        let m = SparseMatrix::from_triplets(vec![0, 1], vec![0, 1], vec![1.0, 2.0], 10, 10)
            .unwrap();
        assert!((m.density() - 0.02).abs() < 1e-10);
    }

    #[test]
    fn test_to_dense() {
        let m = SparseMatrix::from_triplets(vec![0, 1], vec![1, 0], vec![3.0, 7.0], 2, 2)
            .unwrap();
        let dense = m.to_dense();
        assert_eq!(dense, vec![vec![0.0, 3.0], vec![7.0, 0.0]]);
    }

    #[test]
    fn test_from_dense() {
        let data = vec![vec![0.0, 3.0], vec![7.0, 0.0]];
        let m = SparseMatrix::from_dense(&data, 0.0);
        assert_eq!(m.nnz(), 2);
        assert_eq!(m.get(0, 1), 3.0);
        assert_eq!(m.get(1, 0), 7.0);
    }

    #[test]
    fn test_from_dense_with_threshold() {
        let data = vec![vec![0.1, 3.0], vec![7.0, 0.05]];
        let m = SparseMatrix::from_dense(&data, 0.5);
        assert_eq!(m.shape(), (2, 2));
        assert_eq!(m.nnz(), 2);
        // Only the values whose magnitude exceeds the threshold survive.
        assert_eq!(m.get(0, 1), 3.0);
        assert_eq!(m.get(1, 0), 7.0);
        assert_eq!(m.get(0, 0), 0.0);
        assert_eq!(m.get(1, 1), 0.0);
    }

    #[test]
    fn test_row_col_nnz() {
        let m = SparseMatrix::from_triplets(
            vec![0, 0, 1],
            vec![0, 1, 0],
            vec![1.0, 2.0, 3.0],
            2,
            2,
        )
        .unwrap();
        assert_eq!(m.row_nnz(0), 2);
        assert_eq!(m.row_nnz(1), 1);
        assert_eq!(m.col_nnz(0), 2);
        assert_eq!(m.col_nnz(1), 1);
    }

    #[test]
    fn test_iter() {
        let m = SparseMatrix::from_triplets(vec![0, 1], vec![0, 1], vec![1.0, 2.0], 2, 2)
            .unwrap();
        let triplets: Vec<_> = m.iter().collect();
        assert_eq!(triplets, vec![(0, 0, 1.0), (1, 1, 2.0)]);
    }

    #[test]
    fn test_summary() {
        let m = SparseMatrix::from_triplets(vec![0], vec![0], vec![1.0], 100, 50).unwrap();
        assert_eq!(
            m.summary(),
            "SparseMatrix: 100\u{00d7}50, 1 nonzeros (0.02% density)"
        );
    }

    #[test]
    fn test_zero_dimension_density() {
        let m = SparseMatrix::new(0, 0);
        assert_eq!(m.density(), 0.0);
    }

    #[test]
    fn test_csr_roundtrip() {
        let m = SparseMatrix::from_triplets(
            vec![0, 0, 1, 2, 2],
            vec![0, 2, 1, 0, 2],
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            3,
            3,
        )
        .unwrap();

        let (data, indices, indptr) = m.to_csr();
        let m2 = SparseMatrix::from_csr(data, indices, indptr, 3, 3).unwrap();

        assert_eq!(m2.shape(), (3, 3));
        assert_eq!(m2.nnz(), 5);
        assert_eq!(m2.get(0, 0), 1.0);
        assert_eq!(m2.get(0, 2), 2.0);
        assert_eq!(m2.get(1, 1), 3.0);
        assert_eq!(m2.get(2, 0), 4.0);
        assert_eq!(m2.get(2, 2), 5.0);
        assert_eq!(m2.get(1, 0), 0.0);
    }

    #[test]
    fn test_csr_sorts_unsorted_triplets() {
        let m = SparseMatrix::from_triplets(
            vec![2, 0, 0],
            vec![1, 2, 0],
            vec![3.0, 2.0, 1.0],
            3,
            3,
        )
        .unwrap();
        let (data, indices, indptr) = m.to_csr();
        assert_eq!(data, vec![1.0, 2.0, 3.0]);
        assert_eq!(indices, vec![0, 2, 1]);
        // Row 1 is empty, so its span indptr[1]..indptr[2] is empty.
        assert_eq!(indptr, vec![0, 2, 2, 3]);
    }

    #[test]
    fn test_csr_empty() {
        let m = SparseMatrix::new(3, 4);
        let (data, indices, indptr) = m.to_csr();
        assert!(data.is_empty());
        assert!(indices.is_empty());
        assert_eq!(indptr, vec![0, 0, 0, 0]);

        let m2 = SparseMatrix::from_csr(data, indices, indptr, 3, 4).unwrap();
        assert_eq!(m2.nnz(), 0);
        assert_eq!(m2.shape(), (3, 4));
    }

    #[test]
    fn test_from_csr_length_mismatch() {
        let result = SparseMatrix::from_csr(vec![1.0, 2.0], vec![0], vec![0, 1], 1, 2);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_csr_bad_indptr_length() {
        let result = SparseMatrix::from_csr(vec![1.0], vec![0], vec![0, 1], 2, 2);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_csr_column_out_of_bounds() {
        let result = SparseMatrix::from_csr(vec![1.0], vec![5], vec![0, 1], 1, 3);
        assert!(result.is_err());
    }

    #[test]
    fn test_column_sums() {
        let m = SparseMatrix::from_triplets(
            vec![0, 0, 1, 1],
            vec![0, 1, 0, 2],
            vec![1.0, 2.0, 3.0, 4.0],
            2,
            3,
        )
        .unwrap();
        assert_eq!(m.column_sums(), vec![4.0, 2.0, 4.0]);
    }

    #[test]
    fn test_column_sums_empty() {
        let m = SparseMatrix::new(3, 4);
        assert_eq!(m.column_sums(), vec![0.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_column_means() {
        let m = SparseMatrix::from_triplets(vec![0, 1], vec![0, 0], vec![4.0, 6.0], 2, 2)
            .unwrap();
        let means = m.column_means();
        assert!((means[0] - 5.0).abs() < 1e-10);
        assert!((means[1] - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_column_means_zero_rows() {
        let m = SparseMatrix::new(0, 3);
        assert_eq!(m.column_means(), vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_row_sums() {
        let m = SparseMatrix::from_triplets(
            vec![0, 0, 1, 2],
            vec![0, 1, 0, 2],
            vec![1.0, 2.0, 3.0, 4.0],
            3,
            3,
        )
        .unwrap();
        assert_eq!(m.row_sums(), vec![3.0, 3.0, 4.0]);
    }

    #[test]
    fn test_scale_rows() {
        let mut m = SparseMatrix::from_triplets(
            vec![0, 0, 1, 1],
            vec![0, 1, 0, 1],
            vec![2.0, 4.0, 6.0, 8.0],
            2,
            2,
        )
        .unwrap();
        m.scale_rows(&[0.5, 2.0]);
        assert!((m.get(0, 0) - 1.0).abs() < 1e-10);
        assert!((m.get(0, 1) - 2.0).abs() < 1e-10);
        assert!((m.get(1, 0) - 12.0).abs() < 1e-10);
        assert!((m.get(1, 1) - 16.0).abs() < 1e-10);
    }

    #[test]
    fn test_map_values() {
        let mut m = SparseMatrix::from_triplets(vec![0, 1], vec![0, 1], vec![4.0, 9.0], 2, 2)
            .unwrap();
        m.map_values(|v| v.sqrt());
        assert!((m.get(0, 0) - 2.0).abs() < 1e-10);
        assert!((m.get(1, 1) - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_n_rows_n_cols() {
        let m = SparseMatrix::new(5, 8);
        assert_eq!(m.n_rows(), 5);
        assert_eq!(m.n_cols(), 8);
    }

    #[test]
    fn test_csr_single_row() {
        let m = SparseMatrix::from_triplets(
            vec![0, 0, 0],
            vec![0, 2, 4],
            vec![1.0, 2.0, 3.0],
            1,
            5,
        )
        .unwrap();

        let (data, indices, indptr) = m.to_csr();
        assert_eq!(data, vec![1.0, 2.0, 3.0]);
        assert_eq!(indices, vec![0, 2, 4]);
        assert_eq!(indptr, vec![0, 3]);

        let m2 = SparseMatrix::from_csr(data, indices, indptr, 1, 5).unwrap();
        assert_eq!(m2.nnz(), 3);
        assert_eq!(m2.get(0, 0), 1.0);
        assert_eq!(m2.get(0, 2), 2.0);
        assert_eq!(m2.get(0, 4), 3.0);
    }
}