aprender/primitives/
matrix.rs

1//! Matrix type for 2D numeric data.
2
3use super::Vector;
4use serde::{Deserialize, Serialize};
5
6/// A 2D matrix of floating-point values (row-major storage).
7///
8/// # Examples
9///
10/// ```
11/// use aprender::primitives::Matrix;
12///
13/// let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("data length matches rows * cols");
14/// assert_eq!(m.shape(), (2, 3));
15/// ```
16#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
17pub struct Matrix<T> {
18    data: Vec<T>,
19    rows: usize,
20    cols: usize,
21}
22
23impl<T: Copy> Matrix<T> {
24    /// Creates a new matrix from a vector of data.
25    ///
26    /// # Errors
27    ///
28    /// Returns an error if data length doesn't match rows * cols.
29    pub fn from_vec(rows: usize, cols: usize, data: Vec<T>) -> Result<Self, &'static str> {
30        if data.len() != rows * cols {
31            return Err("Data length must equal rows * cols");
32        }
33        Ok(Self { data, rows, cols })
34    }
35
36    /// Returns the shape as (rows, cols).
37    #[must_use]
38    pub fn shape(&self) -> (usize, usize) {
39        (self.rows, self.cols)
40    }
41
42    /// Returns the number of rows.
43    #[must_use]
44    pub fn n_rows(&self) -> usize {
45        self.rows
46    }
47
48    /// Returns the number of columns.
49    #[must_use]
50    pub fn n_cols(&self) -> usize {
51        self.cols
52    }
53
54    /// Gets element at (row, col).
55    ///
56    /// # Panics
57    ///
58    /// Panics if indices are out of bounds.
59    #[must_use]
60    pub fn get(&self, row: usize, col: usize) -> T {
61        self.data[row * self.cols + col]
62    }
63
64    /// Sets element at (row, col).
65    ///
66    /// # Panics
67    ///
68    /// Panics if indices are out of bounds.
69    pub fn set(&mut self, row: usize, col: usize, value: T) {
70        self.data[row * self.cols + col] = value;
71    }
72
73    /// Returns a row as a Vector.
74    #[must_use]
75    pub fn row(&self, row_idx: usize) -> Vector<T> {
76        let start = row_idx * self.cols;
77        let end = start + self.cols;
78        Vector::from_slice(&self.data[start..end])
79    }
80
81    /// Returns a column as a Vector.
82    #[must_use]
83    pub fn column(&self, col_idx: usize) -> Vector<T> {
84        let data: Vec<T> = (0..self.rows)
85            .map(|row| self.data[row * self.cols + col_idx])
86            .collect();
87        Vector::from_vec(data)
88    }
89
90    /// Returns the underlying data as a slice.
91    #[must_use]
92    pub fn as_slice(&self) -> &[T] {
93        &self.data
94    }
95}
96
97impl Matrix<f32> {
98    /// Creates a matrix of zeros.
99    #[must_use]
100    pub fn zeros(rows: usize, cols: usize) -> Self {
101        Self {
102            data: vec![0.0; rows * cols],
103            rows,
104            cols,
105        }
106    }
107
108    /// Creates a matrix of ones.
109    #[must_use]
110    pub fn ones(rows: usize, cols: usize) -> Self {
111        Self {
112            data: vec![1.0; rows * cols],
113            rows,
114            cols,
115        }
116    }
117
118    /// Creates an identity matrix.
119    #[must_use]
120    pub fn eye(n: usize) -> Self {
121        let mut data = vec![0.0; n * n];
122        for i in 0..n {
123            data[i * n + i] = 1.0;
124        }
125        Self {
126            data,
127            rows: n,
128            cols: n,
129        }
130    }
131
132    /// Transposes the matrix.
133    #[must_use]
134    pub fn transpose(&self) -> Self {
135        let mut data = vec![0.0; self.rows * self.cols];
136        for i in 0..self.rows {
137            for j in 0..self.cols {
138                data[j * self.rows + i] = self.data[i * self.cols + j];
139            }
140        }
141        Self {
142            data,
143            rows: self.cols,
144            cols: self.rows,
145        }
146    }
147
148    /// Matrix-matrix multiplication.
149    ///
150    /// # Errors
151    ///
152    /// Returns an error if dimensions don't match.
153    pub fn matmul(&self, other: &Self) -> Result<Self, &'static str> {
154        if self.cols != other.rows {
155            return Err("Matrix dimensions don't match for multiplication");
156        }
157
158        let mut result = vec![0.0; self.rows * other.cols];
159        for i in 0..self.rows {
160            for j in 0..other.cols {
161                let mut sum = 0.0;
162                for k in 0..self.cols {
163                    sum += self.get(i, k) * other.get(k, j);
164                }
165                result[i * other.cols + j] = sum;
166            }
167        }
168
169        Ok(Self {
170            data: result,
171            rows: self.rows,
172            cols: other.cols,
173        })
174    }
175
176    /// Matrix-vector multiplication.
177    ///
178    /// # Errors
179    ///
180    /// Returns an error if dimensions don't match.
181    pub fn matvec(&self, vec: &Vector<f32>) -> Result<Vector<f32>, &'static str> {
182        if self.cols != vec.len() {
183            return Err("Matrix columns must match vector length");
184        }
185
186        let result: Vec<f32> = (0..self.rows)
187            .map(|i| {
188                let row = self.row(i);
189                row.dot(vec)
190            })
191            .collect();
192
193        Ok(Vector::from_vec(result))
194    }
195
196    /// Adds another matrix element-wise.
197    ///
198    /// # Errors
199    ///
200    /// Returns an error if dimensions don't match.
201    pub fn add(&self, other: &Self) -> Result<Self, &'static str> {
202        if self.rows != other.rows || self.cols != other.cols {
203            return Err("Matrix dimensions must match for addition");
204        }
205
206        let data: Vec<f32> = self
207            .data
208            .iter()
209            .zip(other.data.iter())
210            .map(|(a, b)| a + b)
211            .collect();
212
213        Ok(Self {
214            data,
215            rows: self.rows,
216            cols: self.cols,
217        })
218    }
219
220    /// Subtracts another matrix element-wise.
221    ///
222    /// # Errors
223    ///
224    /// Returns an error if dimensions don't match.
225    pub fn sub(&self, other: &Self) -> Result<Self, &'static str> {
226        if self.rows != other.rows || self.cols != other.cols {
227            return Err("Matrix dimensions must match for subtraction");
228        }
229
230        let data: Vec<f32> = self
231            .data
232            .iter()
233            .zip(other.data.iter())
234            .map(|(a, b)| a - b)
235            .collect();
236
237        Ok(Self {
238            data,
239            rows: self.rows,
240            cols: self.cols,
241        })
242    }
243
244    /// Multiplies each element by a scalar.
245    #[must_use]
246    pub fn mul_scalar(&self, scalar: f32) -> Self {
247        Self {
248            data: self.data.iter().map(|x| x * scalar).collect(),
249            rows: self.rows,
250            cols: self.cols,
251        }
252    }
253
254    /// Solves the linear system Ax = b using Cholesky decomposition.
255    ///
256    /// The matrix must be symmetric positive definite.
257    ///
258    /// # Errors
259    ///
260    /// Returns an error if the matrix is not square or not positive definite.
261    pub fn cholesky_solve(&self, b: &Vector<f32>) -> Result<Vector<f32>, &'static str> {
262        if self.rows != self.cols {
263            return Err("Matrix must be square for Cholesky decomposition");
264        }
265        if self.rows != b.len() {
266            return Err("Matrix rows must match vector length");
267        }
268
269        let n = self.rows;
270
271        // Cholesky decomposition: A = L * L^T
272        let mut l = vec![0.0; n * n];
273
274        for i in 0..n {
275            for j in 0..=i {
276                let mut sum = 0.0;
277
278                if i == j {
279                    for k in 0..j {
280                        sum += l[j * n + k] * l[j * n + k];
281                    }
282                    let diag = self.get(j, j) - sum;
283                    if diag <= 0.0 {
284                        return Err("Matrix is not positive definite");
285                    }
286                    l[j * n + j] = diag.sqrt();
287                } else {
288                    for k in 0..j {
289                        sum += l[i * n + k] * l[j * n + k];
290                    }
291                    l[i * n + j] = (self.get(i, j) - sum) / l[j * n + j];
292                }
293            }
294        }
295
296        // Forward substitution: L * y = b
297        let mut y = vec![0.0; n];
298        for i in 0..n {
299            let mut sum = 0.0;
300            for j in 0..i {
301                sum += l[i * n + j] * y[j];
302            }
303            y[i] = (b[i] - sum) / l[i * n + i];
304        }
305
306        // Backward substitution: L^T * x = y
307        let mut x = vec![0.0; n];
308        for i in (0..n).rev() {
309            let mut sum = 0.0;
310            for j in (i + 1)..n {
311                sum += l[j * n + i] * x[j];
312            }
313            x[i] = (y[i] - sum) / l[i * n + i];
314        }
315
316        Ok(Vector::from_vec(x))
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_vec() {
        let m = Matrix::from_vec(2, 3, vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("test data has correct dimensions: 2*3=6 elements");
        assert_eq!(m.shape(), (2, 3));
        assert!((m.get(0, 0) - 1.0).abs() < 1e-6);
        assert!((m.get(1, 2) - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_from_vec_error() {
        let result = Matrix::from_vec(2, 3, vec![1.0_f32, 2.0, 3.0]);
        assert!(result.is_err());
    }

    #[test]
    fn test_zeros() {
        let m = Matrix::<f32>::zeros(2, 3);
        assert_eq!(m.shape(), (2, 3));
        assert!(m.as_slice().iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_ones() {
        let m = Matrix::<f32>::ones(2, 3);
        assert_eq!(m.shape(), (2, 3));
        assert!(m.as_slice().iter().all(|&x| x == 1.0));
    }

    #[test]
    fn test_n_rows_n_cols() {
        let m = Matrix::<f32>::zeros(4, 7);
        assert_eq!(m.n_rows(), 4);
        assert_eq!(m.n_cols(), 7);
    }

    #[test]
    fn test_eye() {
        let m = Matrix::<f32>::eye(3);
        assert!((m.get(0, 0) - 1.0).abs() < 1e-6);
        assert!((m.get(1, 1) - 1.0).abs() < 1e-6);
        assert!((m.get(2, 2) - 1.0).abs() < 1e-6);
        assert!((m.get(0, 1) - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_transpose() {
        let m = Matrix::from_vec(2, 3, vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("test data has correct dimensions: 2*3=6 elements");
        let t = m.transpose();
        assert_eq!(t.shape(), (3, 2));
        assert!((t.get(0, 0) - 1.0).abs() < 1e-6);
        assert!((t.get(0, 1) - 4.0).abs() < 1e-6);
        assert!((t.get(2, 1) - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_row() {
        let m = Matrix::from_vec(2, 3, vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("test data has correct dimensions: 2*3=6 elements");
        let row = m.row(1);
        assert_eq!(row.len(), 3);
        assert!((row[0] - 4.0).abs() < 1e-6);
        assert!((row[1] - 5.0).abs() < 1e-6);
        assert!((row[2] - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_column() {
        let m = Matrix::from_vec(2, 3, vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("test data has correct dimensions: 2*3=6 elements");
        let col = m.column(1);
        assert_eq!(col.len(), 2);
        assert!((col[0] - 2.0).abs() < 1e-6);
        assert!((col[1] - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_matmul() {
        // 2x3 * 3x2 = 2x2
        let a = Matrix::from_vec(2, 3, vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("test data has correct dimensions: 2*3=6 elements");
        let b = Matrix::from_vec(3, 2, vec![7.0_f32, 8.0, 9.0, 10.0, 11.0, 12.0])
            .expect("test data has correct dimensions: 3*2=6 elements");
        let c = a
            .matmul(&b)
            .expect("matrix dimensions are compatible for multiplication: 2x3 * 3x2");

        assert_eq!(c.shape(), (2, 2));
        // c[0,0] = 1*7 + 2*9 + 3*11 = 7 + 18 + 33 = 58
        assert!((c.get(0, 0) - 58.0).abs() < 1e-6);
        // c[0,1] = 1*8 + 2*10 + 3*12 = 8 + 20 + 36 = 64
        assert!((c.get(0, 1) - 64.0).abs() < 1e-6);
    }

    #[test]
    fn test_matmul_dimension_error() {
        let a = Matrix::from_vec(2, 3, vec![1.0_f32; 6])
            .expect("test data has correct dimensions: 2*3=6 elements");
        let b = Matrix::from_vec(2, 2, vec![1.0_f32; 4])
            .expect("test data has correct dimensions: 2*2=4 elements");
        assert!(a.matmul(&b).is_err());
    }

    #[test]
    fn test_matvec() {
        let m = Matrix::from_vec(2, 3, vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("test data has correct dimensions: 2*3=6 elements");
        let v = Vector::from_slice(&[1.0_f32, 2.0, 3.0]);
        let result = m
            .matvec(&v)
            .expect("matrix columns match vector length: both 3");

        assert_eq!(result.len(), 2);
        // result[0] = 1*1 + 2*2 + 3*3 = 14
        assert!((result[0] - 14.0).abs() < 1e-6);
        // result[1] = 4*1 + 5*2 + 6*3 = 32
        assert!((result[1] - 32.0).abs() < 1e-6);
    }

    #[test]
    fn test_matvec_dimension_error() {
        // 2x3 matrix cannot multiply a length-2 vector
        let m = Matrix::<f32>::zeros(2, 3);
        let v = Vector::from_slice(&[1.0_f32, 2.0]);
        assert!(m.matvec(&v).is_err());
    }

    #[test]
    fn test_add() {
        let a = Matrix::from_vec(2, 2, vec![1.0_f32, 2.0, 3.0, 4.0])
            .expect("test data has correct dimensions: 2*2=4 elements");
        let b = Matrix::from_vec(2, 2, vec![5.0_f32, 6.0, 7.0, 8.0])
            .expect("test data has correct dimensions: 2*2=4 elements");
        let c = a.add(&b).expect("both matrices have same dimensions: 2x2");

        assert!((c.get(0, 0) - 6.0).abs() < 1e-6);
        assert!((c.get(1, 1) - 12.0).abs() < 1e-6);
    }

    #[test]
    fn test_add_dimension_mismatch() {
        // Test that mismatched dimensions are detected (catches || → && mutation)
        let a = Matrix::from_vec(2, 2, vec![1.0_f32; 4])
            .expect("test data has correct dimensions: 2*2=4 elements");
        let b = Matrix::from_vec(3, 2, vec![1.0_f32; 6])
            .expect("test data has correct dimensions: 3*2=6 elements");
        assert!(a.add(&b).is_err());

        let c = Matrix::from_vec(2, 3, vec![1.0_f32; 6])
            .expect("test data has correct dimensions: 2*3=6 elements");
        assert!(a.add(&c).is_err());
    }

    #[test]
    fn test_sub() {
        // Test element-wise subtraction
        let a = Matrix::from_vec(2, 2, vec![10.0_f32, 8.0, 6.0, 12.0])
            .expect("test data has correct dimensions: 2*2=4 elements");
        let b = Matrix::from_vec(2, 2, vec![4.0_f32, 3.0, 2.0, 7.0])
            .expect("test data has correct dimensions: 2*2=4 elements");
        let c = a.sub(&b).expect("both matrices have same dimensions: 2x2");

        // Verify all elements: a[i] - b[i]
        assert!((c.get(0, 0) - 6.0).abs() < 1e-6); // 10 - 4 = 6
        assert!((c.get(0, 1) - 5.0).abs() < 1e-6); // 8 - 3 = 5
        assert!((c.get(1, 0) - 4.0).abs() < 1e-6); // 6 - 2 = 4
        assert!((c.get(1, 1) - 5.0).abs() < 1e-6); // 12 - 7 = 5
    }

    #[test]
    fn test_sub_dimension_mismatch_rows() {
        // Test that mismatched rows are detected
        let a = Matrix::from_vec(2, 2, vec![1.0_f32; 4])
            .expect("test data has correct dimensions: 2*2=4 elements");
        let b = Matrix::from_vec(3, 2, vec![1.0_f32; 6])
            .expect("test data has correct dimensions: 3*2=6 elements");
        assert!(a.sub(&b).is_err());
    }

    #[test]
    fn test_sub_dimension_mismatch_cols() {
        // Test that mismatched columns are detected
        let a = Matrix::from_vec(2, 2, vec![1.0_f32; 4])
            .expect("test data has correct dimensions: 2*2=4 elements");
        let b = Matrix::from_vec(2, 3, vec![1.0_f32; 6])
            .expect("test data has correct dimensions: 2*3=6 elements");
        assert!(a.sub(&b).is_err());
    }

    #[test]
    fn test_cholesky_solve() {
        // Solve A*x = b where A is symmetric positive definite
        // A = [[4, 2], [2, 3]]
        // b = [1, 2]
        // Solution: x = [-0.125, 0.75]
        let a = Matrix::from_vec(2, 2, vec![4.0_f32, 2.0, 2.0, 3.0])
            .expect("test data has correct dimensions: 2*2=4 elements");
        let b = Vector::from_slice(&[1.0_f32, 2.0]);
        let x = a
            .cholesky_solve(&b)
            .expect("matrix is square, symmetric positive definite, and vector matches size");

        assert_eq!(x.len(), 2);
        assert!((x[0] - (-0.125)).abs() < 1e-5);
        assert!((x[1] - 0.75).abs() < 1e-5);
    }

    #[test]
    fn test_cholesky_solve_3x3() {
        // A = [[4, 12, -16], [12, 37, -43], [-16, -43, 98]]
        // b = [1, 2, 3]
        let a = Matrix::from_vec(
            3,
            3,
            vec![4.0_f32, 12.0, -16.0, 12.0, 37.0, -43.0, -16.0, -43.0, 98.0],
        )
        .expect("test data has correct dimensions: 3*3=9 elements");
        let b = Vector::from_slice(&[1.0_f32, 2.0, 3.0]);
        let x = a
            .cholesky_solve(&b)
            .expect("matrix is square, symmetric positive definite, and vector matches size");

        // Verify A*x ≈ b
        let result = a
            .matvec(&x)
            .expect("matrix columns match vector length: both 3");
        for i in 0..3 {
            assert!((result[i] - b[i]).abs() < 1e-4);
        }
    }

    #[test]
    fn test_cholesky_solve_strict() {
        // Stricter test to catch arithmetic mutations in cholesky_solve
        // Uses a 4x4 SPD matrix to exercise all accumulation loops
        // A = [[4, 2, 1, 1],
        //      [2, 5, 2, 1],
        //      [1, 2, 6, 2],
        //      [1, 1, 2, 7]]
        // This is symmetric positive definite with non-trivial decomposition
        let a = Matrix::from_vec(
            4,
            4,
            vec![
                4.0_f32, 2.0, 1.0, 1.0, 2.0, 5.0, 2.0, 1.0, 1.0, 2.0, 6.0, 2.0, 1.0, 1.0, 2.0, 7.0,
            ],
        )
        .expect("test data has correct dimensions: 4*4=16 elements");
        let b = Vector::from_slice(&[1.0_f32, 2.0, 3.0, 4.0]);
        let x = a
            .cholesky_solve(&b)
            .expect("matrix is square, symmetric positive definite, and vector matches size");

        // Verify A*x = b with very tight tolerance
        let result = a
            .matvec(&x)
            .expect("matrix columns match vector length: both 4");
        for i in 0..4 {
            assert!(
                (result[i] - b[i]).abs() < 1e-5,
                "Failed at index {}: expected {}, got {}",
                i,
                b[i],
                result[i]
            );
        }

        // Also test a 3x3 case with known solution
        // A = [[9, 3, 3], [3, 5, 1], [3, 1, 4]], b = [15, 9, 8] => x = [1, 1, 1]
        let a3 = Matrix::from_vec(3, 3, vec![9.0_f32, 3.0, 3.0, 3.0, 5.0, 1.0, 3.0, 1.0, 4.0])
            .expect("test data has correct dimensions: 3*3=9 elements");
        let b3 = Vector::from_slice(&[15.0_f32, 9.0, 8.0]);
        let x3 = a3
            .cholesky_solve(&b3)
            .expect("matrix is square, symmetric positive definite, and vector matches size");

        // Verify exact solution [1, 1, 1] with element-by-element check
        assert!((x3[0] - 1.0).abs() < 1e-6);
        assert!((x3[1] - 1.0).abs() < 1e-6);
        assert!((x3[2] - 1.0).abs() < 1e-6);

        // Additional verification: check that A*x3 = b3 with strict tolerance
        let verify3 = a3
            .matvec(&x3)
            .expect("matrix columns match vector length: both 3");
        assert!((verify3[0] - 15.0).abs() < 1e-6);
        assert!((verify3[1] - 9.0).abs() < 1e-6);
        assert!((verify3[2] - 8.0).abs() < 1e-6);
    }

    #[test]
    fn test_cholesky_solve_errors() {
        // Non-square matrix is rejected
        let rect = Matrix::<f32>::zeros(2, 3);
        let b2 = Vector::from_slice(&[1.0_f32, 2.0]);
        assert!(rect.cholesky_solve(&b2).is_err());

        // Right-hand side length must match the matrix size
        let identity = Matrix::<f32>::eye(2);
        let b3 = Vector::from_slice(&[1.0_f32, 2.0, 3.0]);
        assert!(identity.cholesky_solve(&b3).is_err());

        // Symmetric but indefinite: [[1, 2], [2, 1]] has eigenvalues 3 and -1
        let indefinite = Matrix::from_vec(2, 2, vec![1.0_f32, 2.0, 2.0, 1.0])
            .expect("test data has correct dimensions: 2*2=4 elements");
        assert!(indefinite.cholesky_solve(&b2).is_err());
    }

    #[test]
    fn test_mul_scalar() {
        let m = Matrix::from_vec(2, 2, vec![1.0_f32, 2.0, 3.0, 4.0])
            .expect("test data has correct dimensions: 2*2=4 elements");
        let result = m.mul_scalar(2.0);
        assert!((result.get(0, 0) - 2.0).abs() < 1e-6);
        assert!((result.get(1, 1) - 8.0).abs() < 1e-6);
    }

    #[test]
    fn test_set() {
        let mut m = Matrix::<f32>::zeros(2, 2);
        m.set(0, 1, 5.0);
        assert!((m.get(0, 1) - 5.0).abs() < 1e-6);
    }
}