1use rand::{Rng, rngs::StdRng};
2use rayon::prelude::*;
3use serde::{Deserialize, Serialize};
4use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign};
5
/// A dense matrix of `f64` values stored in row-major order.
///
/// Element (`row`, `col`) lives at `data[row * cols + col]`.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
pub struct Matrix {
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
    /// Elements in row-major order; the constructors require its length to be `rows * cols`.
    data: Vec<f64>,
}
13
14impl Matrix {
15 pub fn new(rows: usize, cols: usize) -> Self {
18 let data = vec![0.0; rows * cols];
19 Self { rows, cols, data }
20 }
21
22 pub fn random(rng: &mut StdRng, rows: usize, cols: usize) -> Self {
25 let data = (0..(rows * cols))
26 .map(|_| rng.random_range(-1.0..1.0))
27 .collect();
28 Self { rows, cols, data }
29 }
30
31 pub fn from_vec(rows: usize, cols: usize, data: Vec<f64>) -> Self {
35 if data.len() != rows * cols {
36 panic!("data length does not match row and col count")
37 }
38 Self { rows, cols, data }
39 }
40
41 pub fn from_col_vec(data: Vec<f64>) -> Self {
43 let rows = data.len();
44 let cols = 1;
45 Self::from_vec(rows, cols, data)
46 }
47
48 pub fn transpose(&self) -> Self {
50 let mut transposed_data = vec![0.0; self.rows * self.cols];
51 for i in 0..self.rows {
52 for j in 0..self.cols {
53 transposed_data[j * self.rows + i] = self.data[i * self.cols + j];
54 }
55 }
56 Self::from_vec(self.cols, self.rows, transposed_data)
57 }
58
59 pub fn rows(&self) -> usize {
61 self.rows
62 }
63
64 pub fn cols(&self) -> usize {
66 self.cols
67 }
68
69 pub fn col(&self, col: usize) -> Vec<f64> {
72 if col >= self.cols {
73 panic!("Index out of bounds");
74 }
75 (0..self.rows)
76 .map(|i| self.data[i * self.cols + col])
77 .collect()
78 }
79
80 pub fn data(&self) -> &Vec<f64> {
82 &self.data
83 }
84
85 pub fn data_mut(&mut self) -> &mut Vec<f64> {
87 &mut self.data
88 }
89
90 pub fn get(&self, row: usize, col: usize) -> f64 {
93 if row >= self.rows || col >= self.cols {
94 panic!("Index out of bounds");
95 }
96 self.data[row * self.cols + col]
97 }
98
99 pub fn get_mut(&mut self, row: usize, col: usize) -> &mut f64 {
102 if row >= self.rows || col >= self.cols {
103 panic!("Index out of bounds");
104 }
105 &mut self.data[row * self.cols + col]
106 }
107
108 pub fn set(&mut self, row: usize, col: usize, value: f64) {
111 if row >= self.rows || col >= self.cols {
112 panic!("Index out of bounds");
113 }
114 self.data[row * self.cols + col] = value;
115 }
116
117 pub fn apply<F>(&mut self, f: F)
118 where
119 F: Fn(f64) -> f64,
120 {
121 for i in 0..self.rows {
122 for j in 0..self.cols {
123 let index = i * self.cols + j;
124 self.data[index] = f(self.data[index]);
125 }
126 }
127 }
128
129 pub fn hadamard_product(&mut self, other: &Matrix) {
130 if self.rows != other.rows || self.cols != other.cols {
131 panic!("Matrices must have the same dimensions for Hadamard product");
132 }
133 for i in 0..self.rows {
134 for j in 0..self.cols {
135 self.set(i, j, self.get(i, j) * other.get(i, j));
136 }
137 }
138 }
139
140 fn multiply_matrix_parallelized(&self, other: &Matrix) -> Matrix {
141 if self.cols != other.rows {
142 panic!("Matrices have incompatible dimensions for multiplication");
143 }
144
145 let other_t = Arc::new(other.transpose()); let self_data = &self.data;
147 let other_data = &other_t.data;
148 let self_cols = self.cols;
149 let other_cols = other.cols;
150
151 let result_data: Vec<f64> = (0..self.rows)
152 .into_par_iter()
153 .flat_map_iter(|i| {
154 (0..other_t.rows).map(move |j| {
155 let mut sum = 0.0;
156 let row_start = i * self_cols;
157 let col_start = j * self_cols;
158 for k in 0..self_cols {
159 sum += self_data[row_start + k] * other_data[col_start + k];
160 }
161 sum
162 })
163 })
164 .collect();
165
166 Matrix::from_vec(self.rows, other_cols, result_data)
167 }
168
169 fn multiply_matrix_naive(&self, other: &Matrix) -> Matrix {
170 if self.cols != other.rows {
171 panic!("Matrices have incompatible dimensions for multiplication");
172 }
173
174 let other_t = other.transpose(); let mut result = Matrix::new(self.rows, other.cols);
176
177 let self_data = &self.data;
178 let other_data = &other_t.data;
179 let result_data = &mut result.data;
180
181 let m = self.rows;
182 let n = self.cols;
183 let p = other.cols;
184
185 for i in 0..m {
186 for j in 0..p {
187 let mut sum = 0.0;
188 let a_row = i * n;
189 let b_row = j * n; for k in 0..n {
191 sum += self_data[a_row + k] * other_data[b_row + k];
192 }
193 result_data[i * p + j] = sum;
194 }
195 }
196
197 result
198 }
199
200 pub fn multiply_matrix(&self, other: &Matrix) -> Matrix {
201 if self.rows * other.cols >= 128 * 128 {
202 self.multiply_matrix_parallelized(other)
203 } else {
204 self.multiply_matrix_naive(other)
205 }
206 }
207}
208
209impl Add<&Matrix> for Matrix {
210 type Output = Matrix;
211
212 fn add(self, other: &Matrix) -> Matrix {
215 if self.rows != other.rows || self.cols != other.cols {
216 panic!("Matrices must have the same dimensions to be added");
217 }
218 let mut result = Matrix::new(self.rows, self.cols);
219 for i in 0..self.rows {
220 for j in 0..self.cols {
221 result.set(i, j, self.get(i, j) + other.get(i, j));
222 }
223 }
224 result
225 }
226}
227
228impl AddAssign<&Matrix> for Matrix {
229 fn add_assign(&mut self, other: &Matrix) {
233 if self.rows != other.rows || self.cols != other.cols {
234 panic!("Matrices must have the same dimensions to be added");
235 }
236 for i in 0..self.rows {
237 for j in 0..self.cols {
238 self.set(i, j, self.get(i, j) + other.get(i, j));
239 }
240 }
241 }
242}
243
244impl Sub<&Matrix> for Matrix {
245 type Output = Matrix;
246
247 fn sub(self, rhs: &Matrix) -> Self::Output {
250 if self.rows != rhs.rows || self.cols != rhs.cols {
251 panic!("Matrices must have the same dimensions to be subtracted");
252 }
253 let mut result = Matrix::new(self.rows, self.cols);
254 for i in 0..self.rows {
255 for j in 0..self.cols {
256 result.set(i, j, self.get(i, j) - rhs.get(i, j));
257 }
258 }
259 result
260 }
261}
262
263impl SubAssign<&Matrix> for Matrix {
264 fn sub_assign(&mut self, other: &Matrix) {
265 if self.rows != other.rows || self.cols != other.cols {
266 panic!("Matrices must have the same dimensions to be added");
267 }
268 for i in 0..self.rows {
269 for j in 0..self.cols {
270 self.set(i, j, self.get(i, j) - other.get(i, j));
271 }
272 }
273 }
274}
275
276impl Mul<f64> for Matrix {
277 type Output = Matrix;
278
279 fn mul(self, scalar: f64) -> Matrix {
281 let mut result = Matrix::new(self.rows, self.cols);
282 for i in 0..self.rows {
283 for j in 0..self.cols {
284 result.data[i * self.cols + j] = self.data[i * self.cols + j] * scalar;
285 }
286 }
287 result
288 }
289}
290
291impl MulAssign<f64> for Matrix {
292 fn mul_assign(&mut self, scalar: f64) {
294 for i in 0..self.rows {
295 for j in 0..self.cols {
296 self.data[i * self.cols + j] *= scalar;
297 }
298 }
299 }
300}
301
302use std::sync::Arc;
303
304impl Mul<&Matrix> for &Matrix {
305 type Output = Matrix;
306
307 fn mul(self, other: &Matrix) -> Matrix {
308 self.multiply_matrix(other)
309 }
310}
311
#[cfg(test)]
mod matrix_tests {
    use rand::SeedableRng;

