Skip to main content

ferrolearn_sparse/
coo.rs

1//! Coordinate (COO / triplet) sparse matrix format.
2//!
3//! [`CooMatrix<T>`] is a newtype wrapper around [`sprs::TriMat<T>`]. It is
4//! primarily useful for incrementally building a sparse matrix before
5//! converting it to [`CsrMatrix`](crate::CsrMatrix) or
6//! [`CscMatrix`](crate::CscMatrix) for computation.
7
8use ferrolearn_core::FerroError;
9use ndarray::Array2;
10use num_traits::Zero;
11use sprs::{SpIndex, TriMat};
12
13/// Coordinate-format (COO / triplet) sparse matrix.
14///
15/// Stores non-zero entries as `(row, col, value)` triplets. Duplicate entries
16/// at the same position are **summed** during conversion to CSR/CSC. This
17/// format is most convenient for construction; prefer [`CsrMatrix`](crate::CsrMatrix)
18/// or [`CscMatrix`](crate::CscMatrix) for arithmetic.
19///
20/// # Type Parameter
21///
22/// `T` — the scalar type stored in the matrix. No additional bounds are
23/// required for construction; conversion methods impose their own bounds.
24#[derive(Debug)]
25pub struct CooMatrix<T> {
26    inner: TriMat<T>,
27}
28
29impl<T: Clone> Clone for CooMatrix<T> {
30    /// Clone by rebuilding the inner [`sprs::TriMat`] from raw components.
31    ///
32    /// [`sprs::TriMat`] does not implement `Clone` generically, so we
33    /// reconstruct it from the stored row indices, column indices, and data.
34    fn clone(&self) -> Self {
35        Self {
36            inner: TriMat::from_triplets(
37                (self.n_rows(), self.n_cols()),
38                self.inner.row_inds().to_vec(),
39                self.inner.col_inds().to_vec(),
40                self.inner.data().to_vec(),
41            ),
42        }
43    }
44}
45
46impl<T> CooMatrix<T> {
47    /// Create an empty COO matrix with the given shape.
48    ///
49    /// # Arguments
50    ///
51    /// * `n_rows` — number of rows.
52    /// * `n_cols` — number of columns.
53    pub fn new(n_rows: usize, n_cols: usize) -> Self {
54        Self {
55            inner: TriMat::new((n_rows, n_cols)),
56        }
57    }
58
59    /// Create a COO matrix with the given shape and pre-allocated capacity.
60    ///
61    /// # Arguments
62    ///
63    /// * `n_rows` — number of rows.
64    /// * `n_cols` — number of columns.
65    /// * `capacity` — expected number of non-zero entries.
66    pub fn with_capacity(n_rows: usize, n_cols: usize, capacity: usize) -> Self {
67        Self {
68            inner: TriMat::with_capacity((n_rows, n_cols), capacity),
69        }
70    }
71
72    /// Build a [`CooMatrix`] from raw triplet components.
73    ///
74    /// All three slices must have the same length. Row indices must be less
75    /// than `n_rows`; column indices must be less than `n_cols`.
76    ///
77    /// # Errors
78    ///
79    /// Returns [`FerroError::InvalidParameter`] if the slice lengths differ or
80    /// if any index is out of bounds.
81    pub fn from_triplets(
82        n_rows: usize,
83        n_cols: usize,
84        row_inds: Vec<usize>,
85        col_inds: Vec<usize>,
86        data: Vec<T>,
87    ) -> Result<Self, FerroError> {
88        if row_inds.len() != col_inds.len() || row_inds.len() != data.len() {
89            return Err(FerroError::InvalidParameter {
90                name: "triplet arrays".into(),
91                reason: format!(
92                    "row_inds ({}), col_inds ({}), and data ({}) must all have the same length",
93                    row_inds.len(),
94                    col_inds.len(),
95                    data.len()
96                ),
97            });
98        }
99        if let Some(&r) = row_inds.iter().find(|&&r| r >= n_rows) {
100            return Err(FerroError::InvalidParameter {
101                name: "row_inds".into(),
102                reason: format!("index {r} is out of bounds for n_rows={n_rows}"),
103            });
104        }
105        if let Some(&c) = col_inds.iter().find(|&&c| c >= n_cols) {
106            return Err(FerroError::InvalidParameter {
107                name: "col_inds".into(),
108                reason: format!("index {c} is out of bounds for n_cols={n_cols}"),
109            });
110        }
111        Ok(Self {
112            inner: TriMat::from_triplets((n_rows, n_cols), row_inds, col_inds, data),
113        })
114    }
115
116    /// Append a single non-zero entry `(row, col, value)`.
117    ///
118    /// # Errors
119    ///
120    /// Returns [`FerroError::InvalidParameter`] if `row >= n_rows()` or
121    /// `col >= n_cols()`.
122    pub fn push(&mut self, row: usize, col: usize, value: T) -> Result<(), FerroError> {
123        if row >= self.n_rows() {
124            return Err(FerroError::InvalidParameter {
125                name: "row".into(),
126                reason: format!("index {row} is out of bounds for n_rows={}", self.n_rows()),
127            });
128        }
129        if col >= self.n_cols() {
130            return Err(FerroError::InvalidParameter {
131                name: "col".into(),
132                reason: format!("index {col} is out of bounds for n_cols={}", self.n_cols()),
133            });
134        }
135        self.inner.add_triplet(row, col, value);
136        Ok(())
137    }
138
139    /// Returns the number of rows.
140    pub fn n_rows(&self) -> usize {
141        self.inner.rows()
142    }
143
144    /// Returns the number of columns.
145    pub fn n_cols(&self) -> usize {
146        self.inner.cols()
147    }
148
149    /// Returns the number of stored non-zero entries (counting duplicates).
150    pub fn nnz(&self) -> usize {
151        self.inner.nnz()
152    }
153
154    /// Returns a reference to the underlying [`sprs::TriMat<T>`].
155    pub fn inner(&self) -> &TriMat<T> {
156        &self.inner
157    }
158
159    /// Consume this matrix and return the underlying [`sprs::TriMat<T>`].
160    pub fn into_inner(self) -> TriMat<T> {
161        self.inner
162    }
163}
164
165impl<T> CooMatrix<T>
166where
167    T: Clone + Zero + num_traits::NumAssign + 'static,
168{
169    /// Convert this COO matrix to a dense [`Array2<T>`].
170    ///
171    /// Duplicate entries at the same position are summed.
172    pub fn to_dense(&self) -> Array2<T> {
173        let mut out = Array2::<T>::zeros((self.n_rows(), self.n_cols()));
174        for (val, (r, c)) in self.inner.triplet_iter() {
175            out[[r.index(), c.index()]] += val.clone();
176        }
177        out
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_coo_new() {
        let m: CooMatrix<f64> = CooMatrix::new(4, 5);
        assert_eq!(m.n_rows(), 4);
        assert_eq!(m.n_cols(), 5);
        assert_eq!(m.nnz(), 0);
    }

    #[test]
    fn test_coo_with_capacity() {
        let m: CooMatrix<f64> = CooMatrix::with_capacity(3, 4, 10);
        assert_eq!(m.n_rows(), 3);
        assert_eq!(m.n_cols(), 4);
        // Capacity is only a reservation; no entries are stored.
        assert_eq!(m.nnz(), 0);
    }

    #[test]
    fn test_coo_push() {
        let mut m: CooMatrix<f64> = CooMatrix::new(3, 3);
        m.push(0, 0, 1.0).unwrap();
        m.push(1, 2, 5.0).unwrap();
        assert_eq!(m.nnz(), 2);
    }

    #[test]
    fn test_coo_push_out_of_bounds() {
        let mut m: CooMatrix<f64> = CooMatrix::new(2, 2);
        assert!(m.push(2, 0, 1.0).is_err());
        assert!(m.push(0, 2, 1.0).is_err());
        // Rejected pushes must not store anything.
        assert_eq!(m.nnz(), 0);
    }

    #[test]
    fn test_coo_push_into_zero_sized_matrix() {
        let mut m: CooMatrix<f64> = CooMatrix::new(0, 3);
        assert!(matches!(
            m.push(0, 0, 1.0),
            Err(FerroError::InvalidParameter { .. })
        ));
        assert_eq!(m.to_dense().dim(), (0, 3));
    }

    #[test]
    fn test_coo_from_triplets_mismatch() {
        let result = CooMatrix::<f64>::from_triplets(3, 3, vec![0, 1], vec![0], vec![1.0, 2.0]);
        assert!(result.is_err());
    }

    #[test]
    fn test_coo_from_triplets_out_of_bounds() {
        let result = CooMatrix::<f64>::from_triplets(2, 2, vec![3], vec![0], vec![1.0]);
        assert!(result.is_err());
    }

    #[test]
    fn test_coo_from_triplets_col_out_of_bounds() {
        let result = CooMatrix::<f64>::from_triplets(2, 2, vec![0], vec![2], vec![1.0]);
        assert!(matches!(
            result,
            Err(FerroError::InvalidParameter { .. })
        ));
    }

    #[test]
    fn test_coo_from_triplets_valid_with_duplicates() {
        let m = CooMatrix::<f64>::from_triplets(
            2,
            3,
            vec![0, 1, 1],
            vec![2, 0, 0],
            vec![4.0, 5.0, 1.5],
        )
        .unwrap();
        assert_eq!(m.n_rows(), 2);
        assert_eq!(m.n_cols(), 3);
        // Duplicates are stored separately...
        assert_eq!(m.nnz(), 3);
        // ...and summed on densification.
        let d = m.to_dense();
        assert_eq!(d[[0, 2]], 4.0);
        assert_eq!(d[[1, 0]], 6.5);
        assert_eq!(d[[1, 1]], 0.0);
    }

    #[test]
    fn test_coo_from_triplets_empty() {
        let m = CooMatrix::<f64>::from_triplets(3, 2, Vec::new(), Vec::new(), Vec::new()).unwrap();
        assert_eq!(m.nnz(), 0);
        let d = m.to_dense();
        assert_eq!(d.dim(), (3, 2));
        assert!(d.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_coo_to_dense() {
        let mut m: CooMatrix<f64> = CooMatrix::new(2, 3);
        m.push(0, 1, 3.0).unwrap();
        m.push(1, 0, 7.0).unwrap();
        let d = m.to_dense();
        assert_eq!(d[[0, 1]], 3.0);
        assert_eq!(d[[1, 0]], 7.0);
        assert_eq!(d[[0, 0]], 0.0);
    }

    #[test]
    fn test_coo_to_dense_duplicate_summed() {
        let mut m: CooMatrix<f64> = CooMatrix::new(2, 2);
        m.push(0, 0, 1.0).unwrap();
        m.push(0, 0, 2.0).unwrap(); // duplicate — should sum to 3.0
        let d = m.to_dense();
        assert_eq!(d[[0, 0]], 3.0);
    }

    #[test]
    fn test_coo_clone() {
        let mut m: CooMatrix<f64> = CooMatrix::new(2, 2);
        m.push(0, 0, 5.0).unwrap();
        let m2 = m.clone();
        assert_eq!(m2.nnz(), 1);
        assert_eq!(m2.n_rows(), 2);
        assert_eq!(m2.n_cols(), 2);
    }

    #[test]
    fn test_coo_clone_preserves_values_and_is_independent() {
        let mut m: CooMatrix<f64> = CooMatrix::new(2, 2);
        m.push(0, 1, 2.5).unwrap();
        m.push(1, 0, -1.0).unwrap();
        let mut m2 = m.clone();
        assert_eq!(m2.to_dense(), m.to_dense());
        // Mutating the clone must not affect the original.
        m2.push(1, 1, 9.0).unwrap();
        assert_eq!(m.nnz(), 2);
        assert_eq!(m2.nnz(), 3);
        assert_eq!(m.to_dense()[[1, 1]], 0.0);
    }

    #[test]
    fn test_coo_inner_and_into_inner() {
        let mut m: CooMatrix<f64> = CooMatrix::new(2, 4);
        m.push(1, 3, 8.0).unwrap();
        assert_eq!(m.inner().nnz(), 1);
        let t = m.into_inner();
        assert_eq!(t.rows(), 2);
        assert_eq!(t.cols(), 4);
        assert_eq!(t.nnz(), 1);
    }
}