1use crate::error::{LoadError, LoadErrorKind};
16use crate::utils;
17use crate::vector;
18use csv;
19use num::{FromPrimitive, Num};
20use rand::distributions::uniform::SampleUniform;
21use rand::distributions::{Distribution, Normal, Uniform};
22use std::fs::File;
23use std::ops;
24use std::path::Path;
25use std::marker::PhantomData;
26
/// Builds a [`Matrix`] from literal syntax.
///
/// * `matrix![elem; shape]` — a `shape[0] x shape[1]` matrix filled with `elem`.
/// * `matrix![a, b, c]` (optional trailing comma) — a single-row matrix.
/// * `matrix![a, b; c, d]` (optional trailing `;`) — rows separated by `;`.
///
/// Every arm goes through `$crate::Matrix::from`, so the macro also works in
/// crates that have not imported `Matrix` themselves.
#[macro_export]
macro_rules! matrix {
    ($elem:expr; $shape:expr) => {{
        // Borrow the shape once so `$shape` is evaluated a single time and is not moved.
        let shape = &$shape;
        let nrows = shape[0];
        let ncols = shape[1];
        let elements = vec![vec![$elem; ncols]; nrows];
        $crate::Matrix::from(elements)
    }};

    ($($x:expr),*) => {{
        let elements = vec![vec![$($x),*]];
        $crate::Matrix::from(elements)
    }};

    ($($x:expr,)*) => {{
        let elements = vec![vec![$($x),*]];
        $crate::Matrix::from(elements)
    }};

    ($($($x:expr),*;)*) => {{
        let elements = vec![$(vec![$($x),*]),*];
        $crate::Matrix::from(elements)
    }};

    ($($($x:expr),*);*) => {{
        let elements = vec![$(vec![$($x),*]),*];
        $crate::Matrix::from(elements)
    }};
}
99
100type RowMatrix<T> = vector::Vector<T>;
103#[derive(Debug)]
112pub struct Matrix<T> {
113 nrows: usize,
115 ncols: usize,
116 elements: Vec<RowMatrix<T>>,
117}
118
119impl<T> Matrix<T> {
120 pub fn shape(&self) -> [usize; 2] {
134 [self.nrows, self.ncols]
135 }
136
137 pub fn full(shape: [usize; 2], value: T) -> Matrix<T>
147 where
148 T: FromPrimitive + Num + Copy,
149 {
150 let nrows = shape[0];
152 let ncols = shape[1];
153 let elements = vec![vector![value; ncols]; nrows];
154 Matrix {
155 nrows,
156 ncols,
157 elements,
158 }
159 }
160
161 pub fn full_like(m: &Matrix<T>, value: T) -> Matrix<T>
175 where
176 T: FromPrimitive + Num + Copy,
177 {
178 let nrows = m.nrows;
180 let ncols = m.ncols;
181 let elements = vec![vector![value; ncols]; nrows];
182 Matrix {
183 nrows,
184 ncols,
185 elements,
186 }
187 }
188
189 pub fn zeros(shape: [usize; 2]) -> Matrix<T>
200 where
201 T: FromPrimitive + Num + Copy,
202 {
203 Self::full(shape, T::from_i32(0).unwrap())
204 }
205
206 pub fn zeros_like(m: &Matrix<T>) -> Matrix<T>
217 where
218 T: FromPrimitive + Num + Copy,
219 {
220 Self::full([m.nrows, m.ncols], T::from_i32(0).unwrap())
221 }
222
223 pub fn ones(shape: [usize; 2]) -> Matrix<T>
234 where
235 T: FromPrimitive + Num + Copy,
236 {
237 Self::full(shape, T::from_i32(1).unwrap())
238 }
239
240 pub fn ones_like(m: &Matrix<T>) -> Matrix<T>
251 where
252 T: FromPrimitive + Num + Copy,
253 {
254 Self::full([m.nrows, m.ncols], T::from_i32(1).unwrap())
255 }
256
257 pub fn power(&self, exp: usize) -> Matrix<T>
269 where
270 T: FromPrimitive + Num + Copy,
271 {
272 let elements =
273 self.elements.iter().map(|row| row.power(exp)).collect();
274 Matrix {
275 nrows: self.nrows,
276 ncols: self.ncols,
277 elements,
278 }
279 }
280
281 pub fn uniform(shape: [usize; 2], low: T, high: T) -> Matrix<T>
293 where
294 T: Num + SampleUniform + Copy,
295 {
296 let nrows = shape[0];
298 let ncols = shape[1];
299
300 let mut elements = Vec::with_capacity(nrows);
301 let uniform_distribution = Uniform::new(low, high);
302 let mut rng = rand::thread_rng();
304 for _ in 0..nrows {
305 let mut cols = Vec::with_capacity(ncols);
306 for _ in 0..ncols {
307 cols.push(uniform_distribution.sample(&mut rng));
308 }
309 elements.push(RowMatrix::from(cols));
310 }
311
312 Matrix {
313 nrows,
314 ncols,
315 elements,
316 }
317 }
318
319 pub fn from_csv<P>(file_path: P) -> MatrixLoaderForCSV<T, P>
330 where
331 P: AsRef<Path>,
332 {
333 MatrixLoaderForCSV {
334 file_path,
335 has_headers: false,
336 phantom: PhantomData
337 }
338 }
339}
340
341impl Matrix<f64> {
342 pub fn normal(shape: [usize; 2], mean: f64, std_dev: f64) -> Matrix<f64> {
353 let nrows = shape[0];
355 let ncols = shape[1];
356
357 let mut elements = Vec::with_capacity(nrows);
358 let normal_distribution = Normal::new(mean, std_dev);
359 let mut rng = rand::thread_rng();
361 for _ in 0..nrows {
362 let mut cols = Vec::with_capacity(ncols);
363 for _ in 0..ncols {
364 cols.push(normal_distribution.sample(&mut rng));
365 }
366 elements.push(RowMatrix::from(cols));
367 }
368
369 Matrix {
370 nrows,
371 ncols,
372 elements,
373 }
374 }
375}
376
377impl<T> From<Vec<Vec<T>>> for Matrix<T>
379where
380 T: Num + Copy,
381{
382 fn from(source: Vec<Vec<T>>) -> Self {
383 let nrows = source.len();
384 let ncols = source[0].len();
385 let ncols_inconsistent = source.iter().any(|v| v.len() != ncols);
387 if ncols_inconsistent {
388 panic!("Invalid matrix: the number of columns is inconsistent")
389 }
390 let elements = source
392 .iter()
393 .map(|v| {
394 let mut row = Vec::new();
398 v.iter().for_each(|x| row.push(*x));
399 RowMatrix::from(row)
400 })
401 .collect();
402
403 Matrix {
404 nrows,
405 ncols,
406 elements,
407 }
408 }
409}
410
411impl<T> PartialEq for Matrix<T>
413where
414 T: Num + Copy,
415{
416 fn eq(&self, other: &Matrix<T>) -> bool {
417 if self.elements != other.elements {
418 return false;
419 }
420 true
421 }
422 fn ne(&self, other: &Matrix<T>) -> bool {
423 if self.elements == other.elements {
424 return false;
425 }
426 true
427 }
428}
429
430impl<T> ops::Index<usize> for Matrix<T> {
432 type Output = RowMatrix<T>;
433
434 fn index(&self, i: usize) -> &RowMatrix<T> {
435 &self.elements[i]
436 }
437}
438
439impl<T> IntoIterator for Matrix<T> {
441 type Item = RowMatrix<T>;
442 type IntoIter = ::std::vec::IntoIter<RowMatrix<T>>;
443
444 fn into_iter(self) -> Self::IntoIter {
445 self.elements.into_iter()
446 }
447}
448
449impl<T> ops::Add<Matrix<T>> for Matrix<T>
451where
452 T: Num + Copy,
453{
454 type Output = Matrix<T>;
455
456 fn add(self, other: Matrix<T>) -> Matrix<T> {
457 if self.shape() != other.shape() {
458 panic!(
459 "Matrix addition with invalid shape: {:?} != {:?}",
460 self.shape(),
461 other.shape()
462 );
463 }
464
465 let elements = self
467 .elements
468 .iter()
469 .enumerate()
470 .map(|(i, row)| {
471 row.elements
472 .iter()
473 .enumerate()
474 .map(|(j, value)| *value + other[i][j])
475 .collect()
476 })
477 .collect();
478 Matrix {
479 nrows: self.nrows,
480 ncols: self.ncols,
481 elements,
482 }
483 }
484}
485
486impl<T> ops::Add<T> for Matrix<T>
493where
494 T: Num + Copy,
495{
496 type Output = Matrix<T>;
497
498 fn add(self, value: T) -> Matrix<T> {
499 let elements = self
501 .elements
502 .iter()
503 .map(|row| {
504 row.elements.iter().map(|elem| *elem + value).collect()
505 })
506 .collect();
507 Matrix {
508 nrows: self.nrows,
509 ncols: self.ncols,
510 elements,
511 }
512 }
513}
514
515macro_rules! impl_add_matrix_for_type {
522 ($t: ty) => {
523 impl ops::Add<Matrix<$t>> for $t {
524 type Output = Matrix<$t>;
525
526 fn add(self, m: Matrix<$t>) -> Matrix<$t> {
527 let elements = m
529 .elements
530 .iter()
531 .map(|row| {
532 row.elements.iter().map(|elem| elem + self).collect()
533 })
534 .collect();
535 Matrix {
536 nrows: m.nrows,
537 ncols: m.ncols,
538 elements,
539 }
540 }
541 }
542 };
543}
544
545impl_add_matrix_for_type!(usize);
546impl_add_matrix_for_type!(i8);
547impl_add_matrix_for_type!(i16);
548impl_add_matrix_for_type!(i32);
549impl_add_matrix_for_type!(i64);
550impl_add_matrix_for_type!(i128);
551impl_add_matrix_for_type!(u8);
552impl_add_matrix_for_type!(u16);
553impl_add_matrix_for_type!(u32);
554impl_add_matrix_for_type!(u64);
555impl_add_matrix_for_type!(u128);
556impl_add_matrix_for_type!(f32);
557impl_add_matrix_for_type!(f64);
558
559impl<T> ops::AddAssign<Matrix<T>> for Matrix<T>
562where
563 T: Num + Copy + ops::AddAssign,
564{
565 fn add_assign(&mut self, other: Matrix<T>) {
566 if self.shape() != other.shape() {
567 panic!(
568 "Matrix addition with invalid length: {:?} != {:?}",
569 self.shape(),
570 other.shape()
571 );
572 }
573
574 self.elements.iter_mut().enumerate().for_each(|(i, row)| {
575 row.elements
576 .iter_mut()
577 .enumerate()
578 .for_each(|(j, value)| *value += other[i][j])
579 })
580 }
581}
582
583impl<T> ops::AddAssign<T> for Matrix<T>
591where
592 T: Num + Copy + ops::AddAssign,
593{
594 fn add_assign(&mut self, value: T) {
595 self.elements.iter_mut().for_each(|row| {
596 row.elements.iter_mut().for_each(|elem| *elem += value)
597 })
598 }
599}
600
601impl<T> ops::Sub<Matrix<T>> for Matrix<T>
604where
605 T: Num + Copy,
606{
607 type Output = Matrix<T>;
608
609 fn sub(self, other: Matrix<T>) -> Matrix<T> {
610 if self.shape() != other.shape() {
611 panic!(
612 "Matrix substraction with invalid shape: {:?} != {:?}",
613 self.shape(),
614 other.shape()
615 );
616 }
617
618 let elements = self
620 .elements
621 .iter()
622 .enumerate()
623 .map(|(i, row)| {
624 row.elements
625 .iter()
626 .enumerate()
627 .map(|(j, value)| *value - other[i][j])
628 .collect()
629 })
630 .collect();
631 Matrix {
632 nrows: self.nrows,
633 ncols: self.ncols,
634 elements,
635 }
636 }
637}
638
639impl<T> ops::Sub<T> for Matrix<T>
646where
647 T: Num + Copy,
648{
649 type Output = Matrix<T>;
650
651 fn sub(self, value: T) -> Matrix<T> {
652 let elements = self
654 .elements
655 .iter()
656 .map(|row| {
657 row.elements.iter().map(|elem| *elem - value).collect()
658 })
659 .collect();
660 Matrix {
661 nrows: self.nrows,
662 ncols: self.ncols,
663 elements,
664 }
665 }
666}
667
668macro_rules! impl_sub_matrix_for_type {
675 ($t: ty) => {
676 impl ops::Sub<Matrix<$t>> for $t {
677 type Output = Matrix<$t>;
678
679 fn sub(self, m: Matrix<$t>) -> Matrix<$t> {
680 let elements = m
682 .elements
683 .iter()
684 .map(|row| {
685 row.elements.iter().map(|elem| self - elem).collect()
686 })
687 .collect();
688 Matrix {
689 nrows: m.nrows,
690 ncols: m.ncols,
691 elements,
692 }
693 }
694 }
695 };
696}
697
698impl_sub_matrix_for_type!(usize);
699impl_sub_matrix_for_type!(i8);
700impl_sub_matrix_for_type!(i16);
701impl_sub_matrix_for_type!(i32);
702impl_sub_matrix_for_type!(i64);
703impl_sub_matrix_for_type!(i128);
704impl_sub_matrix_for_type!(u8);
705impl_sub_matrix_for_type!(u16);
706impl_sub_matrix_for_type!(u32);
707impl_sub_matrix_for_type!(u64);
708impl_sub_matrix_for_type!(u128);
709impl_sub_matrix_for_type!(f32);
710impl_sub_matrix_for_type!(f64);
711
712impl<T> ops::SubAssign<Matrix<T>> for Matrix<T>
715where
716 T: Num + Copy + ops::SubAssign,
717{
718 fn sub_assign(&mut self, other: Matrix<T>) {
719 if self.shape() != other.shape() {
720 panic!(
721 "Matrix substraction with invalid length: {:?} != {:?}",
722 self.shape(),
723 other.shape()
724 );
725 }
726
727 self.elements.iter_mut().enumerate().for_each(|(i, row)| {
728 row.elements
729 .iter_mut()
730 .enumerate()
731 .for_each(|(j, value)| *value -= other[i][j])
732 })
733 }
734}
735
736impl<T> ops::SubAssign<T> for Matrix<T>
744where
745 T: Num + Copy + ops::SubAssign,
746{
747 fn sub_assign(&mut self, value: T) {
748 self.elements.iter_mut().for_each(|row| {
749 row.elements.iter_mut().for_each(|elem| *elem -= value)
750 })
751 }
752}
753
754impl<T> ops::Mul<Matrix<T>> for Matrix<T>
757where
758 T: Num + Copy,
759{
760 type Output = Matrix<T>;
761
762 fn mul(self, other: Matrix<T>) -> Matrix<T> {
763 if self.shape() != other.shape() {
764 panic!(
765 "Matrix multiplication with invalid shape: {:?} != {:?}",
766 self.shape(),
767 other.shape()
768 );
769 }
770
771 let elements = self
773 .elements
774 .iter()
775 .enumerate()
776 .map(|(i, row)| {
777 row.elements
778 .iter()
779 .enumerate()
780 .map(|(j, value)| *value * other[i][j])
781 .collect()
782 })
783 .collect();
784 Matrix {
785 nrows: self.nrows,
786 ncols: self.ncols,
787 elements,
788 }
789 }
790}
791
792impl<T> ops::Mul<T> for Matrix<T>
799where
800 T: Num + Copy,
801{
802 type Output = Matrix<T>;
803
804 fn mul(self, value: T) -> Matrix<T> {
805 let elements = self
807 .elements
808 .iter()
809 .map(|row| {
810 row.elements.iter().map(|elem| *elem * value).collect()
811 })
812 .collect();
813 Matrix {
814 nrows: self.nrows,
815 ncols: self.ncols,
816 elements,
817 }
818 }
819}
820
821macro_rules! impl_sub_matrix_for_type {
828 ($t: ty) => {
829 impl ops::Mul<Matrix<$t>> for $t {
830 type Output = Matrix<$t>;
831
832 fn mul(self, m: Matrix<$t>) -> Matrix<$t> {
833 let elements = m
835 .elements
836 .iter()
837 .map(|row| {
838 row.elements.iter().map(|elem| self * elem).collect()
839 })
840 .collect();
841 Matrix {
842 nrows: m.nrows,
843 ncols: m.ncols,
844 elements,
845 }
846 }
847 }
848 };
849}
850
851impl_sub_matrix_for_type!(usize);
852impl_sub_matrix_for_type!(i8);
853impl_sub_matrix_for_type!(i16);
854impl_sub_matrix_for_type!(i32);
855impl_sub_matrix_for_type!(i64);
856impl_sub_matrix_for_type!(i128);
857impl_sub_matrix_for_type!(u8);
858impl_sub_matrix_for_type!(u16);
859impl_sub_matrix_for_type!(u32);
860impl_sub_matrix_for_type!(u64);
861impl_sub_matrix_for_type!(u128);
862impl_sub_matrix_for_type!(f32);
863impl_sub_matrix_for_type!(f64);
864
865impl<T> ops::MulAssign<Matrix<T>> for Matrix<T>
868where
869 T: Num + Copy + ops::MulAssign,
870{
871 fn mul_assign(&mut self, other: Matrix<T>) {
872 if self.shape() != other.shape() {
873 panic!(
874 "Matrix multiplication with invalid length: {:?} != {:?}",
875 self.shape(),
876 other.shape()
877 );
878 }
879
880 self.elements.iter_mut().enumerate().for_each(|(i, row)| {
881 row.elements
882 .iter_mut()
883 .enumerate()
884 .for_each(|(j, value)| *value *= other[i][j])
885 })
886 }
887}
888
889impl<T> ops::MulAssign<T> for Matrix<T>
897where
898 T: Num + Copy + ops::MulAssign,
899{
900 fn mul_assign(&mut self, value: T) {
901 self.elements.iter_mut().for_each(|row| {
902 row.elements.iter_mut().for_each(|elem| *elem *= value)
903 })
904 }
905}
906
/// Builder that reads a [`Matrix`] from a CSV file; created by `Matrix::from_csv`.
#[derive(Debug)]
pub struct MatrixLoaderForCSV<T, P: AsRef<Path>> {
    /// Path of the CSV file to open.
    file_path: P,
    /// Forwarded to `csv::ReaderBuilder::has_headers` when loading.
    has_headers: bool,
    /// Records the element type `T` produced by `load` without storing a value.
    phantom: PhantomData<T>,
}
921
922impl<T, P> MatrixLoaderForCSV<T, P>
923where
924 P: AsRef<Path>,
925{
926 pub fn has_headers(self, yes: bool) -> MatrixLoaderForCSV<T, P> {
940 MatrixLoaderForCSV {
941 file_path: self.file_path,
942 has_headers: yes,
943 phantom: PhantomData
944 }
945 }
946
947 pub fn load(self) -> Result<Matrix<T>, LoadError>
956 where
957 T: FromPrimitive + Num + Copy + utils::TypeName,
958 {
959 let file = File::open(self.file_path)?;
961 let mut rdr = csv::ReaderBuilder::new()
962 .has_headers(self.has_headers)
963 .from_reader(file);
964 let mut elements = Vec::new();
966 for result in rdr.records() {
967 let record = result?;
969 let mut rows = Vec::with_capacity(record.len());
970 for value in record.iter() {
971 let element = match T::from_str_radix(value.trim(), 10) {
973 Ok(value) => value,
974 Err(_err) => {
975 return Err(LoadError::new(
977 LoadErrorKind::InvalidElement,
978 format!(
979 "{:?} is not valid {}",
980 value,
981 T::type_name()
982 ),
983 ));
984 }
985 };
986 rows.push(element);
987 }
988 elements.push(rows);
989 }
990 if elements.len() == 0 {
991 return Err(LoadError::new(
992 LoadErrorKind::Empty,
993 String::from("Cannot load empty file"),
994 ));
995 }
996 Ok(Matrix::from(elements))
997 }
998}