brique/
matrix.rs

1use rand::rng;
2use rand_distr::{Distribution, Normal};
3
// Unused, commented-out sketch: element storage in either f32 or f64 precision.
// enum DataType {
//     F32(Vec<f32>),
//     F64(Vec<f64>),
// }
8
/// Dense 2-D matrix of `f64` values stored in a single flat buffer.
///
/// The logical shape is `height` rows by `width` columns. Transposition is
/// lazy: it only flips `transposed` and swaps `width`/`height`, leaving `data`
/// untouched, so the storage order depends on the flag:
/// - `transposed == false`: element (row, column) is `data[row * width + column]`;
/// - `transposed == true`: element (row, column) is `data[column * height + row]`.
#[derive(Clone, Debug)]
pub struct Matrix {
    /// Flat element storage, laid out as described on the type.
    pub data: Vec<f64>,
    /// Number of logical columns.
    pub width: usize,
    /// Number of logical rows.
    pub height: usize,
    /// Whether `data` is stored in the transposed (column-major) layout.
    pub transposed: bool,
}
16
17impl Matrix {
18    pub fn init_zero(height: usize, width: usize) -> Matrix {
19        Matrix {
20            data: vec![0.0; width * height],
21            width,
22            height,
23            transposed: false,
24        }
25    }
26
27    pub fn init(height: usize, width: usize, data: Vec<f64>) -> Matrix {
28        assert_eq!(
29            height * width,
30            data.len(),
31            "Error while initiating a matrix with data : 
32                   not compatible with the dimension"
33        );
34
35        Matrix {
36            data,
37            width,
38            height,
39            transposed: false,
40        }
41    }
42
43    pub fn init_rand(height: usize, width: usize) -> Matrix {
44        let std_dev = (2.0 / height as f64).sqrt();
45        let normal = Normal::new(0.0, std_dev).unwrap();
46        let mut rng = rng();
47
48        normal.sample(&mut rand::rng());
49        let rand_vec = (0..height * width)
50            .map(|_| normal.sample(&mut rng))
51            .collect();
52
53        Matrix {
54            data: rand_vec,
55            width,
56            height,
57            transposed: false,
58        }
59    }
60
61    pub fn get(&self, row: usize, column: usize) -> f64 {
62        assert!(row < self.height, "Error while accessing matrix data : row greater or equal to height, out of bound index");
63        assert!(column < self.width, "Error while accessing matrix data : column greater or equal to width, out of bound index");
64
65        if !self.transposed {
66            self.data[row * self.width + column]
67        } else {
68            self.data[column * self.height + row]
69        }
70    }
71
72    // access to underlying one dimensional Vec
73    pub fn get_1d(&self, index: usize) -> f64 {
74        assert!(
75            index < self.data.len(),
76            "Error while accessing matrix data : index greater than vec size, out of bound index"
77        );
78
79        self.data[index]
80    }
81
82    pub fn get_row(&self, row: usize) -> Vec<f64> {
83        assert!(row < self.height, "Error while accessing matrix data : row greater or equal to height, out of bound index");
84
85        let mut output: Vec<f64> = Vec::new();
86
87        for i in 0..self.width {
88            output.push(self.get(row, i));
89        }
90
91        output
92    }
93
94    pub fn set(&mut self, value: f64, row: usize, column: usize) {
95        assert!(row < self.height, "Error while modifying matrix data : row greater or equal to height, out of bound index");
96        assert!(column < self.width, "Error while modifying matrix data : column greater or equal to width, out of bound index");
97
98        if !self.transposed {
99            self.data[row * self.width + column] = value;
100        } else {
101            self.data[column * self.height + row] = value;
102        }
103    }
104
105    // access to underlying one dimensional Vec
106    pub fn set_1d(&mut self, value: f64, index: usize) {
107        assert!(
108            index < self.data.len(),
109            "Error while accessing matrix data : index greater than vec size, out of bound index"
110        );
111
112        self.data[index] = value;
113    }
114
115    pub fn set_row(&mut self, new_row: &Vec<f64>, row: usize) {
116        assert!(row < self.height, "Error while accessing matrix data : row greater or equal to height, out of bound index");
117
118        for i in 0..self.width {
119            self.set(new_row[i], row, i);
120        }
121    }
122
123    pub fn dot(&self, m: &Matrix) -> Matrix {
124        let mut res: Matrix = Matrix::init_zero(self.height, m.width);
125        assert_eq!(self.width, m.height, "Error while doing a dot product: Dimension incompatibility, width of vec 1 : {}, height of vec 2 : {}", self.width, m.height);
126        for c in 0..m.width {
127            for r in 0..self.height {
128                let mut tmp: f64 = 0.0;
129                for a in 0..self.width {
130                    tmp = tmp + self.get(r, a) * m.get(a, c);
131                }
132                res.set(tmp, r, c);
133            }
134        }
135        res
136    }
137
138    // adds a matrix of X width and 1 height to a matrix of Y height and X width
139    pub fn add_1d_matrix_to_all_rows(&self, m: &Matrix) -> Matrix {
140        assert_eq!(m.height, 1, "The input matrix should have a height of 1");
141        assert_eq!(
142            m.width, self.width,
143            "The 2 matrices should have the same width"
144        );
145
146        let output_vec: Vec<f64> = (0..self.height * self.width)
147            .map(|i| self.data[i] + m.get(0, i % self.width))
148            .collect();
149
150        Matrix {
151            data: output_vec,
152            width: self.width,
153            height: self.height,
154            transposed: false,
155        }
156    }
157
158    pub fn max(&self) -> f64 {
159        *self.data.iter().max_by(|a, b| a.total_cmp(b)).unwrap()
160    }
161
162    pub fn min(&self) -> f64 {
163        *self.data.iter().min_by(|a, b| a.total_cmp(b)).unwrap()
164    }
165
166    // ! IN PLACE !
167    pub fn normalize(&mut self) {
168        // get the maximum
169        let max: f64 = self.max();
170        let min: f64 = self.min();
171
172        self.data = self.data.iter().map(|x| (x - min) / (max - min)).collect();
173    }
174
175    // transpose ! IN PLACE !
176    pub fn transpose_inplace(&mut self) {
177        self.transposed = !self.transposed;
178        let tmp: usize = self.width;
179        self.width = self.height;
180        self.height = tmp;
181    }
182
183    pub fn t(&self) -> Matrix {
184        let mut output = self.clone();
185        output.transpose_inplace();
186        output
187    }
188
189    // used for test
190    pub fn is_equal(&self, m: &Matrix, precision: i32) -> bool {
191        if self.width != m.width || self.height != m.height || self.transposed != m.transposed {
192            return false;
193        } else {
194            for i in 0..self.height * self.width {
195                let mut a: f64 = self.data[i] * 10_f64.powi(precision);
196                a = a.round() / 10_f64.powi(precision);
197
198                let mut b: f64 = m.data[i] * 10_f64.powi(precision);
199                b = b.round() / 10_f64.powi(precision);
200
201                if a != b {
202                    return false;
203                }
204            }
205        }
206        true
207    }
208
209    pub fn exp_inplace(&mut self) {
210        self.data = self.data.iter().map(|x| x.exp()).collect();
211    }
212
213    pub fn sqrt_inplace(&mut self) {
214        self.data = self
215            .data
216            .iter()
217            .map(|x| {
218                assert!(
219                    *x >= 0.0,
220                    "Trying to square root a negative value in a matrix error"
221                );
222                x.sqrt()
223            })
224            .collect();
225    }
226
227    pub fn exp(&self) -> Matrix {
228        let mut output = self.clone();
229        output.exp_inplace();
230        output
231    }
232
233    pub fn pow_inplace(&mut self, a: i32) {
234        self.data = self.data.iter().map(|x| x.powi(a)).collect();
235    }
236
237    pub fn pow(&self, a: i32) -> Matrix {
238        let mut output: Matrix = self.clone();
239        output.pow_inplace(a);
240
241        output
242    }
243
244    pub fn sum(&self) -> f64 {
245        self.data.iter().sum()
246    }
247
248    pub fn sum_rows(&self) -> Matrix {
249        let mut output: Matrix = Matrix::init_zero(1, self.width);
250
251        self.data
252            .iter()
253            .enumerate()
254            .for_each(|(index, value)| output.data[index % self.width] += value);
255
256        output
257    }
258
259    pub fn add_inplace(&mut self, a: f64) {
260        self.data = self.data.iter().map(|x| x + a).collect();
261    }
262
263    pub fn div_inplace(&mut self, a: f64) {
264        assert_ne!(a, 0.0, "Divide by 0 matrix error");
265        self.data = self.data.iter().map(|x| x / a).collect();
266    }
267
268    pub fn div(&self, a: f64) -> Matrix {
269        let mut output: Matrix = self.clone();
270        output.div_inplace(a);
271        output
272    }
273
274    pub fn mult_inplace(&mut self, a: f64) {
275        self.data = self.data.iter().map(|x| x * a).collect();
276    }
277
278    pub fn mult(&self, a: f64) -> Matrix {
279        let mut output: Matrix = self.clone();
280        output.mult_inplace(a);
281        output
282    }
283
284    pub fn add_two_matrices(&self, m: &Matrix) -> Matrix {
285        assert!(
286            self.height == m.height && self.width == m.width,
287            "The two matrices should have the same dimensions"
288        );
289        let output_vec: Vec<f64> = (0..self.height * self.width)
290            .map(|i| self.data[i] + m.data[i])
291            .collect();
292
293        Matrix {
294            data: output_vec,
295            width: self.width,
296            height: self.height,
297            transposed: false,
298        }
299    }
300
301    pub fn add_two_matrices_inplace(&mut self, m: &Matrix) {
302        assert!(
303            self.height == m.height && self.width == m.width,
304            "The two matrices should have the same dimensions"
305        );
306
307        self.data = self
308            .data
309            .iter()
310            .enumerate()
311            .map(|(i, val)| val + m.data[i])
312            .collect();
313    }
314
315    pub fn div_two_matrices_inplace(&mut self, m: &Matrix) {
316        assert!(
317            self.height == m.height && self.width == m.width,
318            "The two matrices should have the same dimensions"
319        );
320
321        self.data = self
322            .data
323            .iter()
324            .enumerate()
325            .map(|(i, val)| {
326                assert_ne!(m.data[i], 0.0, "Divide by 0 error in matrix to matrix div");
327                val / m.data[i]
328            })
329            .collect();
330    }
331
332    pub fn pop_last_row(&mut self) {
333        let begin_index = self.height * (self.width - 1);
334        let last_index = self.height * self.width;
335
336        for _i in begin_index..last_index {
337            self.data.pop();
338        }
339
340        self.height -= 1;
341    }
342
343    pub fn compute_d_relu_inplace(&mut self, z_minus_1: &Matrix) {
344        self.data = self
345            .data
346            .iter()
347            .enumerate()
348            .map(|(i, v)| if z_minus_1.data[i] <= 0.0 { 0.0 } else { *v })
349            .collect();
350    }
351
352    pub fn display(&self) {
353        print!("\n");
354        print!("-------------");
355        print!("\n");
356        for i in 0..self.height {
357            for j in 0..self.width {
358                print!(" {} |", self.get(i, j));
359            }
360            print!("/ \n");
361        }
362        print!("-------------");
363        print!("\n");
364    }
365
366    pub fn convert_to_csv(&self) -> String {
367        let mut output: String = String::new();
368        for i in 0..self.height {
369            for j in 0..self.width {
370                output.push_str(&self.get(i, j).to_string());
371                output.push(',');
372            }
373            output.push('\n');
374        }
375
376        output
377    }
378}
379
380//unit test
381#[cfg(test)]
382mod tests {
383    use crate::parse_test_csv::parse_test_csv;
384
385    use super::Matrix;
386
387    fn get_test_matrix() -> Matrix {
388        let matrix = Matrix::init(2, 3, vec![0.1, 1.3, 0.5, 12.0, 1.01, -1000.0]);
389
390        matrix
391    }
392
393    #[test]
394    fn valid_get() {
395        let matrix = get_test_matrix();
396
397        assert_eq!(matrix.get(0, 0), 0.1);
398        assert_eq!(matrix.get(0, 1), 1.3);
399        assert_eq!(matrix.get(0, 2), 0.5);
400        assert_eq!(matrix.get(1, 0), 12.0);
401        assert_eq!(matrix.get(1, 1), 1.01);
402        assert_eq!(matrix.get(1, 2), -1000.0);
403    }
404
405    #[test]
406    fn valid_get_on_transposed() {
407        let mut matrix = get_test_matrix();
408        matrix.transpose_inplace();
409
410        assert_eq!(matrix.get(0, 0), 0.1);
411        assert_eq!(matrix.get(0, 1), 12.0);
412        assert_eq!(matrix.get(1, 0), 1.3);
413        assert_eq!(matrix.get(1, 1), 1.01);
414        assert_eq!(matrix.get(2, 0), 0.5);
415        assert_eq!(matrix.get(2, 1), -1000.0);
416    }
417
418    #[test]
419    fn valid_get_on_untransposed() {
420        let mut matrix = get_test_matrix();
421        matrix.transpose_inplace();
422        matrix.transpose_inplace();
423
424        assert_eq!(matrix.get(0, 0), 0.1);
425        assert_eq!(matrix.get(0, 1), 1.3);
426        assert_eq!(matrix.get(0, 2), 0.5);
427        assert_eq!(matrix.get(1, 0), 12.0);
428        assert_eq!(matrix.get(1, 1), 1.01);
429        assert_eq!(matrix.get(1, 2), -1000.0);
430    }
431
432    #[test]
433    #[should_panic]
434    fn unvalid_get_column_out_of_bound() {
435        let matrix = get_test_matrix();
436
437        matrix.get(2, 0);
438    }
439
440    #[test]
441    #[should_panic]
442    fn unvalid_get_row_out_of_bound() {
443        let matrix = get_test_matrix();
444
445        matrix.get(5, 1);
446    }
447
448    #[test]
449    #[should_panic]
450    fn unvalid_get_tranposed_column_out_of_bound() {
451        let mut matrix = get_test_matrix();
452        matrix.transpose_inplace();
453
454        matrix.get(0, 2);
455    }
456
457    #[test]
458    #[should_panic]
459    fn unvalid_get_transposed_row_out_of_bound() {
460        let mut matrix = get_test_matrix();
461        matrix.transpose_inplace();
462
463        matrix.get(3, 1);
464    }
465
466    #[test]
467    fn valid_get_row() {
468        let matrix = get_test_matrix();
469        let expected_vec = vec![12.0, 1.01, -1000.0];
470
471        assert_eq![matrix.get_row(1), expected_vec];
472    }
473
474    #[test]
475    fn valid_get_row_on_transposed() {
476        let mut matrix = get_test_matrix();
477        let expected_vec = vec![0.5, -1000.0];
478        matrix.transpose_inplace();
479
480        assert_eq![matrix.get_row(2), expected_vec];
481    }
482
483    #[test]
484    fn valid_set() {
485        let mut matrix = get_test_matrix();
486        matrix.set(69.69, 1, 1);
487
488        assert_eq![matrix.data[4], 69.69];
489    }
490
491    #[test]
492    #[should_panic]
493    fn unvalid_set_column_out_of_bound() {
494        let mut matrix = get_test_matrix();
495
496        matrix.set(69.69, 2, 0);
497    }
498
499    #[test]
500    #[should_panic]
501    fn unvalid_set_row_out_of_bound() {
502        let mut matrix = get_test_matrix();
503
504        matrix.set(69.69, 5, 1);
505    }
506
507    #[test]
508    #[should_panic]
509    fn unvalid_set_tranposed_column_out_of_bound() {
510        let mut matrix = get_test_matrix();
511        matrix.transpose_inplace();
512
513        matrix.set(69.69, 0, 2);
514    }
515
516    #[test]
517    #[should_panic]
518    fn unvalid_set_transposed_row_out_of_bound() {
519        let mut matrix = get_test_matrix();
520        matrix.transpose_inplace();
521
522        matrix.set(69.69, 3, 1);
523    }
524
525    #[test]
526    fn valid_set_row() {
527        let mut matrix = get_test_matrix();
528        let new_row = vec![0.8, 0.1, 1203123.0];
529
530        matrix.set_row(&new_row, 0);
531
532        assert_eq![matrix.get_row(0), new_row];
533    }
534
535    #[test]
536    fn valid_set_row_on_transposed() {
537        let mut matrix = get_test_matrix();
538        let new_row = vec![0.8, 0.1];
539
540        matrix.transpose_inplace();
541        matrix.set_row(&new_row, 2);
542
543        assert_eq![matrix.get_row(2), new_row];
544    }
545
546    #[test]
547    fn max_test() {
548        let matrix = get_test_matrix();
549
550        assert_eq![matrix.max(), 12.0];
551    }
552
553    #[test]
554    fn min_test() {
555        let matrix = get_test_matrix();
556
557        assert_eq![matrix.min(), -1000.0];
558    }
559
560    #[test]
561    fn add_values_of_a_row_test() {
562        let test_data = parse_test_csv("tests/test_data/add_values_of_a_row_test.csv".to_string());
563
564        assert!(test_data[0]
565            .add_1d_matrix_to_all_rows(&test_data[1])
566            .is_equal(&test_data[2], 10));
567    }
568
569    #[test]
570    fn div_two_matrices_test() {
571        let test_data = parse_test_csv("tests/test_data/div_two_matrices_test.csv".to_string());
572        let mut m1 = test_data[0].clone();
573
574        m1.div_two_matrices_inplace(&test_data[1]);
575        assert!(m1.is_equal(&test_data[2], 10));
576    }
577
578    #[test]
579    #[should_panic]
580    fn unvalid_div_by_0_div_two_matrices_test() {
581        let test_data = parse_test_csv("tests/test_data/div_two_matrices_test.csv".to_string());
582        let mut m1 = test_data[0].clone();
583        let mut m2 = test_data[1].clone();
584        m2.set(0.0, 1, 1);
585        m1.div_two_matrices_inplace(&m2);
586    }
587
588    #[test]
589    fn sqrt_test() {
590        let test_data = parse_test_csv("tests/test_data/sqrt_test.csv".to_string());
591        let mut m1 = test_data[0].clone();
592
593        m1.sqrt_inplace();
594        assert!(m1.is_equal(&test_data[1], 10));
595    }
596
597    #[test]
598    #[should_panic]
599    fn unvalid_sqrt_negavtive_value() {
600        let test_data = parse_test_csv("tests/test_data/div_two_matrices_test.csv".to_string());
601        let mut m1 = test_data[0].clone();
602        m1.set(-1.0, 1, 1);
603        m1.sqrt_inplace();
604    }
605
606    #[test]
607    fn dot_product_test() {
608        let test_data = parse_test_csv("tests/test_data/dot_product_test.csv".to_string());
609
610        assert!(test_data[0].dot(&test_data[1]).is_equal(&test_data[2], 8));
611    }
612
613    #[test]
614    fn normalize_test() {
615        let mut test_data = parse_test_csv("tests/test_data/normalize_test.csv".to_string());
616
617        test_data[0].normalize();
618
619        assert!(test_data[0].is_equal(&test_data[1], 8));
620    }
621}