1use std::{
2 fmt::Debug,
3 ops::{Add, Div, Index, IndexMut, Mul, Sub},
4};
5
6use rand::random;
7
8use crate::quick_grad::{grad_tape::GradTape, var::Var};
9
10use super::errors::MatrixError;
11
/// A dense, row-major matrix.
///
/// The element at `(row, col)` lives at `data[row * cols + col]`, so the
/// constructors in this module keep `data.len() == rows * cols`.
#[derive(PartialEq, Clone)]
pub struct Matrix<T> {
    /// Number of rows.
    rows: usize,
    /// Number of columns; also the distance in `data` between vertically adjacent elements.
    cols: usize,
    /// Elements stored one row after another.
    data: Vec<T>,
}
49
50impl<T: Debug> Debug for Matrix<T> {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 let mut r = String::new();
53 for i in 0..self.rows {
54 for j in 0..self.cols {
55 r.push_str(&format!("{:?} ", self.data[i * self.cols + j]));
56 }
57
58 r.push_str("\n");
59 }
60
61 write!(f, "{}", r)
62 }
63}
64
65impl<T: Copy> Matrix<T> {
66 pub fn get_data(&self) -> &Vec<T> {
70 &self.data
71 }
72
73 pub fn get_data_mut(&mut self) -> &mut Vec<T> {
74 &mut self.data
75 }
76
77 pub fn get_shape(&self) -> Vec<usize> {
80 vec![self.rows, self.cols]
81 }
82
83 pub fn get_row_slice(&self, row: usize) -> &[T] {
86 &self.data[(row * self.cols)..(row * self.cols + self.cols)]
87 }
88
89 pub fn get_row_slice_mut(&mut self, row: usize) -> &[T] {
92 &mut self.data[(row * self.cols)..(row * self.cols + self.cols)]
93 }
94
95 pub fn get_row(&self, row: usize) -> impl Iterator<Item = &T> {
98 self.data.iter().skip(row * self.cols).take(self.cols)
99 }
100
101 pub fn get_row_mut(&mut self, row: usize) -> impl Iterator<Item = &mut T> {
104 self.data.iter_mut().skip(row * self.cols).take(self.cols)
105 }
106
107 pub fn get_col(&self, col: usize) -> impl Iterator<Item = &T> {
110 self.data.iter().skip(col).step_by(self.cols)
111 }
112
113 pub fn get_col_mut(&mut self, col: usize) -> impl Iterator<Item = &mut T> {
116 self.data.iter_mut().skip(col).step_by(self.cols)
117 }
118
119 pub fn get_rows(&self) -> usize {
122 self.rows
123 }
124 pub fn get_cols(&self) -> usize {
127 self.cols
128 }
129
130 pub fn map(&self, f: fn(T) -> T) -> Matrix<T> {
133 Matrix {
134 rows: self.rows,
135 cols: self.cols,
136 data: self.data.iter().copied().map(f).collect::<Vec<_>>(),
137 }
138 }
139
140 pub fn apply(&mut self, f: fn(T) -> T) {
143 for i in &mut self.data {
144 *i = f(*i);
145 }
146 }
147
148 pub fn reshape(&self, new_rows: usize, new_cols: usize) -> Result<Matrix<T>, MatrixError> {
152 if (self.rows * self.cols) != (new_rows * new_cols) {
153 Err(MatrixError::InvalidReshape {
154 numel: self.get_rows() * self.get_cols(),
155 forcing_into: new_rows * new_cols,
156 })
157 } else {
158 Ok(Matrix {
159 rows: new_rows,
160 cols: new_cols,
161 data: self.data.clone(),
162 })
163 }
164 }
165
166 pub fn repeat_h(&self, times: usize) -> Matrix<T> {
169 let mut data = vec![];
170 for x in self.data.clone() {
171 for _ in 0..times {
172 data.push(x);
173 }
174 }
175
176 Matrix {
177 rows: self.rows,
178 cols: self.cols * times,
179 data,
180 }
181 }
182}
183
184impl Matrix<f64> {
185 pub fn rand(rows: usize, cols: usize) -> Matrix<f64> {
189 let mut r = Matrix::zero(rows, cols);
190
191 for i in 0..r.get_rows() {
192 for j in 0..r.get_cols() {
193 r[(i, j)] = random();
194 }
195 }
196
197 r
198 }
199
200 pub fn from_array<const R: usize, const C: usize>(arr: [[f64; C]; R]) -> Matrix<f64> {
205 let mut data: Vec<f64> = Vec::new();
206 for row in arr {
207 for element in row {
208 data.push(element);
209 }
210 }
211
212 Matrix {
213 rows: R,
214 cols: C,
215 data,
216 }
217 }
218 pub fn zero(rows: usize, cols: usize) -> Matrix<f64> {
221 let mut data: Vec<f64> = Vec::new();
222 for _i in 0..rows * cols {
223 data.push(0f64);
224 }
225 Matrix { rows, cols, data }
226 }
227 pub fn one(rows: usize, cols: usize) -> Matrix<f64> {
232 let mut data: Vec<f64> = Vec::new();
233 for _i in 0..rows * cols {
234 data.push(1f64);
235 }
236 Matrix { rows, cols, data }
237 }
238 pub fn fill(&mut self, value: f64) {
241 for item in self.data.iter_mut() {
242 *item = value;
243 }
244 }
245 pub fn dot(&self, other: &Matrix<f64>) -> Result<Matrix<f64>, MatrixError> {
248 if self.cols != other.get_rows() {
249 return Err(MatrixError::MatMulDimensionsMismatch {
250 size_1: self.get_shape(),
251 size_2: other.get_shape(),
252 });
253 }
254
255 let mut m: Matrix<f64> = Matrix::zero(self.rows, other.cols);
256 for i in 0..self.rows {
257 for j in 0..other.cols {
258 m[(i, j)] = 0f64;
259 m[(i, j)] = vec_dot(
260 self.get_row(i).copied().collect(),
261 other.get_col(j).copied().collect(),
262 )
263 }
264 }
265
266 Ok(m)
267 }
268 pub fn transpose(&self) -> Matrix<f64> {
271 let mut r = Matrix::zero(self.cols, self.rows);
272 for i in 0..self.rows {
273 for j in 0..self.cols {
274 r[(j, i)] = self[(i, j)];
275 }
276 }
277
278 r
279 }
280}
281
282impl Matrix<Var> {
283 pub fn g_rand(tape: &GradTape, rows: usize, cols: usize) -> Matrix<Var> {
286 let mut data: Vec<Var> = Vec::new();
287 for _i in 0..rows * cols {
288 data.push(tape.var(random()));
289 }
290 Matrix { rows, cols, data }
291 }
292
293 pub fn apply_to_value(&mut self, f: fn(x: Var) -> f64) {
294 for i in &mut self.data {
295 *i.value_mut() = f(*i);
296 }
297 }
298
299 pub fn g_from_array<const R: usize, const C: usize>(
300 tape: &GradTape,
301 arr: [[f64; C]; R],
302 ) -> Matrix<Var> {
303 let mut data: Vec<Var> = Vec::new();
304 for row in arr {
305 for element in row {
306 data.push(tape.var(element));
307 }
308 }
309
310 Matrix {
311 rows: R,
312 cols: C,
313 data,
314 }
315 }
316
317 pub fn g_zero(tape: &GradTape, rows: usize, cols: usize) -> Matrix<Var> {
320 let mut data: Vec<Var> = Vec::new();
321 for _i in 0..rows * cols {
322 data.push(tape.var(0.0));
323 }
324 Matrix { rows, cols, data }
325 }
326
327 pub fn g_one(tape: &GradTape, rows: usize, cols: usize) -> Matrix<Var> {
330 let mut data: Vec<Var> = Vec::new();
331 for _i in 0..rows * cols {
332 data.push(tape.var(1.0));
333 }
334 Matrix { rows, cols, data }
335 }
336
337 pub fn g_fill(&mut self, tape: &GradTape, value: f64) {
340 for i in 0..self.rows * self.cols {
341 self.data[i] = tape.var(value);
342 }
343 }
344
345 pub fn g_dot(&self, tape: &GradTape, other: &Matrix<Var>) -> Result<Matrix<Var>, MatrixError> {
348 if self.cols != other.get_rows() {
349 return Err(MatrixError::MatMulDimensionsMismatch {
350 size_1: self.get_shape(),
351 size_2: other.get_shape(),
352 });
353 }
354
355 let mut m: Matrix<Var> = Matrix::g_zero(tape, self.rows, other.cols);
356 for i in 0..self.rows {
357 for j in 0..other.cols {
358 m[(i, j)] = g_vec_dot(
360 tape,
361 self.get_row(i).copied().collect(),
362 other.get_col(j).copied().collect(),
363 )
364 }
365 }
366
367 Ok(m)
368 }
369 pub fn transpose(&self, tape: &GradTape) -> Matrix<Var> {
372 let mut r = Matrix::g_zero(tape, self.cols, self.rows);
373 for i in 0..self.rows {
374 for j in 0..self.cols {
375 r[(j, i)] = self[(i, j)];
376 }
377 }
378
379 r
380 }
381
382 pub fn value(&self) -> Matrix<f64> {
385 let mut r = Matrix::zero(self.rows, self.cols);
386 for i in 0..self.rows {
387 for j in 0..self.cols {
388 r[(i, j)] = self[(i, j)].value();
389 }
390 }
391
392 r
393 }
394}
395
396impl<T> Index<(usize, usize)> for Matrix<T> {
397 type Output = T;
398 fn index(&self, index: (usize, usize)) -> &Self::Output {
399 &self.data[index.0 * self.cols + index.1]
400 }
401}
402
403impl<T> IndexMut<(usize, usize)> for Matrix<T> {
404 fn index_mut(&mut self, index: (usize, usize)) -> &mut T {
405 &mut self.data[index.0 * self.cols + index.1]
406 }
407}
408
409impl<T: Copy + Clone + Add<U, Output = T>, U: Copy + Add<T>> Add<Matrix<U>> for Matrix<T> {
410 type Output = Matrix<T>;
411 fn add(self, other: Matrix<U>) -> Matrix<T> {
412 let mut data = self.data.clone();
413 for i in 0..data.len() {
414 data[i] = data[i] + other.data[i];
415 }
416
417 Matrix {
418 rows: self.rows,
419 cols: self.cols,
420 data,
421 }
422 }
423}
424
425impl<T: Copy + Clone + Add<Output = T>> Add<&Matrix<T>> for Matrix<T> {
426 type Output = Matrix<T>;
427 fn add(self, other: &Matrix<T>) -> Matrix<T> {
428 let mut data = self.data.clone();
429 for i in 0..data.len() {
430 data[i] = data[i] + other.data[i];
431 }
432
433 Matrix {
434 rows: self.rows,
435 cols: self.cols,
436 data,
437 }
438 }
439}
440
441impl<T: Copy + Clone + Add<f64, Output = T>> Add<f64> for Matrix<T> {
442 type Output = Matrix<T>;
443 fn add(self, other: f64) -> Matrix<T> {
444 let mut data = self.data.clone();
445 for i in 0..data.len() {
446 data[i] = data[i] + other;
447 }
448
449 Matrix {
450 rows: self.rows,
451 cols: self.cols,
452 data,
453 }
454 }
455}
456
457impl<T: Copy + Clone + Sub<Output = T>> Sub<Matrix<T>> for Matrix<T> {
458 type Output = Matrix<T>;
459 fn sub(self, other: Matrix<T>) -> Matrix<T> {
460 let mut data = self.data.clone();
461 for i in 0..data.len() {
462 data[i] = data[i] - other.data[i];
463 }
464
465 Matrix {
466 rows: self.rows,
467 cols: self.cols,
468 data,
469 }
470 }
471}
472impl<T: Copy + Clone + Sub<Output = T>> Sub<&Matrix<T>> for Matrix<T> {
473 type Output = Matrix<T>;
474 fn sub(self, other: &Matrix<T>) -> Matrix<T> {
475 let mut data = self.data.clone();
476 for i in 0..data.len() {
477 data[i] = data[i] - other.data[i];
478 }
479
480 Matrix {
481 rows: self.rows,
482 cols: self.cols,
483 data,
484 }
485 }
486}
487impl<T: Copy + Clone + Sub<f64, Output = T>> Sub<f64> for Matrix<T> {
488 type Output = Matrix<T>;
489 fn sub(self, other: f64) -> Matrix<T> {
490 let mut data = self.data.clone();
491 for i in 0..data.len() {
492 data[i] = data[i] - other;
493 }
494
495 Matrix {
496 rows: self.rows,
497 cols: self.cols,
498 data,
499 }
500 }
501}
502
503impl<T: Copy + Clone + Mul<Output = T>> Mul<Matrix<T>> for Matrix<T> {
504 type Output = Matrix<T>;
505 fn mul(self, other: Matrix<T>) -> Matrix<T> {
506 let mut data = self.data.clone();
507 for i in 0..data.len() {
508 data[i] = data[i] * other.data[i];
509 }
510
511 Matrix {
512 rows: self.rows,
513 cols: self.cols,
514 data,
515 }
516 }
517}
518impl<T: Copy + Clone + Mul<Output = T>> Mul<&Matrix<T>> for Matrix<T> {
519 type Output = Matrix<T>;
520 fn mul(self, other: &Matrix<T>) -> Matrix<T> {
521 let mut data = self.data.clone();
522 for i in 0..data.len() {
523 data[i] = data[i] * other.data[i];
524 }
525
526 Matrix {
527 rows: self.rows,
528 cols: self.cols,
529 data,
530 }
531 }
532}
533impl<T: Copy + Clone + Mul<f64, Output = T>> Mul<f64> for Matrix<T> {
534 type Output = Matrix<T>;
535 fn mul(self, other: f64) -> Matrix<T> {
536 let mut data = self.data.clone();
537 for i in 0..data.len() {
538 data[i] = data[i] * other;
539 }
540
541 Matrix {
542 rows: self.rows,
543 cols: self.cols,
544 data,
545 }
546 }
547}
548impl<T: Copy + Clone + Div<Output = T>> Div<Matrix<T>> for Matrix<T> {
549 type Output = Matrix<T>;
550 fn div(self, other: Matrix<T>) -> Matrix<T> {
551 let mut data = self.data.clone();
552 for i in 0..data.len() {
553 data[i] = data[i] / other.data[i];
554 }
555
556 Matrix {
557 rows: self.rows,
558 cols: self.cols,
559 data,
560 }
561 }
562}
563impl<T: Copy + Clone + Div<Output = T>> Div<&Matrix<T>> for Matrix<T> {
564 type Output = Matrix<T>;
565 fn div(self, other: &Matrix<T>) -> Matrix<T> {
566 let mut data = self.data.clone();
567 for i in 0..data.len() {
568 data[i] = data[i] / other.data[i];
569 }
570
571 Matrix {
572 rows: self.rows,
573 cols: self.cols,
574 data,
575 }
576 }
577}
578impl<T: Copy + Clone + Div<f64, Output = T>> Div<f64> for Matrix<T> {
579 type Output = Matrix<T>;
580 fn div(self, other: f64) -> Matrix<T> {
581 let mut data = self.data.clone();
582 for i in 0..data.len() {
583 data[i] = data[i] / other;
584 }
585
586 Matrix {
587 rows: self.rows,
588 cols: self.cols,
589 data,
590 }
591 }
592}
593
594impl<T: Copy + Clone + Add<Output = T>> Add<&Matrix<T>> for &Matrix<T> {
596 type Output = Matrix<T>;
597 fn add(self, other: &Matrix<T>) -> Matrix<T> {
598 let mut data = self.data.clone();
599 for i in 0..data.len() {
600 data[i] = data[i] + other.data[i];
601 }
602
603 Matrix {
604 rows: self.rows,
605 cols: self.cols,
606 data,
607 }
608 }
609}
610
611impl<T: Copy + Clone + Add<Output = T>> Sub<&Matrix<T>> for &Matrix<T> {
612 type Output = Matrix<T>;
613 fn sub(self, other: &Matrix<T>) -> Matrix<T> {
614 let mut data = self.data.clone();
615 for i in 0..data.len() {
616 data[i] = data[i] + other.data[i];
617 }
618
619 Matrix {
620 rows: self.rows,
621 cols: self.cols,
622 data,
623 }
624 }
625}
626
627impl<T: Copy + Clone + Mul<Output = T>> Mul<&Matrix<T>> for &Matrix<T> {
628 type Output = Matrix<T>;
629 fn mul(self, other: &Matrix<T>) -> Matrix<T> {
630 let mut data = self.data.clone();
631 for i in 0..data.len() {
632 data[i] = data[i] * other.data[i];
633 }
634
635 Matrix {
636 rows: self.rows,
637 cols: self.cols,
638 data,
639 }
640 }
641}
642
643impl<T: Copy + Clone + Div<Output = T>> Div<&Matrix<T>> for &Matrix<T> {
644 type Output = Matrix<T>;
645 fn div(self, other: &Matrix<T>) -> Matrix<T> {
646 let mut data = self.data.clone();
647 for i in 0..data.len() {
648 data[i] = data[i] / other.data[i];
649 }
650
651 Matrix {
652 rows: self.rows,
653 cols: self.cols,
654 data,
655 }
656 }
657}
658
659impl<T: Copy + Clone + Add<f64, Output = T>> Add<f64> for &Matrix<T> {
660 type Output = Matrix<T>;
661 fn add(self, other: f64) -> Matrix<T> {
662 let mut data = self.data.clone();
663 for i in 0..data.len() {
664 data[i] = data[i] + other;
665 }
666
667 Matrix {
668 rows: self.rows,
669 cols: self.cols,
670 data,
671 }
672 }
673}
674
675impl<T: Copy + Clone + Sub<f64, Output = T>> Sub<f64> for &Matrix<T> {
676 type Output = Matrix<T>;
677 fn sub(self, other: f64) -> Matrix<T> {
678 let mut data = self.data.clone();
679 for i in 0..data.len() {
680 data[i] = data[i] - other;
681 }
682
683 Matrix {
684 rows: self.rows,
685 cols: self.cols,
686 data,
687 }
688 }
689}
690
691impl<T: Copy + Clone + Mul<f64, Output = T>> Mul<f64> for &Matrix<T> {
692 type Output = Matrix<T>;
693 fn mul(self, other: f64) -> Matrix<T> {
694 let mut data = self.data.clone();
695 for i in 0..data.len() {
696 data[i] = data[i] * other;
697 }
698
699 Matrix {
700 rows: self.rows,
701 cols: self.cols,
702 data,
703 }
704 }
705}
706
707impl<T: Copy + Clone + Div<f64, Output = T>> Div<f64> for &Matrix<T> {
708 type Output = Matrix<T>;
709 fn div(self, other: f64) -> Matrix<T> {
710 let mut data = self.data.clone();
711 for i in 0..data.len() {
712 data[i] = data[i] / other;
713 }
714
715 Matrix {
716 rows: self.rows,
717 cols: self.cols,
718 data,
719 }
720 }
721}
722
/// Dot product of two equal-length `f64` vectors.
///
/// Products are accumulated left to right starting from `0.0`.
///
/// # Panics
/// Panics if `v1` and `v2` have different lengths (previously a longer `v2`
/// was silently truncated and a shorter one panicked with an index error).
fn vec_dot(v1: Vec<f64>, v2: Vec<f64>) -> f64 {
    assert_eq!(
        v1.len(),
        v2.len(),
        "vec_dot requires vectors of equal length"
    );
    v1.iter().zip(&v2).fold(0.0, |acc, (a, b)| acc + a * b)
}
735
736fn g_vec_dot(tape: &GradTape, v1: Vec<Var>, v2: Vec<Var>) -> Var {
737 let mut r = tape.var(0.0);
739
740 let len = v1.len();
741 for i in 0..len {
742 r = r + v1[i] * v2[i];
743 }
744 r
746}
747
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn create_matrix() {
        // Construction alone must not panic.
        let _zeros = Matrix::zero(2, 3);
    }

    #[test]
    fn get_row() {
        let mut grid = Matrix::from_array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]);

        let first: Vec<f64> = grid.get_row(0).copied().collect();
        assert_eq!(first, vec![1.0, 2.0, 3.0, 4.0]);

        // Write through the mutable row iterator: row 1, column 1.
        let slot = grid.get_row_mut(1).nth(1).unwrap();
        *slot = 100.0;

        let second: Vec<f64> = grid.get_row(1).copied().collect();
        assert_eq!(second, vec![5.0, 100.0, 7.0, 8.0]);
    }

    #[test]
    fn get_col() {
        let mut grid = Matrix::from_array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]);

        let first: Vec<f64> = grid.get_col(0).copied().collect();
        assert_eq!(first, vec![1.0, 5.0]);

        // Write through the mutable column iterator: column 1, row 1.
        let slot = grid.get_col_mut(1).nth(1).unwrap();
        *slot = 100.0;

        let second: Vec<f64> = grid.get_col(1).copied().collect();
        assert_eq!(second, vec![2.0, 100.0]);
    }

    #[test]
    fn map() {
        let source = Matrix::from_array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let expected = Matrix::from_array([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]);
        assert_eq!(source.map(|x| x + 2.0), expected);
    }

    #[test]
    fn apply() {
        let mut grid = Matrix::from_array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        grid.apply(|x| x + 2.0);
        let expected = Matrix::from_array([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]);
        assert_eq!(grid, expected);
    }

    #[test]
    fn reshape() {
        let tall = Matrix::from_array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        let wide = Matrix::from_array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        assert_eq!(tall.reshape(2, 3).unwrap(), wide);
    }

    #[test]
    fn transpose() {
        let tall = Matrix::from_array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        let expected = Matrix::from_array([[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]]);
        assert_eq!(tall.transpose(), expected);
    }

    #[test]
    fn dot() {
        let lhs = Matrix::from_array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let rhs = Matrix::from_array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        let expected = Matrix::from_array([[22.0, 28.0], [49.0, 64.0]]);
        assert_eq!(lhs.dot(&rhs).unwrap(), expected);
    }

    #[test]
    fn basic_matrix_differentiation() {
        let tape = GradTape::new();
        let mut lhs = Matrix::g_from_array(&tape, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let mut rhs = Matrix::g_from_array(&tape, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);

        // For an element-wise product, d(lhs * rhs)[1,2] / d rhs[1,2] == lhs[1,2].
        let product = &lhs * &rhs;
        let grads = product[(1, 2)].backward();
        assert_eq!(grads[&rhs[(1, 2)]], lhs[(1, 2)].value());

        // `clear` receives mutable references to every input variable
        // (presumably so they survive the reset — confirm against GradTape::clear);
        // the same gradient check must then still hold.
        let inputs: Vec<_> = lhs
            .get_data_mut()
            .iter_mut()
            .chain(rhs.get_data_mut().iter_mut())
            .collect();
        tape.clear(inputs);

        let product = &lhs * &rhs;
        let grads = product[(1, 2)].backward();
        assert_eq!(grads[&rhs[(1, 2)]], lhs[(1, 2)].value());
    }
}