1use crate::error::{Result, ScryLearnError};
9
/// Dense `f64` matrix backed by one flat buffer in column-major order.
///
/// Element `(row, col)` is stored at `data[col * n_rows + row]`, so every
/// column occupies a contiguous run of `n_rows` values.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct DenseMatrix {
    /// Flat element storage, one column after another.
    data: Vec<f64>,
    /// Number of rows, i.e. the length of each column.
    n_rows: usize,
    /// Number of columns.
    n_cols: usize,
}
28
29impl DenseMatrix {
30 pub fn new(data: Vec<f64>, n_rows: usize, n_cols: usize) -> Result<Self> {
34 if data.len() != n_rows * n_cols {
35 return Err(ScryLearnError::InvalidParameter(format!(
36 "DenseMatrix::new: data.len()={} but n_rows*n_cols={}",
37 data.len(),
38 n_rows * n_cols,
39 )));
40 }
41 Ok(Self {
42 data,
43 n_rows,
44 n_cols,
45 })
46 }
47
48 pub fn zeros(n_rows: usize, n_cols: usize) -> Self {
50 Self {
51 data: vec![0.0; n_rows * n_cols],
52 n_rows,
53 n_cols,
54 }
55 }
56
57 pub fn from_col_major(cols: Vec<Vec<f64>>) -> Result<Self> {
61 if cols.is_empty() {
62 return Ok(Self {
63 data: Vec::new(),
64 n_rows: 0,
65 n_cols: 0,
66 });
67 }
68 let n_rows = cols[0].len();
69 let n_cols = cols.len();
70 for (i, col) in cols.iter().enumerate() {
71 if col.len() != n_rows {
72 return Err(ScryLearnError::InvalidParameter(format!(
73 "DenseMatrix::from_col_major: column {i} has {} rows, expected {n_rows}",
74 col.len(),
75 )));
76 }
77 }
78 let mut data = Vec::with_capacity(n_rows * n_cols);
79 for col in &cols {
80 data.extend_from_slice(col);
81 }
82 Ok(Self {
83 data,
84 n_rows,
85 n_cols,
86 })
87 }
88
89 pub fn from_row_major(rows: &[&[f64]], n_rows: usize, n_cols: usize) -> Self {
91 let mut data = vec![0.0; n_rows * n_cols];
92 for (i, row) in rows.iter().enumerate() {
93 for (j, &val) in row.iter().enumerate() {
94 data[j * n_rows + i] = val;
95 }
96 }
97 Self {
98 data,
99 n_rows,
100 n_cols,
101 }
102 }
103
104 #[inline]
106 pub fn col(&self, j: usize) -> &[f64] {
107 let start = j * self.n_rows;
108 &self.data[start..start + self.n_rows]
109 }
110
111 #[inline]
113 pub fn col_mut(&mut self, j: usize) -> &mut [f64] {
114 let start = j * self.n_rows;
115 &mut self.data[start..start + self.n_rows]
116 }
117
118 #[inline]
120 pub fn get(&self, row: usize, col: usize) -> f64 {
121 self.data[col * self.n_rows + row]
122 }
123
124 #[inline]
126 pub fn set(&mut self, row: usize, col: usize, val: f64) {
127 self.data[col * self.n_rows + row] = val;
128 }
129
130 #[inline]
132 pub fn n_rows(&self) -> usize {
133 self.n_rows
134 }
135
136 #[inline]
138 pub fn n_cols(&self) -> usize {
139 self.n_cols
140 }
141
142 #[inline]
144 pub fn as_slice(&self) -> &[f64] {
145 &self.data
146 }
147
148 pub fn row_iter(&self, i: usize) -> impl Iterator<Item = f64> + '_ {
150 (0..self.n_cols).map(move |j| self.data[j * self.n_rows + i])
151 }
152
153 pub fn row_to_vec(&self, i: usize) -> Vec<f64> {
155 self.row_iter(i).collect()
156 }
157
158 pub fn from_col_major_ref(cols: &[Vec<f64>]) -> Result<Self> {
163 if cols.is_empty() {
164 return Ok(Self {
165 data: Vec::new(),
166 n_rows: 0,
167 n_cols: 0,
168 });
169 }
170 let n_rows = cols[0].len();
171 let n_cols = cols.len();
172 for (i, col) in cols.iter().enumerate() {
173 if col.len() != n_rows {
174 return Err(ScryLearnError::InvalidParameter(format!(
175 "DenseMatrix::from_col_major_ref: column {i} has {} rows, expected {n_rows}",
176 col.len(),
177 )));
178 }
179 }
180 let mut data = Vec::with_capacity(n_rows * n_cols);
181 for col in cols {
182 data.extend_from_slice(col);
183 }
184 Ok(Self {
185 data,
186 n_rows,
187 n_cols,
188 })
189 }
190
191 pub fn to_col_vecs(&self) -> Vec<Vec<f64>> {
193 (0..self.n_cols).map(|j| self.col(j).to_vec()).collect()
194 }
195}
196
197impl From<Vec<Vec<f64>>> for DenseMatrix {
202 fn from(cols: Vec<Vec<f64>>) -> Self {
204 Self::from_col_major(cols).expect("ragged column vectors in DenseMatrix::from")
205 }
206}
207
208impl From<&[Vec<f64>]> for DenseMatrix {
209 fn from(cols: &[Vec<f64>]) -> Self {
210 let owned: Vec<Vec<f64>> = cols.to_vec();
211 Self::from(owned)
212 }
213}
214
#[cfg(feature = "serde")]
impl serde::Serialize for DenseMatrix {
    /// Writes the matrix as a three-field struct named `DenseMatrix` with the
    /// fields `data`, `n_rows` and `n_cols`, in that order.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct;

