heapless_matrix/
lib.rs

1#![no_std]
2use core::marker::Sized;
3use core::result::Result::{Err, Ok};
4use heapless::Vec;
5use matrix_trait::{
6    IsSquareMatrix, IsVectorCol, MatrixConcat, MatrixTrait, SquareMatrix, VectorCol,
7};
8
9pub mod matrix_trait;
10
11pub mod matrix_ops;
12
13use core::clone::Clone;
14use core::iter::Iterator;
15
16#[derive(Debug, Clone)]
17pub struct Matrix<const ROWS: usize, const COLS: usize> {
18    data: Vec<Vec<f64, COLS>, ROWS>,
19}
20
21impl<const ROWS: usize, const COLS: usize> Matrix<ROWS, COLS> {
22    #[allow(dead_code)]
23    fn iter(&self) -> core::slice::Iter<'_, Vec<f64, COLS>> {
24        self.data.iter()
25    }
26
27    #[allow(dead_code)]
28    fn iter_mut(&mut self) -> core::slice::IterMut<'_, Vec<f64, COLS>> {
29        self.data.iter_mut()
30    }
31}
32
33impl<const ROWS: usize, const COLS: usize> MatrixTrait<ROWS, COLS> for Matrix<ROWS, COLS> {
34    type TransposeType = Matrix<COLS, ROWS>;
35
36    fn new() -> Result<Self, &'static str>
37    where
38        Self: Sized,
39    {
40        if ROWS < 1 || COLS < 1 {
41            return Err("Matrix dimensions are invalid");
42        }
43        let mut vec: Vec<Vec<f64, COLS>, ROWS> = Vec::new();
44        for _ in 0..ROWS {
45            let mut helper: Vec<f64, COLS> = Vec::new();
46            for _ in 0..COLS {
47                helper.push(0.).unwrap();
48            }
49            vec.push(helper).unwrap();
50        }
51        Ok(Matrix { data: vec })
52    }
53
54    fn eye() -> Result<Self, &'static str>
55    where
56        Self: Sized,
57    {
58        if ROWS < 1 || COLS < 1 {
59            return Err("Matrix dimensions are invalid");
60        }
61
62        let mut mat: Vec<Vec<f64, COLS>, ROWS> = Vec::new();
63
64        for i in 0..ROWS {
65            let mut row: Vec<f64, COLS> = Vec::new();
66            for j in 0..COLS {
67                if i == j {
68                    row.push(1.).unwrap();
69                } else {
70                    row.push(0.).unwrap();
71                }
72            }
73
74            mat.push(row).unwrap();
75        }
76
77        Ok(Matrix { data: mat })
78    }
79    /// Function used to create a heapless matrix from a 2D array
80    /// Example:
81    /// ```
82    /// use heapless_matrix::{matrix_trait::MatrixTrait as _, Matrix};
83    /// let data = [[1., 2.],
84    ///             [3., 4.]];
85    /// let mat: Matrix<2, 2> = Matrix::from_vector(data).unwrap();
86    ///
87    /// for i in 0..2 {
88    ///     for j in 0..2 {
89    ///         assert_eq!(data[i][j], mat[i][j]);
90    ///     }
91    /// }
92    /// ```
93    fn from_vector(data: [[f64; COLS]; ROWS]) -> Result<Self, &'static str>
94    where
95        Self: Sized,
96    {
97        let mut array_data: Vec<Vec<f64, COLS>, ROWS> = Vec::new();
98        for row in data.iter() {
99            let mut row_data: Vec<f64, COLS> = Vec::new();
100            for &value in row.iter() {
101                row_data.push(value).unwrap();
102            }
103
104            array_data.push(row_data).unwrap();
105        }
106
107        Ok(Matrix { data: array_data })
108    }
109
110    fn to_double(&self) -> Result<f64, &'static str> {
111        if ROWS != 1 && COLS != 1 {
112            return Err("The matrix does not have dimensions 2x2");
113        }
114        Ok(self[0][0])
115    }
116
117    fn transpose(&self) -> Self::TransposeType {
118        let mut transpose: Matrix<COLS, ROWS> = Matrix::new().unwrap();
119
120        for i in 0..ROWS {
121            for j in 0..COLS {
122                transpose[j][i] = self[i][j]
123            }
124        }
125        transpose
126    }
127
128    fn swap_rows(&mut self, row1: usize, row2: usize) -> Result<(), &'static str> {
129        if row1 >= ROWS || row2 >= ROWS {
130            return Err("Row index out of bounds");
131        }
132        self.data.swap(row1, row2);
133        Ok(())
134    }
135
136    fn swap_cols(&mut self, col1: usize, col2: usize) -> Result<(), &'static str> {
137        if col1 >= COLS || col2 >= COLS {
138            return Err("Column indexes are outof bounds");
139        }
140
141        for i in 0..ROWS {
142            let help = self[i][col1];
143            self[i][col1] = self[i][col2];
144            self[i][col2] = help;
145        }
146        Ok(())
147    }
148
149    fn sub_matrix<const NEW_ROWS: usize, const NEW_COLS: usize>(
150        &self,
151        row_start: usize,
152        col_start: usize,
153    ) -> Result<Matrix<NEW_ROWS, NEW_COLS>, &'static str> {
154        if row_start + NEW_ROWS > ROWS || col_start + NEW_COLS > COLS {
155            return Err("Submatrix dimensions are out of bounds");
156        }
157
158        let mut sub_data: Matrix<NEW_ROWS, NEW_COLS> = Matrix::new().unwrap();
159        for i in 0..NEW_ROWS {
160            for j in 0..NEW_COLS {
161                sub_data[i][j] = self[row_start + i][col_start + j];
162            }
163        }
164
165        Ok(sub_data)
166    }
167
168    fn vector_to_row(elems: [f64; ROWS]) -> Result<Matrix<ROWS, 1>, &'static str> {
169        let mut vec_elems = [[0.; 1]; ROWS];
170        for i in 0..ROWS {
171            vec_elems[i][0] = elems[i];
172        }
173
174        Matrix::<ROWS, 1>::from_vector(vec_elems)
175    }
176
177    fn pinv<const DOUBLE: usize>(&self) -> Result<Matrix<COLS, ROWS>, &'static str> {
178        if ROWS > COLS {
179            // Left pseuduinversion
180            let mat = self.transpose() * self.clone();
181            let mat = mat.inv::<DOUBLE>()?;
182            Ok(mat * self.transpose())
183        } else {
184            // Right pseuduinversion
185            let mat = self.clone() * self.transpose();
186            let mat = mat.inv::<DOUBLE>()?;
187            Ok(self.transpose() * mat)
188        }
189    }
190}
191
192impl<const ROWS: usize, const COLS: usize> MatrixConcat<ROWS, COLS> for Matrix<ROWS, COLS> {
193    ///
194    /// Example:
195    /// ```
196    /// use heapless_matrix::{matrix_trait::{MatrixTrait, MatrixConcat}, Matrix};
197    /// let mat1: Matrix<2, 2> = Matrix::eye().unwrap();
198    /// let mat2: Matrix<2, 2> = Matrix::new().unwrap();
199    /// let mat3 = mat1.clone().x_concat::<2, 4>(mat2.clone()).unwrap();
200    /// for i in 0..2 {
201    ///     for j in 0..2 {
202    ///         assert_eq!(mat3[i][j], mat1[i][j]);
203    ///     }
204    ///     for j in 0..2 {
205    ///         assert_eq!(mat3[i][j + 2], mat2[i][j]);
206    ///     }
207    /// }
208    /// ```
209    fn x_concat<const RHS_COLS: usize, const NEW_COLS: usize>(
210        self,
211        rhs: Matrix<ROWS, RHS_COLS>,
212    ) -> Result<Matrix<ROWS, NEW_COLS>, &'static str> {
213        if RHS_COLS + COLS != NEW_COLS {
214            return Err(
215                "The number of new columns is not equal to the sum of columns of the matrices",
216            );
217        }
218
219        let mut new_data = [[0.; NEW_COLS]; ROWS];
220
221        for i in 0..ROWS {
222            for j in 0..COLS {
223                new_data[i][j] = self.data[i][j];
224            }
225            for j in 0..RHS_COLS {
226                new_data[i][COLS + j] = rhs.data[i][j];
227            }
228        }
229
230        Matrix::from_vector(new_data)
231    }
232
233    fn y_concat<const RHS_ROWS: usize, const NEW_ROWS: usize>(
234        self,
235        rhs: Matrix<RHS_ROWS, COLS>,
236    ) -> Result<Matrix<NEW_ROWS, COLS>, &'static str> {
237        if ROWS + RHS_ROWS != NEW_ROWS {
238            return Err(
239                "The number of new rows is not equal to the sum of rows of the two matrices",
240            );
241        }
242
243        let mut new_data = [[0.; COLS]; NEW_ROWS];
244        for i in 0..COLS {
245            for j in 0..ROWS {
246                new_data[j][i] = self[j][i];
247            }
248            for j in 0..RHS_ROWS {
249                new_data[ROWS + j][i] = rhs[j][i];
250            }
251        }
252
253        Matrix::from_vector(new_data)
254    }
255}
256
/// Square `N x N` matrices implement the `IsSquareMatrix` marker trait (empty impl).
impl<const N: usize> IsSquareMatrix for Matrix<N, N> {}
258
259impl<const N: usize> SquareMatrix<N> for Matrix<N, N> {
260    fn det(&self) -> f64 {
261        let mut copy = self.clone();
262        for j in 0..(N - 1) {
263            for i in ((j + 1)..N).rev() {
264                // println!("i: {}, j: {}", i, j);
265                if copy[j][j] == 0. && copy[i][j] == 0. {
266                    return 0.;
267                } else if copy[j][j] == 0. {
268                    copy.swap_rows(j, i).unwrap();
269                }
270                let div = copy[i][j] / copy[j][j];
271                for k in 0..N {
272                    // println!("{}, {}, {}", copy[i][j],copy[0][j] , copy[0][k]);
273                    copy[i][k] -= div * copy[j][k];
274                }
275                // println!("{:#?}", copy);
276            }
277        }
278        let mut det = 1.;
279        for i in 0..N {
280            det *= copy[i][i];
281        }
282        det
283    }
284
285    fn inv<const DOUBLE_COLS: usize>(&self) -> Result<Matrix<N, N>, &'static str> {
286        let mat: Matrix<N, N> = Matrix::eye().unwrap();
287        let mut mat = self.clone().x_concat::<N, DOUBLE_COLS>(mat).unwrap();
288
289        for j in 0..(N - 1) {
290            for i in ((j + 1)..N).rev() {
291                // println!("i: {}, j: {}", i, j);
292                if mat[j][j] == 0. && mat[i][j] == 0. {
293                    return Err("Matrix cannot be inverted");
294                } else if mat[j][j] == 0. {
295                    mat.swap_rows(j, i).unwrap();
296                }
297                let div = mat[i][j] / mat[j][j];
298                for k in 0..DOUBLE_COLS {
299                    // println!("{}, {}, {}", mat[i][j],mat[0][j] , mat[0][k]);
300                    mat[i][k] -= div * mat[j][k];
301                }
302                // println!("{:#?}", mat);
303            }
304        }
305
306        for j in (1..N).rev() {
307            for i in 0..j {
308                // println!("i: {}, j: {}", i, j);
309                if mat[j][j] == 0. && mat[i][j] == 0. {
310                    return Err("Matrix cannot be inverted");
311                } else if mat[j][j] == 0. {
312                    mat.swap_rows(j, i).unwrap();
313                }
314                let div = mat[i][j] / mat[j][j];
315                for k in 0..DOUBLE_COLS {
316                    // println!("{}, {}, {}", mat[i][j],mat[0][j] , mat[0][k]);
317                    mat[i][k] -= div * mat[j][k];
318                }
319                // println!("{:#?}", mat);
320            }
321        }
322
323        for i in 0..N {
324            let div = mat[i][i];
325            for j in N..DOUBLE_COLS {
326                mat[i][j] /= div;
327            }
328        }
329
330        mat.sub_matrix::<N, N>(0, N)
331    }
332
333    fn pow(&self, n: usize) -> Matrix<N, N> {
334        let mut copy: Matrix<N, N> = Matrix::eye().unwrap();
335        for _ in 0..n {
336            copy *= self;
337        }
338        copy
339    }
340
341    fn diag(elems: [f64; N]) -> Result<Matrix<N, N>, &'static str> {
342        let mut vec_elem = [[0.; N]; N];
343        for i in 0..N {
344            vec_elem[i][i] = elems[i];
345        }
346
347        Matrix::from_vector(vec_elem)
348    }
349}
350
/// Single-column `ROWS x 1` matrices implement the `IsVectorCol` marker trait (empty impl).
impl<const ROWS: usize> IsVectorCol for Matrix<ROWS, 1> {}
352
353impl<const ROWS: usize> VectorCol<ROWS> for Matrix<ROWS, 1> {
354    fn shift_data(&mut self, data: f64) {
355        for i in (1..ROWS).rev() {
356            self[i][0] = self[i - 1][0];
357        }
358        self[0][0] = data;
359    }
360}
361
// Unit tests for construction, indexing, arithmetic, decompositions and concatenation.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrix_ops::approx_equal;

    #[test]
    fn success_creation() {
        type Mat3x3 = Matrix<3, 3>;
        assert!(Mat3x3::new().is_ok());
    }

    #[test]
    fn fail_creation_1() {
        type Mat0x3 = Matrix<0, 3>;
        assert!(Mat0x3::new().is_err());
    }

    #[test]
    fn indexing_elems() {
        let mut mat2x2: Matrix<2, 2> = Matrix::new().unwrap();

        for i in 0..2 {
            for j in 0..2 {
                mat2x2[i][j] = i as f64 + j as f64;
            }
        }

        for i in 0..2 {
            for j in 0..2 {
                assert_eq!(i as f64 + j as f64, mat2x2[i][j]);
            }
        }
    }

    #[test]
    fn iterating_elems() {
        let mut mat1x3: Matrix<1, 3> = Matrix::new().unwrap();

        for vec in mat1x3.iter_mut() {
            for elem in vec.iter_mut() {
                *elem = 3.;
            }
        }
        for vec in mat1x3.iter() {
            for elem in vec.iter() {
                assert_eq!(3., *elem);
            }
        }
    }

    #[test]
    fn testing_transpose1() {
        // Distinct values so a wrong index mapping is actually detected.
        let mat4x1: Matrix<4, 1> = Matrix::from_vector([[1.], [2.], [3.], [4.]]).unwrap();
        let mat1x4 = mat4x1.transpose();
        for i in 0..4 {
            for j in 0..1 {
                assert_eq!(mat4x1[i][j], mat1x4[j][i])
            }
        }
    }

    #[test]
    fn testing_eye() {
        let mat: Matrix<3, 3> = Matrix::eye().unwrap();

        for i in 0..3 {
            assert_eq!(1., mat[i][i])
        }
    }

    #[test]
    fn from_vector_slices() {
        let data = [[2., 2.], [3., 3.], [4., 5.]];
        let mat: Matrix<3, 2> = Matrix::from_vector(data).unwrap();
        for i in 0..3 {
            for j in 0..2 {
                assert_eq!(data[i][j], mat[i][j]);
            }
        }
    }

    #[test]
    fn some_addition() {
        let mat1: Matrix<2, 2> = Matrix::from_vector([[1., 2.], [3., 4.]]).unwrap();

        let mat2: Matrix<2, 2> = Matrix::from_vector([[1., 2.], [3., 4.]]).unwrap();

        assert_eq!(
            Matrix::<2, 2>::from_vector([[2., 4.], [6., 8.],]).unwrap(),
            mat1 + mat2
        );
    }

    #[test]
    fn some_mul_1() {
        let mat1: Matrix<1, 4> = Matrix::from_vector([[1., 2., 3., 4.]]).unwrap();
        let mat2 = mat1.transpose();
        let res = mat1 * mat2;

        assert_eq!(30., res.to_double().unwrap());
    }

    #[test]
    fn basic_cloning() {
        let mat: Matrix<2, 2> = Matrix::from_vector([[2., 3.], [4., -2.]]).unwrap();

        let clone = mat.clone();
        assert_eq!(clone, mat);
    }

    #[test]
    fn basic_concat() {
        let mat1: Matrix<2, 2> = Matrix::from_vector([[1., 2.], [3., 4.]]).unwrap();

        let mat2: Matrix<2, 1> = Matrix::from_vector([[1.], [2.]]).unwrap();

        let mat3 = mat1.clone().x_concat::<1, 3>(mat2.clone()).unwrap();
        for i in 0..2 {
            for j in 0..2 {
                assert_eq!(mat3[i][j], mat1[i][j]);
            }
            for j in 0..1 {
                assert_eq!(mat3[i][2 + j], mat2[i][j]);
            }
        }
    }

    #[test]
    fn some_y_concat() {
        let mat1: Matrix<2, 2> = Matrix::from_vector([[2., 3.], [1., 4.]]).unwrap();

        let mat2: Matrix<1, 2> = Matrix::from_vector([[1., 2.]]).unwrap();

        let mat3 = mat1.clone().y_concat::<1, 3>(mat2.clone()).unwrap();
        for i in 0..2 {
            for j in 0..2 {
                assert_eq!(mat3[j][i], mat1[j][i]);
            }
            for j in 0..1 {
                assert_eq!(mat3[j + 2][i], mat2[j][i]);
            }
        }
    }

    #[test]
    fn basic_det() {
        let mat: Matrix<2, 2> = Matrix::from_vector([[1., 2.], [3., 4.]]).unwrap();

        assert_eq!(-2., mat.det());

        let mat: Matrix<1, 1> = Matrix::eye().unwrap();
        assert_eq!(1., mat.det());

        let mat: Matrix<3, 3> =
            Matrix::from_vector([[3., 5., 6.], [-2., 3., 5.], [-1., 2., 7.]]).unwrap();
        assert!(approx_equal(72., mat.det(), 1e-4));

        let mat: Matrix<4, 4> = Matrix::from_vector([
            [1., 2., 3., 4.],
            [6., 7., 8., 9.],
            [11., 12., 13., 14.],
            [16., 17., 18., 19.],
        ])
        .unwrap();
        assert!(approx_equal(0., mat.det(), 1e-10));

        let mat: Matrix<5, 5> = Matrix::from_vector([
            [1., 2., 3., 4., 5.],
            [6., 7., 8., 9., 10.],
            [11., 12., 13., 14., 11.],
            [16., 17., 18., 19., 20.],
            [21., 22., 23., 24., 25.],
        ])
        .unwrap();
        assert!(approx_equal(-2.7150e-44, mat.det(), 1e-10));

        let mat: Matrix<3, 3> = Matrix::eye().unwrap();
        assert!(approx_equal(1., mat.det(), 1e-10));

        let mat: Matrix<2, 2> = Matrix::from_vector([[1., 2.], [1., 2.]]).unwrap();
        assert!(approx_equal(0., mat.det(), 1e-10));

        let mut mat: Matrix<3, 3> = Matrix::eye().unwrap();
        mat[0][0] = 0.;
        assert!(approx_equal(0., mat.det(), 1e-10));
    }

    #[test]
    fn testing_inversion() {
        let mat: Matrix<3, 3> =
            Matrix::from_vector([[2., 3., 2.], [1., 5., 3.], [1., 3., 6.]]).unwrap();

        assert_eq!(
            Matrix::<3, 3>::eye().unwrap(),
            mat.clone() * mat.inv::<6>().unwrap()
        );

        let mat: Matrix<2, 2> = Matrix::from_vector([[2., 3.], [1., 2.]]).unwrap();
        assert_eq!(
            Matrix::<2, 2>::eye().unwrap(),
            mat.clone() * mat.inv::<4>().unwrap()
        );

        let mat: Matrix<1, 1> = Matrix::from_vector([[2.]]).unwrap();
        assert_eq!(0.5, mat.inv::<2>().unwrap().to_double().unwrap());

        let mat: Matrix<2, 2> = Matrix::eye().unwrap();

        assert_eq!(Matrix::<2, 2>::eye().unwrap(), mat.inv::<4>().unwrap());

        let mat: Matrix<5, 5> = Matrix::diag([1., 2., 3., 4., 5.]).unwrap();
        assert_eq!(
            Matrix::<5, 5>::eye().unwrap(),
            mat.clone() * mat.inv::<10>().unwrap()
        );
    }

    #[test]
    fn testing_pow() {
        let mat: Matrix<3, 3> =
            Matrix::from_vector([[1., 3., 5.], [2., 4., 6.], [-2., -3., -4.]]).unwrap();

        assert_eq!(Matrix::<3, 3>::eye().unwrap(), mat.pow(0));
        assert_eq!(
            Matrix::<3, 3>::from_vector([[9., -18., -45.], [-2., -44., -86.], [12., 48., 84.],])
                .unwrap(),
            mat.pow(4)
        )
    }

    #[test]
    fn testing_swap_cols() {
        let mut mat1: Matrix<2, 2> = Matrix::eye().unwrap();
        let mat2: Matrix<2, 2> = Matrix::from_vector([[0., 1.], [1., 0.]]).unwrap();
        mat1.swap_cols(0, 1).unwrap();
        assert_eq!(mat1, mat2)
    }

    #[test]
    fn testing_scalar_mul() {
        let mat: Matrix<2, 2> = Matrix::eye().unwrap();
        assert_eq!(
            Matrix::<2, 2>::from_vector([[-1., 0.], [0., -1.],]).unwrap(),
            -1. * mat
        );
    }

    #[test]
    fn testing_sub() {
        let mat: Matrix<10, 10> = Matrix::eye().unwrap();
        assert_eq!(Matrix::<10, 10>::new().unwrap(), mat.clone() - mat.clone());
    }

    #[test]
    fn testing_diag() {
        let mat: Matrix<3, 3> = Matrix::diag([1., 2., 3.]).unwrap();
        let mat1: Matrix<3, 3> =
            Matrix::from_vector([[1., 0., 0.], [0., 2., 0.], [0., 0., 3.]]).unwrap();
        assert_eq!(mat1, mat);
    }

    #[test]
    fn testing_vec_row() {
        let mat = Matrix::<5, 1>::vector_to_row([1., 2., 3., 4., 5.]).unwrap();
        let mat1 = mat.transpose();

        assert_eq!(55., (mat1 * mat).to_double().unwrap());
    }

    #[test]
    fn testing_pinv1() {
        let mat: Matrix<4, 2> =
            Matrix::from_vector([[1., 0.5], [5., 1.], [-2., 2.], [1., 5.]]).unwrap();
        assert_eq!(
            Matrix::<2, 2>::eye().unwrap(),
            mat.pinv::<4>().unwrap() * mat
        );
    }

    #[test]
    fn testing_pinv2() {
        let mat: Matrix<2, 4> =
            Matrix::from_vector([[1., 5., 3., -2.], [2., -1., 5., 2.]]).unwrap();
        assert_eq!(
            Matrix::<2, 2>::eye().unwrap(),
            mat.clone() * mat.pinv::<4>().unwrap()
        )
    }

    #[test]
    fn testing_shift_vector_col() {
        let mut mat: Matrix<3, 1> = Matrix::from_vector([[1.], [2.], [3.]]).unwrap();
        mat.shift_data(0.);
        assert_eq!(
            Matrix::<3, 1>::from_vector([[0.], [1.], [2.]]).unwrap(),
            mat
        );
    }
}