1use rand::rng;
2use rand_distr::{Distribution, Normal};
3
/// Dense matrix of `f64` values kept in a single flat vector.
///
/// `width` and `height` always describe the logical shape. The element at
/// (`row`, `column`) lives at `data[row * width + column]` while `transposed`
/// is `false`, and at `data[column * height + row]` once it is `true`, so
/// transposing only flips the flag and swaps the two dimensions.
#[derive(Clone, Debug)]
pub struct Matrix {
    /// Flat storage of the entries; the index mapping depends on `transposed`.
    pub data: Vec<f64>,
    /// Number of logical columns.
    pub width: usize,
    /// Number of logical rows.
    pub height: usize,
    /// Whether `data` has to be read through the transposed index mapping.
    pub transposed: bool,
}
16
17impl Matrix {
18 pub fn init_zero(height: usize, width: usize) -> Matrix {
19 Matrix {
20 data: vec![0.0; width * height],
21 width,
22 height,
23 transposed: false,
24 }
25 }
26
27 pub fn init(height: usize, width: usize, data: Vec<f64>) -> Matrix {
28 assert_eq!(
29 height * width,
30 data.len(),
31 "Error while initiating a matrix with data :
32 not compatible with the dimension"
33 );
34
35 Matrix {
36 data,
37 width,
38 height,
39 transposed: false,
40 }
41 }
42
43 pub fn init_rand(height: usize, width: usize) -> Matrix {
44 let std_dev = (2.0 / height as f64).sqrt();
45 let normal = Normal::new(0.0, std_dev).unwrap();
46 let mut rng = rng();
47
48 normal.sample(&mut rand::rng());
49 let rand_vec = (0..height * width)
50 .map(|_| normal.sample(&mut rng))
51 .collect();
52
53 Matrix {
54 data: rand_vec,
55 width,
56 height,
57 transposed: false,
58 }
59 }
60
61 pub fn get(&self, row: usize, column: usize) -> f64 {
62 assert!(row < self.height, "Error while accessing matrix data : row greater or equal to height, out of bound index");
63 assert!(column < self.width, "Error while accessing matrix data : column greater or equal to width, out of bound index");
64
65 if !self.transposed {
66 self.data[row * self.width + column]
67 } else {
68 self.data[column * self.height + row]
69 }
70 }
71
72 pub fn get_1d(&self, index: usize) -> f64 {
74 assert!(
75 index < self.data.len(),
76 "Error while accessing matrix data : index greater than vec size, out of bound index"
77 );
78
79 self.data[index]
80 }
81
82 pub fn get_row(&self, row: usize) -> Vec<f64> {
83 assert!(row < self.height, "Error while accessing matrix data : row greater or equal to height, out of bound index");
84
85 let mut output: Vec<f64> = Vec::new();
86
87 for i in 0..self.width {
88 output.push(self.get(row, i));
89 }
90
91 output
92 }
93
94 pub fn set(&mut self, value: f64, row: usize, column: usize) {
95 assert!(row < self.height, "Error while modifying matrix data : row greater or equal to height, out of bound index");
96 assert!(column < self.width, "Error while modifying matrix data : column greater or equal to width, out of bound index");
97
98 if !self.transposed {
99 self.data[row * self.width + column] = value;
100 } else {
101 self.data[column * self.height + row] = value;
102 }
103 }
104
105 pub fn set_1d(&mut self, value: f64, index: usize) {
107 assert!(
108 index < self.data.len(),
109 "Error while accessing matrix data : index greater than vec size, out of bound index"
110 );
111
112 self.data[index] = value;
113 }
114
115 pub fn set_row(&mut self, new_row: &Vec<f64>, row: usize) {
116 assert!(row < self.height, "Error while accessing matrix data : row greater or equal to height, out of bound index");
117
118 for i in 0..self.width {
119 self.set(new_row[i], row, i);
120 }
121 }
122
123 pub fn dot(&self, m: &Matrix) -> Matrix {
124 let mut res: Matrix = Matrix::init_zero(self.height, m.width);
125 assert_eq!(self.width, m.height, "Error while doing a dot product: Dimension incompatibility, width of vec 1 : {}, height of vec 2 : {}", self.width, m.height);
126 for c in 0..m.width {
127 for r in 0..self.height {
128 let mut tmp: f64 = 0.0;
129 for a in 0..self.width {
130 tmp = tmp + self.get(r, a) * m.get(a, c);
131 }
132 res.set(tmp, r, c);
133 }
134 }
135 res
136 }
137
138 pub fn add_1d_matrix_to_all_rows(&self, m: &Matrix) -> Matrix {
140 assert_eq!(m.height, 1, "The input matrix should have a height of 1");
141 assert_eq!(
142 m.width, self.width,
143 "The 2 matrices should have the same width"
144 );
145
146 let output_vec: Vec<f64> = (0..self.height * self.width)
147 .map(|i| self.data[i] + m.get(0, i % self.width))
148 .collect();
149
150 Matrix {
151 data: output_vec,
152 width: self.width,
153 height: self.height,
154 transposed: false,
155 }
156 }
157
158 pub fn max(&self) -> f64 {
159 *self.data.iter().max_by(|a, b| a.total_cmp(b)).unwrap()
160 }
161
162 pub fn min(&self) -> f64 {
163 *self.data.iter().min_by(|a, b| a.total_cmp(b)).unwrap()
164 }
165
166 pub fn normalize(&mut self) {
168 let max: f64 = self.max();
170 let min: f64 = self.min();
171
172 self.data = self.data.iter().map(|x| (x - min) / (max - min)).collect();
173 }
174
175 pub fn transpose_inplace(&mut self) {
177 self.transposed = !self.transposed;
178 let tmp: usize = self.width;
179 self.width = self.height;
180 self.height = tmp;
181 }
182
183 pub fn t(&self) -> Matrix {
184 let mut output = self.clone();
185 output.transpose_inplace();
186 output
187 }
188
189 pub fn is_equal(&self, m: &Matrix, precision: i32) -> bool {
191 if self.width != m.width || self.height != m.height || self.transposed != m.transposed {
192 return false;
193 } else {
194 for i in 0..self.height * self.width {
195 let mut a: f64 = self.data[i] * 10_f64.powi(precision);
196 a = a.round() / 10_f64.powi(precision);
197
198 let mut b: f64 = m.data[i] * 10_f64.powi(precision);
199 b = b.round() / 10_f64.powi(precision);
200
201 if a != b {
202 return false;
203 }
204 }
205 }
206 true
207 }
208
209 pub fn exp_inplace(&mut self) {
210 self.data = self.data.iter().map(|x| x.exp()).collect();
211 }
212
213 pub fn sqrt_inplace(&mut self) {
214 self.data = self
215 .data
216 .iter()
217 .map(|x| {
218 assert!(
219 *x >= 0.0,
220 "Trying to square root a negative value in a matrix error"
221 );
222 x.sqrt()
223 })
224 .collect();
225 }
226
227 pub fn exp(&self) -> Matrix {
228 let mut output = self.clone();
229 output.exp_inplace();
230 output
231 }
232
233 pub fn pow_inplace(&mut self, a: i32) {
234 self.data = self.data.iter().map(|x| x.powi(a)).collect();
235 }
236
237 pub fn pow(&self, a: i32) -> Matrix {
238 let mut output: Matrix = self.clone();
239 output.pow_inplace(a);
240
241 output
242 }
243
244 pub fn sum(&self) -> f64 {
245 self.data.iter().sum()
246 }
247
248 pub fn sum_rows(&self) -> Matrix {
249 let mut output: Matrix = Matrix::init_zero(1, self.width);
250
251 self.data
252 .iter()
253 .enumerate()
254 .for_each(|(index, value)| output.data[index % self.width] += value);
255
256 output
257 }
258
259 pub fn add_inplace(&mut self, a: f64) {
260 self.data = self.data.iter().map(|x| x + a).collect();
261 }
262
263 pub fn div_inplace(&mut self, a: f64) {
264 assert_ne!(a, 0.0, "Divide by 0 matrix error");
265 self.data = self.data.iter().map(|x| x / a).collect();
266 }
267
268 pub fn div(&self, a: f64) -> Matrix {
269 let mut output: Matrix = self.clone();
270 output.div_inplace(a);
271 output
272 }
273
274 pub fn mult_inplace(&mut self, a: f64) {
275 self.data = self.data.iter().map(|x| x * a).collect();
276 }
277
278 pub fn mult(&self, a: f64) -> Matrix {
279 let mut output: Matrix = self.clone();
280 output.mult_inplace(a);
281 output
282 }
283
284 pub fn add_two_matrices(&self, m: &Matrix) -> Matrix {
285 assert!(
286 self.height == m.height && self.width == m.width,
287 "The two matrices should have the same dimensions"
288 );
289 let output_vec: Vec<f64> = (0..self.height * self.width)
290 .map(|i| self.data[i] + m.data[i])
291 .collect();
292
293 Matrix {
294 data: output_vec,
295 width: self.width,
296 height: self.height,
297 transposed: false,
298 }
299 }
300
301 pub fn add_two_matrices_inplace(&mut self, m: &Matrix) {
302 assert!(
303 self.height == m.height && self.width == m.width,
304 "The two matrices should have the same dimensions"
305 );
306
307 self.data = self
308 .data
309 .iter()
310 .enumerate()
311 .map(|(i, val)| val + m.data[i])
312 .collect();
313 }
314
315 pub fn div_two_matrices_inplace(&mut self, m: &Matrix) {
316 assert!(
317 self.height == m.height && self.width == m.width,
318 "The two matrices should have the same dimensions"
319 );
320
321 self.data = self
322 .data
323 .iter()
324 .enumerate()
325 .map(|(i, val)| {
326 assert_ne!(m.data[i], 0.0, "Divide by 0 error in matrix to matrix div");
327 val / m.data[i]
328 })
329 .collect();
330 }
331
332 pub fn pop_last_row(&mut self) {
333 let begin_index = self.height * (self.width - 1);
334 let last_index = self.height * self.width;
335
336 for _i in begin_index..last_index {
337 self.data.pop();
338 }
339
340 self.height -= 1;
341 }
342
343 pub fn compute_d_relu_inplace(&mut self, z_minus_1: &Matrix) {
344 self.data = self
345 .data
346 .iter()
347 .enumerate()
348 .map(|(i, v)| if z_minus_1.data[i] <= 0.0 { 0.0 } else { *v })
349 .collect();
350 }
351
352 pub fn display(&self) {
353 print!("\n");
354 print!("-------------");
355 print!("\n");
356 for i in 0..self.height {
357 for j in 0..self.width {
358 print!(" {} |", self.get(i, j));
359 }
360 print!("/ \n");
361 }
362 print!("-------------");
363 print!("\n");
364 }
365
366 pub fn convert_to_csv(&self) -> String {
367 let mut output: String = String::new();
368 for i in 0..self.height {
369 for j in 0..self.width {
370 output.push_str(&self.get(i, j).to_string());
371 output.push(',');
372 }
373 output.push('\n');
374 }
375
376 output
377 }
378}
379
/// Unit tests for `Matrix`.
///
/// Tests expected to panic state the message they expect, so they cannot pass
/// because of an unrelated failure such as a missing fixture file.
#[cfg(test)]
mod tests {
    use crate::parse_test_csv::parse_test_csv;

