math/tensor/matrix.rs

1use crate::tensor::{
2    ephemeral_view::EphemeralView,
3    has_tensor_shape_data::HasTensorShapeData,
4    indexable_tensor::IndexableTensor,
5    tensor_shape::{HasTensorShape, TensorShape},
6    tensor_storage::{HasTensorData, IntoTensorStorage, TensorStorage},
7    Unitless,
8};
9use num::Num;
10use std::{
11    fmt,
12    fmt::Formatter,
13    ops::{Index, IndexMut},
14};
15
16#[derive(Clone, Debug, Eq, PartialEq)]
17pub struct Matrix<Dtype> {
18    shape: TensorShape,
19    storage: TensorStorage<Dtype>,
20}
21
22impl<Dtype> Matrix<Dtype>
23where
24    Dtype: Copy + Num,
25{
26    pub fn from_vec(
27        v: Vec<Dtype>,
28        num_rows: Unitless,
29        num_columns: Unitless,
30    ) -> Matrix<Dtype> {
31        assert_eq!(
32            v.len(),
33            (num_rows * num_columns) as usize,
34            "number of elements in the vector does not match the shape"
35        );
36        Matrix {
37            shape: create_row_major_shape(num_rows, num_columns),
38            storage: v.into_tensor_storage(),
39        }
40    }
41}
42
43impl<Dtype> HasTensorShape for Matrix<Dtype> {
44    fn shape(&self) -> &TensorShape {
45        &self.shape
46    }
47}
48
49impl<Dtype> HasTensorData<Dtype> for Matrix<Dtype> {
50    fn data(&self) -> &TensorStorage<Dtype> {
51        &self.storage
52    }
53}
54
55impl<Dtype> Index<[Unitless; 2]> for Matrix<Dtype>
56where
57    Dtype: Copy + Num,
58{
59    type Output = Dtype;
60
61    fn index(&self, index: [Unitless; 2]) -> &Self::Output {
62        &self.storage[self.coord_to_index(&[index[0], index[1]]) as usize]
63    }
64}
65
66impl<Dtype> IndexMut<[Unitless; 2]> for Matrix<Dtype>
67where
68    Dtype: Copy + Num,
69{
70    fn index_mut(&mut self, index: [Unitless; 2]) -> &mut Self::Output {
71        let index = self.coord_to_index(&[index[0], index[1]]);
72        &mut self.storage[index as usize]
73    }
74}
75
76pub trait MatrixTrait<Dtype> {
77    fn num_rows(&self) -> Unitless;
78
79    fn num_columns(&self) -> Unitless;
80}
81
82pub trait IndexableMatrix<Dtype>:
83    IndexableTensor<Dtype> + MatrixTrait<Dtype>
84where
85    Dtype: Copy + Num, {
86    fn matmul<R>(&self, other: &R) -> Matrix<Dtype>
87    where
88        R: MatrixTrait<Dtype> + IndexableTensor<Dtype>, {
89        let m = self.num_rows();
90        let n = self.num_columns();
91        let n2 = other.num_rows();
92        let l = other.num_columns();
93        assert_eq!(n, n2, "self.num_columns {} != other.num_rows {}", n, n2);
94        let mut result =
95            Matrix::from_vec(vec![Dtype::zero(); (m * l) as usize], m, l);
96        for i in 0..m {
97            for j in 0..l {
98                // multiply the i-th row against the j-th column
99                for k in 0..n {
100                    let old = result[[i, j]];
101                    let x = self.at([i, k]);
102                    let y = other.at([k, j]);
103                    result[[i, j]] = old + x * y;
104                }
105            }
106        }
107        result
108    }
109}
110
111impl<Dtype, T> IndexableMatrix<Dtype> for T
112where
113    Dtype: Copy + Num,
114    T: MatrixTrait<Dtype> + IndexableTensor<Dtype>,
115{
116}
117
118impl<Dtype> MatrixTrait<Dtype> for Matrix<Dtype> {
119    fn num_rows(&self) -> Unitless {
120        self.shape.dims_strides[0].0
121    }
122
123    fn num_columns(&self) -> Unitless {
124        self.shape.dims_strides[1].0
125    }
126}
127
128fn create_row_major_shape(
129    num_rows: Unitless,
130    num_columns: Unitless,
131) -> TensorShape {
132    TensorShape {
133        dims_strides: vec![(num_rows, num_columns), (num_columns, 1)],
134    }
135}
136
137impl<'a, Dtype> From<&'a Matrix<Dtype>> for EphemeralView<'a, Dtype> {
138    fn from(matrix: &'a Matrix<Dtype>) -> Self {
139        EphemeralView::new(&matrix.data(), matrix.shape().clone())
140    }
141}
142
143impl<Dtype> fmt::Display for Matrix<Dtype>
144where
145    Dtype: Copy + Num + fmt::Display,
146{
147    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
148        let num_rows = self.num_rows();
149        let num_columns = self.num_columns();
150        write!(f, "[")?;
151        for i in 0..num_rows {
152            write!(f, "[")?;
153            for j in 0..num_columns {
154                if (j + 1) == num_columns {
155                    if (i + 1) != num_rows {
156                        write!(f, "{}]\n", self[[i, j]])?;
157                    } else {
158                        write!(f, "{}]", self[[i, j]])?;
159                    }
160                } else {
161                    write!(f, "{}, ", self[[i, j]])?;
162                }
163            }
164        }
165        write!(f, "]")?;
166        Ok(())
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::matrix_transpose::MatrixTranspose;

    #[test]
    fn test_matmul() {
        let lhs = Matrix::from_vec(vec![1, 2, 3, 4], 2, 2);

        // 2x2 · 2x2.
        let rhs = Matrix::from_vec(vec![1, 2, 3, 4], 2, 2);
        let product = lhs.matmul(&rhs);
        assert_eq!(product.data().vec, vec![7, 10, 15, 22]);

        // 2x2 · 2x1: multiplying by a column of ones sums each row.
        let ones = Matrix::from_vec(vec![1, 1], 2, 1);
        let row_sums = lhs.matmul(&ones);
        assert_eq!(row_sums.storage.vec, vec![3, 7]);

        // Aᵀ · A, with the transpose as the left operand.
        let gram = lhs.t().matmul(&lhs);
        assert_eq!(gram, Matrix::from_vec(vec![10, 14, 14, 20], 2, 2));
    }

    #[test]
    fn test_print_matrix() {
        let render = |values: Vec<i32>, rows: Unitless, columns: Unitless| -> String {
            Matrix::from_vec(values, rows, columns).to_string()
        };

        assert_eq!(render(vec![1, 2, 3, 4], 2, 2), "[[1, 2]\n[3, 4]]");
        assert_eq!(render(vec![1, 2], 2, 1), "[[1]\n[2]]");
        assert_eq!(render(vec![1, 2], 1, 2), "[[1, 2]]");
        assert_eq!(render(vec![], 0, 0), "[]");
        assert_eq!(
            render((1..=12).collect(), 3, 4),
            "[[1, 2, 3, 4]\n[5, 6, 7, 8]\n[9, 10, 11, 12]]"
        );
    }
}