1#![no_std]
2use core::marker::Sized;
3use core::result::Result::{Err, Ok};
4use heapless::Vec;
5use matrix_trait::{
6 IsSquareMatrix, IsVectorCol, MatrixConcat, MatrixTrait, SquareMatrix, VectorCol,
7};
8
9pub mod matrix_trait;
10
11pub mod matrix_ops;
12
13use core::clone::Clone;
14use core::iter::Iterator;
15
16#[derive(Debug, Clone)]
17pub struct Matrix<const ROWS: usize, const COLS: usize> {
18 data: Vec<Vec<f64, COLS>, ROWS>,
19}
20
21impl<const ROWS: usize, const COLS: usize> Matrix<ROWS, COLS> {
22 #[allow(dead_code)]
23 fn iter(&self) -> core::slice::Iter<'_, Vec<f64, COLS>> {
24 self.data.iter()
25 }
26
27 #[allow(dead_code)]
28 fn iter_mut(&mut self) -> core::slice::IterMut<'_, Vec<f64, COLS>> {
29 self.data.iter_mut()
30 }
31}
32
33impl<const ROWS: usize, const COLS: usize> MatrixTrait<ROWS, COLS> for Matrix<ROWS, COLS> {
34 type TransposeType = Matrix<COLS, ROWS>;
35
36 fn new() -> Result<Self, &'static str>
37 where
38 Self: Sized,
39 {
40 if ROWS < 1 || COLS < 1 {
41 return Err("Matrix dimensions are invalid");
42 }
43 let mut vec: Vec<Vec<f64, COLS>, ROWS> = Vec::new();
44 for _ in 0..ROWS {
45 let mut helper: Vec<f64, COLS> = Vec::new();
46 for _ in 0..COLS {
47 helper.push(0.).unwrap();
48 }
49 vec.push(helper).unwrap();
50 }
51 Ok(Matrix { data: vec })
52 }
53
54 fn eye() -> Result<Self, &'static str>
55 where
56 Self: Sized,
57 {
58 if ROWS < 1 || COLS < 1 {
59 return Err("Matrix dimensions are invalid");
60 }
61
62 let mut mat: Vec<Vec<f64, COLS>, ROWS> = Vec::new();
63
64 for i in 0..ROWS {
65 let mut row: Vec<f64, COLS> = Vec::new();
66 for j in 0..COLS {
67 if i == j {
68 row.push(1.).unwrap();
69 } else {
70 row.push(0.).unwrap();
71 }
72 }
73
74 mat.push(row).unwrap();
75 }
76
77 Ok(Matrix { data: mat })
78 }
79 fn from_vector(data: [[f64; COLS]; ROWS]) -> Result<Self, &'static str>
94 where
95 Self: Sized,
96 {
97 let mut array_data: Vec<Vec<f64, COLS>, ROWS> = Vec::new();
98 for row in data.iter() {
99 let mut row_data: Vec<f64, COLS> = Vec::new();
100 for &value in row.iter() {
101 row_data.push(value).unwrap();
102 }
103
104 array_data.push(row_data).unwrap();
105 }
106
107 Ok(Matrix { data: array_data })
108 }
109
110 fn to_double(&self) -> Result<f64, &'static str> {
111 if ROWS != 1 && COLS != 1 {
112 return Err("The matrix does not have dimensions 2x2");
113 }
114 Ok(self[0][0])
115 }
116
117 fn transpose(&self) -> Self::TransposeType {
118 let mut transpose: Matrix<COLS, ROWS> = Matrix::new().unwrap();
119
120 for i in 0..ROWS {
121 for j in 0..COLS {
122 transpose[j][i] = self[i][j]
123 }
124 }
125 transpose
126 }
127
128 fn swap_rows(&mut self, row1: usize, row2: usize) -> Result<(), &'static str> {
129 if row1 >= ROWS || row2 >= ROWS {
130 return Err("Row index out of bounds");
131 }
132 self.data.swap(row1, row2);
133 Ok(())
134 }
135
136 fn swap_cols(&mut self, col1: usize, col2: usize) -> Result<(), &'static str> {
137 if col1 >= COLS || col2 >= COLS {
138 return Err("Column indexes are outof bounds");
139 }
140
141 for i in 0..ROWS {
142 let help = self[i][col1];
143 self[i][col1] = self[i][col2];
144 self[i][col2] = help;
145 }
146 Ok(())
147 }
148
149 fn sub_matrix<const NEW_ROWS: usize, const NEW_COLS: usize>(
150 &self,
151 row_start: usize,
152 col_start: usize,
153 ) -> Result<Matrix<NEW_ROWS, NEW_COLS>, &'static str> {
154 if row_start + NEW_ROWS > ROWS || col_start + NEW_COLS > COLS {
155 return Err("Submatrix dimensions are out of bounds");
156 }
157
158 let mut sub_data: Matrix<NEW_ROWS, NEW_COLS> = Matrix::new().unwrap();
159 for i in 0..NEW_ROWS {
160 for j in 0..NEW_COLS {
161 sub_data[i][j] = self[row_start + i][col_start + j];
162 }
163 }
164
165 Ok(sub_data)
166 }
167
168 fn vector_to_row(elems: [f64; ROWS]) -> Result<Matrix<ROWS, 1>, &'static str> {
169 let mut vec_elems = [[0.; 1]; ROWS];
170 for i in 0..ROWS {
171 vec_elems[i][0] = elems[i];
172 }
173
174 Matrix::<ROWS, 1>::from_vector(vec_elems)
175 }
176
177 fn pinv<const DOUBLE: usize>(&self) -> Result<Matrix<COLS, ROWS>, &'static str> {
178 if ROWS > COLS {
179 let mat = self.transpose() * self.clone();
181 let mat = mat.inv::<DOUBLE>()?;
182 Ok(mat * self.transpose())
183 } else {
184 let mat = self.clone() * self.transpose();
186 let mat = mat.inv::<DOUBLE>()?;
187 Ok(self.transpose() * mat)
188 }
189 }
190}
191
192impl<const ROWS: usize, const COLS: usize> MatrixConcat<ROWS, COLS> for Matrix<ROWS, COLS> {
193 fn x_concat<const RHS_COLS: usize, const NEW_COLS: usize>(
210 self,
211 rhs: Matrix<ROWS, RHS_COLS>,
212 ) -> Result<Matrix<ROWS, NEW_COLS>, &'static str> {
213 if RHS_COLS + COLS != NEW_COLS {
214 return Err(
215 "The number of new columns is not equal to the sum of columns of the matrices",
216 );
217 }
218
219 let mut new_data = [[0.; NEW_COLS]; ROWS];
220
221 for i in 0..ROWS {
222 for j in 0..COLS {
223 new_data[i][j] = self.data[i][j];
224 }
225 for j in 0..RHS_COLS {
226 new_data[i][COLS + j] = rhs.data[i][j];
227 }
228 }
229
230 Matrix::from_vector(new_data)
231 }
232
233 fn y_concat<const RHS_ROWS: usize, const NEW_ROWS: usize>(
234 self,
235 rhs: Matrix<RHS_ROWS, COLS>,
236 ) -> Result<Matrix<NEW_ROWS, COLS>, &'static str> {
237 if ROWS + RHS_ROWS != NEW_ROWS {
238 return Err(
239 "The number of new rows is not equal to the sum of rows of the two matrices",
240 );
241 }
242
243 let mut new_data = [[0.; COLS]; NEW_ROWS];
244 for i in 0..COLS {
245 for j in 0..ROWS {
246 new_data[j][i] = self[j][i];
247 }
248 for j in 0..RHS_ROWS {
249 new_data[ROWS + j][i] = rhs[j][i];
250 }
251 }
252
253 Matrix::from_vector(new_data)
254 }
255}
256
257impl<const N: usize> IsSquareMatrix for Matrix<N, N> {}
258
259impl<const N: usize> SquareMatrix<N> for Matrix<N, N> {
260 fn det(&self) -> f64 {
261 let mut copy = self.clone();
262 for j in 0..(N - 1) {
263 for i in ((j + 1)..N).rev() {
264 if copy[j][j] == 0. && copy[i][j] == 0. {
266 return 0.;
267 } else if copy[j][j] == 0. {
268 copy.swap_rows(j, i).unwrap();
269 }
270 let div = copy[i][j] / copy[j][j];
271 for k in 0..N {
272 copy[i][k] -= div * copy[j][k];
274 }
275 }
277 }
278 let mut det = 1.;
279 for i in 0..N {
280 det *= copy[i][i];
281 }
282 det
283 }
284
285 fn inv<const DOUBLE_COLS: usize>(&self) -> Result<Matrix<N, N>, &'static str> {
286 let mat: Matrix<N, N> = Matrix::eye().unwrap();
287 let mut mat = self.clone().x_concat::<N, DOUBLE_COLS>(mat).unwrap();
288
289 for j in 0..(N - 1) {
290 for i in ((j + 1)..N).rev() {
291 if mat[j][j] == 0. && mat[i][j] == 0. {
293 return Err("Matrix cannot be inverted");
294 } else if mat[j][j] == 0. {
295 mat.swap_rows(j, i).unwrap();
296 }
297 let div = mat[i][j] / mat[j][j];
298 for k in 0..DOUBLE_COLS {
299 mat[i][k] -= div * mat[j][k];
301 }
302 }
304 }
305
306 for j in (1..N).rev() {
307 for i in 0..j {
308 if mat[j][j] == 0. && mat[i][j] == 0. {
310 return Err("Matrix cannot be inverted");
311 } else if mat[j][j] == 0. {
312 mat.swap_rows(j, i).unwrap();
313 }
314 let div = mat[i][j] / mat[j][j];
315 for k in 0..DOUBLE_COLS {
316 mat[i][k] -= div * mat[j][k];
318 }
319 }
321 }
322
323 for i in 0..N {
324 let div = mat[i][i];
325 for j in N..DOUBLE_COLS {
326 mat[i][j] /= div;
327 }
328 }
329
330 mat.sub_matrix::<N, N>(0, N)
331 }
332
333 fn pow(&self, n: usize) -> Matrix<N, N> {
334 let mut copy: Matrix<N, N> = Matrix::eye().unwrap();
335 for _ in 0..n {
336 copy *= self;
337 }
338 copy
339 }
340
341 fn diag(elems: [f64; N]) -> Result<Matrix<N, N>, &'static str> {
342 let mut vec_elem = [[0.; N]; N];
343 for i in 0..N {
344 vec_elem[i][i] = elems[i];
345 }
346
347 Matrix::from_vector(vec_elem)
348 }
349}
350
351impl<const ROWS: usize> IsVectorCol for Matrix<ROWS, 1> {}
352
353impl<const ROWS: usize> VectorCol<ROWS> for Matrix<ROWS, 1> {
354 fn shift_data(&mut self, data: f64) {
355 for i in (1..ROWS).rev() {
356 self[i][0] = self[i - 1][0];
357 }
358 self[0][0] = data;
359 }
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrix_ops::approx_equal;

    #[test]
    fn success_creation() {
        type Mat3x3 = Matrix<3, 3>;
        assert!(Mat3x3::new().is_ok());
    }

    #[test]
    fn fail_creation_1() {
        type Mat0x3 = Matrix<0, 3>;
        assert!(Mat0x3::new().is_err());
    }

    #[test]
    fn indexing_elems() {
        let mut mat2x2: Matrix<2, 2> = Matrix::new().unwrap();

        for i in 0..2 {
            for j in 0..2 {
                mat2x2[i][j] = i as f64 + j as f64;
            }
        }

        for i in 0..2 {
            for j in 0..2 {
                assert_eq!(i as f64 + j as f64, mat2x2[i][j]);
            }
        }
    }

    #[test]
    fn iterating_elems() {
        let mut mat1x3: Matrix<1, 3> = Matrix::new().unwrap();

        for vec in mat1x3.iter_mut() {
            for elem in vec.iter_mut() {
                *elem = 3.;
            }
        }
        for vec in mat1x3.iter() {
            for elem in vec.iter() {
                assert_eq!(3., *elem);
            }
        }
    }

    #[test]
    fn testing_transpose1() {
        // Distinct values so a wrong index mapping cannot pass unnoticed.
        let mat4x1: Matrix<4, 1> = Matrix::from_vector([[1.], [2.], [3.], [4.]]).unwrap();
        let mat1x4 = mat4x1.transpose();
        for i in 0..4 {
            for j in 0..1 {
                assert_eq!(mat4x1[i][j], mat1x4[j][i])
            }
        }
    }

    #[test]
    fn testing_eye() {
        let mat: Matrix<3, 3> = Matrix::eye().unwrap();

        for i in 0..3 {
            assert_eq!(1., mat[i][i])
        }
    }

    #[test]
    fn from_vector_slices() {
        let data = [[2., 2.], [3., 3.], [4., 5.]];
        let mat: Matrix<3, 2> = Matrix::from_vector(data).unwrap();
        for i in 0..3 {
            for j in 0..2 {
                assert_eq!(data[i][j], mat[i][j]);
            }
        }
    }

    #[test]
    fn some_addition() {
        let mat1: Matrix<2, 2> = Matrix::from_vector([[1., 2.], [3., 4.]]).unwrap();

        let mat2: Matrix<2, 2> = Matrix::from_vector([[1., 2.], [3., 4.]]).unwrap();

        assert_eq!(
            Matrix::<2, 2>::from_vector([[2., 4.], [6., 8.],]).unwrap(),
            mat1 + mat2
        );
    }

    #[test]
    fn some_mul_1() {
        let mat1: Matrix<1, 4> = Matrix::from_vector([[1., 2., 3., 4.]]).unwrap();
        let mat2 = mat1.transpose();
        let res = mat1 * mat2;

        assert_eq!(30., res.to_double().unwrap());
    }

    #[test]
    fn basic_cloning() {
        let mat: Matrix<2, 2> = Matrix::from_vector([[2., 3.], [4., -2.]]).unwrap();

        let clone = mat.clone();
        assert_eq!(clone, mat);
    }

    #[test]
    fn basic_concat() {
        let mat1: Matrix<2, 2> = Matrix::from_vector([[1., 2.], [3., 4.]]).unwrap();

        let mat2: Matrix<2, 1> = Matrix::from_vector([[1.], [2.]]).unwrap();

        let mat3 = mat1.clone().x_concat::<1, 3>(mat2.clone()).unwrap();
        for i in 0..2 {
            for j in 0..2 {
                assert_eq!(mat3[i][j], mat1[i][j]);
            }
            for j in 0..1 {
                assert_eq!(mat3[i][2 + j], mat2[i][j]);
            }
        }
    }

    #[test]
    fn concat_size_mismatch() {
        let mat1: Matrix<2, 2> = Matrix::eye().unwrap();
        let mat2: Matrix<2, 1> = Matrix::new().unwrap();
        // 2 + 1 != 4, so the requested width is rejected.
        assert!(mat1.x_concat::<1, 4>(mat2).is_err());
    }

    #[test]
    fn some_y_concat() {
        let mat1: Matrix<2, 2> = Matrix::from_vector([[2., 3.], [1., 4.]]).unwrap();

        let mat2: Matrix<1, 2> = Matrix::from_vector([[1., 2.]]).unwrap();

        let mat3 = mat1.clone().y_concat::<1, 3>(mat2.clone()).unwrap();
        for i in 0..2 {
            for j in 0..2 {
                assert_eq!(mat3[j][i], mat1[j][i]);
            }
            for j in 0..1 {
                assert_eq!(mat3[j + 2][i], mat2[j][i]);
            }
        }
    }

    #[test]
    fn basic_det() {
        let mat: Matrix<2, 2> = Matrix::from_vector([[1., 2.], [3., 4.]]).unwrap();

        assert_eq!(-2., mat.det());

        let mat: Matrix<1, 1> = Matrix::eye().unwrap();
        assert_eq!(1., mat.det());

        let mat: Matrix<3, 3> =
            Matrix::from_vector([[3., 5., 6.], [-2., 3., 5.], [-1., 2., 7.]]).unwrap();
        assert!(approx_equal(72., mat.det(), 1e-4));

        let mat: Matrix<4, 4> = Matrix::from_vector([
            [1., 2., 3., 4.],
            [6., 7., 8., 9.],
            [11., 12., 13., 14.],
            [16., 17., 18., 19.],
        ])
        .unwrap();
        assert!(approx_equal(0., mat.det(), 1e-10));

        let mat: Matrix<5, 5> = Matrix::from_vector([
            [1., 2., 3., 4., 5.],
            [6., 7., 8., 9., 10.],
            [11., 12., 13., 14., 11.],
            [16., 17., 18., 19., 20.],
            [21., 22., 23., 24., 25.],
        ])
        .unwrap();
        assert!(approx_equal(-2.7150e-44, mat.det(), 1e-10));

        let mat: Matrix<3, 3> = Matrix::eye().unwrap();
        assert!(approx_equal(1., mat.det(), 1e-10));

        let mat: Matrix<2, 2> = Matrix::from_vector([[1., 2.], [1., 2.]]).unwrap();
        assert!(approx_equal(0., mat.det(), 1e-10));

        let mut mat: Matrix<3, 3> = Matrix::eye().unwrap();
        mat[0][0] = 0.;
        assert!(approx_equal(0., mat.det(), 1e-10));
    }

    #[test]
    fn testing_inversion() {
        let mat: Matrix<3, 3> =
            Matrix::from_vector([[2., 3., 2.], [1., 5., 3.], [1., 3., 6.]]).unwrap();

        assert_eq!(
            Matrix::<3, 3>::eye().unwrap(),
            mat.clone() * mat.inv::<6>().unwrap()
        );

        let mat: Matrix<2, 2> = Matrix::from_vector([[2., 3.], [1., 2.]]).unwrap();
        assert_eq!(
            Matrix::<2, 2>::eye().unwrap(),
            mat.clone() * mat.inv::<4>().unwrap()
        );

        let mat: Matrix<1, 1> = Matrix::from_vector([[2.]]).unwrap();
        assert_eq!(0.5, mat.inv::<2>().unwrap().to_double().unwrap());

        let mat: Matrix<2, 2> = Matrix::eye().unwrap();

        assert_eq!(Matrix::<2, 2>::eye().unwrap(), mat.inv::<4>().unwrap());

        let mat: Matrix<5, 5> = Matrix::diag([1., 2., 3., 4., 5.]).unwrap();
        assert_eq!(
            Matrix::<5, 5>::eye().unwrap(),
            mat.clone() * mat.inv::<10>().unwrap()
        );
    }

    #[test]
    fn testing_pow() {
        let mat: Matrix<3, 3> =
            Matrix::from_vector([[1., 3., 5.], [2., 4., 6.], [-2., -3., -4.]]).unwrap();

        assert_eq!(Matrix::<3, 3>::eye().unwrap(), mat.pow(0));
        assert_eq!(
            Matrix::<3, 3>::from_vector([[9., -18., -45.], [-2., -44., -86.], [12., 48., 84.],])
                .unwrap(),
            mat.pow(4)
        )
    }

    #[test]
    fn testing_swap_cols() {
        let mut mat1: Matrix<2, 2> = Matrix::eye().unwrap();
        let mat2: Matrix<2, 2> = Matrix::from_vector([[0., 1.], [1., 0.]]).unwrap();
        mat1.swap_cols(0, 1).unwrap();
        assert_eq!(mat1, mat2)
    }

    #[test]
    fn swap_out_of_bounds() {
        let mut mat: Matrix<2, 2> = Matrix::eye().unwrap();
        assert!(mat.swap_rows(0, 2).is_err());
        assert!(mat.swap_cols(2, 0).is_err());
    }

    #[test]
    fn extracting_sub_matrix() {
        let mat: Matrix<3, 3> =
            Matrix::from_vector([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]).unwrap();
        let sub: Matrix<2, 2> = mat.sub_matrix::<2, 2>(1, 1).unwrap();
        assert_eq!(Matrix::<2, 2>::from_vector([[5., 6.], [8., 9.]]).unwrap(), sub);
        // A 2-row window starting at row 2 would run past the last row.
        assert!(mat.sub_matrix::<2, 2>(2, 0).is_err());
    }

    #[test]
    fn testing_scalar_mul() {
        let mat: Matrix<2, 2> = Matrix::eye().unwrap();
        assert_eq!(
            Matrix::<2, 2>::from_vector([[-1., 0.], [0., -1.],]).unwrap(),
            -1. * mat
        );
    }

    #[test]
    fn testing_sub() {
        let mat: Matrix<10, 10> = Matrix::eye().unwrap();
        assert_eq!(Matrix::<10, 10>::new().unwrap(), mat.clone() - mat.clone());
    }

    #[test]
    fn testing_diag() {
        let mat: Matrix<3, 3> = Matrix::diag([1., 2., 3.]).unwrap();
        let mat1: Matrix<3, 3> =
            Matrix::from_vector([[1., 0., 0.], [0., 2., 0.], [0., 0., 3.]]).unwrap();
        assert_eq!(mat1, mat);
    }

    #[test]
    fn testing_vec_row() {
        let mat = Matrix::<5, 1>::vector_to_row([1., 2., 3., 4., 5.]).unwrap();
        let mat1 = mat.transpose();

        assert_eq!(55., (mat1 * mat).to_double().unwrap());
    }

    #[test]
    fn testing_pinv1() {
        let mat: Matrix<4, 2> =
            Matrix::from_vector([[1., 0.5], [5., 1.], [-2., 2.], [1., 5.]]).unwrap();
        assert_eq!(
            Matrix::<2, 2>::eye().unwrap(),
            mat.pinv::<4>().unwrap() * mat
        );
    }

    #[test]
    fn testing_pinv2() {
        let mat: Matrix<2, 4> =
            Matrix::from_vector([[1., 5., 3., -2.], [2., -1., 5., 2.]]).unwrap();
        assert_eq!(
            Matrix::<2, 2>::eye().unwrap(),
            mat.clone() * mat.pinv::<4>().unwrap()
        )
    }

    #[test]
    fn testing_shift_vector_col() {
        let mut mat: Matrix<3, 1> = Matrix::from_vector([[1.], [2.], [3.]]).unwrap();
        mat.shift_data(0.);
        assert_eq!(
            Matrix::<3, 1>::from_vector([[0.], [1.], [2.]]).unwrap(),
            mat
        );
    }
}