ferrolearn_sparse/
helpers.rs1use ferrolearn_core::FerroError;
11use num_traits::One;
12use std::ops::Add;
13
14use crate::coo::CooMatrix;
15use crate::csr::CsrMatrix;
16
17pub fn eye<T>(n: usize) -> Result<CsrMatrix<T>, FerroError>
19where
20 T: Clone + One + Add<Output = T> + 'static,
21{
22 let mut coo = CooMatrix::<T>::with_capacity(n, n, n);
23 for i in 0..n {
24 coo.push(i, i, T::one())
25 .map_err(|e| FerroError::InvalidParameter {
26 name: "eye".into(),
27 reason: format!("push failed at ({i}, {i}): {e}"),
28 })?;
29 }
30 CsrMatrix::from_coo(&coo)
31}
32
33pub fn diags<T>(values: &[T], offset: isize, n: usize) -> Result<CsrMatrix<T>, FerroError>
38where
39 T: Clone + Add<Output = T> + 'static,
40{
41 let mut coo = CooMatrix::<T>::with_capacity(n, n, values.len());
42 for (k, v) in values.iter().enumerate() {
43 let (i, j) = if offset >= 0 {
44 (k, k + offset as usize)
45 } else {
46 (k + (-offset) as usize, k)
47 };
48 if i < n && j < n {
49 coo.push(i, j, v.clone())
50 .map_err(|e| FerroError::InvalidParameter {
51 name: "diags".into(),
52 reason: format!("push failed at ({i}, {j}): {e}"),
53 })?;
54 }
55 }
56 CsrMatrix::from_coo(&coo)
57}
58
59pub fn hstack<T>(matrices: &[&CsrMatrix<T>]) -> Result<CsrMatrix<T>, FerroError>
63where
64 T: Clone + Add<Output = T> + 'static,
65{
66 if matrices.is_empty() {
67 return Err(FerroError::InvalidParameter {
68 name: "matrices".into(),
69 reason: "hstack: at least one matrix required".into(),
70 });
71 }
72 let n_rows = matrices[0].n_rows();
73 for (idx, m) in matrices.iter().enumerate() {
74 if m.n_rows() != n_rows {
75 return Err(FerroError::ShapeMismatch {
76 expected: vec![n_rows],
77 actual: vec![m.n_rows()],
78 context: format!("hstack: matrix {idx} has {} rows", m.n_rows()),
79 });
80 }
81 }
82 let total_cols: usize = matrices.iter().map(|m| m.n_cols()).sum();
83 let mut coo = CooMatrix::<T>::new(n_rows, total_cols);
84 let mut col_offset = 0usize;
85 for m in matrices {
86 for (val, (r, c)) in m.inner().iter() {
87 coo.push(r, c + col_offset, val.clone())
88 .map_err(|e| FerroError::InvalidParameter {
89 name: "hstack".into(),
90 reason: format!("push failed: {e}"),
91 })?;
92 }
93 col_offset += m.n_cols();
94 }
95 CsrMatrix::from_coo(&coo)
96}
97
98pub fn vstack<T>(matrices: &[&CsrMatrix<T>]) -> Result<CsrMatrix<T>, FerroError>
102where
103 T: Clone + Add<Output = T> + 'static,
104{
105 if matrices.is_empty() {
106 return Err(FerroError::InvalidParameter {
107 name: "matrices".into(),
108 reason: "vstack: at least one matrix required".into(),
109 });
110 }
111 let n_cols = matrices[0].n_cols();
112 for (idx, m) in matrices.iter().enumerate() {
113 if m.n_cols() != n_cols {
114 return Err(FerroError::ShapeMismatch {
115 expected: vec![n_cols],
116 actual: vec![m.n_cols()],
117 context: format!("vstack: matrix {idx} has {} cols", m.n_cols()),
118 });
119 }
120 }
121 let total_rows: usize = matrices.iter().map(|m| m.n_rows()).sum();
122 let mut coo = CooMatrix::<T>::new(total_rows, n_cols);
123 let mut row_offset = 0usize;
124 for m in matrices {
125 for (val, (r, c)) in m.inner().iter() {
126 coo.push(r + row_offset, c, val.clone())
127 .map_err(|e| FerroError::InvalidParameter {
128 name: "vstack".into(),
129 reason: format!("push failed: {e}"),
130 })?;
131 }
132 row_offset += m.n_rows();
133 }
134 CsrMatrix::from_coo(&coo)
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing dense `f64` entries.
    const TOL: f64 = 1e-12;

    #[test]
    fn test_eye_basic() {
        let m: CsrMatrix<f64> = eye(3).unwrap();
        let dense = m.to_dense();
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!((dense[[i, j]] - expected).abs() < TOL);
            }
        }
    }

    #[test]
    fn test_diags_main_diagonal() {
        let m: CsrMatrix<f64> = diags(&[1.0, 2.0, 3.0], 0, 3).unwrap();
        let d = m.to_dense();
        assert!((d[[0, 0]] - 1.0).abs() < TOL);
        assert!((d[[1, 1]] - 2.0).abs() < TOL);
        assert!((d[[2, 2]] - 3.0).abs() < TOL);
    }

    #[test]
    fn test_diags_super_diagonal() {
        let m: CsrMatrix<f64> = diags(&[1.0, 2.0], 1, 3).unwrap();
        let d = m.to_dense();
        assert!((d[[0, 1]] - 1.0).abs() < TOL);
        assert!((d[[1, 2]] - 2.0).abs() < TOL);
    }

    #[test]
    fn test_diags_sub_diagonal() {
        let m: CsrMatrix<f64> = diags(&[1.0, 2.0], -1, 3).unwrap();
        let d = m.to_dense();
        assert!((d[[1, 0]] - 1.0).abs() < TOL);
        assert!((d[[2, 1]] - 2.0).abs() < TOL);
        // Nothing should leak onto the main diagonal.
        assert!(d[[0, 0]].abs() < TOL);
    }

    #[test]
    fn test_diags_drops_out_of_range_values() {
        // Only two entries of the first super-diagonal fit in a 3x3 matrix.
        let m: CsrMatrix<f64> = diags(&[1.0, 2.0, 3.0, 4.0], 1, 3).unwrap();
        assert_eq!(m.n_rows(), 3);
        assert_eq!(m.n_cols(), 3);
        let d = m.to_dense();
        assert!((d[[0, 1]] - 1.0).abs() < TOL);
        assert!((d[[1, 2]] - 2.0).abs() < TOL);
    }

    #[test]
    fn test_hstack_basic() {
        let a: CsrMatrix<f64> = eye(2).unwrap();
        let b: CsrMatrix<f64> = diags(&[5.0, 5.0], 0, 2).unwrap();
        let h = hstack(&[&a, &b]).unwrap();
        assert_eq!(h.n_rows(), 2);
        assert_eq!(h.n_cols(), 4);
        let d = h.to_dense();
        assert!((d[[0, 2]] - 5.0).abs() < TOL);
    }

    #[test]
    fn test_hstack_empty_is_error() {
        assert!(matches!(
            hstack::<f64>(&[]),
            Err(FerroError::InvalidParameter { .. })
        ));
    }

    #[test]
    fn test_hstack_row_mismatch_is_error() {
        let a: CsrMatrix<f64> = eye(2).unwrap();
        let b: CsrMatrix<f64> = eye(3).unwrap();
        assert!(matches!(
            hstack(&[&a, &b]),
            Err(FerroError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_vstack_basic() {
        let a: CsrMatrix<f64> = eye(2).unwrap();
        let b: CsrMatrix<f64> = diags(&[5.0, 5.0], 0, 2).unwrap();
        let v = vstack(&[&a, &b]).unwrap();
        assert_eq!(v.n_rows(), 4);
        assert_eq!(v.n_cols(), 2);
        let d = v.to_dense();
        assert!((d[[2, 0]] - 5.0).abs() < TOL);
    }

    #[test]
    fn test_vstack_empty_is_error() {
        assert!(matches!(
            vstack::<f64>(&[]),
            Err(FerroError::InvalidParameter { .. })
        ));
    }

    #[test]
    fn test_vstack_col_mismatch_is_error() {
        let a: CsrMatrix<f64> = eye(2).unwrap();
        let b: CsrMatrix<f64> = eye(3).unwrap();
        assert!(matches!(
            vstack(&[&a, &b]),
            Err(FerroError::ShapeMismatch { .. })
        ));
    }
}