    use super::Matrix;

    /// 2 x 3 matrix shared by the accessor tests:
    /// `[0.1, 1.3, 0.5]` / `[12.0, 1.01, -1000.0]`.
    fn get_test_matrix() -> Matrix {
        Matrix::init(2, 3, vec![0.1, 1.3, 0.5, 12.0, 1.01, -1000.0])
    }

    #[test]
    fn valid_get() {
        let matrix = get_test_matrix();

        assert_eq!(matrix.get(0, 0), 0.1);
        assert_eq!(matrix.get(0, 1), 1.3);
        assert_eq!(matrix.get(0, 2), 0.5);
        assert_eq!(matrix.get(1, 0), 12.0);
        assert_eq!(matrix.get(1, 1), 1.01);
        assert_eq!(matrix.get(1, 2), -1000.0);
    }

    #[test]
    fn valid_get_on_transposed() {
        let mut matrix = get_test_matrix();
        matrix.transpose_inplace();

        assert_eq!(matrix.get(0, 0), 0.1);
        assert_eq!(matrix.get(0, 1), 12.0);
        assert_eq!(matrix.get(1, 0), 1.3);
        assert_eq!(matrix.get(1, 1), 1.01);
        assert_eq!(matrix.get(2, 0), 0.5);
        assert_eq!(matrix.get(2, 1), -1000.0);
    }

