1use ferrolearn_core::FerroError;
9use ndarray::Array2;
10use num_traits::Zero;
11use sprs::{SpIndex, TriMat};
12
13#[derive(Debug)]
25pub struct CooMatrix<T> {
26 inner: TriMat<T>,
27}
28
29impl<T: Clone> Clone for CooMatrix<T> {
30 fn clone(&self) -> Self {
35 Self {
36 inner: TriMat::from_triplets(
37 (self.n_rows(), self.n_cols()),
38 self.inner.row_inds().to_vec(),
39 self.inner.col_inds().to_vec(),
40 self.inner.data().to_vec(),
41 ),
42 }
43 }
44}
45
46impl<T> CooMatrix<T> {
47 pub fn new(n_rows: usize, n_cols: usize) -> Self {
54 Self {
55 inner: TriMat::new((n_rows, n_cols)),
56 }
57 }
58
59 pub fn with_capacity(n_rows: usize, n_cols: usize, capacity: usize) -> Self {
67 Self {
68 inner: TriMat::with_capacity((n_rows, n_cols), capacity),
69 }
70 }
71
72 pub fn from_triplets(
82 n_rows: usize,
83 n_cols: usize,
84 row_inds: Vec<usize>,
85 col_inds: Vec<usize>,
86 data: Vec<T>,
87 ) -> Result<Self, FerroError> {
88 if row_inds.len() != col_inds.len() || row_inds.len() != data.len() {
89 return Err(FerroError::InvalidParameter {
90 name: "triplet arrays".into(),
91 reason: format!(
92 "row_inds ({}), col_inds ({}), and data ({}) must all have the same length",
93 row_inds.len(),
94 col_inds.len(),
95 data.len()
96 ),
97 });
98 }
99 if let Some(&r) = row_inds.iter().find(|&&r| r >= n_rows) {
100 return Err(FerroError::InvalidParameter {
101 name: "row_inds".into(),
102 reason: format!("index {r} is out of bounds for n_rows={n_rows}"),
103 });
104 }
105 if let Some(&c) = col_inds.iter().find(|&&c| c >= n_cols) {
106 return Err(FerroError::InvalidParameter {
107 name: "col_inds".into(),
108 reason: format!("index {c} is out of bounds for n_cols={n_cols}"),
109 });
110 }
111 Ok(Self {
112 inner: TriMat::from_triplets((n_rows, n_cols), row_inds, col_inds, data),
113 })
114 }
115
116 pub fn push(&mut self, row: usize, col: usize, value: T) -> Result<(), FerroError> {
123 if row >= self.n_rows() {
124 return Err(FerroError::InvalidParameter {
125 name: "row".into(),
126 reason: format!("index {row} is out of bounds for n_rows={}", self.n_rows()),
127 });
128 }
129 if col >= self.n_cols() {
130 return Err(FerroError::InvalidParameter {
131 name: "col".into(),
132 reason: format!("index {col} is out of bounds for n_cols={}", self.n_cols()),
133 });
134 }
135 self.inner.add_triplet(row, col, value);
136 Ok(())
137 }
138
139 pub fn n_rows(&self) -> usize {
141 self.inner.rows()
142 }
143
144 pub fn n_cols(&self) -> usize {
146 self.inner.cols()
147 }
148
149 pub fn nnz(&self) -> usize {
151 self.inner.nnz()
152 }
153
154 pub fn inner(&self) -> &TriMat<T> {
156 &self.inner
157 }
158
159 pub fn into_inner(self) -> TriMat<T> {
161 self.inner
162 }
163}
164
165impl<T> CooMatrix<T>
166where
167 T: Clone + Zero + num_traits::NumAssign + 'static,
168{
169 pub fn to_dense(&self) -> Array2<T> {
173 let mut out = Array2::<T>::zeros((self.n_rows(), self.n_cols()));
174 for (val, (r, c)) in self.inner.triplet_iter() {
175 out[[r.index(), c.index()]] += val.clone();
176 }
177 out
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_coo_new() {
        let m: CooMatrix<f64> = CooMatrix::new(4, 5);
        assert_eq!(m.n_rows(), 4);
        assert_eq!(m.n_cols(), 5);
        assert_eq!(m.nnz(), 0);
    }

    #[test]
    fn test_coo_with_capacity() {
        let m: CooMatrix<f64> = CooMatrix::with_capacity(3, 4, 8);
        assert_eq!(m.n_rows(), 3);
        assert_eq!(m.n_cols(), 4);
        // Capacity is a reservation only; no entries are stored.
        assert_eq!(m.nnz(), 0);
    }

    #[test]
    fn test_coo_push() {
        let mut m: CooMatrix<f64> = CooMatrix::new(3, 3);
        m.push(0, 0, 1.0).unwrap();
        m.push(1, 2, 5.0).unwrap();
        assert_eq!(m.nnz(), 2);
    }

    #[test]
    fn test_coo_push_out_of_bounds() {
        let mut m: CooMatrix<f64> = CooMatrix::new(2, 2);
        assert!(matches!(
            m.push(2, 0, 1.0),
            Err(FerroError::InvalidParameter { .. })
        ));
        assert!(matches!(
            m.push(0, 2, 1.0),
            Err(FerroError::InvalidParameter { .. })
        ));
        // A rejected push must not record a triplet.
        assert_eq!(m.nnz(), 0);
    }

    #[test]
    fn test_coo_from_triplets_valid() {
        let m = CooMatrix::<f64>::from_triplets(
            2,
            3,
            vec![0, 1, 1],
            vec![2, 0, 0],
            vec![1.0, 2.0, 4.0],
        )
        .unwrap();
        assert_eq!(m.n_rows(), 2);
        assert_eq!(m.n_cols(), 3);
        assert_eq!(m.nnz(), 3);
        let d = m.to_dense();
        assert_eq!(d[[0, 2]], 1.0);
        // (1, 0) appears twice and must be summed.
        assert_eq!(d[[1, 0]], 6.0);
        assert_eq!(d[[1, 2]], 0.0);
    }

    #[test]
    fn test_coo_from_triplets_mismatch() {
        let result = CooMatrix::<f64>::from_triplets(3, 3, vec![0, 1], vec![0], vec![1.0, 2.0]);
        assert!(matches!(result, Err(FerroError::InvalidParameter { .. })));
    }

    #[test]
    fn test_coo_from_triplets_out_of_bounds() {
        let result = CooMatrix::<f64>::from_triplets(2, 2, vec![3], vec![0], vec![1.0]);
        assert!(matches!(result, Err(FerroError::InvalidParameter { .. })));
    }

    #[test]
    fn test_coo_from_triplets_col_out_of_bounds() {
        let result = CooMatrix::<f64>::from_triplets(2, 2, vec![0], vec![2], vec![1.0]);
        assert!(matches!(result, Err(FerroError::InvalidParameter { .. })));
    }

    #[test]
    fn test_coo_to_dense() {
        let mut m: CooMatrix<f64> = CooMatrix::new(2, 3);
        m.push(0, 1, 3.0).unwrap();
        m.push(1, 0, 7.0).unwrap();
        let d = m.to_dense();
        assert_eq!(d.dim(), (2, 3));
        assert_eq!(d[[0, 1]], 3.0);
        assert_eq!(d[[1, 0]], 7.0);
        assert_eq!(d[[0, 0]], 0.0);
    }

    #[test]
    fn test_coo_to_dense_duplicate_summed() {
        let mut m: CooMatrix<f64> = CooMatrix::new(2, 2);
        m.push(0, 0, 1.0).unwrap();
        m.push(0, 0, 2.0).unwrap();
        let d = m.to_dense();
        assert_eq!(d[[0, 0]], 3.0);
    }

    #[test]
    fn test_coo_clone() {
        let mut m: CooMatrix<f64> = CooMatrix::new(2, 2);
        m.push(0, 0, 5.0).unwrap();
        let m2 = m.clone();
        assert_eq!(m2.nnz(), 1);
        assert_eq!(m2.n_rows(), 2);
        assert_eq!(m2.n_cols(), 2);
        // The clone must carry the same entries, not just the same shape.
        assert_eq!(m2.to_dense(), m.to_dense());
    }

    #[test]
    fn test_coo_inner_accessors() {
        let mut m: CooMatrix<f64> = CooMatrix::new(3, 2);
        m.push(2, 1, 4.0).unwrap();
        assert_eq!(m.inner().rows(), 3);
        assert_eq!(m.inner().cols(), 2);
        let tri = m.into_inner();
        assert_eq!(tri.nnz(), 1);
        assert_eq!(tri.data().to_vec(), vec![4.0]);
    }
}