// matrix_simp/lib.rs

use num_traits::{Float, Zero};
use std::{convert::From, error::Error, fmt, ops};

/// A dense, row-major `n x m` matrix.
///
/// Element `(row, col)` is stored at `data[row * m + col]`, so for a
/// well-formed matrix `data.len() == n * m`.
#[derive(Clone, PartialEq)]
pub struct Matrix<T> {
    /// Number of rows.
    pub n: usize,
    /// Number of columns.
    pub m: usize,
    /// Elements in row-major order.
    data: Vec<T>,
}

/// Errors returned by fallible [`Matrix`] operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatrixError {
    /// The operands' dimensions are incompatible for the requested operation.
    DimensionError,
}

impl fmt::Display for MatrixError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            MatrixError::DimensionError => f.write_str("incompatible matrix dimensions"),
        }
    }
}

// Lets callers propagate `MatrixError` through `Box<dyn Error>` / `?`.
impl Error for MatrixError {}

18impl<T: Float + Zero + From<f32>> Matrix<T> {
19    pub fn new(val: f32, n: usize, m:usize) -> Matrix<T> {
20        Matrix { n , m , data: vec![<T as From<f32>>::from(val); n * m] }
21    }
22
23    pub fn zeros(n: usize, m:usize) -> Matrix<T> {
24        Matrix { n , m , data: vec![<T as From<f32>>::from(0_f32); n * m] }
25    }
26    pub fn identity(n: usize) -> Matrix<T> {
27        let mut mat: Matrix<T> = Matrix::zeros(n,n);
28        for idx in 0..n {
29            mat.data[(n+1) * idx] = <T as From<f32>>::from(1_f32); 
30        }
31        mat
32    }
33
34    pub fn piece_mult(_lhs: Matrix<T>, _rhs: Matrix<T>) -> Result<Matrix<T>, MatrixError> {
35        if _lhs.data.len() != _rhs.data.len() {
36            return Err(MatrixError::DimensionError);
37        }
38
39        let data = _lhs.data.into_iter().zip(_rhs.data.into_iter())
40                       .map(|(a , b)| a * b)
41                       .collect::<Vec<T>>();
42
43        Ok(Matrix{n:_lhs.n, m:_lhs.m, data})
44
45    }
46
47    pub fn get_row(&self, row: usize) -> &[T] {
48        &self.data[row * self.m..(row + 1) * self.m]
49    }
50
51    pub fn get_col(&self, col: usize) -> Vec<T> {
52        self.data.iter()
53                 .enumerate()
54                 .filter(|&(idx, _)| idx%self.m == col)
55                 .map(|(_ , num)| *num)
56                 .collect::<Vec<T>>()
57    }
58
59    pub fn transpose(&self) -> Matrix<T> {
60
61        let data: Vec<Vec<T>> = (0..self.m).map(|x| self.get_col(x))
62                                           .collect::<Vec<Vec<T>>>();
63        
64        Matrix::from(data)
65    }
66
67    pub fn exp(&mut self) -> Matrix<T> {
68
69        let data: Vec<T> = self.data.iter()
70                               .map(|x| x.exp())
71                               .collect::<Vec<T>>();
72        
73        Matrix { n:self.n , m:self.m , data }
74        
75    }
76
77    pub fn one_over(&self) -> Matrix<T> {
78
79        let data: Vec<T> = self.data.iter()
80                               .map(|x| T::one() / *x)
81                               .collect::<Vec<T>>();
82        
83        Matrix { n:self.n , m:self.m , data }
84
85    }
86}
87
88impl<T: Float + Zero + From<f32>> fmt::Display for Matrix<T> {
89    
90    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
91        write!(f, "Matrix<{} rows x {} cols>", self.n, self.m)
92    }
93}
94
95
96
97impl<T: Float + Zero + From<f32> + fmt::Display> fmt::Debug for Matrix<T> {
98    
99    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100        write!(f, "Matrix<{} rows x {} cols>:", self.n, self.m)?;
101        for (idx, item) in self.data.iter().enumerate() {
102            if idx%self.m == 0 {
103                write!(f, "\n")?;
104            }
105            write!(f, "{item}\t")?;
106
107        }
108
109    Ok(())
110    }
111}
112
113impl<T: Float + Zero + From<f32> + fmt::Display + ops::Add> ops::Add<Matrix<T>> for Matrix<T> {
114   
115    type Output = Matrix<T>;
116
117    fn add(self, _rhs: Matrix<T>) -> Matrix<T> {
118        
119       let mut mat: Matrix<_> = self.clone();
120
121       for (idx, item) in mat.data.iter_mut().enumerate() {
122            *item = *item + _rhs.data[idx]; 
123       }
124
125       mat
126    }
127}
128
129impl<T: Float + Zero + From<f32> + fmt::Display> ops::Index<usize> for Matrix<T> {
130   
131    type Output = T;
132
133    fn index(&self, index: usize) -> &Self::Output {
134        &self.data[index]
135    }
136}
137
138impl<T: Float + Zero + From<f32> + fmt::Display> ops::IndexMut<usize> for Matrix<T> {
139   
140    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
141        &mut self.data[index]
142    }
143}
144
145
146impl<T: Float + Zero + From<f32> + fmt::Display + ops::Sub> ops::Sub<Matrix<T>> for Matrix<T> {
147   
148    type Output = Matrix<T>;
149
150    fn sub(self, _rhs: Matrix<T>) -> Matrix<T> {
151        
152       let mut mat: Matrix<_> = self.clone();
153
154       for (idx, item) in mat.data.iter_mut().enumerate() {
155            *item = *item - _rhs.data[idx]; 
156       }
157
158       mat
159    }
160}
161
162impl<T: Float + Zero + From<f32> + fmt::Display + fmt::Debug + ops::Mul> ops::Mul<T> for Matrix<T> {
163
164    type Output = Matrix<T>;
165
166    fn mul(self, _rhs: T) -> Matrix<T> {
167
168        let data: Vec<T> = self.data.iter()
169                                    .map(|x| *x * _rhs)
170                                    .collect::<Vec<T>>();
171
172        
173        Matrix { n: self.n, m: self.m, data }
174    }
175
176}
177
178impl<T: Float + Zero + From<f32> + fmt::Display + fmt::Debug + ops::Div> ops::Div<T> for Matrix<T> {
179
180    type Output = Matrix<T>;
181
182    fn div(self, _rhs: T) -> Matrix<T> {
183
184        let data: Vec<T> = self.data.iter()
185                                    .map(|x| *x / _rhs)
186                                    .collect::<Vec<T>>();
187
188        
189        Matrix { n: self.n, m: self.m, data }
190    }
191
192}
193
194impl<T: Float + Zero + From<f32> + fmt::Display + fmt::Debug + ops::Add> ops::Add<T> for Matrix<T> {
195   
196    type Output = Matrix<T>;
197
198    fn add(self, _rhs: T) -> Matrix<T> {
199        let data = self.data.iter()
200                            .map(|x| *x + _rhs)
201                            .collect::<Vec<T>>();
202        
203        Matrix { n:self.n, m:self.m, data }
204    }
205
206}
207impl<T: Float + Zero + From<f32> + fmt::Display + fmt::Debug + ops::Mul> ops::Mul<Matrix<T>> for Matrix<T> {
208   
209    type Output = Matrix<T>;
210
211    fn mul(self, _rhs: Matrix<T>) -> Matrix<T> {
212        
213       let mut mat: Matrix<T> = Matrix::zeros(
214                                              self.n,
215                                              _rhs.m);
216
217       for rdx in 0..self.n {
218           for cdx in 0.._rhs.m {
219               let row = self.get_row(rdx);
220               let col = _rhs.get_col(cdx);
221
222               let sum = row.into_iter()
223                            .enumerate()
224                            .map(|(idx, num)| {
225                                *num * col[idx]
226                            })
227                            .reduce(|acc, x| acc + x)
228                            .unwrap();
229                
230                    mat.data[rdx * mat.n + cdx] = sum;      
231           }
232       }
233       mat
234    }
235}
236
237
238impl<T> From<Vec<Vec<T>>> for Matrix<T> where 
239T : Float + Zero + From<f32> {
240
241    fn from(value: Vec<Vec<T>>) -> Matrix<T> {
242        
243        let n = value.len();
244        let m = value[0].len();
245
246        let data: Vec<T> = value
247                        .iter().fold(Vec::new(),|mut acc, n| {
248                            acc.extend(n);
249                            acc
250                        });
251
252        Matrix { n, m, data }
253    }
254}
255
256impl<T: Float + Zero + From<f32>> From<&[&[T]]> for Matrix<T> {
257
258    fn from(value: &[&[T]]) -> Matrix<T> {
259        
260        let n = value.len();
261        let m = value[0].len();
262
263        let data: Vec<T> = value
264                        .iter().fold(Vec::new(),|mut acc, n| {
265                            acc.extend(*n);
266                            acc
267                        });
268
269        Matrix { n, m, data }
270    }
271}
272
273
274
275
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn init_zeroes() {
        let mat: Matrix<f32> = Matrix::zeros(5, 4);
        assert_eq!(mat.data.len(), 20);
        assert!(mat.data.iter().all(|&x| x == 0_f32));
    }

    #[test]
    fn init_val() {
        let mat: Matrix<f32> = Matrix::new(1_f32, 5, 4);
        assert_eq!(mat.data, vec![1_f32; 20]);
    }

    #[test]
    fn test_identity() {
        let mat: Matrix<f32> = Matrix::identity(3);
        let comp: Vec<f32> = vec![1_f32, 0_f32, 0_f32,
                                  0_f32, 1_f32, 0_f32,
                                  0_f32, 0_f32, 1_f32];
        assert_eq!(mat.data, comp);
        assert_eq!((mat.n, mat.m), (3, 3));
    }

    #[test]
    fn test_from_2d_vec() {
        let data: Vec<Vec<f32>> = vec![
            vec![1_f32, 2_f32, 3_f32],
            vec![3_f32, 2_f32, 1_f32],
        ];

        let mat: Matrix<f32> = Matrix::from(data);

        let comp: Vec<f32> = vec![1_f32, 2_f32, 3_f32,
                                  3_f32, 2_f32, 1_f32];

        assert_eq!(mat.data, comp);
        assert_eq!(mat.n, 2);
        assert_eq!(mat.m, 3);
    }

    #[test]
    fn test_from_2d_slice() {
        let data: &[&[f32]] = &[&[1_f32, 0_f32],
                                &[0_f32, 1_f32]];

        let mat: Matrix<f32> = Matrix::from(data);

        let comp: Vec<f32> = vec![1_f32, 0_f32,
                                  0_f32, 1_f32];

        assert_eq!(mat.data, comp);
        assert_eq!(mat.n, 2);
        assert_eq!(mat.m, 2);
    }

    #[test]
    fn test_from_reg_add() {
        let data: &[&[f32]] = &[&[1_f32, 0_f32],
                                &[0_f32, 1_f32]];
        let data2: &[&[f32]] = &[&[0_f32, 1_f32],
                                 &[1_f32, 0_f32]];

        let mat: Matrix<f32> = Matrix::from(data);
        let mat2: Matrix<f32> = Matrix::from(data2);

        let comp: Vec<f32> = vec![1_f32, 1_f32,
                                  1_f32, 1_f32];

        let res: Matrix<f32> = mat + mat2;
        assert_eq!(res.data, comp);
    }

    #[test]
    fn test_get_row() {
        let data: &[&[f32]] = &[&[1_f32, 0_f32, 0_f32],
                                &[0_f32, 1_f32, 3_f32],
                                &[2_f32, 9_f32, 3_f32],
                                &[0_f32, 1_f32, 3_f32]];

        let mat: Matrix<f32> = Matrix::from(data);
        let comp: &[f32] = &[2_f32, 9_f32, 3_f32];

        assert_eq!(mat.get_row(2), comp);
    }

    #[test]
    fn test_get_col() {
        let data: &[&[f32]] = &[&[1_f32, 0_f32, 0_f32],
                                &[0_f32, 1_f32, 3_f32],
                                &[2_f32, 9_f32, 3_f32],
                                &[0_f32, 1_f32, 3_f32]];

        let mat: Matrix<f32> = Matrix::from(data);
        let comp: Vec<f32> = vec![0_f32, 1_f32, 9_f32, 1_f32];

        assert_eq!(mat.get_col(1), comp);
        // Columns past the right edge have no elements.
        assert!(mat.get_col(5).is_empty());
    }

    #[test]
    fn test_mult() {
        let data1: &[&[f32]] = &[
            &[1_f32, 2_f32],
            &[3_f32, 4_f32],
            &[5_f32, 6_f32],
        ];
        let data2: &[&[f32]] = &[
            &[1_f32, 2_f32, 3_f32],
            &[4_f32, 5_f32, 6_f32],
        ];

        let mat1: Matrix<f32> = Matrix::from(data1);
        let mat2: Matrix<f32> = Matrix::from(data2);

        let comp: Vec<f32> = vec![9_f32, 12_f32, 15_f32,
                                  19_f32, 26_f32, 33_f32,
                                  29_f32, 40_f32, 51_f32];

        assert_eq!((mat1 * mat2).data, comp);
    }

    #[test]
    fn test_transpose() {
        let data1: &[&[f32]] = &[
            &[1_f32, 2_f32],
            &[3_f32, 4_f32],
            &[5_f32, 6_f32],
        ];

        let mat1: Matrix<f32> = Matrix::from(data1);

        let comp: Vec<f32> = vec![1_f32, 3_f32, 5_f32,
                                  2_f32, 4_f32, 6_f32];

        let transposed = mat1.transpose();
        assert_eq!((transposed.n, transposed.m), (2, 3));
        assert_eq!(transposed.data, comp);
    }

    #[test]
    fn test_exp() {
        let data1: &[&[f32]] = &[
            &[1_f32, 2_f32],
            &[3_f32, 4_f32],
            &[5_f32, 6_f32],
        ];

        let mut mat1: Matrix<f32> = Matrix::from(data1);

        let comp: Vec<f32> = vec![1_f32.exp(), 2_f32.exp(), 3_f32.exp(),
                                  4_f32.exp(), 5_f32.exp(), 6_f32.exp()];

        assert_eq!(mat1.exp().data, comp);
    }

    #[test]
    fn test_mult_scalar() {
        let data1: &[&[f32]] = &[
            &[1_f32, 2_f32],
            &[3_f32, 4_f32],
            &[5_f32, 6_f32],
        ];

        let mat1: Matrix<f32> = Matrix::from(data1);

        let comp: Vec<f32> = vec![1_f32 * 5_f32, 2_f32 * 5_f32, 3_f32 * 5_f32,
                                  4_f32 * 5_f32, 5_f32 * 5_f32, 6_f32 * 5_f32];

        assert_eq!((mat1 * 5_f32).data, comp);
    }

    #[test]
    fn test_div_scalar() {
        let data1: &[&[f32]] = &[
            &[1_f32, 2_f32],
            &[3_f32, 4_f32],
            &[5_f32, 6_f32],
        ];

        let mat1: Matrix<f32> = Matrix::from(data1);

        let comp: Vec<f32> = vec![1_f32 / 5_f32, 2_f32 / 5_f32, 3_f32 / 5_f32,
                                  4_f32 / 5_f32, 5_f32 / 5_f32, 6_f32 / 5_f32];

        assert_eq!((mat1 / 5_f32).data, comp);
    }

    #[test]
    fn test_add_scalar() {
        let mat: Matrix<f32> = Matrix::new(1_f32, 2, 2);
        assert_eq!((mat + 2_f32).data, vec![3_f32; 4]);
    }

    #[test]
    fn test_one_over() {
        let data1: &[&[f32]] = &[
            &[1_f32, 2_f32],
            &[3_f32, 4_f32],
            &[5_f32, 6_f32],
        ];

        let mat1: Matrix<f32> = Matrix::from(data1);

        let comp: Vec<f32> = vec![1_f32 / 1_f32, 1_f32 / 2_f32,
                                  1_f32 / 3_f32, 1_f32 / 4_f32,
                                  1_f32 / 5_f32, 1_f32 / 6_f32];

        assert_eq!(mat1.one_over().data, comp);
    }

    #[test]
    fn test_piece_by_mult() {
        let data1: &[&[f32]] = &[&[1_f32, 2_f32, 3_f32]];
        let data2: &[&[f32]] = &[&[1_f32], &[2_f32], &[3_f32]];

        let mat1: Matrix<f32> = Matrix::from(data1);
        let mat2: Matrix<f32> = Matrix::from(data2);

        let comp: Vec<f32> = vec![1_f32, 4_f32, 9_f32];

        let res = match Matrix::piece_mult(mat1, mat2) {
            Ok(matrix) => matrix,
            Err(_) => panic!("err"),
        };

        assert_eq!(res.data, comp);
    }

    #[test]
    fn test_piece_mult_length_mismatch() {
        let mat1: Matrix<f32> = Matrix::zeros(2, 2);
        let mat2: Matrix<f32> = Matrix::zeros(1, 3);
        assert!(matches!(
            Matrix::piece_mult(mat1, mat2),
            Err(MatrixError::DimensionError)
        ));
    }

    #[test]
    fn test_matrix_addition() {
        let data1: &[&[f32]] = &[
            &[1_f32, 2_f32],
            &[3_f32, 4_f32],
            &[5_f32, 6_f32],
        ];

        let mat1: Matrix<f32> = Matrix::from(data1);
        let mat2: Matrix<f32> = Matrix::from(data1);

        let comp: Vec<f32> = vec![2_f32, 4_f32,
                                  6_f32, 8_f32,
                                  10_f32, 12_f32];

        assert_eq!((mat1 + mat2).data, comp);
    }

    #[test]
    fn test_matrix_subtraction() {
        let data1: &[&[f32]] = &[
            &[1_f32, 2_f32],
            &[3_f32, 4_f32],
            &[5_f32, 6_f32],
        ];

        let mat1: Matrix<f32> = Matrix::from(data1);
        let mat2: Matrix<f32> = Matrix::from(data1);

        let comp: Vec<f32> = vec![0_f32; 6];

        assert_eq!((mat1 - mat2).data, comp);
    }

    #[test]
    fn test_index() {
        let mut mat: Matrix<f32> = Matrix::zeros(2, 2);
        mat[3] = 7_f32;
        assert_eq!(mat[3], 7_f32);
        assert_eq!(mat.get_row(1), &[0_f32, 7_f32][..]);
    }

    #[test]
    fn test_display_and_debug() {
        let mat: Matrix<f32> = Matrix::identity(2);
        assert_eq!(format!("{}", mat), "Matrix<2 rows x 2 cols>");
        assert_eq!(format!("{:?}", mat), "Matrix<2 rows x 2 cols>:\n1\t0\t\n0\t1\t");
    }
}
559    }
560
561}