Skip to main content

ferrolearn_sparse/
helpers.rs

1//! Convenience constructors for common sparse-matrix patterns.
2//!
3//! These mirror the most-used scipy.sparse helpers:
4//!
5//! - [`eye`] — sparse identity matrix.
6//! - [`diags`] — sparse diagonal matrix.
7//! - [`hstack`] — horizontal concatenation of CSR matrices.
8//! - [`vstack`] — vertical concatenation of CSR matrices.
9//!
10//! ## REQ status
11//!
12//! Mirrors `scipy.sparse` construction helpers (`scipy/sparse/_construct.py`;
13//! live oracle scipy 1.17, deterministic). Design doc: `.design/sparse/helpers.md`
14//! (9 REQs). Every REQ is BINARY (R-DEFER-2): SHIPPED or NOT-STARTED (with a
15//! concrete blocker). Behavior is oracle-verified vs the live scipy (R-CHAR-3) —
16//! see `tests/divergence_helpers.rs`.
17//!
18//! **5 SHIPPED / 4 NOT-STARTED.**
19//!
20//! | REQ | Status | Notes |
21//! |---|---|---|
22//! | REQ-EYE (n×n identity) | SHIPPED | `eye(n)` == scipy `eye(n).toarray()` (square identity). Guard `eye_3_matches_scipy_identity`. (Rectangular `eye(m,n,k)` gap — #2019.) |
23//! | REQ-DIAGS-SINGLE (single diagonal + alignment) | SHIPPED | `diags(values, offset, n)` main/super/sub alignment matches scipy `diags([values],[offset])`. Guards `diags_main`/`diags_super_offset1`/`diags_sub_offset_neg1_matches_scipy`. |
24//! | REQ-DIAGS-LENGTH-VALIDATION | SHIPPED | FIXED #2016: a too-SHORT diagonal (`values.len() < n−\|offset\|`) now returns `Err(FerroError)` matching scipy's `ValueError`; a too-LONG diagonal still truncates (matching scipy's silent truncation). Guards `diags_too_short_must_error_like_scipy`/`diags_too_long_truncates_like_scipy`. |
25//! | REQ-HSTACK (horizontal CSR concat) | SHIPPED | `hstack(&[..])` == scipy `hstack([..])` (oracle `[[1,0,5,0],[0,1,0,5]]`); same-rows validation → `Err`. Guards `hstack_matches_scipy`/`hstack_row_mismatch_is_err`. (`format=`/mixed-input gap — #2020.) |
26//! | REQ-VSTACK (vertical CSR concat) | SHIPPED | `vstack(&[..])` == scipy `vstack([..])`; same-cols validation → `Err`. Guards `vstack_matches_scipy`/`vstack_col_mismatch_is_err`. (`format=` gap — #2020.) |
27//! | REQ-DIAGS-MULTI (list of diagonals/offsets) | NOT-STARTED | single-diagonal only; scipy `diags(LIST, LIST)`. Blocker #2017. |
28//! | REQ-MISSING-HELPERS | NOT-STARTED | no `identity`/`spdiags`/`bmat`/`block_diag`/`kron`/`random`/`tril`/`triu`. Blocker #2018. |
29//! | REQ-CONSUMER (production consumer) | NOT-STARTED | no estimator consumes `eye`/`diags`/`hstack`/`vstack` (standalone; only the `lib.rs` re-export). Blocker #2021. |
30//! | REQ-FERRAY (ferray sparse substrate) | NOT-STARTED | builds on `sprs`/`ndarray` (via Coo/Csr) vs ferray's sparse analog (R-SUBSTRATE-1). Blocker #2022. |
31
32use ferrolearn_core::FerroError;
33use num_traits::One;
34use std::ops::Add;
35
36use crate::coo::CooMatrix;
37use crate::csr::CsrMatrix;
38
39/// Build an `n x n` sparse identity matrix.
40pub fn eye<T>(n: usize) -> Result<CsrMatrix<T>, FerroError>
41where
42    T: Clone + One + Add<Output = T> + 'static,
43{
44    let mut coo = CooMatrix::<T>::with_capacity(n, n, n);
45    for i in 0..n {
46        coo.push(i, i, T::one())
47            .map_err(|e| FerroError::InvalidParameter {
48                name: "eye".into(),
49                reason: format!("push failed at ({i}, {i}): {e}"),
50            })?;
51    }
52    CsrMatrix::from_coo(&coo)
53}
54
55/// Build a sparse `n x n` matrix from a single diagonal vector at `offset`.
56///
57/// `offset == 0` puts `values` on the main diagonal; `offset > 0` shifts to
58/// a super-diagonal; `offset < 0` shifts to a sub-diagonal.
59///
60/// The required diagonal length for an `n x n` grid at signed `offset` is
61/// `n - |offset|`. A diagonal that is too SHORT returns `Err`, matching scipy's
62/// `ValueError` (`scipy/sparse/_construct.py:435`); a too-LONG diagonal is
63/// silently truncated, matching scipy's behavior (`_construct.py:433`).
64pub fn diags<T>(values: &[T], offset: isize, n: usize) -> Result<CsrMatrix<T>, FerroError>
65where
66    T: Clone + Add<Output = T> + 'static,
67{
68    // scipy raises `ValueError` on a too-SHORT diagonal but silently truncates a
69    // too-LONG one (`_construct.py:433-439`). The required length is `n - |offset|`;
70    // saturate to avoid `usize` underflow when the diagonal is entirely off-grid.
71    let required = n.saturating_sub(offset.unsigned_abs());
72    if values.len() < required {
73        return Err(FerroError::InvalidParameter {
74            name: "diags".into(),
75            reason: format!(
76                "diagonal length {} does not agree with array size ({n}, {n}) at offset {offset} (expected {required})",
77                values.len()
78            ),
79        });
80    }
81    let mut coo = CooMatrix::<T>::with_capacity(n, n, values.len());
82    for (k, v) in values.iter().enumerate() {
83        let (i, j) = if offset >= 0 {
84            (k, k + offset as usize)
85        } else {
86            (k + (-offset) as usize, k)
87        };
88        if i < n && j < n {
89            coo.push(i, j, v.clone())
90                .map_err(|e| FerroError::InvalidParameter {
91                    name: "diags".into(),
92                    reason: format!("push failed at ({i}, {j}): {e}"),
93                })?;
94        }
95    }
96    CsrMatrix::from_coo(&coo)
97}
98
99/// Horizontally concatenate CSR matrices.
100///
101/// All matrices must have the same number of rows.
102pub fn hstack<T>(matrices: &[&CsrMatrix<T>]) -> Result<CsrMatrix<T>, FerroError>
103where
104    T: Clone + Add<Output = T> + 'static,
105{
106    if matrices.is_empty() {
107        return Err(FerroError::InvalidParameter {
108            name: "matrices".into(),
109            reason: "hstack: at least one matrix required".into(),
110        });
111    }
112    let n_rows = matrices[0].n_rows();
113    for (idx, m) in matrices.iter().enumerate() {
114        if m.n_rows() != n_rows {
115            return Err(FerroError::ShapeMismatch {
116                expected: vec![n_rows],
117                actual: vec![m.n_rows()],
118                context: format!("hstack: matrix {idx} has {} rows", m.n_rows()),
119            });
120        }
121    }
122    let total_cols: usize = matrices.iter().map(|m| m.n_cols()).sum();
123    let mut coo = CooMatrix::<T>::new(n_rows, total_cols);
124    let mut col_offset = 0usize;
125    for m in matrices {
126        for (val, (r, c)) in m.inner().iter() {
127            coo.push(r, c + col_offset, val.clone())
128                .map_err(|e| FerroError::InvalidParameter {
129                    name: "hstack".into(),
130                    reason: format!("push failed: {e}"),
131                })?;
132        }
133        col_offset += m.n_cols();
134    }
135    CsrMatrix::from_coo(&coo)
136}
137
138/// Vertically concatenate CSR matrices.
139///
140/// All matrices must have the same number of columns.
141pub fn vstack<T>(matrices: &[&CsrMatrix<T>]) -> Result<CsrMatrix<T>, FerroError>
142where
143    T: Clone + Add<Output = T> + 'static,
144{
145    if matrices.is_empty() {
146        return Err(FerroError::InvalidParameter {
147            name: "matrices".into(),
148            reason: "vstack: at least one matrix required".into(),
149        });
150    }
151    let n_cols = matrices[0].n_cols();
152    for (idx, m) in matrices.iter().enumerate() {
153        if m.n_cols() != n_cols {
154            return Err(FerroError::ShapeMismatch {
155                expected: vec![n_cols],
156                actual: vec![m.n_cols()],
157                context: format!("vstack: matrix {idx} has {} cols", m.n_cols()),
158            });
159        }
160    }
161    let total_rows: usize = matrices.iter().map(|m| m.n_rows()).sum();
162    let mut coo = CooMatrix::<T>::new(total_rows, n_cols);
163    let mut row_offset = 0usize;
164    for m in matrices {
165        for (val, (r, c)) in m.inner().iter() {
166            coo.push(r + row_offset, c, val.clone())
167                .map_err(|e| FerroError::InvalidParameter {
168                    name: "vstack".into(),
169                    reason: format!("push failed: {e}"),
170                })?;
171        }
172        row_offset += m.n_rows();
173    }
174    CsrMatrix::from_coo(&coo)
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance applied to every dense-entry comparison below.
    const TOL: f64 = 1e-12;

    /// Returns `true` when `actual` differs from `expected` by less than `TOL`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOL
    }

    /// `eye(3)` densifies to ones on the diagonal and zeros elsewhere.
    #[test]
    fn test_eye_basic() {
        let identity: CsrMatrix<f64> = eye(3).unwrap();
        let dense = identity.to_dense();
        for row in 0..3 {
            for col in 0..3 {
                let expected = if row == col { 1.0 } else { 0.0 };
                assert!(close(dense[[row, col]], expected));
            }
        }
    }

    /// Offset 0 places the values on the main diagonal in order.
    #[test]
    fn test_diags_main_diagonal() {
        let matrix: CsrMatrix<f64> = diags(&[1.0, 2.0, 3.0], 0, 3).unwrap();
        let dense = matrix.to_dense();
        assert!(close(dense[[0, 0]], 1.0));
        assert!(close(dense[[1, 1]], 2.0));
        assert!(close(dense[[2, 2]], 3.0));
    }

    /// Offset 1 places the values on the first super-diagonal.
    #[test]
    fn test_diags_super_diagonal() {
        let matrix: CsrMatrix<f64> = diags(&[1.0, 2.0], 1, 3).unwrap();
        let dense = matrix.to_dense();
        assert!(close(dense[[0, 1]], 1.0));
        assert!(close(dense[[1, 2]], 2.0));
    }

    /// Stacking two 2x2 matrices horizontally yields 2x4 with the second
    /// block's entries shifted right by two columns.
    #[test]
    fn test_hstack_basic() {
        let left: CsrMatrix<f64> = eye(2).unwrap();
        let right: CsrMatrix<f64> = diags(&[5.0, 5.0], 0, 2).unwrap();
        let joined = hstack(&[&left, &right]).unwrap();
        assert_eq!(joined.n_rows(), 2);
        assert_eq!(joined.n_cols(), 4);
        let dense = joined.to_dense();
        assert!(close(dense[[0, 2]], 5.0));
    }

    /// Stacking two 2x2 matrices vertically yields 4x2 with the second
    /// block's entries shifted down by two rows.
    #[test]
    fn test_vstack_basic() {
        let top: CsrMatrix<f64> = eye(2).unwrap();
        let bottom: CsrMatrix<f64> = diags(&[5.0, 5.0], 0, 2).unwrap();
        let joined = vstack(&[&top, &bottom]).unwrap();
        assert_eq!(joined.n_rows(), 4);
        assert_eq!(joined.n_cols(), 2);
        let dense = joined.to_dense();
        assert!(close(dense[[2, 0]], 5.0));
    }
}