quickbrain/quick_math/
matrix.rs

1use std::{
2    fmt::Debug,
3    ops::{Add, Div, Index, IndexMut, Mul, Sub},
4};
5
6use rand::random;
7
8use crate::quick_grad::{grad_tape::GradTape, var::Var};
9
10use super::errors::MatrixError;
11
/// # Struct **Matrix**
/// A classic implementation of the Matrix data structure: a `rows` x `cols`
/// grid of `T` values backed by a single flat [`Vec`]
///
/// ---
/// Includes
/// - The possibility to create a Matrix structure starting from various other structures
/// - The possibility to manipulate entire matrices as if they were individual items
/// - Operator overloading
/// - Matrix math
/// - Safely reshaping and transposing matrices
/// - Trapping you in a simulation
/// ---
/// ### A demonstration
/// ```
/// use quickbrain::quick_math::matrix::Matrix;
/// let m : Matrix<f64> = Matrix::from_array([[1., 2., 3.], [4., 5., 6.]]);
/// // [1, 2, 3]
/// // [4, 5, 6]
/// let m2 : Matrix<f64> = Matrix::from_array([[1., 2.], [3., 4.], [5., 6.]]);
/// // [1, 2]
/// // [3, 4]
/// // [5, 6]
///
/// // Multiplying matrices by a scalar
/// // Multiplying matrices by matrices
/// // Mapping a function to a matrix
/// // Note: `dot` returns a `Result`, which is `Err` when the inner dimensions differ
/// let r = m.dot(&(m2 * 2.0).map(|x| x * x));
/// ```
#[derive(Clone, PartialEq)]
pub struct Matrix<T> {
    /// The number of rows that the matrix has
    rows: usize,
    /// The number of columns that the matrix has
    cols: usize,
    /// Raw vector for the data contained by the Matrix, stored row by row:
    /// element `(i, j)` lives at `data[i * cols + j]`
    data: Vec<T>,
}
49
50impl<T: Debug> Debug for Matrix<T> {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        let mut r = String::new();
53        for i in 0..self.rows {
54            for j in 0..self.cols {
55                r.push_str(&format!("{:?} ", self.data[i * self.cols + j]));
56            }
57
58            r.push_str("\n");
59        }
60
61        write!(f, "{}", r)
62    }
63}
64
65impl<T: Copy> Matrix<T> {
66    /// # Get Data
67    /// Returns a reference to the plain vector holding the raw data of the matrix
68    /// Likely not an useful method
69    pub fn get_data(&self) -> &Vec<T> {
70        &self.data
71    }
72
73    pub fn get_data_mut(&mut self) -> &mut Vec<T> {
74        &mut self.data
75    }
76
77    /// # Get Shape
78    /// Returns a [`Vec<usize>`] representing the shape of the Matrix
79    pub fn get_shape(&self) -> Vec<usize> {
80        vec![self.rows, self.cols]
81    }
82
83    /// # Get row slice
84    /// Returns a row of the matrix as a slice, it's very fast!
85    pub fn get_row_slice(&self, row: usize) -> &[T] {
86        &self.data[(row * self.cols)..(row * self.cols + self.cols)]
87    }
88
89    /// # Get row slice mut
90    /// Returns a mutable reference to a row of the matrix as a slice
91    pub fn get_row_slice_mut(&mut self, row: usize) -> &[T] {
92        &mut self.data[(row * self.cols)..(row * self.cols + self.cols)]
93    }
94
95    /// # Get row
96    /// Returns a reference to a given row of the Matrix as an [Iterator]
97    pub fn get_row(&self, row: usize) -> impl Iterator<Item = &T> {
98        self.data.iter().skip(row * self.cols).take(self.cols)
99    }
100
101    /// # Get row mut
102    /// Returns a mutable reference to a given row of the Matrix as an [Iterator]
103    pub fn get_row_mut(&mut self, row: usize) -> impl Iterator<Item = &mut T> {
104        self.data.iter_mut().skip(row * self.cols).take(self.cols)
105    }
106
107    /// # Get column
108    /// Returns a reference to a given column of the Matrix as an [Iterator]
109    pub fn get_col(&self, col: usize) -> impl Iterator<Item = &T> {
110        self.data.iter().skip(col).step_by(self.cols)
111    }
112
113    /// # Get column mutable
114    /// Returns a mutable reference to a given column of the Matrix as an [Iterator]
115    pub fn get_col_mut(&mut self, col: usize) -> impl Iterator<Item = &mut T> {
116        self.data.iter_mut().skip(col).step_by(self.cols)
117    }
118
119    /// # Get rows
120    /// A simple getter for the number of rows in the Matrix
121    pub fn get_rows(&self) -> usize {
122        self.rows
123    }
124    /// # Get cols
125    /// A simple getter for the number of columns in the Matrix
126    pub fn get_cols(&self) -> usize {
127        self.cols
128    }
129
130    /// # Map
131    /// Returns a copy of the matrix with a function f applied to it
132    pub fn map(&self, f: fn(T) -> T) -> Matrix<T> {
133        Matrix {
134            rows: self.rows,
135            cols: self.cols,
136            data: self.data.iter().copied().map(f).collect::<Vec<_>>(),
137        }
138    }
139
140    /// # Apply
141    /// Mutates the matrix by applying a function F to each element
142    pub fn apply(&mut self, f: fn(T) -> T) {
143        for i in &mut self.data {
144            *i = f(*i);
145        }
146    }
147
148    /// # Reshape
149    /// Returns [MatrixError::InvalidReshape] if the data doesn't fit the new size
150    /// Returns the reshaped matrix otherwise
151    pub fn reshape(&self, new_rows: usize, new_cols: usize) -> Result<Matrix<T>, MatrixError> {
152        if (self.rows * self.cols) != (new_rows * new_cols) {
153            Err(MatrixError::InvalidReshape {
154                numel: self.get_rows() * self.get_cols(),
155                forcing_into: new_rows * new_cols,
156            })
157        } else {
158            Ok(Matrix {
159                rows: new_rows,
160                cols: new_cols,
161                data: self.data.clone(),
162            })
163        }
164    }
165
166    // # Repeat
167    // Replicates the matrix by copying its columns N times
168    pub fn repeat_h(&self, times: usize) -> Matrix<T> {
169        let mut data = vec![];
170        for x in self.data.clone() {
171            for _ in 0..times {
172                data.push(x);
173            }
174        }
175
176        Matrix {
177            rows: self.rows,
178            cols: self.cols * times,
179            data,
180        }
181    }
182}
183
184impl Matrix<f64> {
185    /// # Rand
186    /// Creates a matrix filled with random values in the range [0, 1)
187    /// of a given size
188    pub fn rand(rows: usize, cols: usize) -> Matrix<f64> {
189        let mut r = Matrix::zero(rows, cols);
190
191        for i in 0..r.get_rows() {
192            for j in 0..r.get_cols() {
193                r[(i, j)] = random();
194            }
195        }
196
197        r
198    }
199
200    /// # From array
201    /// Creates a new [Matrix] from a 2D array of any shape
202    /// # WARNING: Uses static dispatch, this is mostly used for a bunch of constants and for
203    /// testing, to not push the limits of it or you will end up with a hige executable
204    pub fn from_array<const R: usize, const C: usize>(arr: [[f64; C]; R]) -> Matrix<f64> {
205        let mut data: Vec<f64> = Vec::new();
206        for row in arr {
207            for element in row {
208                data.push(element);
209            }
210        }
211
212        Matrix {
213            rows: R,
214            cols: C,
215            data,
216        }
217    }
218    /// # Zero
219    /// Creates a new [Matrix] of the given shape and fills it with ZEROS
220    pub fn zero(rows: usize, cols: usize) -> Matrix<f64> {
221        let mut data: Vec<f64> = Vec::new();
222        for _i in 0..rows * cols {
223            data.push(0f64);
224        }
225        Matrix { rows, cols, data }
226    }
227    /// # One
228    /// Creates a new [Matrix] of the given shape and fills it with ONES
229    /// Useful for initializing matrices with dummy values that do not give 0 as a multiplication
230    /// result
231    pub fn one(rows: usize, cols: usize) -> Matrix<f64> {
232        let mut data: Vec<f64> = Vec::new();
233        for _i in 0..rows * cols {
234            data.push(1f64);
235        }
236        Matrix { rows, cols, data }
237    }
238    /// # Fill
239    /// Mutates the matrix by filling it with the given value
240    pub fn fill(&mut self, value: f64) {
241        for item in self.data.iter_mut() {
242            *item = value;
243        }
244    }
245    /// # Dot
246    /// Returns the result of a Matrix multiplication operation -> Dot product
247    pub fn dot(&self, other: &Matrix<f64>) -> Result<Matrix<f64>, MatrixError> {
248        if self.cols != other.get_rows() {
249            return Err(MatrixError::MatMulDimensionsMismatch {
250                size_1: self.get_shape(),
251                size_2: other.get_shape(),
252            });
253        }
254
255        let mut m: Matrix<f64> = Matrix::zero(self.rows, other.cols);
256        for i in 0..self.rows {
257            for j in 0..other.cols {
258                m[(i, j)] = 0f64;
259                m[(i, j)] = vec_dot(
260                    self.get_row(i).copied().collect(),
261                    other.get_col(j).copied().collect(),
262                )
263            }
264        }
265
266        Ok(m)
267    }
268    /// # Transpose
269    /// Returns a copy of the transposed matrix
270    pub fn transpose(&self) -> Matrix<f64> {
271        let mut r = Matrix::zero(self.cols, self.rows);
272        for i in 0..self.rows {
273            for j in 0..self.cols {
274                r[(j, i)] = self[(i, j)];
275            }
276        }
277
278        r
279    }
280}
281
282impl Matrix<Var> {
283    /// # Random
284    /// Creates a new [Matrix] of the given shape and fills it with random values
285    pub fn g_rand(tape: &GradTape, rows: usize, cols: usize) -> Matrix<Var> {
286        let mut data: Vec<Var> = Vec::new();
287        for _i in 0..rows * cols {
288            data.push(tape.var(random()));
289        }
290        Matrix { rows, cols, data }
291    }
292
293    pub fn apply_to_value(&mut self, f: fn(x: Var) -> f64) {
294        for i in &mut self.data {
295            *i.value_mut() = f(*i);
296        }
297    }
298
299    pub fn g_from_array<const R: usize, const C: usize>(
300        tape: &GradTape,
301        arr: [[f64; C]; R],
302    ) -> Matrix<Var> {
303        let mut data: Vec<Var> = Vec::new();
304        for row in arr {
305            for element in row {
306                data.push(tape.var(element));
307            }
308        }
309
310        Matrix {
311            rows: R,
312            cols: C,
313            data,
314        }
315    }
316
317    /// # G Zero
318    /// Creates a new [Matrix] of the given shape and fills it with ZERO [Var]s
319    pub fn g_zero(tape: &GradTape, rows: usize, cols: usize) -> Matrix<Var> {
320        let mut data: Vec<Var> = Vec::new();
321        for _i in 0..rows * cols {
322            data.push(tape.var(0.0));
323        }
324        Matrix { rows, cols, data }
325    }
326
327    /// # G One
328    /// Creates a new [Matrix] of the given shape and fills it with ONE [Var]s
329    pub fn g_one(tape: &GradTape, rows: usize, cols: usize) -> Matrix<Var> {
330        let mut data: Vec<Var> = Vec::new();
331        for _i in 0..rows * cols {
332            data.push(tape.var(1.0));
333        }
334        Matrix { rows, cols, data }
335    }
336
337    /// # G Fill
338    /// Mutates the matrix by filling it with the given value
339    pub fn g_fill(&mut self, tape: &GradTape, value: f64) {
340        for i in 0..self.rows * self.cols {
341            self.data[i] = tape.var(value);
342        }
343    }
344
345    /// # Dot
346    /// Returns the result of a Matrix multiplication operation -> Dot product
347    pub fn g_dot(&self, tape: &GradTape, other: &Matrix<Var>) -> Result<Matrix<Var>, MatrixError> {
348        if self.cols != other.get_rows() {
349            return Err(MatrixError::MatMulDimensionsMismatch {
350                size_1: self.get_shape(),
351                size_2: other.get_shape(),
352            });
353        }
354
355        let mut m: Matrix<Var> = Matrix::g_zero(tape, self.rows, other.cols);
356        for i in 0..self.rows {
357            for j in 0..other.cols {
358                // m[(i, j)] = 0f64;
359                m[(i, j)] = g_vec_dot(
360                    tape,
361                    self.get_row(i).copied().collect(),
362                    other.get_col(j).copied().collect(),
363                )
364            }
365        }
366
367        Ok(m)
368    }
369    /// # Transpose
370    /// Returns a copy of the transposed matrix
371    pub fn transpose(&self, tape: &GradTape) -> Matrix<Var> {
372        let mut r = Matrix::g_zero(tape, self.cols, self.rows);
373        for i in 0..self.rows {
374            for j in 0..self.cols {
375                r[(j, i)] = self[(i, j)];
376            }
377        }
378
379        r
380    }
381
382    /// # Value
383    /// Returns a copy of the matrix with all the variables replaced with their values
384    pub fn value(&self) -> Matrix<f64> {
385        let mut r = Matrix::zero(self.rows, self.cols);
386        for i in 0..self.rows {
387            for j in 0..self.cols {
388                r[(i, j)] = self[(i, j)].value();
389            }
390        }
391
392        r
393    }
394}
395
396impl<T> Index<(usize, usize)> for Matrix<T> {
397    type Output = T;
398    fn index(&self, index: (usize, usize)) -> &Self::Output {
399        &self.data[index.0 * self.cols + index.1]
400    }
401}
402
403impl<T> IndexMut<(usize, usize)> for Matrix<T> {
404    fn index_mut(&mut self, index: (usize, usize)) -> &mut T {
405        &mut self.data[index.0 * self.cols + index.1]
406    }
407}
408
409impl<T: Copy + Clone + Add<U, Output = T>, U: Copy + Add<T>> Add<Matrix<U>> for Matrix<T> {
410    type Output = Matrix<T>;
411    fn add(self, other: Matrix<U>) -> Matrix<T> {
412        let mut data = self.data.clone();
413        for i in 0..data.len() {
414            data[i] = data[i] + other.data[i];
415        }
416
417        Matrix {
418            rows: self.rows,
419            cols: self.cols,
420            data,
421        }
422    }
423}
424
425impl<T: Copy + Clone + Add<Output = T>> Add<&Matrix<T>> for Matrix<T> {
426    type Output = Matrix<T>;
427    fn add(self, other: &Matrix<T>) -> Matrix<T> {
428        let mut data = self.data.clone();
429        for i in 0..data.len() {
430            data[i] = data[i] + other.data[i];
431        }
432
433        Matrix {
434            rows: self.rows,
435            cols: self.cols,
436            data,
437        }
438    }
439}
440
441impl<T: Copy + Clone + Add<f64, Output = T>> Add<f64> for Matrix<T> {
442    type Output = Matrix<T>;
443    fn add(self, other: f64) -> Matrix<T> {
444        let mut data = self.data.clone();
445        for i in 0..data.len() {
446            data[i] = data[i] + other;
447        }
448
449        Matrix {
450            rows: self.rows,
451            cols: self.cols,
452            data,
453        }
454    }
455}
456
457impl<T: Copy + Clone + Sub<Output = T>> Sub<Matrix<T>> for Matrix<T> {
458    type Output = Matrix<T>;
459    fn sub(self, other: Matrix<T>) -> Matrix<T> {
460        let mut data = self.data.clone();
461        for i in 0..data.len() {
462            data[i] = data[i] - other.data[i];
463        }
464
465        Matrix {
466            rows: self.rows,
467            cols: self.cols,
468            data,
469        }
470    }
471}
472impl<T: Copy + Clone + Sub<Output = T>> Sub<&Matrix<T>> for Matrix<T> {
473    type Output = Matrix<T>;
474    fn sub(self, other: &Matrix<T>) -> Matrix<T> {
475        let mut data = self.data.clone();
476        for i in 0..data.len() {
477            data[i] = data[i] - other.data[i];
478        }
479
480        Matrix {
481            rows: self.rows,
482            cols: self.cols,
483            data,
484        }
485    }
486}
487impl<T: Copy + Clone + Sub<f64, Output = T>> Sub<f64> for Matrix<T> {
488    type Output = Matrix<T>;
489    fn sub(self, other: f64) -> Matrix<T> {
490        let mut data = self.data.clone();
491        for i in 0..data.len() {
492            data[i] = data[i] - other;
493        }
494
495        Matrix {
496            rows: self.rows,
497            cols: self.cols,
498            data,
499        }
500    }
501}
502
503impl<T: Copy + Clone + Mul<Output = T>> Mul<Matrix<T>> for Matrix<T> {
504    type Output = Matrix<T>;
505    fn mul(self, other: Matrix<T>) -> Matrix<T> {
506        let mut data = self.data.clone();
507        for i in 0..data.len() {
508            data[i] = data[i] * other.data[i];
509        }
510
511        Matrix {
512            rows: self.rows,
513            cols: self.cols,
514            data,
515        }
516    }
517}
518impl<T: Copy + Clone + Mul<Output = T>> Mul<&Matrix<T>> for Matrix<T> {
519    type Output = Matrix<T>;
520    fn mul(self, other: &Matrix<T>) -> Matrix<T> {
521        let mut data = self.data.clone();
522        for i in 0..data.len() {
523            data[i] = data[i] * other.data[i];
524        }
525
526        Matrix {
527            rows: self.rows,
528            cols: self.cols,
529            data,
530        }
531    }
532}
533impl<T: Copy + Clone + Mul<f64, Output = T>> Mul<f64> for Matrix<T> {
534    type Output = Matrix<T>;
535    fn mul(self, other: f64) -> Matrix<T> {
536        let mut data = self.data.clone();
537        for i in 0..data.len() {
538            data[i] = data[i] * other;
539        }
540
541        Matrix {
542            rows: self.rows,
543            cols: self.cols,
544            data,
545        }
546    }
547}
548impl<T: Copy + Clone + Div<Output = T>> Div<Matrix<T>> for Matrix<T> {
549    type Output = Matrix<T>;
550    fn div(self, other: Matrix<T>) -> Matrix<T> {
551        let mut data = self.data.clone();
552        for i in 0..data.len() {
553            data[i] = data[i] / other.data[i];
554        }
555
556        Matrix {
557            rows: self.rows,
558            cols: self.cols,
559            data,
560        }
561    }
562}
563impl<T: Copy + Clone + Div<Output = T>> Div<&Matrix<T>> for Matrix<T> {
564    type Output = Matrix<T>;
565    fn div(self, other: &Matrix<T>) -> Matrix<T> {
566        let mut data = self.data.clone();
567        for i in 0..data.len() {
568            data[i] = data[i] / other.data[i];
569        }
570
571        Matrix {
572            rows: self.rows,
573            cols: self.cols,
574            data,
575        }
576    }
577}
578impl<T: Copy + Clone + Div<f64, Output = T>> Div<f64> for Matrix<T> {
579    type Output = Matrix<T>;
580    fn div(self, other: f64) -> Matrix<T> {
581        let mut data = self.data.clone();
582        for i in 0..data.len() {
583            data[i] = data[i] / other;
584        }
585
586        Matrix {
587            rows: self.rows,
588            cols: self.cols,
589            data,
590        }
591    }
592}
593
594// Operator overloading for matrix references
595impl<T: Copy + Clone + Add<Output = T>> Add<&Matrix<T>> for &Matrix<T> {
596    type Output = Matrix<T>;
597    fn add(self, other: &Matrix<T>) -> Matrix<T> {
598        let mut data = self.data.clone();
599        for i in 0..data.len() {
600            data[i] = data[i] + other.data[i];
601        }
602
603        Matrix {
604            rows: self.rows,
605            cols: self.cols,
606            data,
607        }
608    }
609}
610
611impl<T: Copy + Clone + Add<Output = T>> Sub<&Matrix<T>> for &Matrix<T> {
612    type Output = Matrix<T>;
613    fn sub(self, other: &Matrix<T>) -> Matrix<T> {
614        let mut data = self.data.clone();
615        for i in 0..data.len() {
616            data[i] = data[i] + other.data[i];
617        }
618
619        Matrix {
620            rows: self.rows,
621            cols: self.cols,
622            data,
623        }
624    }
625}
626
627impl<T: Copy + Clone + Mul<Output = T>> Mul<&Matrix<T>> for &Matrix<T> {
628    type Output = Matrix<T>;
629    fn mul(self, other: &Matrix<T>) -> Matrix<T> {
630        let mut data = self.data.clone();
631        for i in 0..data.len() {
632            data[i] = data[i] * other.data[i];
633        }
634
635        Matrix {
636            rows: self.rows,
637            cols: self.cols,
638            data,
639        }
640    }
641}
642
643impl<T: Copy + Clone + Div<Output = T>> Div<&Matrix<T>> for &Matrix<T> {
644    type Output = Matrix<T>;
645    fn div(self, other: &Matrix<T>) -> Matrix<T> {
646        let mut data = self.data.clone();
647        for i in 0..data.len() {
648            data[i] = data[i] / other.data[i];
649        }
650
651        Matrix {
652            rows: self.rows,
653            cols: self.cols,
654            data,
655        }
656    }
657}
658
659impl<T: Copy + Clone + Add<f64, Output = T>> Add<f64> for &Matrix<T> {
660    type Output = Matrix<T>;
661    fn add(self, other: f64) -> Matrix<T> {
662        let mut data = self.data.clone();
663        for i in 0..data.len() {
664            data[i] = data[i] + other;
665        }
666
667        Matrix {
668            rows: self.rows,
669            cols: self.cols,
670            data,
671        }
672    }
673}
674
675impl<T: Copy + Clone + Sub<f64, Output = T>> Sub<f64> for &Matrix<T> {
676    type Output = Matrix<T>;
677    fn sub(self, other: f64) -> Matrix<T> {
678        let mut data = self.data.clone();
679        for i in 0..data.len() {
680            data[i] = data[i] - other;
681        }
682
683        Matrix {
684            rows: self.rows,
685            cols: self.cols,
686            data,
687        }
688    }
689}
690
691impl<T: Copy + Clone + Mul<f64, Output = T>> Mul<f64> for &Matrix<T> {
692    type Output = Matrix<T>;
693    fn mul(self, other: f64) -> Matrix<T> {
694        let mut data = self.data.clone();
695        for i in 0..data.len() {
696            data[i] = data[i] * other;
697        }
698
699        Matrix {
700            rows: self.rows,
701            cols: self.cols,
702            data,
703        }
704    }
705}
706
707impl<T: Copy + Clone + Div<f64, Output = T>> Div<f64> for &Matrix<T> {
708    type Output = Matrix<T>;
709    fn div(self, other: f64) -> Matrix<T> {
710        let mut data = self.data.clone();
711        for i in 0..data.len() {
712            data[i] = data[i] / other;
713        }
714
715        Matrix {
716            rows: self.rows,
717            cols: self.cols,
718            data,
719        }
720    }
721}
722
/// Dot product of two vectors: the sum of `v1[i] * v2[i]` over every index
/// of `v1`, accumulated from left to right starting at `0.0`.
///
/// Panics if `v2` is shorter than `v1`; extra elements of `v2` are ignored.
fn vec_dot(v1: Vec<f64>, v2: Vec<f64>) -> f64 {
    v1.iter()
        .enumerate()
        .fold(0.0, |acc, (idx, lhs)| acc + lhs * v2[idx])
}
735
736fn g_vec_dot(tape: &GradTape, v1: Vec<Var>, v2: Vec<Var>) -> Var {
737    // print!("{:?} . {:?} ", v1, v2);
738    let mut r = tape.var(0.0);
739
740    let len = v1.len();
741    for i in 0..len {
742        r = r + v1[i] * v2[i];
743    }
744    // println!(" = {}", r);
745    r
746}
747
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: building a zero matrix must not panic
    #[test]
    fn create_matrix() {
        let _zeros = Matrix::zero(2, 3);
    }

    // Rows can be read through `get_row` and edited through `get_row_mut`
    #[test]
    fn get_row() {
        let mut matrix = Matrix::from_array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]);

        let first_row: Vec<f64> = matrix.get_row(0).copied().collect();
        assert_eq!(first_row, vec![1.0, 2.0, 3.0, 4.0]);

        // Overwrite element (1, 1) through the mutable row iterator
        *matrix.get_row_mut(1).nth(1).unwrap() = 100.0;

        let second_row: Vec<f64> = matrix.get_row(1).copied().collect();
        assert_eq!(second_row, vec![5.0, 100.0, 7.0, 8.0]);
    }

    // Columns can be read through `get_col` and edited through `get_col_mut`
    #[test]
    fn get_col() {
        let mut matrix = Matrix::from_array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]);

        let first_col: Vec<f64> = matrix.get_col(0).copied().collect();
        assert_eq!(first_col, vec![1.0, 5.0]);

        // Overwrite element (1, 1) through the mutable column iterator
        *matrix.get_col_mut(1).nth(1).unwrap() = 100.0;

        let second_col: Vec<f64> = matrix.get_col(1).copied().collect();
        assert_eq!(second_col, vec![2.0, 100.0]);
    }

    // `map` returns a transformed copy
    #[test]
    fn map() {
        let source = Matrix::from_array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let expected = Matrix::from_array([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]);

        assert_eq!(source.map(|x| x + 2.0), expected);
    }

    // `apply` transforms the matrix in place
    #[test]
    fn apply() {
        let mut matrix = Matrix::from_array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let expected = Matrix::from_array([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]);

        matrix.apply(|x| x + 2.0);

        assert_eq!(matrix, expected);
    }

    // Reshaping 3x2 into 2x3 keeps the element order
    #[test]
    fn reshape() {
        let tall = Matrix::from_array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        let expected = Matrix::from_array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);

        assert_eq!(tall.reshape(2, 3).unwrap(), expected)
    }

    // Transposing swaps rows and columns
    #[test]
    fn transpose() {
        let tall = Matrix::from_array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        let expected = Matrix::from_array([[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]]);

        assert_eq!(tall.transpose(), expected)
    }

    // (2x3) . (3x2) matrix multiplication
    #[test]
    fn dot() {
        let lhs = Matrix::from_array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let rhs = Matrix::from_array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        let expected = Matrix::from_array([[22.0, 28.0], [49.0, 64.0]]);

        assert_eq!(lhs.dot(&rhs).unwrap(), expected);
    }

    // For an element-wise product, the gradient of an output cell with
    // respect to the right operand equals the left operand's value; this
    // must hold again after the tape has been cleared.
    #[test]
    fn basic_matrix_differentiation() {
        let tape = GradTape::new();
        let values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let mut lhs = Matrix::g_from_array(&tape, values);
        let mut rhs = Matrix::g_from_array(&tape, values);

        let product = &lhs * &rhs;

        let grad = product[(1, 2)].backward();

        assert_eq!(grad[&rhs[(1, 2)]], lhs[(1, 2)].value());

        // Hand every variable of both operands to `clear`
        let tracked = lhs
            .get_data_mut()
            .iter_mut()
            .chain(rhs.get_data_mut().iter_mut())
            .collect::<Vec<_>>();
        tape.clear(tracked);

        let product = &lhs * &rhs;
        let grad = product[(1, 2)].backward();
        assert_eq!(grad[&rhs[(1, 2)]], lhs[(1, 2)].value());
    }
}