1use num_traits::{Float, Zero};
2use std::{
3 convert::From,
4 fmt,
5 ops};
6
/// A dense two-dimensional matrix.
///
/// Elements are kept in one flat buffer in row-major order: the value at row
/// `r`, column `c` is stored at index `r * m + c`.
#[derive(Clone)]
pub struct Matrix<T> {
    /// Number of rows.
    pub n: usize,
    /// Number of columns (the row width).
    pub m: usize,
    /// Flat row-major element storage; normally holds `n * m` values.
    data: Vec<T>,
}
13
/// Errors produced by fallible matrix operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatrixError {
    /// The operands' sizes are incompatible with the requested operation.
    DimensionError,
}

impl fmt::Display for MatrixError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            MatrixError::DimensionError => write!(f, "matrix dimensions are incompatible"),
        }
    }
}

// Lets `MatrixError` flow through `?` into `Box<dyn Error>` and similar.
impl std::error::Error for MatrixError {}
17
18impl<T: Float + Zero + From<f32>> Matrix<T> {
19 pub fn new(val: f32, n: usize, m:usize) -> Matrix<T> {
20 Matrix { n , m , data: vec![<T as From<f32>>::from(val); n * m] }
21 }
22
23 pub fn zeros(n: usize, m:usize) -> Matrix<T> {
24 Matrix { n , m , data: vec![<T as From<f32>>::from(0_f32); n * m] }
25 }
26 pub fn identity(n: usize) -> Matrix<T> {
27 let mut mat: Matrix<T> = Matrix::zeros(n,n);
28 for idx in 0..n {
29 mat.data[(n+1) * idx] = <T as From<f32>>::from(1_f32);
30 }
31 mat
32 }
33
34 pub fn piece_mult(_lhs: Matrix<T>, _rhs: Matrix<T>) -> Result<Matrix<T>, MatrixError> {
35 if _lhs.data.len() != _rhs.data.len() {
36 return Err(MatrixError::DimensionError);
37 }
38
39 let data = _lhs.data.into_iter().zip(_rhs.data.into_iter())
40 .map(|(a , b)| a * b)
41 .collect::<Vec<T>>();
42
43 Ok(Matrix{n:_lhs.n, m:_lhs.m, data})
44
45 }
46
47 pub fn get_row(&self, row: usize) -> &[T] {
48 &self.data[row * self.m..(row + 1) * self.m]
49 }
50
51 pub fn get_col(&self, col: usize) -> Vec<T> {
52 self.data.iter()
53 .enumerate()
54 .filter(|&(idx, _)| idx%self.m == col)
55 .map(|(_ , num)| *num)
56 .collect::<Vec<T>>()
57 }
58
59 pub fn transpose(&self) -> Matrix<T> {
60
61 let data: Vec<Vec<T>> = (0..self.m).map(|x| self.get_col(x))
62 .collect::<Vec<Vec<T>>>();
63
64 Matrix::from(data)
65 }
66
67 pub fn exp(&mut self) -> Matrix<T> {
68
69 let data: Vec<T> = self.data.iter()
70 .map(|x| x.exp())
71 .collect::<Vec<T>>();
72
73 Matrix { n:self.n , m:self.m , data }
74
75 }
76
77 pub fn one_over(&self) -> Matrix<T> {
78
79 let data: Vec<T> = self.data.iter()
80 .map(|x| T::one() / *x)
81 .collect::<Vec<T>>();
82
83 Matrix { n:self.n , m:self.m , data }
84
85 }
86}
87
88impl<T: Float + Zero + From<f32>> fmt::Display for Matrix<T> {
89
90 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
91 write!(f, "Matrix<{} rows x {} cols>", self.n, self.m)
92 }
93}
94
95
96
97impl<T: Float + Zero + From<f32> + fmt::Display> fmt::Debug for Matrix<T> {
98
99 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100 write!(f, "Matrix<{} rows x {} cols>:", self.n, self.m)?;
101 for (idx, item) in self.data.iter().enumerate() {
102 if idx%self.m == 0 {
103 write!(f, "\n")?;
104 }
105 write!(f, "{item}\t")?;
106
107 }
108
109 Ok(())
110 }
111}
112
113impl<T: Float + Zero + From<f32> + fmt::Display + ops::Add> ops::Add<Matrix<T>> for Matrix<T> {
114
115 type Output = Matrix<T>;
116
117 fn add(self, _rhs: Matrix<T>) -> Matrix<T> {
118
119 let mut mat: Matrix<_> = self.clone();
120
121 for (idx, item) in mat.data.iter_mut().enumerate() {
122 *item = *item + _rhs.data[idx];
123 }
124
125 mat
126 }
127}
128
129impl<T: Float + Zero + From<f32> + fmt::Display> ops::Index<usize> for Matrix<T> {
130
131 type Output = T;
132
133 fn index(&self, index: usize) -> &Self::Output {
134 &self.data[index]
135 }
136}
137
138impl<T: Float + Zero + From<f32> + fmt::Display> ops::IndexMut<usize> for Matrix<T> {
139
140 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
141 &mut self.data[index]
142 }
143}
144
145
146impl<T: Float + Zero + From<f32> + fmt::Display + ops::Sub> ops::Sub<Matrix<T>> for Matrix<T> {
147
148 type Output = Matrix<T>;
149
150 fn sub(self, _rhs: Matrix<T>) -> Matrix<T> {
151
152 let mut mat: Matrix<_> = self.clone();
153
154 for (idx, item) in mat.data.iter_mut().enumerate() {
155 *item = *item - _rhs.data[idx];
156 }
157
158 mat
159 }
160}
161
162impl<T: Float + Zero + From<f32> + fmt::Display + fmt::Debug + ops::Mul> ops::Mul<T> for Matrix<T> {
163
164 type Output = Matrix<T>;
165
166 fn mul(self, _rhs: T) -> Matrix<T> {
167
168 let data: Vec<T> = self.data.iter()
169 .map(|x| *x * _rhs)
170 .collect::<Vec<T>>();
171
172
173 Matrix { n: self.n, m: self.m, data }
174 }
175
176}
177
178impl<T: Float + Zero + From<f32> + fmt::Display + fmt::Debug + ops::Div> ops::Div<T> for Matrix<T> {
179
180 type Output = Matrix<T>;
181
182 fn div(self, _rhs: T) -> Matrix<T> {
183
184 let data: Vec<T> = self.data.iter()
185 .map(|x| *x / _rhs)
186 .collect::<Vec<T>>();
187
188
189 Matrix { n: self.n, m: self.m, data }
190 }
191
192}
193
194impl<T: Float + Zero + From<f32> + fmt::Display + fmt::Debug + ops::Add> ops::Add<T> for Matrix<T> {
195
196 type Output = Matrix<T>;
197
198 fn add(self, _rhs: T) -> Matrix<T> {
199 let data = self.data.iter()
200 .map(|x| *x + _rhs)
201 .collect::<Vec<T>>();
202
203 Matrix { n:self.n, m:self.m, data }
204 }
205
206}
207impl<T: Float + Zero + From<f32> + fmt::Display + fmt::Debug + ops::Mul> ops::Mul<Matrix<T>> for Matrix<T> {
208
209 type Output = Matrix<T>;
210
211 fn mul(self, _rhs: Matrix<T>) -> Matrix<T> {
212
213 let mut mat: Matrix<T> = Matrix::zeros(
214 self.n,
215 _rhs.m);
216
217 for rdx in 0..self.n {
218 for cdx in 0.._rhs.m {
219 let row = self.get_row(rdx);
220 let col = _rhs.get_col(cdx);
221
222 let sum = row.into_iter()
223 .enumerate()
224 .map(|(idx, num)| {
225 *num * col[idx]
226 })
227 .reduce(|acc, x| acc + x)
228 .unwrap();
229
230 mat.data[rdx * mat.n + cdx] = sum;
231 }
232 }
233 mat
234 }
235}
236
237
238impl<T> From<Vec<Vec<T>>> for Matrix<T> where
239T : Float + Zero + From<f32> {
240
241 fn from(value: Vec<Vec<T>>) -> Matrix<T> {
242
243 let n = value.len();
244 let m = value[0].len();
245
246 let data: Vec<T> = value
247 .iter().fold(Vec::new(),|mut acc, n| {
248 acc.extend(n);
249 acc
250 });
251
252 Matrix { n, m, data }
253 }
254}
255
256impl<T: Float + Zero + From<f32>> From<&[&[T]]> for Matrix<T> {
257
258 fn from(value: &[&[T]]) -> Matrix<T> {
259
260 let n = value.len();
261 let m = value[0].len();
262
263 let data: Vec<T> = value
264 .iter().fold(Vec::new(),|mut acc, n| {
265 acc.extend(*n);
266 acc
267 });
268
269 Matrix { n, m, data }
270 }
271}
272
273
274
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the 3 x 2 matrix `[[1, 2], [3, 4], [5, 6]]` shared by several tests.
    fn sample_3x2() -> Matrix<f32> {
        let rows: &[&[f32]] = &[
            &[1_f32, 2_f32],
            &[3_f32, 4_f32],
            &[5_f32, 6_f32],
        ];
        Matrix::from(rows)
    }

    #[test]
    fn init_zeroes() {
        let mat: Matrix<f32> = Matrix::zeros(5, 4);
        // Check every element, not just the sum (which could hide +x/-x pairs).
        assert_eq!(mat.data.len(), 20);
        assert!(mat.data.iter().all(|&x| x == 0_f32));
    }

    #[test]
    fn init_val() {
        let mat: Matrix<f32> = Matrix::new(1_f32, 5, 4);
        assert_eq!(mat.data.len(), 20);
        assert!(mat.data.iter().all(|&x| x == 1_f32));
    }

    #[test]
    fn init_identity() {
        let mat: Matrix<f32> = Matrix::identity(3);
        assert_eq!((mat.n, mat.m), (3, 3));
        assert_eq!(
            mat.data,
            vec![1_f32, 0_f32, 0_f32, 0_f32, 1_f32, 0_f32, 0_f32, 0_f32, 1_f32]
        );
    }

    #[test]
    fn test_from_2d_vec() {
        let data: Vec<Vec<f32>> = vec![
            vec![1_f32, 2_f32, 3_f32],
            vec![3_f32, 2_f32, 1_f32],
        ];

        let mat: Matrix<f32> = Matrix::from(data);

        let comp: Vec<f32> = vec![1_f32, 2_f32, 3_f32, 3_f32, 2_f32, 1_f32];
        assert_eq!(mat.data, comp);
        assert_eq!(mat.n, 2);
        assert_eq!(mat.m, 3);
    }

    #[test]
    fn test_from_2d_slice() {
        let data: &[&[f32]] = &[&[1_f32, 0_f32], &[0_f32, 1_f32]];

        let mat: Matrix<f32> = Matrix::from(data);

        let comp: Vec<f32> = vec![1_f32, 0_f32, 0_f32, 1_f32];
        assert_eq!(mat.data, comp);
        assert_eq!(mat.n, 2);
        assert_eq!(mat.m, 2);
    }

    #[test]
    fn test_from_reg_add() {
        let data: &[&[f32]] = &[&[1_f32, 0_f32], &[0_f32, 1_f32]];
        let data2: &[&[f32]] = &[&[0_f32, 1_f32], &[1_f32, 0_f32]];

        let mat: Matrix<f32> = Matrix::from(data);
        let mat2: Matrix<f32> = Matrix::from(data2);

        let res: Matrix<f32> = mat + mat2;
        assert_eq!(res.data, vec![1_f32; 4]);
    }

    #[test]
    fn test_get_row() {
        let data: &[&[f32]] = &[
            &[1_f32, 0_f32, 0_f32],
            &[0_f32, 1_f32, 3_f32],
            &[2_f32, 9_f32, 3_f32],
            &[0_f32, 1_f32, 3_f32],
        ];

        let mat: Matrix<f32> = Matrix::from(data);

        let comp: &[f32] = &[2_f32, 9_f32, 3_f32];
        assert_eq!(mat.get_row(2), comp);
    }

    #[test]
    fn test_get_col() {
        let data: &[&[f32]] = &[
            &[1_f32, 0_f32, 0_f32],
            &[0_f32, 1_f32, 3_f32],
            &[2_f32, 9_f32, 3_f32],
            &[0_f32, 1_f32, 3_f32],
        ];

        let mat: Matrix<f32> = Matrix::from(data);

        let comp: Vec<f32> = vec![0_f32, 1_f32, 9_f32, 1_f32];
        assert_eq!(mat.get_col(1), comp);
    }

    #[test]
    fn test_mult() {
        let data2: &[&[f32]] = &[
            &[1_f32, 2_f32, 3_f32],
            &[4_f32, 5_f32, 6_f32],
        ];

        let mat1 = sample_3x2();
        let mat2: Matrix<f32> = Matrix::from(data2);

        let comp: Vec<f32> = vec![
            9_f32, 12_f32, 15_f32,
            19_f32, 26_f32, 33_f32,
            29_f32, 40_f32, 51_f32,
        ];
        assert_eq!((mat1 * mat2).data, comp);
    }

    #[test]
    fn test_mult_row_vector() {
        let row: &[&[f32]] = &[&[1_f32, 2_f32]];
        let rhs: &[&[f32]] = &[
            &[1_f32, 2_f32, 3_f32],
            &[4_f32, 5_f32, 6_f32],
        ];

        let lhs_mat: Matrix<f32> = Matrix::from(row);
        let rhs_mat: Matrix<f32> = Matrix::from(rhs);
        let res = lhs_mat * rhs_mat;

        assert_eq!((res.n, res.m), (1, 3));
        assert_eq!(res.data, vec![9_f32, 12_f32, 15_f32]);
    }

    #[test]
    fn test_transpose() {
        let mat1 = sample_3x2();
        let res = mat1.transpose();

        assert_eq!((res.n, res.m), (2, 3));
        assert_eq!(res.data, vec![1_f32, 3_f32, 5_f32, 2_f32, 4_f32, 6_f32]);
    }

    #[test]
    fn test_exp() {
        let mut mat1 = sample_3x2();

        let comp: Vec<f32> = vec![
            1_f32.exp(), 2_f32.exp(), 3_f32.exp(),
            4_f32.exp(), 5_f32.exp(), 6_f32.exp(),
        ];
        assert_eq!(mat1.exp().data, comp);
    }

    #[test]
    fn test_mult_scalar() {
        let comp: Vec<f32> = vec![
            1_f32 * 5_f32, 2_f32 * 5_f32, 3_f32 * 5_f32,
            4_f32 * 5_f32, 5_f32 * 5_f32, 6_f32 * 5_f32,
        ];
        assert_eq!((sample_3x2() * 5_f32).data, comp);
    }

    #[test]
    fn test_div_scalar() {
        let comp: Vec<f32> = vec![
            1_f32 / 5_f32, 2_f32 / 5_f32, 3_f32 / 5_f32,
            4_f32 / 5_f32, 5_f32 / 5_f32, 6_f32 / 5_f32,
        ];
        assert_eq!((sample_3x2() / 5_f32).data, comp);
    }

    #[test]
    fn test_add_scalar() {
        let comp: Vec<f32> = vec![2_f32, 3_f32, 4_f32, 5_f32, 6_f32, 7_f32];
        assert_eq!((sample_3x2() + 1_f32).data, comp);
    }

    #[test]
    fn test_one_over() {
        let comp: Vec<f32> = vec![
            1_f32 / 1_f32, 1_f32 / 2_f32,
            1_f32 / 3_f32, 1_f32 / 4_f32,
            1_f32 / 5_f32, 1_f32 / 6_f32,
        ];
        assert_eq!(sample_3x2().one_over().data, comp);
    }

    #[test]
    fn test_piece_by_mult() {
        let data1: &[&[f32]] = &[&[1_f32, 2_f32, 3_f32]];
        let data2: &[&[f32]] = &[&[1_f32], &[2_f32], &[3_f32]];

        let mat1: Matrix<f32> = Matrix::from(data1);
        let mat2: Matrix<f32> = Matrix::from(data2);

        let res = Matrix::piece_mult(mat1, mat2).unwrap_or_else(|_| panic!("err"));
        assert_eq!(res.data, vec![1_f32, 4_f32, 9_f32]);
    }

    #[test]
    fn test_piece_mult_rejects_mismatched_sizes() {
        let res = Matrix::<f32>::piece_mult(Matrix::zeros(1, 2), Matrix::zeros(1, 3));
        assert!(matches!(res, Err(MatrixError::DimensionError)));
    }

    #[test]
    fn test_matrix_addition() {
        let comp: Vec<f32> = vec![2_f32, 4_f32, 6_f32, 8_f32, 10_f32, 12_f32];
        assert_eq!((sample_3x2() + sample_3x2()).data, comp);
    }

    #[test]
    fn test_matrix_subtraction() {
        assert_eq!((sample_3x2() - sample_3x2()).data, vec![0_f32; 6]);

        let ones: Matrix<f32> = Matrix::new(1_f32, 3, 2);
        let comp: Vec<f32> = vec![0_f32, 1_f32, 2_f32, 3_f32, 4_f32, 5_f32];
        assert_eq!((sample_3x2() - ones).data, comp);
    }

    #[test]
    fn test_display_shows_shape() {
        assert_eq!(format!("{}", sample_3x2()), "Matrix<3 rows x 2 cols>");
    }

    #[test]
    fn test_debug_lists_rows() {
        let data: &[&[f32]] = &[&[1_f32, 2_f32], &[3_f32, 4_f32]];
        let mat: Matrix<f32> = Matrix::from(data);
        assert_eq!(
            format!("{:?}", mat),
            "Matrix<2 rows x 2 cols>:\n1\t2\t\n3\t4\t"
        );
    }
}