Skip to main content

hisab/num/
dense_matrix.rs

//! Row-major dense matrix backed by a flat `Vec<f64>`.
//!
//! [`DenseMatrix`] stores an *m × n* matrix as a single contiguous allocation,
//! which is cache-friendly for row-wise access patterns and avoids the pointer
//! indirection of `Vec<Vec<f64>>`.
6
7use crate::HisabError;
8
9// ---------------------------------------------------------------------------
10
/// Row-major dense matrix stored as a flat `Vec<f64>`.
///
/// Element `(row, col)` lives at flat index `row * cols + col` of the backing
/// storage, and the constructors size that storage to `rows * cols` elements.
/// Mutation happens in place through `&mut self` methods (and `IndexMut`);
/// none of the mutating accessors resize the storage after construction.
///
/// # Examples
///
/// ```
/// use hisab::num::DenseMatrix;
///
/// let mut m = DenseMatrix::zeros(2, 3);
/// m.set(0, 1, 7.0);
/// assert_eq!(m.get(0, 1), 7.0);
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct DenseMatrix {
    /// Row-major element storage of length `rows * cols`.
    data: Vec<f64>,
    /// Number of rows.
    rows: usize,
    /// Number of columns (also the stride between consecutive rows).
    cols: usize,
}
32
33impl DenseMatrix {
34    // -----------------------------------------------------------------------
35    // Constructors
36
37    /// Create a zero-filled *rows × cols* matrix.
38    #[must_use]
39    #[inline]
40    pub fn zeros(rows: usize, cols: usize) -> Self {
41        Self {
42            data: vec![0.0; rows * cols],
43            rows,
44            cols,
45        }
46    }
47
48    /// Create an *n × n* identity matrix.
49    #[must_use]
50    #[inline]
51    pub fn identity(n: usize) -> Self {
52        let mut m = Self::zeros(n, n);
53        for i in 0..n {
54            m.data[i * n + i] = 1.0;
55        }
56        m
57    }
58
59    /// Construct from a flat row-major `Vec<f64>`.
60    ///
61    /// # Errors
62    ///
63    /// Returns [`HisabError::InvalidInput`] if `data.len() != rows * cols`.
64    #[must_use = "returns the matrix or an error"]
65    pub fn from_rows(rows: usize, cols: usize, data: Vec<f64>) -> Result<Self, HisabError> {
66        if data.len() != rows * cols {
67            return Err(HisabError::InvalidInput(alloc_msg(
68                "data length",
69                data.len(),
70                rows * cols,
71            )));
72        }
73        Ok(Self { data, rows, cols })
74    }
75
76    /// Construct from a slice of row vectors.
77    ///
78    /// All rows must have the same length.
79    ///
80    /// # Errors
81    ///
82    /// Returns [`HisabError::InvalidInput`] if the input is empty or rows have
83    /// inconsistent lengths.
84    #[must_use = "returns the matrix or an error"]
85    pub fn from_vec_of_vec(v: &[Vec<f64>]) -> Result<Self, HisabError> {
86        if v.is_empty() {
87            return Err(HisabError::InvalidInput("empty row list".into()));
88        }
89        let cols = v[0].len();
90        let rows = v.len();
91        let mut data = Vec::with_capacity(rows * cols);
92        for (r, row) in v.iter().enumerate() {
93            if row.len() != cols {
94                return Err(HisabError::InvalidInput(alloc_msg(
95                    &format!("row {r} length"),
96                    row.len(),
97                    cols,
98                )));
99            }
100            data.extend_from_slice(row);
101        }
102        Ok(Self { data, rows, cols })
103    }
104
105    // -----------------------------------------------------------------------
106    // Conversions
107
108    /// Convert to `Vec<Vec<f64>>` (row-major).
109    #[must_use]
110    pub fn to_vec_of_vec(&self) -> Vec<Vec<f64>> {
111        (0..self.rows)
112            .map(|r| self.data[r * self.cols..(r + 1) * self.cols].to_vec())
113            .collect()
114    }
115
116    // -----------------------------------------------------------------------
117    // Dimensions
118
119    /// Number of rows.
120    #[must_use]
121    #[inline]
122    pub fn rows(&self) -> usize {
123        self.rows
124    }
125
126    /// Number of columns.
127    #[must_use]
128    #[inline]
129    pub fn cols(&self) -> usize {
130        self.cols
131    }
132
133    // -----------------------------------------------------------------------
134    // Element access
135
136    /// Read the element at `(row, col)`.
137    ///
138    /// # Panics
139    ///
140    /// Panics in debug builds if `row >= self.rows || col >= self.cols`.
141    #[must_use]
142    #[inline]
143    pub fn get(&self, row: usize, col: usize) -> f64 {
144        debug_assert!(row < self.rows && col < self.cols, "index out of bounds");
145        self.data[row * self.cols + col]
146    }
147
148    /// Mutable reference to the element at `(row, col)`.
149    ///
150    /// # Panics
151    ///
152    /// Panics in debug builds if `row >= self.rows || col >= self.cols`.
153    #[inline]
154    pub fn get_mut(&mut self, row: usize, col: usize) -> &mut f64 {
155        debug_assert!(row < self.rows && col < self.cols, "index out of bounds");
156        &mut self.data[row * self.cols + col]
157    }
158
159    /// Immutable slice of row `i`.
160    ///
161    /// # Panics
162    ///
163    /// Panics in debug builds if `i >= self.rows`.
164    #[must_use]
165    #[inline]
166    pub fn row(&self, i: usize) -> &[f64] {
167        debug_assert!(i < self.rows, "row index out of bounds");
168        &self.data[i * self.cols..(i + 1) * self.cols]
169    }
170
171    /// Set the element at `(row, col)` to `val`.
172    ///
173    /// # Panics
174    ///
175    /// Panics in debug builds if `row >= self.rows || col >= self.cols`.
176    #[inline]
177    pub fn set(&mut self, row: usize, col: usize, val: f64) {
178        debug_assert!(row < self.rows && col < self.cols, "index out of bounds");
179        self.data[row * self.cols + col] = val;
180    }
181
182    // -----------------------------------------------------------------------
183    // Operations
184
185    /// Matrix-vector multiply: **A** · **x**, returning **y** = **Ax**.
186    ///
187    /// # Errors
188    ///
189    /// Returns [`HisabError::InvalidInput`] if `x.len() != self.cols`.
190    #[must_use = "returns the product vector or an error"]
191    pub fn mul_vec(&self, x: &[f64]) -> Result<Vec<f64>, HisabError> {
192        if x.len() != self.cols {
193            return Err(HisabError::InvalidInput(alloc_msg(
194                "vector length",
195                x.len(),
196                self.cols,
197            )));
198        }
199        let mut out = vec![0.0; self.rows];
200        for (r, dst) in out.iter_mut().enumerate() {
201            let row = &self.data[r * self.cols..(r + 1) * self.cols];
202            // Neumaier-compensated dot product for accuracy.
203            let mut sum = 0.0_f64;
204            let mut comp = 0.0_f64;
205            for c in 0..self.cols {
206                let v = row[c] * x[c];
207                let t = sum + v;
208                if sum.abs() >= v.abs() {
209                    comp += (sum - t) + v;
210                } else {
211                    comp += (v - t) + sum;
212                }
213                sum = t;
214            }
215            *dst = sum + comp;
216        }
217        Ok(out)
218    }
219
220    /// Matrix-matrix multiply: **self** · **other**.
221    ///
222    /// # Errors
223    ///
224    /// Returns [`HisabError::InvalidInput`] if `self.cols != other.rows`.
225    #[must_use = "returns the product matrix or an error"]
226    pub fn mul_mat(&self, other: &DenseMatrix) -> Result<DenseMatrix, HisabError> {
227        if self.cols != other.rows {
228            return Err(HisabError::InvalidInput(alloc_msg(
229                "self.cols",
230                self.cols,
231                other.rows,
232            )));
233        }
234        let rows = self.rows;
235        let cols = other.cols;
236        let inner = self.cols;
237        let mut out = DenseMatrix::zeros(rows, cols);
238        for r in 0..rows {
239            for c in 0..cols {
240                // Neumaier-compensated dot product along the inner dimension.
241                let mut sum = 0.0_f64;
242                let mut comp = 0.0_f64;
243                for k in 0..inner {
244                    let v = self.data[r * inner + k] * other.data[k * cols + c];
245                    let t = sum + v;
246                    if sum.abs() >= v.abs() {
247                        comp += (sum - t) + v;
248                    } else {
249                        comp += (v - t) + sum;
250                    }
251                    sum = t;
252                }
253                out.data[r * cols + c] = sum + comp;
254            }
255        }
256        Ok(out)
257    }
258
259    /// Transpose: returns a new *cols × rows* matrix.
260    #[must_use]
261    pub fn transpose(&self) -> DenseMatrix {
262        let mut out = DenseMatrix::zeros(self.cols, self.rows);
263        for r in 0..self.rows {
264            for c in 0..self.cols {
265                out.data[c * self.rows + r] = self.data[r * self.cols + c];
266            }
267        }
268        out
269    }
270
271    /// Frobenius norm: √(∑ aᵢⱼ²).
272    #[must_use]
273    pub fn frobenius_norm(&self) -> f64 {
274        self.data.iter().map(|&v| v * v).sum::<f64>().sqrt()
275    }
276}
277
278// ---------------------------------------------------------------------------
279// Index / IndexMut
280
281impl std::ops::Index<(usize, usize)> for DenseMatrix {
282    type Output = f64;
283
284    #[inline]
285    fn index(&self, (row, col): (usize, usize)) -> &f64 {
286        debug_assert!(row < self.rows && col < self.cols, "index out of bounds");
287        &self.data[row * self.cols + col]
288    }
289}
290
291impl std::ops::IndexMut<(usize, usize)> for DenseMatrix {
292    #[inline]
293    fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut f64 {
294        debug_assert!(row < self.rows && col < self.cols, "index out of bounds");
295        &mut self.data[row * self.cols + col]
296    }
297}
298
299// ---------------------------------------------------------------------------
300// Internal helpers
301
/// Format a size-mismatch message as `"{field}: expected {expected}, got {got}"`.
///
/// The returned `String` is used as the payload of `HisabError::InvalidInput`
/// by the constructors and multiplication methods in this module.
fn alloc_msg(field: &str, got: usize, expected: usize) -> String {
    format!("{field}: expected {expected}, got {got}")
}
311
312// ---------------------------------------------------------------------------
313// Tests
314
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zeros_is_all_zero() {
        let m = DenseMatrix::zeros(3, 4);
        assert_eq!(m.rows(), 3);
        assert_eq!(m.cols(), 4);
        for r in 0..3 {
            for c in 0..4 {
                assert_eq!(m.get(r, c), 0.0);
            }
        }
    }

    #[test]
    fn identity_diagonal() {
        let id = DenseMatrix::identity(4);
        for r in 0..4 {
            for c in 0..4 {
                let expected = if r == c { 1.0 } else { 0.0 };
                assert_eq!(id.get(r, c), expected);
            }
        }
    }

    #[test]
    fn from_rows_roundtrip() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let m = DenseMatrix::from_rows(2, 3, data).unwrap();
        assert_eq!(m.rows(), 2);
        assert_eq!(m.cols(), 3);
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(0, 2), 3.0);
        assert_eq!(m.get(1, 0), 4.0);
        assert_eq!(m.get(1, 2), 6.0);
        assert_eq!(
            m.to_vec_of_vec(),
            vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]
        );
    }

    #[test]
    fn from_rows_size_mismatch() {
        let result = DenseMatrix::from_rows(2, 3, vec![1.0; 5]);
        assert!(result.is_err());
    }

    #[test]
    fn from_vec_of_vec_and_back() {
        let rows = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let m = DenseMatrix::from_vec_of_vec(&rows).unwrap();
        assert_eq!(m.rows(), 3);
        assert_eq!(m.cols(), 2);
        let back = m.to_vec_of_vec();
        assert_eq!(back, rows);
    }

    #[test]
    fn from_vec_of_vec_inconsistent_cols() {
        let rows = vec![vec![1.0, 2.0], vec![3.0]];
        assert!(DenseMatrix::from_vec_of_vec(&rows).is_err());
    }

    #[test]
    fn from_vec_of_vec_empty() {
        assert!(DenseMatrix::from_vec_of_vec(&[]).is_err());
    }

    #[test]
    fn set_get_roundtrip() {
        let mut m = DenseMatrix::zeros(3, 3);
        m.set(1, 2, 42.0);
        assert_eq!(m.get(1, 2), 42.0);
        // Other cells untouched.
        assert_eq!(m.get(0, 0), 0.0);
        assert_eq!(m.get(2, 1), 0.0);
    }

    #[test]
    fn index_operator() {
        let mut m = DenseMatrix::zeros(2, 2);
        m[(0, 1)] = 99.0;
        assert_eq!(m[(0, 1)], 99.0);
        assert_eq!(m.get(0, 1), 99.0);
    }

    #[test]
    fn row_slice() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let m = DenseMatrix::from_rows(2, 3, data).unwrap();
        assert_eq!(m.row(0), &[1.0, 2.0, 3.0]);
        assert_eq!(m.row(1), &[4.0, 5.0, 6.0]);
    }

    #[test]
    fn mul_vec_identity() {
        let id = DenseMatrix::identity(3);
        let x = vec![1.0, 2.0, 3.0];
        let y = id.mul_vec(&x).unwrap();
        assert_eq!(y, x);
    }

    #[test]
    fn mul_vec_known() {
        // [[1,2],[3,4]] * [1,1] = [3,7]
        let m = DenseMatrix::from_rows(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = m.mul_vec(&[1.0, 1.0]).unwrap();
        assert!((y[0] - 3.0).abs() < 1e-12);
        assert!((y[1] - 7.0).abs() < 1e-12);
    }

    #[test]
    fn mul_vec_size_mismatch() {
        let m = DenseMatrix::zeros(2, 3);
        assert!(m.mul_vec(&[1.0, 2.0]).is_err());
    }

    #[test]
    fn mul_vec_and_mul_mat_are_compensated() {
        // Naive left-to-right summation of 1e16 + 1 - 1e16 gives 0.0 because
        // 1e16 + 1 rounds back to 1e16; the compensated sum recovers the 1.
        let a = DenseMatrix::from_rows(1, 3, vec![1e16, 1.0, -1e16]).unwrap();
        assert_eq!(a.mul_vec(&[1.0, 1.0, 1.0]).unwrap(), vec![1.0]);
        let ones = DenseMatrix::from_rows(3, 1, vec![1.0; 3]).unwrap();
        assert_eq!(a.mul_mat(&ones).unwrap().get(0, 0), 1.0);
    }

    #[test]
    fn mul_vec_matches_mul_mat_with_column() {
        let a = DenseMatrix::from_rows(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let x = [7.0, 8.0, 9.0];
        let y = a.mul_vec(&x).unwrap();
        let x_col = DenseMatrix::from_rows(3, 1, x.to_vec()).unwrap();
        let y_col = a.mul_mat(&x_col).unwrap();
        assert_eq!(y, vec![50.0, 122.0]);
        assert_eq!(y_col.to_vec_of_vec(), vec![vec![50.0], vec![122.0]]);
    }

    #[test]
    fn mul_mat_identity() {
        let m = DenseMatrix::from_rows(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let id = DenseMatrix::identity(2);
        let result = m.mul_mat(&id).unwrap();
        assert_eq!(result, m);
    }

    #[test]
    fn mul_mat_known() {
        // [[1,2],[3,4]] * [[5,6],[7,8]] = [[19,22],[43,50]]
        let a = DenseMatrix::from_rows(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = DenseMatrix::from_rows(2, 2, vec![5.0, 6.0, 7.0, 8.0]).unwrap();
        let c = a.mul_mat(&b).unwrap();
        assert!((c.get(0, 0) - 19.0).abs() < 1e-12);
        assert!((c.get(0, 1) - 22.0).abs() < 1e-12);
        assert!((c.get(1, 0) - 43.0).abs() < 1e-12);
        assert!((c.get(1, 1) - 50.0).abs() < 1e-12);
    }

    #[test]
    fn mul_mat_size_mismatch() {
        let a = DenseMatrix::zeros(2, 3);
        let b = DenseMatrix::zeros(2, 2);
        assert!(a.mul_mat(&b).is_err());
    }

    #[test]
    fn mul_mat_non_square() {
        // (2×3) * (3×4) = (2×4); `a` selects rows 0 and 1 of `b`.
        let a = DenseMatrix::from_rows(2, 3, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0]).unwrap();
        let b = DenseMatrix::from_rows(
            3,
            4,
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .unwrap();
        let c = a.mul_mat(&b).unwrap();
        assert_eq!(c.rows(), 2);
        assert_eq!(c.cols(), 4);
        assert_eq!(
            c.to_vec_of_vec(),
            vec![vec![1.0, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]]
        );
    }

    #[test]
    fn zero_sized_dimensions() {
        let no_rows = DenseMatrix::zeros(0, 3);
        assert!(no_rows.to_vec_of_vec().is_empty());

        let no_cols = DenseMatrix::zeros(2, 0);
        assert_eq!(no_cols.to_vec_of_vec(), vec![Vec::<f64>::new(), Vec::new()]);
        assert_eq!(no_cols.mul_vec(&[]).unwrap(), vec![0.0, 0.0]);

        // (2×0) * (0×3) is the 2×3 zero matrix.
        let product = no_cols.mul_mat(&DenseMatrix::zeros(0, 3)).unwrap();
        assert_eq!(product, DenseMatrix::zeros(2, 3));

        let t = no_cols.transpose();
        assert_eq!(t.rows(), 0);
        assert_eq!(t.cols(), 2);
    }

    #[test]
    fn transpose_square() {
        let m = DenseMatrix::from_rows(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let t = m.transpose();
        assert_eq!(t.get(0, 0), 1.0);
        assert_eq!(t.get(0, 1), 3.0);
        assert_eq!(t.get(1, 0), 2.0);
        assert_eq!(t.get(1, 1), 4.0);
    }

    #[test]
    fn transpose_rectangular() {
        // 2×3 → 3×2
        let m = DenseMatrix::from_rows(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let t = m.transpose();
        assert_eq!(t.rows(), 3);
        assert_eq!(t.cols(), 2);
        assert_eq!(
            t.to_vec_of_vec(),
            vec![vec![1.0, 4.0], vec![2.0, 5.0], vec![3.0, 6.0]]
        );
    }

    #[test]
    fn transpose_double_is_identity() {
        let m = DenseMatrix::from_rows(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert_eq!(m.transpose().transpose(), m);
    }

    #[test]
    fn frobenius_norm_identity() {
        // Identity n×n has n ones, so Frobenius = sqrt(n).
        let id = DenseMatrix::identity(4);
        assert!((id.frobenius_norm() - 2.0).abs() < 1e-12);
    }

    #[test]
    fn frobenius_norm_known() {
        // sqrt(3² + 4²) = 5
        let m = DenseMatrix::from_rows(1, 2, vec![3.0, 4.0]).unwrap();
        assert!((m.frobenius_norm() - 5.0).abs() < 1e-12);
    }

    #[test]
    fn frobenius_norm_zeros() {
        assert_eq!(DenseMatrix::zeros(5, 5).frobenius_norm(), 0.0);
    }

    #[test]
    fn get_mut_modifies() {
        let mut m = DenseMatrix::zeros(2, 2);
        *m.get_mut(1, 0) = 55.0;
        assert_eq!(m.get(1, 0), 55.0);
    }
}