        let Self {
            data,
            n_rows,
            n_cols,
        } = self;
        let mut out = serializer.serialize_struct("DenseMatrix", 3)?;
        SerializeStruct::serialize_field(&mut out, "data", data)?;
        SerializeStruct::serialize_field(&mut out, "n_rows", n_rows)?;
        SerializeStruct::serialize_field(&mut out, "n_cols", n_cols)?;
        SerializeStruct::end(out)
    }
}
233
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for DenseMatrix {
    /// Reads the `data`, `n_rows` and `n_cols` fields and rebuilds the matrix
    /// through [`DenseMatrix::new`], so a payload that `new` rejects becomes a
    /// deserialization error instead of an inconsistent matrix.
    fn deserialize<D: serde::Deserializer<'de>>(
        deserializer: D,
    ) -> std::result::Result<Self, D::Error> {
        // The helper must carry the same struct name the `Serialize` impl
        // writes ("DenseMatrix"). Formats that emit and check struct names
        // would otherwise reject their own output.
        #[derive(serde::Deserialize)]
        #[serde(rename = "DenseMatrix")]
        struct Raw {
            data: Vec<f64>,
            n_rows: usize,
            n_cols: usize,
        }
        let raw = Raw::deserialize(deserializer)?;
        Self::new(raw.data, raw.n_rows, raw.n_cols).map_err(serde::de::Error::custom)
    }
}
249
// Unit tests for `DenseMatrix` construction, indexing and conversions.
// Exact float comparison is intentional: every assertion reads back values
// that were stored verbatim, with no arithmetic in between.
#[cfg(test)]
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    /// Columns passed to `from_col_major` come back unchanged from `to_col_vecs`.
    #[test]
    fn from_col_major_roundtrip() {
        let cols = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let m = DenseMatrix::from_col_major(cols.clone()).unwrap();
        assert_eq!(m.n_rows(), 3);
        assert_eq!(m.n_cols(), 2);
        assert_eq!(m.to_col_vecs(), cols);
    }

    /// `col(j)` returns exactly the j-th input column.
    #[test]
    fn col_correctness() {
        let m = DenseMatrix::from_col_major(vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]])
            .unwrap();
        assert_eq!(m.col(0), &[1.0, 2.0]);
        assert_eq!(m.col(1), &[3.0, 4.0]);
        assert_eq!(m.col(2), &[5.0, 6.0]);
    }

    /// `row_iter(i)` picks the i-th entry of every column, in column order.
    #[test]
    fn row_iter_correctness() {
        let m =
            DenseMatrix::from_col_major(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]).unwrap();
        let row0: Vec<f64> = m.row_iter(0).collect();
        assert_eq!(row0, vec![1.0, 4.0]);
        let row2: Vec<f64> = m.row_iter(2).collect();
        assert_eq!(row2, vec![3.0, 6.0]);
    }

    /// Values written with `set` are read back by `get`; untouched cells stay zero.
    #[test]
    fn get_set_indexing() {
        let mut m = DenseMatrix::zeros(3, 2);
        m.set(1, 0, 42.0);
        m.set(2, 1, 99.0);
        assert_eq!(m.get(1, 0), 42.0);
        assert_eq!(m.get(2, 1), 99.0);
        assert_eq!(m.get(0, 0), 0.0);
    }

    /// `From<Vec<Vec<f64>>>` treats the outer vector as a list of columns.
    #[test]
    fn from_vec_vec_conversion() {
        let cols = vec![vec![10.0, 20.0], vec![30.0, 40.0]];
        let m: DenseMatrix = cols.into();
        assert_eq!(m.n_rows(), 2);
        assert_eq!(m.n_cols(), 2);
        assert_eq!(m.get(0, 0), 10.0);
        assert_eq!(m.get(1, 1), 40.0);
    }

    /// `From<&[Vec<f64>]>` builds the same matrix from borrowed columns.
    #[test]
    fn from_slice_conversion() {
        let cols = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let m: DenseMatrix = cols.as_slice().into();
        assert_eq!(m.col(0), &[1.0, 2.0]);
    }

    /// No columns at all produces a 0 x 0 matrix with an empty buffer.
    #[test]
    fn empty_matrix() {
        let m = DenseMatrix::from_col_major(vec![]).unwrap();
        assert_eq!(m.n_rows(), 0);
        assert_eq!(m.n_cols(), 0);
        assert_eq!(m.as_slice(), &[] as &[f64]);
    }

    /// Empty columns still count as columns: the result is 0 x 2.
    #[test]
    fn zero_row_matrix() {
        let m = DenseMatrix::from_col_major(vec![vec![], vec![]]).unwrap();
        assert_eq!(m.n_rows(), 0);
        assert_eq!(m.n_cols(), 2);
    }

    /// A single column: every row has exactly one value.
    #[test]
    fn single_column() {
        let m = DenseMatrix::from_col_major(vec![vec![1.0, 2.0, 3.0]]).unwrap();
        assert_eq!(m.n_cols(), 1);
        assert_eq!(m.col(0), &[1.0, 2.0, 3.0]);
        assert_eq!(m.row_to_vec(1), vec![2.0]);
    }

    /// Columns of different lengths are rejected with an error.
    #[test]
    fn ragged_error() {
        let result = DenseMatrix::from_col_major(vec![vec![1.0, 2.0], vec![3.0]]);
        assert!(result.is_err());
    }

    /// `new` accepts a buffer only when its length equals `n_rows * n_cols`.
    #[test]
    fn new_validates_length() {
        assert!(DenseMatrix::new(vec![1.0, 2.0, 3.0], 2, 2).is_err());
        assert!(DenseMatrix::new(vec![1.0, 2.0, 3.0, 4.0], 2, 2).is_ok());
    }

    /// Row-major input is transposed into column-major storage.
    #[test]
    fn from_row_major_transposes() {
        let rows: Vec<&[f64]> = vec![&[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0]];
        let m = DenseMatrix::from_row_major(&rows, 3, 2);
        // Column 0 holds the first value of each row, column 1 the second.
        assert_eq!(m.col(0), &[1.0, 3.0, 5.0]);
        assert_eq!(m.col(1), &[2.0, 4.0, 6.0]);
    }

    /// Writes through `col_mut` are visible through `col`.
    #[test]
    fn col_mut_works() {
        let mut m = DenseMatrix::zeros(3, 2);
        let col = m.col_mut(1);
        col[0] = 10.0;
        col[1] = 20.0;
        col[2] = 30.0;
        assert_eq!(m.col(1), &[10.0, 20.0, 30.0]);
    }

    /// A JSON round trip preserves the buffer and both dimensions.
    #[cfg(feature = "serde")]
    #[test]
    fn serde_roundtrip() {
        let m = DenseMatrix::from_col_major(vec![vec![1.0, 2.0], vec![3.0, 4.0]]).unwrap();
        let json = serde_json::to_string(&m).unwrap();
        let m2: DenseMatrix = serde_json::from_str(&json).unwrap();
        assert_eq!(m.as_slice(), m2.as_slice());
        assert_eq!(m.n_rows(), m2.n_rows());
        assert_eq!(m.n_cols(), m2.n_cols());
    }
}