Skip to main content

ferrolearn_sparse/
helpers.rs

//! Convenience constructors for common sparse-matrix patterns.
//!
//! These mirror the most-used scipy.sparse helpers:
//!
//! - [`eye`] — sparse identity matrix.
//! - [`diags`] — sparse matrix with a single (possibly offset) diagonal.
//! - [`hstack`] — horizontal concatenation of CSR matrices.
//! - [`vstack`] — vertical concatenation of CSR matrices.

10use ferrolearn_core::FerroError;
11use num_traits::One;
12use std::ops::Add;
13
14use crate::coo::CooMatrix;
15use crate::csr::CsrMatrix;
16
17/// Build an `n x n` sparse identity matrix.
18pub fn eye<T>(n: usize) -> Result<CsrMatrix<T>, FerroError>
19where
20    T: Clone + One + Add<Output = T> + 'static,
21{
22    let mut coo = CooMatrix::<T>::with_capacity(n, n, n);
23    for i in 0..n {
24        coo.push(i, i, T::one())
25            .map_err(|e| FerroError::InvalidParameter {
26                name: "eye".into(),
27                reason: format!("push failed at ({i}, {i}): {e}"),
28            })?;
29    }
30    CsrMatrix::from_coo(&coo)
31}
32
33/// Build a sparse `n x n` matrix from a single diagonal vector at `offset`.
34///
35/// `offset == 0` puts `values` on the main diagonal; `offset > 0` shifts to
36/// a super-diagonal; `offset < 0` shifts to a sub-diagonal.
37pub fn diags<T>(values: &[T], offset: isize, n: usize) -> Result<CsrMatrix<T>, FerroError>
38where
39    T: Clone + Add<Output = T> + 'static,
40{
41    let mut coo = CooMatrix::<T>::with_capacity(n, n, values.len());
42    for (k, v) in values.iter().enumerate() {
43        let (i, j) = if offset >= 0 {
44            (k, k + offset as usize)
45        } else {
46            (k + (-offset) as usize, k)
47        };
48        if i < n && j < n {
49            coo.push(i, j, v.clone())
50                .map_err(|e| FerroError::InvalidParameter {
51                    name: "diags".into(),
52                    reason: format!("push failed at ({i}, {j}): {e}"),
53                })?;
54        }
55    }
56    CsrMatrix::from_coo(&coo)
57}
58
59/// Horizontally concatenate CSR matrices.
60///
61/// All matrices must have the same number of rows.
62pub fn hstack<T>(matrices: &[&CsrMatrix<T>]) -> Result<CsrMatrix<T>, FerroError>
63where
64    T: Clone + Add<Output = T> + 'static,
65{
66    if matrices.is_empty() {
67        return Err(FerroError::InvalidParameter {
68            name: "matrices".into(),
69            reason: "hstack: at least one matrix required".into(),
70        });
71    }
72    let n_rows = matrices[0].n_rows();
73    for (idx, m) in matrices.iter().enumerate() {
74        if m.n_rows() != n_rows {
75            return Err(FerroError::ShapeMismatch {
76                expected: vec![n_rows],
77                actual: vec![m.n_rows()],
78                context: format!("hstack: matrix {idx} has {} rows", m.n_rows()),
79            });
80        }
81    }
82    let total_cols: usize = matrices.iter().map(|m| m.n_cols()).sum();
83    let mut coo = CooMatrix::<T>::new(n_rows, total_cols);
84    let mut col_offset = 0usize;
85    for m in matrices {
86        for (val, (r, c)) in m.inner().iter() {
87            coo.push(r, c + col_offset, val.clone())
88                .map_err(|e| FerroError::InvalidParameter {
89                    name: "hstack".into(),
90                    reason: format!("push failed: {e}"),
91                })?;
92        }
93        col_offset += m.n_cols();
94    }
95    CsrMatrix::from_coo(&coo)
96}
97
98/// Vertically concatenate CSR matrices.
99///
100/// All matrices must have the same number of columns.
101pub fn vstack<T>(matrices: &[&CsrMatrix<T>]) -> Result<CsrMatrix<T>, FerroError>
102where
103    T: Clone + Add<Output = T> + 'static,
104{
105    if matrices.is_empty() {
106        return Err(FerroError::InvalidParameter {
107            name: "matrices".into(),
108            reason: "vstack: at least one matrix required".into(),
109        });
110    }
111    let n_cols = matrices[0].n_cols();
112    for (idx, m) in matrices.iter().enumerate() {
113        if m.n_cols() != n_cols {
114            return Err(FerroError::ShapeMismatch {
115                expected: vec![n_cols],
116                actual: vec![m.n_cols()],
117                context: format!("vstack: matrix {idx} has {} cols", m.n_cols()),
118            });
119        }
120    }
121    let total_rows: usize = matrices.iter().map(|m| m.n_rows()).sum();
122    let mut coo = CooMatrix::<T>::new(total_rows, n_cols);
123    let mut row_offset = 0usize;
124    for m in matrices {
125        for (val, (r, c)) in m.inner().iter() {
126            coo.push(r + row_offset, c, val.clone())
127                .map_err(|e| FerroError::InvalidParameter {
128                    name: "vstack".into(),
129                    reason: format!("push failed: {e}"),
130                })?;
131        }
132        row_offset += m.n_rows();
133    }
134    CsrMatrix::from_coo(&coo)
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_eye_basic() {
        let m: CsrMatrix<f64> = eye(3).unwrap();
        let dense = m.to_dense();
        for i in 0..3 {
            for j in 0..3 {
                assert!((dense[[i, j]] - if i == j { 1.0 } else { 0.0 }).abs() < 1e-12);
            }
        }
    }

    #[test]
    fn test_diags_main_diagonal() {
        let m: CsrMatrix<f64> = diags(&[1.0, 2.0, 3.0], 0, 3).unwrap();
        let d = m.to_dense();
        assert!((d[[0, 0]] - 1.0).abs() < 1e-12);
        assert!((d[[1, 1]] - 2.0).abs() < 1e-12);
        assert!((d[[2, 2]] - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_diags_super_diagonal() {
        let m: CsrMatrix<f64> = diags(&[1.0, 2.0], 1, 3).unwrap();
        let d = m.to_dense();
        assert!((d[[0, 1]] - 1.0).abs() < 1e-12);
        assert!((d[[1, 2]] - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_diags_sub_diagonal() {
        // Negative offsets place values below the main diagonal.
        let m: CsrMatrix<f64> = diags(&[1.0, 2.0], -1, 3).unwrap();
        let d = m.to_dense();
        assert!((d[[1, 0]] - 1.0).abs() < 1e-12);
        assert!((d[[2, 1]] - 2.0).abs() < 1e-12);
        assert!(d[[0, 0]].abs() < 1e-12);
        assert!(d[[0, 1]].abs() < 1e-12);
    }

    #[test]
    fn test_diags_ignores_values_past_diagonal_end() {
        // The first super-diagonal of a 3x3 matrix has only two slots, so the
        // third value must be dropped rather than stored out of bounds.
        let m: CsrMatrix<f64> = diags(&[1.0, 2.0, 3.0], 1, 3).unwrap();
        assert_eq!(m.n_rows(), 3);
        assert_eq!(m.n_cols(), 3);
        let d = m.to_dense();
        for i in 0..3 {
            for j in 0..3 {
                let expected = match (i, j) {
                    (0, 1) => 1.0,
                    (1, 2) => 2.0,
                    _ => 0.0,
                };
                assert!((d[[i, j]] - expected).abs() < 1e-12);
            }
        }
    }

    #[test]
    fn test_hstack_basic() {
        let a: CsrMatrix<f64> = eye(2).unwrap();
        let b: CsrMatrix<f64> = diags(&[5.0, 5.0], 0, 2).unwrap();
        let h = hstack(&[&a, &b]).unwrap();
        assert_eq!(h.n_rows(), 2);
        assert_eq!(h.n_cols(), 4);
        let d = h.to_dense();
        assert!((d[[0, 2]] - 5.0).abs() < 1e-12);
    }

    #[test]
    fn test_hstack_row_mismatch() {
        let a: CsrMatrix<f64> = eye(2).unwrap();
        let b: CsrMatrix<f64> = eye(3).unwrap();
        assert!(matches!(
            hstack(&[&a, &b]),
            Err(FerroError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_vstack_basic() {
        let a: CsrMatrix<f64> = eye(2).unwrap();
        let b: CsrMatrix<f64> = diags(&[5.0, 5.0], 0, 2).unwrap();
        let v = vstack(&[&a, &b]).unwrap();
        assert_eq!(v.n_rows(), 4);
        assert_eq!(v.n_cols(), 2);
        let d = v.to_dense();
        assert!((d[[2, 0]] - 5.0).abs() < 1e-12);
    }

    #[test]
    fn test_vstack_col_mismatch() {
        let a: CsrMatrix<f64> = eye(2).unwrap();
        let b: CsrMatrix<f64> = eye(3).unwrap();
        assert!(matches!(
            vstack(&[&a, &b]),
            Err(FerroError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_stack_empty_input_is_rejected() {
        assert!(matches!(
            hstack::<f64>(&[]),
            Err(FerroError::InvalidParameter { .. })
        ));
        assert!(matches!(
            vstack::<f64>(&[]),
            Err(FerroError::InvalidParameter { .. })
        ));
    }
}