    #[test]
    fn valid_get_on_untransposed() {
        let mut matrix = get_test_matrix();
        matrix.transpose_inplace();
        matrix.transpose_inplace();

        assert_eq!(matrix.get(0, 0), 0.1);
        assert_eq!(matrix.get(0, 1), 1.3);
        assert_eq!(matrix.get(0, 2), 0.5);
        assert_eq!(matrix.get(1, 0), 12.0);
        assert_eq!(matrix.get(1, 1), 1.01);
        assert_eq!(matrix.get(1, 2), -1000.0);
    }

    #[test]
    #[should_panic(expected = "column greater or equal to width")]
    fn unvalid_get_column_out_of_bound() {
        let matrix = get_test_matrix();

        // Valid row, column one past the last one (width is 3).
        matrix.get(0, 3);
    }

    #[test]
    #[should_panic(expected = "row greater or equal to height")]
    fn unvalid_get_row_out_of_bound() {
        let matrix = get_test_matrix();

        matrix.get(5, 1);
    }

    #[test]
    #[should_panic(expected = "column greater or equal to width")]
    fn unvalid_get_tranposed_column_out_of_bound() {
        let mut matrix = get_test_matrix();
        matrix.transpose_inplace();

        // Once transposed the matrix is 3 x 2, so column 2 is out of bounds.
        matrix.get(0, 2);
    }

