datas/
matrix.rs

/// Errors returned by fallible [`Matrix`] operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatrixError {
    /// The rows passed to [`Matrix::new`] do not all have the same length.
    InconsistentColumnSizes,
    /// The left operand's column count differs from the right operand's row count.
    MultiplicationDimensionMismatch,
    /// The operands of an element-wise operation (e.g. [`Matrix::add`]) have different shapes.
    DimensionMismatch,
    /// A row index is not smaller than the matrix's row count.
    RowOutOfBound,
}

impl std::fmt::Display for MatrixError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            MatrixError::InconsistentColumnSizes => "rows have inconsistent column counts",
            MatrixError::MultiplicationDimensionMismatch => {
                "column count of the left operand does not match row count of the right operand"
            }
            MatrixError::DimensionMismatch => "matrix dimensions do not match",
            MatrixError::RowOutOfBound => "row index out of bounds",
        };
        f.write_str(msg)
    }
}

// Lets `MatrixError` flow through `?` into `Box<dyn std::error::Error>` and similar.
impl std::error::Error for MatrixError {}

/// A matrix stored as a vector of rows.
///
/// `rows` and `cols` cache the shape; [`Matrix::new`] validates that every row has the same
/// length before filling them in.
#[derive(Clone, Debug, PartialEq)]
pub struct Matrix<T> {
    /// Element storage: `data[r][c]` is the entry at row `r`, column `c`.
    data: Vec<Vec<T>>,
    /// Number of rows (the length of `data`).
    rows: u64,
    /// Number of columns (the length of every row; `0` for an empty matrix).
    cols: u64,
}