    use super::*;

    /// Fixed seed so tests that build random matrices are reproducible.
    const SEED: u64 = 42;

    #[test]
    fn it_works() {
        let m = Matrix::new(2, 3);
        assert_eq!(m.rows(), 2);
        assert_eq!(m.cols(), 3);
        assert_eq!(m.data().len(), 2 * 3);
    }

    #[test]
    fn it_creates_random_matrix() {
        let mut rng = StdRng::seed_from_u64(SEED);
        let m = Matrix::random(&mut rng, 2, 3);
        assert_eq!(m.rows, 2);
        assert_eq!(m.cols, 3);
        assert_eq!(m.data.len(), 2 * 3);
        assert!(m.data.iter().all(|x| (-1.0..=1.0).contains(x)));
    }

    #[test]
    fn it_creates_a_matrix_from_a_vector() {
        let v = vec![1.0, 2.0, 5.0, 3.0, 4.0, 6.0];
        let m = Matrix::from_vec(2, 3, v.clone());
        assert_eq!(m.rows, 2);
        assert_eq!(m.cols, 3);
        assert_eq!(m.data, v);
    }

    #[test]
    #[should_panic(expected = "data length does not match row and col count")]
    fn it_panics_on_mismatched_from_vec() {
        Matrix::from_vec(2, 3, vec![1.0, 2.0]);
    }