    #[test]
    #[should_panic(expected = "row greater or equal to height")]
    fn unvalid_get_transposed_row_out_of_bound() {
        let mut matrix = get_test_matrix();
        matrix.transpose_inplace();

        // Once transposed the matrix is 3 x 2, so row 3 is out of bounds.
        matrix.get(3, 1);
    }

    #[test]
    fn valid_get_row() {
        let matrix = get_test_matrix();
        let expected_vec = vec![12.0, 1.01, -1000.0];

        assert_eq!(matrix.get_row(1), expected_vec);
    }

    #[test]
    fn valid_get_row_on_transposed() {
        let mut matrix = get_test_matrix();
        let expected_vec = vec![0.5, -1000.0];
        matrix.transpose_inplace();

        assert_eq!(matrix.get_row(2), expected_vec);
    }

    #[test]
    fn valid_set() {
        let mut matrix = get_test_matrix();
        matrix.set(69.69, 1, 1);

        assert_eq!(matrix.data[4], 69.69);
    }

    #[test]
    #[should_panic(expected = "column greater or equal to width")]
    fn unvalid_set_column_out_of_bound() {
        let mut matrix = get_test_matrix();

        // Valid row, column one past the last one (width is 3).
        matrix.set(69.69, 0, 3);
    }

    #[test]
    #[should_panic(expected = "row greater or equal to height")]
    fn unvalid_set_row_out_of_bound() {
        let mut matrix = get_test_matrix();

        matrix.set(69.69, 5, 1);
    }

