1use crate::HisabError;
8
/// A dense matrix of `f64` values stored contiguously in row-major order.
///
/// Element `(r, c)` lives at flat index `r * cols + c` of the backing buffer.
/// Invariant (maintained by the constructors): `data.len() == rows * cols`.
#[derive(Debug, Clone, PartialEq)]
pub struct DenseMatrix {
    /// Row-major element storage; `rows * cols` entries long.
    data: Vec<f64>,
    /// Number of rows.
    rows: usize,
    /// Number of columns (the stride between consecutive rows in `data`).
    cols: usize,
}
32
33impl DenseMatrix {
34 #[must_use]
39 #[inline]
40 pub fn zeros(rows: usize, cols: usize) -> Self {
41 Self {
42 data: vec![0.0; rows * cols],
43 rows,
44 cols,
45 }
46 }
47
48 #[must_use]
50 #[inline]
51 pub fn identity(n: usize) -> Self {
52 let mut m = Self::zeros(n, n);
53 for i in 0..n {
54 m.data[i * n + i] = 1.0;
55 }
56 m
57 }
58
59 #[must_use = "returns the matrix or an error"]
65 pub fn from_rows(rows: usize, cols: usize, data: Vec<f64>) -> Result<Self, HisabError> {
66 if data.len() != rows * cols {
67 return Err(HisabError::InvalidInput(alloc_msg(
68 "data length",
69 data.len(),
70 rows * cols,
71 )));
72 }
73 Ok(Self { data, rows, cols })
74 }
75
76 #[must_use = "returns the matrix or an error"]
85 pub fn from_vec_of_vec(v: &[Vec<f64>]) -> Result<Self, HisabError> {
86 if v.is_empty() {
87 return Err(HisabError::InvalidInput("empty row list".into()));
88 }
89 let cols = v[0].len();
90 let rows = v.len();
91 let mut data = Vec::with_capacity(rows * cols);
92 for (r, row) in v.iter().enumerate() {
93 if row.len() != cols {
94 return Err(HisabError::InvalidInput(alloc_msg(
95 &format!("row {r} length"),
96 row.len(),
97 cols,
98 )));
99 }
100 data.extend_from_slice(row);
101 }
102 Ok(Self { data, rows, cols })
103 }
104
105 #[must_use]
110 pub fn to_vec_of_vec(&self) -> Vec<Vec<f64>> {
111 (0..self.rows)
112 .map(|r| self.data[r * self.cols..(r + 1) * self.cols].to_vec())
113 .collect()
114 }
115
116 #[must_use]
121 #[inline]
122 pub fn rows(&self) -> usize {
123 self.rows
124 }
125
126 #[must_use]
128 #[inline]
129 pub fn cols(&self) -> usize {
130 self.cols
131 }
132
133 #[must_use]
142 #[inline]
143 pub fn get(&self, row: usize, col: usize) -> f64 {
144 debug_assert!(row < self.rows && col < self.cols, "index out of bounds");
145 self.data[row * self.cols + col]
146 }
147
148 #[inline]
154 pub fn get_mut(&mut self, row: usize, col: usize) -> &mut f64 {
155 debug_assert!(row < self.rows && col < self.cols, "index out of bounds");
156 &mut self.data[row * self.cols + col]
157 }
158
159 #[must_use]
165 #[inline]
166 pub fn row(&self, i: usize) -> &[f64] {
167 debug_assert!(i < self.rows, "row index out of bounds");
168 &self.data[i * self.cols..(i + 1) * self.cols]
169 }
170
171 #[inline]
177 pub fn set(&mut self, row: usize, col: usize, val: f64) {
178 debug_assert!(row < self.rows && col < self.cols, "index out of bounds");
179 self.data[row * self.cols + col] = val;
180 }
181
182 #[must_use = "returns the product vector or an error"]
191 pub fn mul_vec(&self, x: &[f64]) -> Result<Vec<f64>, HisabError> {
192 if x.len() != self.cols {
193 return Err(HisabError::InvalidInput(alloc_msg(
194 "vector length",
195 x.len(),
196 self.cols,
197 )));
198 }
199 let mut out = vec![0.0; self.rows];
200 for (r, dst) in out.iter_mut().enumerate() {
201 let row = &self.data[r * self.cols..(r + 1) * self.cols];
202 let mut sum = 0.0_f64;
204 let mut comp = 0.0_f64;
205 for c in 0..self.cols {
206 let v = row[c] * x[c];
207 let t = sum + v;
208 if sum.abs() >= v.abs() {
209 comp += (sum - t) + v;
210 } else {
211 comp += (v - t) + sum;
212 }
213 sum = t;
214 }
215 *dst = sum + comp;
216 }
217 Ok(out)
218 }
219
220 #[must_use = "returns the product matrix or an error"]
226 pub fn mul_mat(&self, other: &DenseMatrix) -> Result<DenseMatrix, HisabError> {
227 if self.cols != other.rows {
228 return Err(HisabError::InvalidInput(alloc_msg(
229 "self.cols",
230 self.cols,
231 other.rows,
232 )));
233 }
234 let rows = self.rows;
235 let cols = other.cols;
236 let inner = self.cols;
237 let mut out = DenseMatrix::zeros(rows, cols);
238 for r in 0..rows {
239 for c in 0..cols {
240 let mut sum = 0.0_f64;
242 let mut comp = 0.0_f64;
243 for k in 0..inner {
244 let v = self.data[r * inner + k] * other.data[k * cols + c];
245 let t = sum + v;
246 if sum.abs() >= v.abs() {
247 comp += (sum - t) + v;
248 } else {
249 comp += (v - t) + sum;
250 }
251 sum = t;
252 }
253 out.data[r * cols + c] = sum + comp;
254 }
255 }
256 Ok(out)
257 }
258
259 #[must_use]
261 pub fn transpose(&self) -> DenseMatrix {
262 let mut out = DenseMatrix::zeros(self.cols, self.rows);
263 for r in 0..self.rows {
264 for c in 0..self.cols {
265 out.data[c * self.rows + r] = self.data[r * self.cols + c];
266 }
267 }
268 out
269 }
270
271 #[must_use]
273 pub fn frobenius_norm(&self) -> f64 {
274 self.data.iter().map(|&v| v * v).sum::<f64>().sqrt()
275 }
276}
277
278impl std::ops::Index<(usize, usize)> for DenseMatrix {
282 type Output = f64;
283
284 #[inline]
285 fn index(&self, (row, col): (usize, usize)) -> &f64 {
286 debug_assert!(row < self.rows && col < self.cols, "index out of bounds");
287 &self.data[row * self.cols + col]
288 }
289}
290
291impl std::ops::IndexMut<(usize, usize)> for DenseMatrix {
292 #[inline]
293 fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut f64 {
294 debug_assert!(row < self.rows && col < self.cols, "index out of bounds");
295 &mut self.data[row * self.cols + col]
296 }
297}
298
/// Formats a size-mismatch message of the form
/// `"{field}: expected {expected}, got {got}"`.
///
/// Used as the payload of `HisabError::InvalidInput` for dimension checks.
fn alloc_msg(field: &str, got: usize, expected: usize) -> String {
    format!("{field}: expected {expected}, got {got}")
}
311
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zeros_is_all_zero() {
        let m = DenseMatrix::zeros(3, 4);
        for r in 0..3 {
            for c in 0..4 {
                assert_eq!(m.get(r, c), 0.0);
            }
        }
    }

    #[test]
    fn identity_diagonal() {
        let id = DenseMatrix::identity(4);
        for r in 0..4 {
            for c in 0..4 {
                let expected = if r == c { 1.0 } else { 0.0 };
                assert_eq!(id.get(r, c), expected);
            }
        }
    }

    #[test]
    fn from_rows_roundtrip() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let m = DenseMatrix::from_rows(2, 3, data.clone()).unwrap();
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(0, 2), 3.0);
        assert_eq!(m.get(1, 0), 4.0);
        assert_eq!(m.get(1, 2), 6.0);
        assert_eq!(m.to_vec_of_vec().concat(), data);
    }

    #[test]
    fn from_rows_size_mismatch() {
        let result = DenseMatrix::from_rows(2, 3, vec![1.0; 5]);
        assert!(result.is_err());
    }

    #[test]
    fn from_vec_of_vec_and_back() {
        let rows = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let m = DenseMatrix::from_vec_of_vec(&rows).unwrap();
        let back = m.to_vec_of_vec();
        assert_eq!(back, rows);
    }

    #[test]
    fn from_vec_of_vec_inconsistent_cols() {
        let rows = vec![vec![1.0, 2.0], vec![3.0]];
        assert!(DenseMatrix::from_vec_of_vec(&rows).is_err());
    }

    #[test]
    fn from_vec_of_vec_empty() {
        assert!(DenseMatrix::from_vec_of_vec(&[]).is_err());
    }

    #[test]
    fn set_get_roundtrip() {
        let mut m = DenseMatrix::zeros(3, 3);
        m.set(1, 2, 42.0);
        assert_eq!(m.get(1, 2), 42.0);
        assert_eq!(m.get(0, 0), 0.0);
    }

    #[test]
    fn index_operator() {
        let mut m = DenseMatrix::zeros(2, 2);
        m[(0, 1)] = 99.0;
        assert_eq!(m[(0, 1)], 99.0);
    }

    #[test]
    fn row_slice() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let m = DenseMatrix::from_rows(2, 3, data).unwrap();
        assert_eq!(m.row(0), &[1.0, 2.0, 3.0]);
        assert_eq!(m.row(1), &[4.0, 5.0, 6.0]);
    }

    #[test]
    fn mul_vec_identity() {
        let id = DenseMatrix::identity(3);
        let x = vec![1.0, 2.0, 3.0];
        let y = id.mul_vec(&x).unwrap();
        assert_eq!(y, x);
    }

    #[test]
    fn mul_vec_known() {
        // [[1, 2], [3, 4]] · [1, 1] = [3, 7]
        let m = DenseMatrix::from_rows(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = m.mul_vec(&[1.0, 1.0]).unwrap();
        assert!((y[0] - 3.0).abs() < 1e-12);
        assert!((y[1] - 7.0).abs() < 1e-12);
    }

    #[test]
    fn mul_vec_size_mismatch() {
        let m = DenseMatrix::zeros(2, 3);
        assert!(m.mul_vec(&[1.0, 2.0]).is_err());
    }

    #[test]
    fn mul_mat_identity() {
        let m = DenseMatrix::from_rows(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let id = DenseMatrix::identity(2);
        let result = m.mul_mat(&id).unwrap();
        assert_eq!(result, m);
    }

    #[test]
    fn mul_mat_known() {
        // [[1, 2], [3, 4]] · [[5, 6], [7, 8]] = [[19, 22], [43, 50]]
        let a = DenseMatrix::from_rows(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = DenseMatrix::from_rows(2, 2, vec![5.0, 6.0, 7.0, 8.0]).unwrap();
        let c = a.mul_mat(&b).unwrap();
        assert!((c.get(0, 0) - 19.0).abs() < 1e-12);
        assert!((c.get(0, 1) - 22.0).abs() < 1e-12);
        assert!((c.get(1, 0) - 43.0).abs() < 1e-12);
        assert!((c.get(1, 1) - 50.0).abs() < 1e-12);
    }

    #[test]
    fn mul_mat_size_mismatch() {
        let a = DenseMatrix::zeros(2, 3);
        let b = DenseMatrix::zeros(2, 2);
        assert!(a.mul_mat(&b).is_err());
    }

    #[test]
    fn transpose_square() {
        let m = DenseMatrix::from_rows(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let t = m.transpose();
        assert_eq!(t.get(0, 0), 1.0);
        assert_eq!(t.get(0, 1), 3.0);
        assert_eq!(t.get(1, 0), 2.0);
        assert_eq!(t.get(1, 1), 4.0);
    }

    #[test]
    fn transpose_rectangular() {
        // [[1, 2, 3], [4, 5, 6]]ᵀ = [[1, 4], [2, 5], [3, 6]]
        let m = DenseMatrix::from_rows(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let t = m.transpose();
        assert_eq!(t.rows(), 3);
        assert_eq!(t.cols(), 2);
        assert_eq!(
            t.to_vec_of_vec(),
            vec![vec![1.0, 4.0], vec![2.0, 5.0], vec![3.0, 6.0]]
        );
    }

    #[test]
    fn transpose_double_is_identity() {
        let m = DenseMatrix::from_rows(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert_eq!(m.transpose().transpose(), m);
    }

    #[test]
    fn frobenius_norm_identity() {
        // ‖I₄‖_F = sqrt(4) = 2
        let id = DenseMatrix::identity(4);
        assert!((id.frobenius_norm() - 2.0).abs() < 1e-12);
    }

    #[test]
    fn frobenius_norm_zeros() {
        assert_eq!(DenseMatrix::zeros(5, 5).frobenius_norm(), 0.0);
    }

    #[test]
    fn frobenius_norm_known() {
        // sqrt(3² + 4²) = 5
        let m = DenseMatrix::from_rows(1, 2, vec![3.0, -4.0]).unwrap();
        assert!((m.frobenius_norm() - 5.0).abs() < 1e-12);
    }

    #[test]
    fn get_mut_modifies() {
        let mut m = DenseMatrix::zeros(2, 2);
        *m.get_mut(1, 0) = 55.0;
        assert_eq!(m.get(1, 0), 55.0);
    }

    #[test]
    fn mul_mat_non_square() {
        // `a` selects the first two rows of `b`.
        let a = DenseMatrix::from_rows(2, 3, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0]).unwrap();
        let b = DenseMatrix::from_rows(
            3,
            4,
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .unwrap();
        let c = a.mul_mat(&b).unwrap();
        assert_eq!(c.rows(), 2);
        assert_eq!(c.cols(), 4);
        // Every entry is a sum of exact integer products, so equality is exact.
        assert_eq!(
            c.to_vec_of_vec(),
            vec![vec![1.0, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]]
        );
    }

    #[test]
    fn zero_column_matrix() {
        let m = DenseMatrix::zeros(2, 0);
        assert_eq!(m.to_vec_of_vec(), vec![Vec::<f64>::new(), Vec::new()]);
        assert_eq!(m.mul_vec(&[]).unwrap(), vec![0.0, 0.0]);
        let p = m.mul_mat(&DenseMatrix::zeros(0, 3)).unwrap();
        assert_eq!(p, DenseMatrix::zeros(2, 3));
    }
}