    #[test]
    fn it_creates_a_column_vector() {
        let m = Matrix::from_col_vec(vec![1.0, 2.0, 3.0]);
        assert_eq!(m.rows(), 3);
        assert_eq!(m.cols(), 1);
        assert_eq!(m.col(0), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn it_transposes_matrix() {
        let m = Matrix::from_vec(3, 2, vec![1.0, 2.0, 5.0, 3.0, 4.0, 6.0]);
        let transposed = m.transpose();
        assert_eq!(transposed.rows, 2);
        assert_eq!(transposed.cols, 3);
        assert_eq!(transposed.get(0, 0), 1.0);
        assert_eq!(transposed.get(0, 1), 5.0);
        assert_eq!(transposed.get(0, 2), 4.0);
        assert_eq!(transposed.get(1, 0), 2.0);
        assert_eq!(transposed.get(1, 1), 3.0);
        assert_eq!(transposed.get(1, 2), 6.0);
    }

    #[test]
    fn it_returns_a_column() {
        let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(m.col(1), vec![2.0, 5.0]);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_col() {
        let m = Matrix::new(2, 3);
        m.col(3);
    }

    #[test]
    fn it_gets_and_sets_values() {
        let mut m = Matrix::new(2, 3);
        m.set(0, 0, 1.0);
        m.set(1, 2, 2.0);
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(1, 2), 2.0);
        assert_eq!(m.get(0, 1), 0.0);
        assert_eq!(m.get(1, 0), 0.0);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_get() {
        let m = Matrix::new(2, 3);
        m.get(2, 0);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_set() {
        let mut m = Matrix::new(2, 3);
        m.set(2, 0, 1.0);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_get_mut() {
        let mut m = Matrix::new(2, 3);
        m.get_mut(2, 0);
    }

    /// Exercises the column bound; the row bound is covered by the test above.
    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_col_get_mut() {
        let mut m = Matrix::new(2, 3);
        m.get_mut(0, 3);
    }

    #[test]
    fn it_gets_and_sets_mutable_values() {
        let mut m = Matrix::new(2, 3);
        *m.get_mut(0, 0) = 1.0;
        *m.get_mut(1, 2) = 2.0;
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(1, 2), 2.0);
        assert_eq!(m.get(0, 1), 0.0);
        assert_eq!(m.get(1, 0), 0.0);
    }

    #[test]
    fn it_returns_mutable_data() {
        let mut m = Matrix::new(2, 3);
        m.data_mut()[0] = 1.0;
        // Row 1, column 2 of a 3-column matrix: 1 * 3 + 2.
        m.data_mut()[5] = 2.0;
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(1, 2), 2.0);
        assert_eq!(m.get(0, 1), 0.0);
        assert_eq!(m.get(1, 0), 0.0);
    }

    #[test]
    fn it_adds_matrices() {
        let m1 = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let m2 = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]);
        let result = m1 + &m2;
        assert_eq!(result.get(0, 0), 6.0);
        assert_eq!(result.get(0, 1), 8.0);
        assert_eq!(result.get(1, 0), 10.0);
        assert_eq!(result.get(1, 1), 12.0);
    }

    #[test]
    fn it_adds_and_assigns() {
        let mut m1 = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let m2 = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]);
        m1 += &m2;
        assert_eq!(m1.get(0, 0), 6.0);
        assert_eq!(m1.get(0, 1), 8.0);
        assert_eq!(m1.get(1, 0), 10.0);
        assert_eq!(m1.get(1, 1), 12.0);
    }

    #[test]
    fn it_subtracts_matrices() {
        let m1 = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]);
        let m2 = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 5.0]);
        let result = m1 - &m2;
        assert_eq!(result, Matrix::from_vec(2, 2, vec![4.0, 4.0, 4.0, 3.0]));
    }

    #[test]
    fn it_subtracts_and_assigns() {
        let mut m1 = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]);
        let m2 = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 5.0]);
        m1 -= &m2;
        assert_eq!(m1, Matrix::from_vec(2, 2, vec![4.0, 4.0, 4.0, 3.0]));
    }

    #[test]
    fn it_computes_hadamard_product() {
        let mut m1 = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let m2 = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]);
        m1.hadamard_product(&m2);
        assert_eq!(m1, Matrix::from_vec(2, 2, vec![5.0, 12.0, 21.0, 32.0]));
    }

    #[test]
    fn it_multiplies_by_scalar() {
        let m = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let result = m * 2.0;
        assert_eq!(result.get(0, 0), 2.0);
        assert_eq!(result.get(0, 1), 4.0);
        assert_eq!(result.get(1, 0), 6.0);
        assert_eq!(result.get(1, 1), 8.0);
    }

    #[test]
    fn it_multiplies_by_scalar_in_place() {
        let mut m = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        m *= 2.0;
        assert_eq!(m.get(0, 0), 2.0);
        assert_eq!(m.get(0, 1), 4.0);
        assert_eq!(m.get(1, 0), 6.0);
        assert_eq!(m.get(1, 1), 8.0);
    }

    #[test]
    fn it_multiplies_matrices() {
        let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let n = Matrix::from_vec(3, 2, vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);
        let e = Matrix::from_vec(2, 2, vec![58.0, 64.0, 139.0, 154.0]);
        let r = &m * &n;
        assert_eq!(r, e);
    }

    #[test]
    fn it_multiplies_matrices_with_empty_inner_dimension() {
        let r = &Matrix::new(2, 0) * &Matrix::new(0, 2);
        assert_eq!(r, Matrix::new(2, 2));
    }

    #[test]
    #[should_panic(expected = "Matrices have incompatible dimensions for multiplication")]
    fn it_panics_on_incompatible_multiplication() {
        let m = Matrix::new(2, 3);
        let _ = &m * &m;
    }

    #[test]
    fn parallel_and_naive_multiplication_agree() {
        let mut rng = StdRng::seed_from_u64(SEED);
        // A 128 x 128 product reaches the parallel threshold in `multiply_matrix`.
        let a = Matrix::random(&mut rng, 128, 8);
        let b = Matrix::random(&mut rng, 8, 128);
        let naive = a.multiply_matrix_naive(&b);
        assert_eq!(a.multiply_matrix_parallelized(&b), naive);
        assert_eq!(&a * &b, naive);
    }

    #[test]
    fn it_maps() {
        let mut m = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        m.apply(|x| x * 2.0);
        assert_eq!(m.get(0, 0), 2.0);
        assert_eq!(m.get(0, 1), 4.0);
        assert_eq!(m.get(1, 0), 6.0);
        assert_eq!(m.get(1, 1), 8.0);
    }
}