1use crate::tensor::{
2 ephemeral_view::EphemeralView,
3 has_tensor_shape_data::HasTensorShapeData,
4 indexable_tensor::IndexableTensor,
5 tensor_shape::{HasTensorShape, TensorShape},
6 tensor_storage::{HasTensorData, IntoTensorStorage, TensorStorage},
7 Unitless,
8};
9use num::Num;
10use std::{
11 fmt,
12 fmt::Formatter,
13 ops::{Index, IndexMut},
14};
15
16#[derive(Clone, Debug, Eq, PartialEq)]
17pub struct Matrix<Dtype> {
18 shape: TensorShape,
19 storage: TensorStorage<Dtype>,
20}
21
22impl<Dtype> Matrix<Dtype>
23where
24 Dtype: Copy + Num,
25{
26 pub fn from_vec(
27 v: Vec<Dtype>,
28 num_rows: Unitless,
29 num_columns: Unitless,
30 ) -> Matrix<Dtype> {
31 assert_eq!(
32 v.len(),
33 (num_rows * num_columns) as usize,
34 "number of elements in the vector does not match the shape"
35 );
36 Matrix {
37 shape: create_row_major_shape(num_rows, num_columns),
38 storage: v.into_tensor_storage(),
39 }
40 }
41}
42
43impl<Dtype> HasTensorShape for Matrix<Dtype> {
44 fn shape(&self) -> &TensorShape {
45 &self.shape
46 }
47}
48
49impl<Dtype> HasTensorData<Dtype> for Matrix<Dtype> {
50 fn data(&self) -> &TensorStorage<Dtype> {
51 &self.storage
52 }
53}
54
55impl<Dtype> Index<[Unitless; 2]> for Matrix<Dtype>
56where
57 Dtype: Copy + Num,
58{
59 type Output = Dtype;
60
61 fn index(&self, index: [Unitless; 2]) -> &Self::Output {
62 &self.storage[self.coord_to_index(&[index[0], index[1]]) as usize]
63 }
64}
65
66impl<Dtype> IndexMut<[Unitless; 2]> for Matrix<Dtype>
67where
68 Dtype: Copy + Num,
69{
70 fn index_mut(&mut self, index: [Unitless; 2]) -> &mut Self::Output {
71 let index = self.coord_to_index(&[index[0], index[1]]);
72 &mut self.storage[index as usize]
73 }
74}
75
76pub trait MatrixTrait<Dtype> {
77 fn num_rows(&self) -> Unitless;
78
79 fn num_columns(&self) -> Unitless;
80}
81
82pub trait IndexableMatrix<Dtype>:
83 IndexableTensor<Dtype> + MatrixTrait<Dtype>
84where
85 Dtype: Copy + Num, {
86 fn matmul<R>(&self, other: &R) -> Matrix<Dtype>
87 where
88 R: MatrixTrait<Dtype> + IndexableTensor<Dtype>, {
89 let m = self.num_rows();
90 let n = self.num_columns();
91 let n2 = other.num_rows();
92 let l = other.num_columns();
93 assert_eq!(n, n2, "self.num_columns {} != other.num_rows {}", n, n2);
94 let mut result =
95 Matrix::from_vec(vec![Dtype::zero(); (m * l) as usize], m, l);
96 for i in 0..m {
97 for j in 0..l {
98 for k in 0..n {
100 let old = result[[i, j]];
101 let x = self.at([i, k]);
102 let y = other.at([k, j]);
103 result[[i, j]] = old + x * y;
104 }
105 }
106 }
107 result
108 }
109}
110
111impl<Dtype, T> IndexableMatrix<Dtype> for T
112where
113 Dtype: Copy + Num,
114 T: MatrixTrait<Dtype> + IndexableTensor<Dtype>,
115{
116}
117
118impl<Dtype> MatrixTrait<Dtype> for Matrix<Dtype> {
119 fn num_rows(&self) -> Unitless {
120 self.shape.dims_strides[0].0
121 }
122
123 fn num_columns(&self) -> Unitless {
124 self.shape.dims_strides[1].0
125 }
126}
127
128fn create_row_major_shape(
129 num_rows: Unitless,
130 num_columns: Unitless,
131) -> TensorShape {
132 TensorShape {
133 dims_strides: vec![(num_rows, num_columns), (num_columns, 1)],
134 }
135}
136
137impl<'a, Dtype> From<&'a Matrix<Dtype>> for EphemeralView<'a, Dtype> {
138 fn from(matrix: &'a Matrix<Dtype>) -> Self {
139 EphemeralView::new(&matrix.data(), matrix.shape().clone())
140 }
141}
142
143impl<Dtype> fmt::Display for Matrix<Dtype>
144where
145 Dtype: Copy + Num + fmt::Display,
146{
147 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
148 let num_rows = self.num_rows();
149 let num_columns = self.num_columns();
150 write!(f, "[")?;
151 for i in 0..num_rows {
152 write!(f, "[")?;
153 for j in 0..num_columns {
154 if (j + 1) == num_columns {
155 if (i + 1) != num_rows {
156 write!(f, "{}]\n", self[[i, j]])?;
157 } else {
158 write!(f, "{}]", self[[i, j]])?;
159 }
160 } else {
161 write!(f, "{}, ", self[[i, j]])?;
162 }
163 }
164 }
165 write!(f, "]")?;
166 Ok(())
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::matrix_transpose::MatrixTranspose;

    #[test]
    fn test_matmul() {
        // Square times square.
        let a = Matrix::from_vec(vec![1, 2, 3, 4], 2, 2);
        let b = Matrix::from_vec(vec![1, 2, 3, 4], 2, 2);
        assert_eq!(a.matmul(&b).data().vec, vec![7, 10, 15, 22]);

        // Multiplying by a column of ones sums each row.
        let ones = Matrix::from_vec(vec![1, 1], 2, 1);
        assert_eq!(a.matmul(&ones).storage.vec, vec![3, 7]);

        // A^T * A through the transposed view.
        let a = Matrix::from_vec(vec![1, 2, 3, 4], 2, 2);
        let gram = a.t().matmul(&a);
        assert_eq!(gram, Matrix::from_vec(vec![10, 14, 14, 20], 2, 2));
    }

    #[test]
    fn test_print_matrix() {
        fn render(values: Vec<i32>, num_rows: Unitless, num_columns: Unitless) -> String {
            Matrix::from_vec(values, num_rows, num_columns).to_string()
        }

        let cases: Vec<(Vec<i32>, Unitless, Unitless, &str)> = vec![
            (vec![1, 2, 3, 4], 2, 2, "[[1, 2]\n[3, 4]]"),
            (vec![1, 2], 2, 1, "[[1]\n[2]]"),
            (vec![1, 2], 1, 2, "[[1, 2]]"),
            (vec![], 0, 0, "[]"),
            (
                (1..=12).collect(),
                3,
                4,
                "[[1, 2, 3, 4]\n[5, 6, 7, 8]\n[9, 10, 11, 12]]",
            ),
        ];
        for (values, num_rows, num_columns, expected) in cases {
            assert_eq!(render(values, num_rows, num_columns), expected);
        }
    }
}