16impl Matrix<i64> {
17    pub fn new(data: Vec<Vec<i64>>) -> Result<Self, MatrixError> {
18        let rows = data.len() as u64;
19        if rows == 0 {
20            return Ok(Self {
21                data,
22                rows: 0,
23                cols: 0,
24            });
25        }
26        let cols = data[0].len() as u64;
27        if !data.iter().all(|row| row.len() as u64 == cols) {
28            return Err(MatrixError::InconsistentColumnSizes);
29        }
30
31        Ok(Self { data, rows, cols })
32    }
33
34    pub fn swap_row(&mut self, f_row: u64, s_row: u64) -> Result<(), MatrixError> {
35        if f_row >= self.rows || s_row >= self.rows {
36            return Err(MatrixError::RowOutOfBound);
37        }
38        let temp: Vec<i64> = self.data[f_row as usize].clone();
39        self.data[f_row as usize] = self.data[s_row as usize].clone();
40        self.data[s_row as usize] = temp;
41        Ok(())
42    }
43
44    pub fn add(&mut self, matrix: &Matrix<i64>) -> Result<(), MatrixError> {
45        if self.cols != matrix.cols || self.rows != matrix.rows {
46            return Err(MatrixError::DimensionMismatch);
47        }
48
49        self.data
50            .iter_mut()
51            .zip(&matrix.data)
52            .for_each(|(self_row, matrix_row)| {
53                self_row
54                    .iter_mut()
55                    .zip(matrix_row)
56                    .for_each(|(a, b)| *a += b)
57            });
58        Ok(())
59    }
60    pub fn scalar_multiplication(&mut self, scalar: i64) {
61        self.data
62            .iter_mut()
63            .for_each(|col| col.iter_mut().for_each(|number| *number *= scalar));
64    }
65
66    pub fn matrix_multiplication(self, matrix: &Matrix<i64>) -> Result<Matrix<i64>, MatrixError> {
67        if self.cols != matrix.rows {
68            return Err(MatrixError::MultiplicationDimensionMismatch);
69        }
70
71        let mut data = Vec::with_capacity(self.rows as usize);
72
73        for row in &self.data {
74            let mut new_row = Vec::with_capacity(matrix.cols as usize);
75
76            for col_idx in 0..matrix.cols {
77                let mut sum = 0;
78                for (i, val) in row.iter().enumerate() {
79                    sum += val * matrix.data[i][col_idx as usize];
80                }
81                new_row.push(sum);
82            }
83
84            data.push(new_row);
85        }
86        return Ok(match Matrix::<i64>::new(data) {
87            Ok(m) => m,
88            Err(e) => panic!("{:?}", e),
89        });
90    }
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fixture matrix, panicking with a clear message if the rows are ragged.
    fn mat(rows: Vec<Vec<i64>>) -> Matrix<i64> {
        Matrix::<i64>::new(rows).expect("fixture rows must all have the same length")
    }

    #[test]
    fn imatrix_new_rejects_ragged_rows() {
        assert_eq!(
            Matrix::<i64>::new(vec![vec![1, 2], vec![3]]),
            Err(MatrixError::InconsistentColumnSizes)
        );
    }

    #[test]
    fn imatrix_addition() {
        let mut matrix_a = mat(vec![vec![1, 2, 3], vec![1, 2, 3], vec![1, 2, 3]]);
        let matrix_b = mat(vec![vec![1, 2, 3], vec![2, 3, 1], vec![3, 1, 2]]);

        assert_eq!(matrix_a.add(&matrix_b), Ok(()));
        assert_eq!(
            matrix_a,
            mat(vec![vec![2, 4, 6], vec![3, 5, 4], vec![4, 3, 5]])
        );
    }

    #[test]
    fn imatrix_addition_dimension_mismatch() {
        let mut square = mat(vec![vec![1, 2], vec![3, 4]]);
        let wide = mat(vec![vec![1, 2, 3], vec![4, 5, 6]]);

        assert_eq!(square.add(&wide), Err(MatrixError::DimensionMismatch));
        // A failed addition must leave the left operand untouched.
        assert_eq!(square, mat(vec![vec![1, 2], vec![3, 4]]));
    }

    #[test]
    fn imatrix_cloning() {
        let matrix_a = mat(vec![vec![1, 2, 3], vec![1, 2, 3], vec![1, 2, 3]]);
        let matrix_b = matrix_a.clone();

        assert_eq!(matrix_a, matrix_b);
    }

    #[test]
    fn imatrix_swap() {
        let mut matrix_a = mat(vec![vec![1, 2, 3], vec![3, 2, 1], vec![1, 2, 3]]);

        // 0-based indexing: swap the second and third rows.
        assert_eq!(matrix_a.swap_row(1, 2), Ok(()));
        assert_eq!(
            matrix_a,
            mat(vec![vec![1, 2, 3], vec![1, 2, 3], vec![3, 2, 1]])
        );
    }

    #[test]
    fn imatrix_swap_out_of_bounds() {
        let mut matrix_a = mat(vec![vec![1, 2], vec![3, 4]]);

        assert_eq!(matrix_a.swap_row(0, 2), Err(MatrixError::RowOutOfBound));
        assert_eq!(matrix_a.swap_row(2, 0), Err(MatrixError::RowOutOfBound));
        assert_eq!(matrix_a, mat(vec![vec![1, 2], vec![3, 4]]));
    }

    #[test]
    fn imatrix_scalar_multiplication() {
        let mut matrix_a = mat(vec![vec![1, -2], vec![0, 4]]);
        matrix_a.scalar_multiplication(3);

        assert_eq!(matrix_a, mat(vec![vec![3, -6], vec![0, 12]]));
    }

    #[test]
    fn imatrix_multiplication() {
        let lhs = mat(vec![vec![1, 2], vec![3, 4]]);
        let rhs = mat(vec![vec![2, 0], vec![1, 2]]);

        assert_eq!(
            lhs.matrix_multiplication(&rhs),
            Ok(mat(vec![vec![4, 4], vec![10, 8]]))
        );
    }

    // Multiplying by the identity matrix must return the same matrix.
    #[test]
    fn imatrix_multiplication_identity() {
        let matrix_a = mat(vec![vec![1, 2], vec![3, 4]]);
        let identity = mat(vec![vec![1, 0], vec![0, 1]]);
        let expected = matrix_a.clone();

        assert_eq!(matrix_a.matrix_multiplication(&identity), Ok(expected));
    }

    // Multiplying by the zero matrix must return the zero matrix.
    #[test]
    fn imatrix_multiplication_zero_matrix() {
        let matrix_a = mat(vec![vec![1, 2], vec![3, 4]]);
        let zero = mat(vec![vec![0, 0], vec![0, 0]]);

        assert_eq!(
            matrix_a.matrix_multiplication(&zero),
            Ok(mat(vec![vec![0, 0], vec![0, 0]]))
        );
    }

    // (2 x 3) * (3 x 2) must produce a 2 x 2 result.
    #[test]
    fn imatrix_multiplication_non_square() {
        let lhs = mat(vec![vec![1, 2, 3], vec![4, 5, 6]]);
        let rhs = mat(vec![vec![7, 8], vec![9, 10], vec![11, 12]]);

        assert_eq!(
            lhs.matrix_multiplication(&rhs),
            Ok(mat(vec![vec![58, 64], vec![139, 154]]))
        );
    }

    #[test]
    fn imatrix_multiplication_dimension_mismatch() {
        let lhs = mat(vec![vec![1, 2, 3], vec![4, 5, 6]]);
        let rhs = mat(vec![vec![1, 2], vec![3, 4]]);

        assert_eq!(
            lhs.matrix_multiplication(&rhs),
            Err(MatrixError::MultiplicationDimensionMismatch)
        );
    }

    #[test]
    fn imatrix_multiplication_empty() {
        let empty = mat(vec![]);

        assert_eq!(empty.clone().matrix_multiplication(&empty), Ok(mat(vec![])));
    }

    #[test]
    fn imatrix_multiplication_large_matrix() {
        let lhs = mat(vec![
            vec![1, 2, 3, 4],
            vec![5, 6, 7, 8],
            vec![9, 10, 11, 12],
            vec![13, 14, 15, 16],
        ]);
        let rhs = mat(vec![
            vec![16, 15, 14, 13],
            vec![12, 11, 10, 9],
            vec![8, 7, 6, 5],
            vec![4, 3, 2, 1],
        ]);
        let expected = mat(vec![
            vec![80, 70, 60, 50],
            vec![240, 214, 188, 162],
            vec![400, 358, 316, 274],
            vec![560, 502, 444, 386],
        ]);

        assert_eq!(lhs.matrix_multiplication(&rhs), Ok(expected));
    }
}