1use ferrolearn_core::FerroError;
33use num_traits::One;
34use std::ops::Add;
35
36use crate::coo::CooMatrix;
37use crate::csr::CsrMatrix;
38
39pub fn eye<T>(n: usize) -> Result<CsrMatrix<T>, FerroError>
41where
42 T: Clone + One + Add<Output = T> + 'static,
43{
44 let mut coo = CooMatrix::<T>::with_capacity(n, n, n);
45 for i in 0..n {
46 coo.push(i, i, T::one())
47 .map_err(|e| FerroError::InvalidParameter {
48 name: "eye".into(),
49 reason: format!("push failed at ({i}, {i}): {e}"),
50 })?;
51 }
52 CsrMatrix::from_coo(&coo)
53}
54
55pub fn diags<T>(values: &[T], offset: isize, n: usize) -> Result<CsrMatrix<T>, FerroError>
65where
66 T: Clone + Add<Output = T> + 'static,
67{
68 let required = n.saturating_sub(offset.unsigned_abs());
72 if values.len() < required {
73 return Err(FerroError::InvalidParameter {
74 name: "diags".into(),
75 reason: format!(
76 "diagonal length {} does not agree with array size ({n}, {n}) at offset {offset} (expected {required})",
77 values.len()
78 ),
79 });
80 }
81 let mut coo = CooMatrix::<T>::with_capacity(n, n, values.len());
82 for (k, v) in values.iter().enumerate() {
83 let (i, j) = if offset >= 0 {
84 (k, k + offset as usize)
85 } else {
86 (k + (-offset) as usize, k)
87 };
88 if i < n && j < n {
89 coo.push(i, j, v.clone())
90 .map_err(|e| FerroError::InvalidParameter {
91 name: "diags".into(),
92 reason: format!("push failed at ({i}, {j}): {e}"),
93 })?;
94 }
95 }
96 CsrMatrix::from_coo(&coo)
97}
98
99pub fn hstack<T>(matrices: &[&CsrMatrix<T>]) -> Result<CsrMatrix<T>, FerroError>
103where
104 T: Clone + Add<Output = T> + 'static,
105{
106 if matrices.is_empty() {
107 return Err(FerroError::InvalidParameter {
108 name: "matrices".into(),
109 reason: "hstack: at least one matrix required".into(),
110 });
111 }
112 let n_rows = matrices[0].n_rows();
113 for (idx, m) in matrices.iter().enumerate() {
114 if m.n_rows() != n_rows {
115 return Err(FerroError::ShapeMismatch {
116 expected: vec![n_rows],
117 actual: vec![m.n_rows()],
118 context: format!("hstack: matrix {idx} has {} rows", m.n_rows()),
119 });
120 }
121 }
122 let total_cols: usize = matrices.iter().map(|m| m.n_cols()).sum();
123 let mut coo = CooMatrix::<T>::new(n_rows, total_cols);
124 let mut col_offset = 0usize;
125 for m in matrices {
126 for (val, (r, c)) in m.inner().iter() {
127 coo.push(r, c + col_offset, val.clone())
128 .map_err(|e| FerroError::InvalidParameter {
129 name: "hstack".into(),
130 reason: format!("push failed: {e}"),
131 })?;
132 }
133 col_offset += m.n_cols();
134 }
135 CsrMatrix::from_coo(&coo)
136}
137
138pub fn vstack<T>(matrices: &[&CsrMatrix<T>]) -> Result<CsrMatrix<T>, FerroError>
142where
143 T: Clone + Add<Output = T> + 'static,
144{
145 if matrices.is_empty() {
146 return Err(FerroError::InvalidParameter {
147 name: "matrices".into(),
148 reason: "vstack: at least one matrix required".into(),
149 });
150 }
151 let n_cols = matrices[0].n_cols();
152 for (idx, m) in matrices.iter().enumerate() {
153 if m.n_cols() != n_cols {
154 return Err(FerroError::ShapeMismatch {
155 expected: vec![n_cols],
156 actual: vec![m.n_cols()],
157 context: format!("vstack: matrix {idx} has {} cols", m.n_cols()),
158 });
159 }
160 }
161 let total_rows: usize = matrices.iter().map(|m| m.n_rows()).sum();
162 let mut coo = CooMatrix::<T>::new(total_rows, n_cols);
163 let mut row_offset = 0usize;
164 for m in matrices {
165 for (val, (r, c)) in m.inner().iter() {
166 coo.push(r + row_offset, c, val.clone())
167 .map_err(|e| FerroError::InvalidParameter {
168 name: "vstack".into(),
169 reason: format!("push failed: {e}"),
170 })?;
171 }
172 row_offset += m.n_rows();
173 }
174 CsrMatrix::from_coo(&coo)
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing dense floating-point entries.
    const TOL: f64 = 1e-12;

    /// Returns `true` when `a` and `b` differ by less than [`TOL`].
    fn approx(a: f64, b: f64) -> bool {
        (a - b).abs() < TOL
    }

    #[test]
    fn test_eye_basic() {
        let m: CsrMatrix<f64> = eye(3).unwrap();
        let dense = m.to_dense();
        for i in 0..3 {
            for j in 0..3 {
                assert!(approx(dense[[i, j]], if i == j { 1.0 } else { 0.0 }));
            }
        }
    }

    #[test]
    fn test_diags_main_diagonal() {
        let m: CsrMatrix<f64> = diags(&[1.0, 2.0, 3.0], 0, 3).unwrap();
        let d = m.to_dense();
        assert!(approx(d[[0, 0]], 1.0));
        assert!(approx(d[[1, 1]], 2.0));
        assert!(approx(d[[2, 2]], 3.0));
    }

    #[test]
    fn test_diags_super_diagonal() {
        let m: CsrMatrix<f64> = diags(&[1.0, 2.0], 1, 3).unwrap();
        let d = m.to_dense();
        assert!(approx(d[[0, 1]], 1.0));
        assert!(approx(d[[1, 2]], 2.0));
        assert!(approx(d[[0, 0]], 0.0));
    }

    #[test]
    fn test_diags_sub_diagonal() {
        let m: CsrMatrix<f64> = diags(&[1.0, 2.0], -1, 3).unwrap();
        let d = m.to_dense();
        assert!(approx(d[[1, 0]], 1.0));
        assert!(approx(d[[2, 1]], 2.0));
        assert!(approx(d[[0, 1]], 0.0));
    }

    #[test]
    fn test_diags_too_short_is_error() {
        assert!(diags::<f64>(&[1.0], 0, 3).is_err());
        assert!(diags::<f64>(&[1.0], -1, 3).is_err());
    }

    #[test]
    fn test_diags_extra_values_are_ignored() {
        let m: CsrMatrix<f64> = diags(&[1.0, 2.0, 3.0, 4.0], 1, 3).unwrap();
        assert_eq!(m.n_rows(), 3);
        assert_eq!(m.n_cols(), 3);
        let d = m.to_dense();
        assert!(approx(d[[0, 1]], 1.0));
        assert!(approx(d[[1, 2]], 2.0));
        assert!(approx(d[[2, 2]], 0.0));
    }

    #[test]
    fn test_hstack_basic() {
        let a: CsrMatrix<f64> = eye(2).unwrap();
        let b: CsrMatrix<f64> = diags(&[5.0, 5.0], 0, 2).unwrap();
        let h = hstack(&[&a, &b]).unwrap();
        assert_eq!(h.n_rows(), 2);
        assert_eq!(h.n_cols(), 4);
        let d = h.to_dense();
        assert!(approx(d[[0, 0]], 1.0));
        assert!(approx(d[[0, 2]], 5.0));
        assert!(approx(d[[1, 3]], 5.0));
    }

    #[test]
    fn test_hstack_empty_is_error() {
        assert!(hstack::<f64>(&[]).is_err());
    }

    #[test]
    fn test_hstack_row_mismatch_is_error() {
        let a: CsrMatrix<f64> = eye(2).unwrap();
        let b: CsrMatrix<f64> = eye(3).unwrap();
        assert!(matches!(
            hstack(&[&a, &b]),
            Err(FerroError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn test_vstack_basic() {
        let a: CsrMatrix<f64> = eye(2).unwrap();
        let b: CsrMatrix<f64> = diags(&[5.0, 5.0], 0, 2).unwrap();
        let v = vstack(&[&a, &b]).unwrap();
        assert_eq!(v.n_rows(), 4);
        assert_eq!(v.n_cols(), 2);
        let d = v.to_dense();
        assert!(approx(d[[1, 1]], 1.0));
        assert!(approx(d[[2, 0]], 5.0));
        assert!(approx(d[[3, 1]], 5.0));
    }

    #[test]
    fn test_vstack_empty_is_error() {
        assert!(vstack::<f64>(&[]).is_err());
    }

    #[test]
    fn test_vstack_col_mismatch_is_error() {
        let a: CsrMatrix<f64> = eye(2).unwrap();
        let b: CsrMatrix<f64> = eye(3).unwrap();
        assert!(matches!(
            vstack(&[&a, &b]),
            Err(FerroError::ShapeMismatch { .. })
        ));
    }
}