crabsformer/
matrix.rs

1// Copyright (c) 2019, Bayu Aldi Yansyah <bayualdiyansyah@gmail.com>
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::error::{LoadError, LoadErrorKind};
16use crate::utils;
17use crate::vector;
18use csv;
19use num::{FromPrimitive, Num};
20use rand::distributions::uniform::SampleUniform;
21use rand::distributions::{Distribution, Normal, Uniform};
22use std::fs::File;
23use std::ops;
24use std::path::Path;
25use std::marker::PhantomData;
26
/// Creates a [matrix] containing the arguments.
///
/// `matrix!` allows a matrix to be defined with
/// the same syntax as array expressions.
///
/// There are two forms of this macro:
///
/// 1. Create a matrix containing a given list of elements.
///    Elements within a row are separated by `,` and rows by `;`:
///
/// ```
/// # use crabsformer::*;
/// let w = matrix![
///     3, 1, 4;
///     1, 5, 9;
/// ];
/// assert_eq!(w[0][0], 3);
/// assert_eq!(w[0][1], 1);
/// assert_eq!(w[0][2], 4);
/// assert_eq!(w[1][0], 1);
/// assert_eq!(w[1][1], 5);
/// assert_eq!(w[1][2], 9);
/// ```
///
/// 2. Create a matrix from a given element and shape `[nrows, ncols]`:
///
/// ```
/// # use crabsformer::*;
/// let w = matrix![1; [3, 3]];
/// assert_eq!(w, matrix![
///     1, 1, 1;
///     1, 1, 1;
///     1, 1, 1;
/// ]);
/// ```
///
/// # Panics
///
/// Panics if the rows do not all have the same number of elements
/// (the check is done by `Matrix::from`).
///
/// [matrix]: struct.Matrix.html
#[macro_export]
macro_rules! matrix {
    // NOTE: the order of the rules is very important.
    // Every arm goes through `$crate::Matrix` so the macro also works in
    // modules that have not imported `Matrix` themselves.

    // Samples: matrix![0; [3, 3]]
    ($elem:expr; $shape:expr) => {{
        // Borrow `$shape` once: the expression is evaluated a single time
        // and is not moved out of the caller (e.g. a `Vec<usize>`).
        let shape = &$shape;
        let nrows = shape[0];
        let ncols = shape[1];
        let elements = vec![vec![$elem; ncols]; nrows];
        $crate::Matrix::from(elements)
    }};

    // Samples: matrix![1, 3, 4]
    ($($x:expr),*) => {{
        let elements = vec![vec![$($x),*]];
        $crate::Matrix::from(elements)
    }};

    // Samples: matrix![1, 2, 3, 4,]
    ($($x:expr,)*) => {{
        let elements = vec![vec![$($x),*]];
        $crate::Matrix::from(elements)
    }};

    // Samples: matrix![2.0, 1.0, 4.0; 2.0, 4.0, 2.0;]
    ($($($x:expr),*;)*) => {{
        let elements = vec![$(vec![$($x),*]),*];
        $crate::Matrix::from(elements)
    }};

    // Samples: matrix![2.0, 1.0, 4.0; 2.0, 4.0, 2.0]
    ($($($x:expr),*);*) => {{
        let elements = vec![$(vec![$($x),*]),*];
        $crate::Matrix::from(elements)
    }};
}
99
100/// Row & Column Matrix
101/// https://en.wikipedia.org/wiki/Row_and_column_vectors
102type RowMatrix<T> = vector::Vector<T>;
103// type ColMatrix<T> = vector::Vector<T>;
104
/// Matrix.
///
/// A two-dimensional collection of elements of type `T`, stored as a list
/// of rows that all have the same number of columns.
///
/// Ways to create a matrix:
///
/// 1. the `matrix!` macro, e.g. `matrix![1, 2; 3, 4]` or `matrix![0; [3, 3]]`;
/// 2. constructors such as `Matrix::full`, `Matrix::zeros`, `Matrix::ones`,
///    `Matrix::uniform`, `Matrix::normal` (for `f64`) and `Matrix::from_csv`;
/// 3. `Matrix::from(Vec<Vec<T>>)`, one inner vector per row.
///
/// Supported operations include element-wise `+`, `-` and `*` between
/// matrices of the same shape, the same operators with a scalar on the
/// right (and, for primitive numeric types, on the left), and the compound
/// assignments `+=`, `-=` and `*=`.
///
/// Indexing with `m[i]` returns row `i`, so `m[i][j]` is the element at
/// row `i`, column `j`.
#[derive(Debug)]
pub struct Matrix<T> {
    /// Number of rows.
    nrows: usize,
    /// Number of columns in every row.
    ncols: usize,
    /// The rows of the matrix, in order.
    elements: Vec<RowMatrix<T>>,
}
118
119impl<T> Matrix<T> {
120    /// The shape of the matrix `[nrows, ncols]`.
121    ///
122    /// # Examples
123    ///
124    /// ```
125    /// # use crabsformer::*;
126    /// let W = matrix![
127    ///     3.0, 1.0;
128    ///     4.0, 1.0;
129    ///     5.0, 9.0;
130    /// ];
131    /// assert_eq!(W.shape(), [3, 2]);
132    /// ```
133    pub fn shape(&self) -> [usize; 2] {
134        [self.nrows, self.ncols]
135    }
136
137    /// Create a new matrix of given shape `shape` and type `T`,
138    /// filled with `value`.
139    ///
140    /// # Examples
141    ///
142    /// ```
143    /// # use crabsformer::*;
144    /// let W = Matrix::full([5, 5], 2.5);
145    /// ```
146    pub fn full(shape: [usize; 2], value: T) -> Matrix<T>
147    where
148        T: FromPrimitive + Num + Copy,
149    {
150        // Initialize and populate the matrix with specified value
151        let nrows = shape[0];
152        let ncols = shape[1];
153        let elements = vec![vector![value; ncols]; nrows];
154        Matrix {
155            nrows,
156            ncols,
157            elements,
158        }
159    }
160
161    /// Create a new matrix that have the same shape and type
162    /// as matrix `m`, filled with `value`.
163    ///
164    /// # Examples
165    ///
166    /// ```
167    /// # use crabsformer::*;
168    /// let w1 = matrix![
169    ///     3.0, 1.0;
170    ///     4.0, 1.0;
171    /// ];
172    /// let w2 = Matrix::full_like(&w1, 3.1415);
173    /// ```
174    pub fn full_like(m: &Matrix<T>, value: T) -> Matrix<T>
175    where
176        T: FromPrimitive + Num + Copy,
177    {
178        // Initialize and populate the matrix with specified value
179        let nrows = m.nrows;
180        let ncols = m.ncols;
181        let elements = vec![vector![value; ncols]; nrows];
182        Matrix {
183            nrows,
184            ncols,
185            elements,
186        }
187    }
188
189    /// Create a new matrix of given shape `shape` and type `T`,
190    /// filled with zeros. You need to explicitly annotate the
191    /// numeric type.
192    ///
193    /// # Examples
194    ///
195    /// ```
196    /// # use crabsformer::*;
197    /// let W: Matrix<i32> = Matrix::zeros([5, 5]);
198    /// ```
199    pub fn zeros(shape: [usize; 2]) -> Matrix<T>
200    where
201        T: FromPrimitive + Num + Copy,
202    {
203        Self::full(shape, T::from_i32(0).unwrap())
204    }
205
206    /// Create a new matrix that have the same shape and type
207    /// as matrix `m`, filled with zeros.
208    ///
209    /// # Examples
210    ///
211    /// ```
212    /// # use crabsformer::*;
213    /// let W1 = matrix![3.0, 1.0; 4.0, 1.0];
214    /// let W2 = Matrix::zeros_like(&W1);
215    /// ```
216    pub fn zeros_like(m: &Matrix<T>) -> Matrix<T>
217    where
218        T: FromPrimitive + Num + Copy,
219    {
220        Self::full([m.nrows, m.ncols], T::from_i32(0).unwrap())
221    }
222
223    /// Create a new matrix of given shaoe `shape` and type `T`,
224    /// filled with ones. You need to explicitly annotate the
225    /// numeric type.
226    ///
227    /// # Examples
228    ///
229    /// ```
230    /// # use crabsformer::*;
231    /// let W: Matrix<i32> = Matrix::ones([3, 5]);
232    /// ```
233    pub fn ones(shape: [usize; 2]) -> Matrix<T>
234    where
235        T: FromPrimitive + Num + Copy,
236    {
237        Self::full(shape, T::from_i32(1).unwrap())
238    }
239
240    /// Create a new matrix that have the same shape and type
241    /// as matrix `m`, filled with ones.
242    ///
243    /// # Examples
244    ///
245    /// ```
246    /// # use crabsformer::*;
247    /// let W1 = matrix![3, 1; 4, 1; 5, 9];
248    /// let W2 = Matrix::ones_like(&W1);
249    /// ```
250    pub fn ones_like(m: &Matrix<T>) -> Matrix<T>
251    where
252        T: FromPrimitive + Num + Copy,
253    {
254        Self::full([m.nrows, m.ncols], T::from_i32(1).unwrap())
255    }
256
257    /// Raises each elements of matrix to the power of `exp`,
258    /// using exponentiation by squaring.
259    ///
260    /// # Examples
261    ///
262    /// ```
263    /// # use crabsformer::*;
264    /// let W1 = matrix![3, 1, 4; 1, 5, 9];
265    /// let W2 = W1.power(2);
266    /// assert_eq!(W2, matrix![9, 1, 16; 1, 25, 81]);
267    /// ```
268    pub fn power(&self, exp: usize) -> Matrix<T>
269    where
270        T: FromPrimitive + Num + Copy,
271    {
272        let elements =
273            self.elements.iter().map(|row| row.power(exp)).collect();
274        Matrix {
275            nrows: self.nrows,
276            ncols: self.ncols,
277            elements,
278        }
279    }
280
281    /// Create a new matrix of the given shape `shape` and
282    /// populate it with random samples from a uniform distribution
283    /// over the half-open interval `[low, high)` (includes `low`,
284    /// but excludes `high`).
285    ///
286    /// # Examples
287    ///
288    /// ```
289    /// # use crabsformer::*;
290    /// let W = Matrix::uniform([5, 5], 0.0, 1.0);
291    /// ```
292    pub fn uniform(shape: [usize; 2], low: T, high: T) -> Matrix<T>
293    where
294        T: Num + SampleUniform + Copy,
295    {
296        // Get the shape of the matrix
297        let nrows = shape[0];
298        let ncols = shape[1];
299
300        let mut elements = Vec::with_capacity(nrows);
301        let uniform_distribution = Uniform::new(low, high);
302        // Populate the Matrix with the default value
303        let mut rng = rand::thread_rng();
304        for _ in 0..nrows {
305            let mut cols = Vec::with_capacity(ncols);
306            for _ in 0..ncols {
307                cols.push(uniform_distribution.sample(&mut rng));
308            }
309            elements.push(RowMatrix::from(cols));
310        }
311
312        Matrix {
313            nrows,
314            ncols,
315            elements,
316        }
317    }
318
319    /// Load Matrix from CSV file. You need to explicitly annotate the numeric type.
320    ///
321    /// # Examples
322    ///
323    /// ```
324    /// use crabsformer::*;
325    ///
326    /// let dataset: Matrix<f32> = Matrix::from("tests/weight.csv").load().unwrap();
327    /// ```
328    ///
329    pub fn from_csv<P>(file_path: P) -> MatrixLoaderForCSV<T, P>
330    where
331        P: AsRef<Path>,
332    {
333        MatrixLoaderForCSV {
334            file_path,
335            has_headers: false,
336            phantom: PhantomData
337        }
338    }
339}
340
341impl Matrix<f64> {
342    /// Create a new matrix of the given shape `shape` and
343    /// populate it with random samples from a normal distribution
344    /// `N(mean, std_dev**2)`.
345    ///
346    /// # Examples
347    ///
348    /// ```
349    /// # use crabsformer::*;
350    /// let W = Matrix::normal([5, 5], 0.0, 1.0); // Gaussian mean=0.0 std_dev=1.0
351    /// ```
352    pub fn normal(shape: [usize; 2], mean: f64, std_dev: f64) -> Matrix<f64> {
353        // Get the shape of the matrix
354        let nrows = shape[0];
355        let ncols = shape[1];
356
357        let mut elements = Vec::with_capacity(nrows);
358        let normal_distribution = Normal::new(mean, std_dev);
359        // Populate the Matrix with the default value
360        let mut rng = rand::thread_rng();
361        for _ in 0..nrows {
362            let mut cols = Vec::with_capacity(ncols);
363            for _ in 0..ncols {
364                cols.push(normal_distribution.sample(&mut rng));
365            }
366            elements.push(RowMatrix::from(cols));
367        }
368
369        Matrix {
370            nrows,
371            ncols,
372            elements,
373        }
374    }
375}
376
377// Conversion from Vec<Vec<T>>
378impl<T> From<Vec<Vec<T>>> for Matrix<T>
379where
380    T: Num + Copy,
381{
382    fn from(source: Vec<Vec<T>>) -> Self {
383        let nrows = source.len();
384        let ncols = source[0].len();
385        // Raise panic if number of columns on each row is inconsistent
386        let ncols_inconsistent = source.iter().any(|v| v.len() != ncols);
387        if ncols_inconsistent {
388            panic!("Invalid matrix: the number of columns is inconsistent")
389        }
390        // Convert each row to RowMatrix
391        let elements = source
392            .iter()
393            .map(|v| {
394                // We cannot directly convert &Vec<T>
395                // to RowMatrix<T> because we cannot
396                // move out borrowed content
397                let mut row = Vec::new();
398                v.iter().for_each(|x| row.push(*x));
399                RowMatrix::from(row)
400            })
401            .collect();
402
403        Matrix {
404            nrows,
405            ncols,
406            elements,
407        }
408    }
409}
410
411// Matrix comparison
412impl<T> PartialEq for Matrix<T>
413where
414    T: Num + Copy,
415{
416    fn eq(&self, other: &Matrix<T>) -> bool {
417        if self.elements != other.elements {
418            return false;
419        }
420        true
421    }
422    fn ne(&self, other: &Matrix<T>) -> bool {
423        if self.elements == other.elements {
424            return false;
425        }
426        true
427    }
428}
429
430// Implement matrix indexing
431impl<T> ops::Index<usize> for Matrix<T> {
432    type Output = RowMatrix<T>;
433
434    fn index(&self, i: usize) -> &RowMatrix<T> {
435        &self.elements[i]
436    }
437}
438
439// Implement iterator for matrix
440impl<T> IntoIterator for Matrix<T> {
441    type Item = RowMatrix<T>;
442    type IntoIter = ::std::vec::IntoIter<RowMatrix<T>>;
443
444    fn into_iter(self) -> Self::IntoIter {
445        self.elements.into_iter()
446    }
447}
448
449// This trait is implemented to support for matrix addition operator
450impl<T> ops::Add<Matrix<T>> for Matrix<T>
451where
452    T: Num + Copy,
453{
454    type Output = Matrix<T>;
455
456    fn add(self, other: Matrix<T>) -> Matrix<T> {
457        if self.shape() != other.shape() {
458            panic!(
459                "Matrix addition with invalid shape: {:?} != {:?}",
460                self.shape(),
461                other.shape()
462            );
463        }
464
465        // Add the matrix
466        let elements = self
467            .elements
468            .iter()
469            .enumerate()
470            .map(|(i, row)| {
471                row.elements
472                    .iter()
473                    .enumerate()
474                    .map(|(j, value)| *value + other[i][j])
475                    .collect()
476            })
477            .collect();
478        Matrix {
479            nrows: self.nrows,
480            ncols: self.ncols,
481            elements,
482        }
483    }
484}
485
486// This trait is implemented to support for matrix addition
487// operator with scalar on the right side,
488// for example:
489//
490// let a = matrix![5, 5; 5, 5] + 6;
491//
492impl<T> ops::Add<T> for Matrix<T>
493where
494    T: Num + Copy,
495{
496    type Output = Matrix<T>;
497
498    fn add(self, value: T) -> Matrix<T> {
499        // Add the matrix
500        let elements = self
501            .elements
502            .iter()
503            .map(|row| {
504                row.elements.iter().map(|elem| *elem + value).collect()
505            })
506            .collect();
507        Matrix {
508            nrows: self.nrows,
509            ncols: self.ncols,
510            elements,
511        }
512    }
513}
514
515// This macro is to generate support for matrix addition
516// operator with scalar on the left side,
517// for example:
518//
519// let a = 6 + matrix![5, 5; 5, 5];
520//
521macro_rules! impl_add_matrix_for_type {
522    ($t: ty) => {
523        impl ops::Add<Matrix<$t>> for $t {
524            type Output = Matrix<$t>;
525
526            fn add(self, m: Matrix<$t>) -> Matrix<$t> {
527                // Add the matrix
528                let elements = m
529                    .elements
530                    .iter()
531                    .map(|row| {
532                        row.elements.iter().map(|elem| elem + self).collect()
533                    })
534                    .collect();
535                Matrix {
536                    nrows: m.nrows,
537                    ncols: m.ncols,
538                    elements,
539                }
540            }
541        }
542    };
543}
544
545impl_add_matrix_for_type!(usize);
546impl_add_matrix_for_type!(i8);
547impl_add_matrix_for_type!(i16);
548impl_add_matrix_for_type!(i32);
549impl_add_matrix_for_type!(i64);
550impl_add_matrix_for_type!(i128);
551impl_add_matrix_for_type!(u8);
552impl_add_matrix_for_type!(u16);
553impl_add_matrix_for_type!(u32);
554impl_add_matrix_for_type!(u64);
555impl_add_matrix_for_type!(u128);
556impl_add_matrix_for_type!(f32);
557impl_add_matrix_for_type!(f64);
558
559// This trait is implemented to support for matrix addition
560// and assignment operator (+=)
561impl<T> ops::AddAssign<Matrix<T>> for Matrix<T>
562where
563    T: Num + Copy + ops::AddAssign,
564{
565    fn add_assign(&mut self, other: Matrix<T>) {
566        if self.shape() != other.shape() {
567            panic!(
568                "Matrix addition with invalid length: {:?} != {:?}",
569                self.shape(),
570                other.shape()
571            );
572        }
573
574        self.elements.iter_mut().enumerate().for_each(|(i, row)| {
575            row.elements
576                .iter_mut()
577                .enumerate()
578                .for_each(|(j, value)| *value += other[i][j])
579        })
580    }
581}
582
583// This trait is implemented to support for matrix addition
584// assignment operator (+=) with scalar on the right side,
585// for example:
586//
587// let a = matrix![5, 5; 5, 5];
588// a += 6;
589//
590impl<T> ops::AddAssign<T> for Matrix<T>
591where
592    T: Num + Copy + ops::AddAssign,
593{
594    fn add_assign(&mut self, value: T) {
595        self.elements.iter_mut().for_each(|row| {
596            row.elements.iter_mut().for_each(|elem| *elem += value)
597        })
598    }
599}
600
601// This trait is implemented to support for matrix
602// substraction operator
603impl<T> ops::Sub<Matrix<T>> for Matrix<T>
604where
605    T: Num + Copy,
606{
607    type Output = Matrix<T>;
608
609    fn sub(self, other: Matrix<T>) -> Matrix<T> {
610        if self.shape() != other.shape() {
611            panic!(
612                "Matrix substraction with invalid shape: {:?} != {:?}",
613                self.shape(),
614                other.shape()
615            );
616        }
617
618        // Substract the matrix
619        let elements = self
620            .elements
621            .iter()
622            .enumerate()
623            .map(|(i, row)| {
624                row.elements
625                    .iter()
626                    .enumerate()
627                    .map(|(j, value)| *value - other[i][j])
628                    .collect()
629            })
630            .collect();
631        Matrix {
632            nrows: self.nrows,
633            ncols: self.ncols,
634            elements,
635        }
636    }
637}
638
639// This trait is implemented to support for matrix substraction
640// operator with scalar on the right side,
641// for example:
642//
643// let a = matrix![5, 5; 5, 5] - 6;
644//
645impl<T> ops::Sub<T> for Matrix<T>
646where
647    T: Num + Copy,
648{
649    type Output = Matrix<T>;
650
651    fn sub(self, value: T) -> Matrix<T> {
652        // Substract the matrix
653        let elements = self
654            .elements
655            .iter()
656            .map(|row| {
657                row.elements.iter().map(|elem| *elem - value).collect()
658            })
659            .collect();
660        Matrix {
661            nrows: self.nrows,
662            ncols: self.ncols,
663            elements,
664        }
665    }
666}
667
668// This macro is to generate support for matrix substraction
669// operator with scalar on the left side,
670// for example:
671//
672// let a = 6 - matrix![5, 5; 5, 5];
673//
674macro_rules! impl_sub_matrix_for_type {
675    ($t: ty) => {
676        impl ops::Sub<Matrix<$t>> for $t {
677            type Output = Matrix<$t>;
678
679            fn sub(self, m: Matrix<$t>) -> Matrix<$t> {
680                // Substract the matrix
681                let elements = m
682                    .elements
683                    .iter()
684                    .map(|row| {
685                        row.elements.iter().map(|elem| self - elem).collect()
686                    })
687                    .collect();
688                Matrix {
689                    nrows: m.nrows,
690                    ncols: m.ncols,
691                    elements,
692                }
693            }
694        }
695    };
696}
697
698impl_sub_matrix_for_type!(usize);
699impl_sub_matrix_for_type!(i8);
700impl_sub_matrix_for_type!(i16);
701impl_sub_matrix_for_type!(i32);
702impl_sub_matrix_for_type!(i64);
703impl_sub_matrix_for_type!(i128);
704impl_sub_matrix_for_type!(u8);
705impl_sub_matrix_for_type!(u16);
706impl_sub_matrix_for_type!(u32);
707impl_sub_matrix_for_type!(u64);
708impl_sub_matrix_for_type!(u128);
709impl_sub_matrix_for_type!(f32);
710impl_sub_matrix_for_type!(f64);
711
712// This trait is implemented to support for matrix substraction
713// and assignment operator (-=)
714impl<T> ops::SubAssign<Matrix<T>> for Matrix<T>
715where
716    T: Num + Copy + ops::SubAssign,
717{
718    fn sub_assign(&mut self, other: Matrix<T>) {
719        if self.shape() != other.shape() {
720            panic!(
721                "Matrix substraction with invalid length: {:?} != {:?}",
722                self.shape(),
723                other.shape()
724            );
725        }
726
727        self.elements.iter_mut().enumerate().for_each(|(i, row)| {
728            row.elements
729                .iter_mut()
730                .enumerate()
731                .for_each(|(j, value)| *value -= other[i][j])
732        })
733    }
734}
735
736// This trait is implemented to support for matrix substraction
737// assignment operator (-=) with scalar on the right side,
738// for example:
739//
740// let a = matrix![5, 5; 5, 5];
741// a -= 6;
742//
743impl<T> ops::SubAssign<T> for Matrix<T>
744where
745    T: Num + Copy + ops::SubAssign,
746{
747    fn sub_assign(&mut self, value: T) {
748        self.elements.iter_mut().for_each(|row| {
749            row.elements.iter_mut().for_each(|elem| *elem -= value)
750        })
751    }
752}
753
754// This trait is implemented to support for matrix
755// multiplication operator
756impl<T> ops::Mul<Matrix<T>> for Matrix<T>
757where
758    T: Num + Copy,
759{
760    type Output = Matrix<T>;
761
762    fn mul(self, other: Matrix<T>) -> Matrix<T> {
763        if self.shape() != other.shape() {
764            panic!(
765                "Matrix multiplication with invalid shape: {:?} != {:?}",
766                self.shape(),
767                other.shape()
768            );
769        }
770
771        // Multiply the matrix
772        let elements = self
773            .elements
774            .iter()
775            .enumerate()
776            .map(|(i, row)| {
777                row.elements
778                    .iter()
779                    .enumerate()
780                    .map(|(j, value)| *value * other[i][j])
781                    .collect()
782            })
783            .collect();
784        Matrix {
785            nrows: self.nrows,
786            ncols: self.ncols,
787            elements,
788        }
789    }
790}
791
792// This trait is implemented to support for matrix multiplication
793// operator with scalar on the right side,
794// for example:
795//
796// let a = matrix![5, 5; 5, 5] * 6;
797//
798impl<T> ops::Mul<T> for Matrix<T>
799where
800    T: Num + Copy,
801{
802    type Output = Matrix<T>;
803
804    fn mul(self, value: T) -> Matrix<T> {
805        // Multiply the matrix
806        let elements = self
807            .elements
808            .iter()
809            .map(|row| {
810                row.elements.iter().map(|elem| *elem * value).collect()
811            })
812            .collect();
813        Matrix {
814            nrows: self.nrows,
815            ncols: self.ncols,
816            elements,
817        }
818    }
819}
820
821// This macro is to generate support for matrix multiplication
822// operator with scalar on the left side,
823// for example:
824//
825// let a = 6 * matrix![5, 5; 5, 5];
826//
827macro_rules! impl_sub_matrix_for_type {
828    ($t: ty) => {
829        impl ops::Mul<Matrix<$t>> for $t {
830            type Output = Matrix<$t>;
831
832            fn mul(self, m: Matrix<$t>) -> Matrix<$t> {
833                // Multiply the matrix
834                let elements = m
835                    .elements
836                    .iter()
837                    .map(|row| {
838                        row.elements.iter().map(|elem| self * elem).collect()
839                    })
840                    .collect();
841                Matrix {
842                    nrows: m.nrows,
843                    ncols: m.ncols,
844                    elements,
845                }
846            }
847        }
848    };
849}
850
851impl_sub_matrix_for_type!(usize);
852impl_sub_matrix_for_type!(i8);
853impl_sub_matrix_for_type!(i16);
854impl_sub_matrix_for_type!(i32);
855impl_sub_matrix_for_type!(i64);
856impl_sub_matrix_for_type!(i128);
857impl_sub_matrix_for_type!(u8);
858impl_sub_matrix_for_type!(u16);
859impl_sub_matrix_for_type!(u32);
860impl_sub_matrix_for_type!(u64);
861impl_sub_matrix_for_type!(u128);
862impl_sub_matrix_for_type!(f32);
863impl_sub_matrix_for_type!(f64);
864
865// This trait is implemented to support for matrix substraction
866// and assignment operator (-=)
867impl<T> ops::MulAssign<Matrix<T>> for Matrix<T>
868where
869    T: Num + Copy + ops::MulAssign,
870{
871    fn mul_assign(&mut self, other: Matrix<T>) {
872        if self.shape() != other.shape() {
873            panic!(
874                "Matrix multiplication with invalid length: {:?} != {:?}",
875                self.shape(),
876                other.shape()
877            );
878        }
879
880        self.elements.iter_mut().enumerate().for_each(|(i, row)| {
881            row.elements
882                .iter_mut()
883                .enumerate()
884                .for_each(|(j, value)| *value *= other[i][j])
885        })
886    }
887}
888
889// This trait is implemented to support for matrix multiplication
890// assignment operator (*=) with scalar on the right side,
891// for example:
892//
893// let a = matrix![5, 5; 5, 5];
894// a *= 6;
895//
896impl<T> ops::MulAssign<T> for Matrix<T>
897where
898    T: Num + Copy + ops::MulAssign,
899{
900    fn mul_assign(&mut self, value: T) {
901        self.elements.iter_mut().for_each(|row| {
902            row.elements.iter_mut().for_each(|elem| *elem *= value)
903        })
904    }
905}
906
/// Matrix loader for CSV formatted file.
///
/// Created by [`Matrix::from_csv`]; optionally configure it with
/// `has_headers`, then call `load` to read the file into a `Matrix<T>`.
///
/// See also: [`Matrix::from_csv`].
///
/// [`Matrix::from_csv`]: struct.Matrix.html#method.from_csv
#[derive(Debug)]
pub struct MatrixLoaderForCSV<T, P>
where
    P: AsRef<Path>,
{
    /// Path of the CSV file to read.
    file_path: P,
    /// Whether the first CSV record is treated as a header row instead of
    /// data; forwarded to `csv::ReaderBuilder::has_headers` by `load`.
    /// `Matrix::from_csv` initialises it to `false`.
    has_headers: bool,
    /// Records the element type `T` of the matrix that `load` produces;
    /// no value of type `T` is stored.
    phantom: PhantomData<T>,
}
921
922impl<T, P> MatrixLoaderForCSV<T, P>
923where
924    P: AsRef<Path>,
925{
926    /// Set to true to treat the first row as a special header row. By default, it is set
927    /// to false.
928    ///
929    /// # Examples
930    ///
931    /// ```
932    /// use crabsformer::*;
933    ///
934    /// let dataset: Matrix<f32> = Matrix::from_csv("tests/dataset.csv")
935    ///     .has_headers(true)
936    ///     .load()
937    ///     .unwrap();
938    /// ```
939    pub fn has_headers(self, yes: bool) -> MatrixLoaderForCSV<T, P> {
940        MatrixLoaderForCSV {
941            file_path: self.file_path,
942            has_headers: yes,
943            phantom: PhantomData
944        }
945    }
946
947    /// Load Matrix from CSV file. You need to explicitly annotate the numeric type.
948    ///
949    /// # Examples
950    /// ```
951    /// use crabsformer::*;
952    ///
953    /// let dataset: Matrix<f32> = Matrix::from_csv("tests/weight.csv").load().unwrap();
954    /// ```
955    pub fn load(self) -> Result<Matrix<T>, LoadError>
956    where
957        T: FromPrimitive + Num + Copy + utils::TypeName,
958    {
959        // Open CSV file
960        let file = File::open(self.file_path)?;
961        let mut rdr = csv::ReaderBuilder::new()
962            .has_headers(self.has_headers)
963            .from_reader(file);
964        // Collect each row
965        let mut elements = Vec::new();
966        for result in rdr.records() {
967            // Convert each row in the CSV file to RowMatrix
968            let record = result?;
969            let mut rows = Vec::with_capacity(record.len());
970            for value in record.iter() {
971                // It will return error if any
972                let element = match T::from_str_radix(value.trim(), 10) {
973                    Ok(value) => value,
974                    Err(_err) => {
975                        // Return error early
976                        return Err(LoadError::new(
977                            LoadErrorKind::InvalidElement,
978                            format!(
979                                "{:?} is not valid {}",
980                                value,
981                                T::type_name()
982                            ),
983                        ));
984                    }
985                };
986                rows.push(element);
987            }
988            elements.push(rows);
989        }
990        if elements.len() == 0 {
991            return Err(LoadError::new(
992                LoadErrorKind::Empty,
993                String::from("Cannot load empty file"),
994            ));
995        }
996        Ok(Matrix::from(elements))
997    }
998}