    #[test]
    #[should_panic(expected = "column greater or equal to width")]
    fn unvalid_set_tranposed_column_out_of_bound() {
        let mut matrix = get_test_matrix();
        matrix.transpose_inplace();

        matrix.set(69.69, 0, 2);
    }

    #[test]
    #[should_panic(expected = "row greater or equal to height")]
    fn unvalid_set_transposed_row_out_of_bound() {
        let mut matrix = get_test_matrix();
        matrix.transpose_inplace();

        matrix.set(69.69, 3, 1);
    }

    #[test]
    fn valid_set_row() {
        let mut matrix = get_test_matrix();
        let new_row = vec![0.8, 0.1, 1203123.0];

        matrix.set_row(&new_row, 0);

        assert_eq!(matrix.get_row(0), new_row);
    }

    #[test]
    fn valid_set_row_on_transposed() {
        let mut matrix = get_test_matrix();
        let new_row = vec![0.8, 0.1];

        matrix.transpose_inplace();
        matrix.set_row(&new_row, 2);

        assert_eq!(matrix.get_row(2), new_row);
    }

    #[test]
    fn max_test() {
        let matrix = get_test_matrix();

        assert_eq!(matrix.max(), 12.0);
    }

    #[test]
    fn min_test() {
        let matrix = get_test_matrix();

        assert_eq!(matrix.min(), -1000.0);
    }

    #[test]
    fn add_values_of_a_row_test() {
        let test_data = parse_test_csv("tests/test_data/add_values_of_a_row_test.csv".to_string());

        assert!(test_data[0]
            .add_1d_matrix_to_all_rows(&test_data[1])
            .is_equal(&test_data[2], 10));
    }

    #[test]
    fn div_two_matrices_test() {
        let test_data = parse_test_csv("tests/test_data/div_two_matrices_test.csv".to_string());
        let mut m1 = test_data[0].clone();

        m1.div_two_matrices_inplace(&test_data[1]);
        assert!(m1.is_equal(&test_data[2], 10));
    }

    #[test]
    #[should_panic(expected = "Divide by 0 error in matrix to matrix div")]
    fn unvalid_div_by_0_div_two_matrices_test() {
        let test_data = parse_test_csv("tests/test_data/div_two_matrices_test.csv".to_string());
        let mut m1 = test_data[0].clone();
        let mut m2 = test_data[1].clone();
        m2.set(0.0, 1, 1);
        m1.div_two_matrices_inplace(&m2);
    }

    #[test]
    fn sqrt_test() {
        let test_data = parse_test_csv("tests/test_data/sqrt_test.csv".to_string());
        let mut m1 = test_data[0].clone();

        m1.sqrt_inplace();
        assert!(m1.is_equal(&test_data[1], 10));
    }

    #[test]
    #[should_panic(expected = "Trying to square root a negative value")]
    fn unvalid_sqrt_negavtive_value() {
        let test_data = parse_test_csv("tests/test_data/div_two_matrices_test.csv".to_string());
        let mut m1 = test_data[0].clone();
        m1.set(-1.0, 1, 1);
        m1.sqrt_inplace();
    }

    #[test]
    fn dot_product_test() {
        let test_data = parse_test_csv("tests/test_data/dot_product_test.csv".to_string());

        assert!(test_data[0].dot(&test_data[1]).is_equal(&test_data[2], 8));
    }

    #[test]
    fn normalize_test() {
        let mut test_data = parse_test_csv("tests/test_data/normalize_test.csv".to_string());

        test_data[0].normalize();

        assert!(test_data[0].is_equal(&test_data[1], 8));
    }
}