1use crate::CsrMatrix;
2use alloc::vec::Vec;
3use tang::Scalar;
4
/// Sparse matrix in coordinate (COO, "triplet") format.
///
/// The matrix is stored as three parallel arrays: entry `k` is the triplet
/// `(rows[k], cols[k], vals[k])`. Entries may appear in any order and the
/// same `(row, col)` position may be stored more than once; `to_csr` sums
/// such duplicates when converting.
///
/// The fields are public, so code that fills them directly (instead of going
/// through `push`) is responsible for keeping the three arrays the same
/// length and every index inside `nrows` x `ncols`.
#[derive(Debug, Clone)]
pub struct CooMatrix<S> {
    /// Number of rows of the matrix.
    pub nrows: usize,
    /// Number of columns of the matrix.
    pub ncols: usize,
    /// Row index of each stored entry.
    pub rows: Vec<usize>,
    /// Column index of each stored entry.
    pub cols: Vec<usize>,
    /// Value of each stored entry.
    pub vals: Vec<S>,
}
15
16impl<S: Scalar> CooMatrix<S> {
17 pub fn new(nrows: usize, ncols: usize) -> Self {
18 Self {
19 nrows,
20 ncols,
21 rows: Vec::new(),
22 cols: Vec::new(),
23 vals: Vec::new(),
24 }
25 }
26
27 pub fn with_capacity(nrows: usize, ncols: usize, nnz: usize) -> Self {
28 Self {
29 nrows,
30 ncols,
31 rows: Vec::with_capacity(nnz),
32 cols: Vec::with_capacity(nnz),
33 vals: Vec::with_capacity(nnz),
34 }
35 }
36
37 pub fn push(&mut self, row: usize, col: usize, val: S) {
39 assert!(row < self.nrows && col < self.ncols);
40 self.rows.push(row);
41 self.cols.push(col);
42 self.vals.push(val);
43 }
44
45 pub fn nnz(&self) -> usize {
46 self.rows.len()
47 }
48
49 pub fn to_csr(&self) -> CsrMatrix<S> {
51 let mut row_counts = alloc::vec![0usize; self.nrows + 1];
52 for &r in &self.rows {
53 row_counts[r + 1] += 1;
54 }
55 for i in 1..=self.nrows {
57 row_counts[i] += row_counts[i - 1];
58 }
59 let nnz = row_counts[self.nrows];
60 let mut col_indices = alloc::vec![0usize; nnz];
61 let mut values = alloc::vec![S::ZERO; nnz];
62 let mut offsets = row_counts.clone();
63
64 for k in 0..self.rows.len() {
65 let r = self.rows[k];
66 let pos = offsets[r];
67 col_indices[pos] = self.cols[k];
68 values[pos] = self.vals[k];
69 offsets[r] += 1;
70 }
71
72 let row_ptrs = row_counts;
74 for i in 0..self.nrows {
75 let start = row_ptrs[i];
76 let end = row_ptrs[i + 1];
77 for j in (start + 1)..end {
79 let mut k = j;
80 while k > start && col_indices[k] < col_indices[k - 1] {
81 col_indices.swap(k, k - 1);
82 values.swap(k, k - 1);
83 k -= 1;
84 }
85 }
86 }
87
88 let mut new_col = Vec::with_capacity(nnz);
90 let mut new_val = Vec::with_capacity(nnz);
91 let mut new_ptrs = alloc::vec![0usize; self.nrows + 1];
92
93 for i in 0..self.nrows {
94 let start = row_ptrs[i];
95 let end = row_ptrs[i + 1];
96 let mut j = start;
97 while j < end {
98 let c = col_indices[j];
99 let mut v = values[j];
100 j += 1;
101 while j < end && col_indices[j] == c {
102 v += values[j];
103 j += 1;
104 }
105 new_col.push(c);
106 new_val.push(v);
107 }
108 new_ptrs[i + 1] = new_col.len();
109 }
110
111 CsrMatrix {
112 nrows: self.nrows,
113 ncols: self.ncols,
114 row_ptrs: new_ptrs,
115 col_indices: new_col,
116 values: new_val,
117 }
118 }
119}