1use rand::{Rng, rngs::StdRng};
2use rayon::prelude::*;
3use serde::{Deserialize, Serialize};
4use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign};
5
/// A dense matrix of `f64` values stored in row-major order: the element at
/// `(row, col)` lives at `data[row * cols + col]`.
///
/// The methods on this type index `data` using `rows * cols`, i.e. they expect
/// `data.len() == rows * cols`.
///
/// NOTE(review): `from_vec` checks that invariant, but the derived `Deserialize`
/// (and `data_mut`) do not. A malformed serialized payload can therefore yield
/// an inconsistent matrix. Consider validating after deserialization.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
pub struct Matrix {
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
    /// Row-major element buffer, expected to hold `rows * cols` values.
    data: Vec<f64>,
}
13
14impl Matrix {
15 pub fn new(rows: usize, cols: usize) -> Self {
18 let data = vec![0.0; rows * cols];
19 Self { rows, cols, data }
20 }
21
22 pub fn random(rng: &mut StdRng, rows: usize, cols: usize) -> Self {
25 Self::random_range(rng, rows, cols, -1.0, 1.0)
26 }
27
28 pub fn random_range(rng: &mut StdRng, rows: usize, cols: usize, min: f64, max: f64) -> Self {
31 let data = (0..(rows * cols))
32 .map(|_| rng.random_range(min..max))
33 .collect();
34 Self { rows, cols, data }
35 }
36
37 pub fn from_vec(rows: usize, cols: usize, data: Vec<f64>) -> Self {
41 if data.len() != rows * cols {
42 panic!("data length does not match row and col count")
43 }
44 Self { rows, cols, data }
45 }
46
47 pub fn from_col_vec(data: Vec<f64>) -> Self {
49 let rows = data.len();
50 let cols = 1;
51 Self::from_vec(rows, cols, data)
52 }
53
54 pub fn transpose(&self) -> Self {
56 let mut transposed_data = vec![0.0; self.rows * self.cols];
57 for i in 0..self.rows {
58 for j in 0..self.cols {
59 transposed_data[j * self.rows + i] = self.data[i * self.cols + j];
60 }
61 }
62 Self::from_vec(self.cols, self.rows, transposed_data)
63 }
64
65 pub fn rows(&self) -> usize {
67 self.rows
68 }
69
70 pub fn cols(&self) -> usize {
72 self.cols
73 }
74
75 pub fn col(&self, col: usize) -> Vec<f64> {
78 if col >= self.cols {
79 panic!("Index out of bounds");
80 }
81 (0..self.rows)
82 .map(|i| self.data[i * self.cols + col])
83 .collect()
84 }
85
86 pub fn data(&self) -> &Vec<f64> {
88 &self.data
89 }
90
91 pub fn data_mut(&mut self) -> &mut Vec<f64> {
93 &mut self.data
94 }
95
96 pub fn get(&self, row: usize, col: usize) -> f64 {
99 if row >= self.rows || col >= self.cols {
100 panic!("Index out of bounds");
101 }
102 self.data[row * self.cols + col]
103 }
104
105 pub fn get_mut(&mut self, row: usize, col: usize) -> &mut f64 {
108 if row >= self.rows || col >= self.cols {
109 panic!("Index out of bounds");
110 }
111 &mut self.data[row * self.cols + col]
112 }
113
114 pub fn set(&mut self, row: usize, col: usize, value: f64) {
117 if row >= self.rows || col >= self.cols {
118 panic!("Index out of bounds");
119 }
120 self.data[row * self.cols + col] = value;
121 }
122
123 pub fn apply<F>(&mut self, f: F)
124 where
125 F: Fn(f64) -> f64,
126 {
127 for i in 0..self.rows {
128 for j in 0..self.cols {
129 let index = i * self.cols + j;
130 self.data[index] = f(self.data[index]);
131 }
132 }
133 }
134
135 pub fn hadamard_product(&mut self, other: &Matrix) {
136 if self.rows != other.rows || self.cols != other.cols {
137 panic!("Matrices must have the same dimensions for Hadamard product");
138 }
139 for i in 0..self.rows {
140 for j in 0..self.cols {
141 self.set(i, j, self.get(i, j) * other.get(i, j));
142 }
143 }
144 }
145
146 fn multiply_matrix_parallelized(&self, other: &Matrix) -> Matrix {
147 if self.cols != other.rows {
148 panic!("Matrices have incompatible dimensions for multiplication");
149 }
150
151 let other_t = Arc::new(other.transpose()); let self_data = &self.data;
153 let other_data = &other_t.data;
154 let self_cols = self.cols;
155 let other_cols = other.cols;
156
157 let result_data: Vec<f64> = (0..self.rows)
158 .into_par_iter()
159 .flat_map_iter(|i| {
160 (0..other_t.rows).map(move |j| {
161 let mut sum = 0.0;
162 let row_start = i * self_cols;
163 let col_start = j * self_cols;
164 for k in 0..self_cols {
165 sum += self_data[row_start + k] * other_data[col_start + k];
166 }
167 sum
168 })
169 })
170 .collect();
171
172 Matrix::from_vec(self.rows, other_cols, result_data)
173 }
174
175 fn multiply_matrix_naive(&self, other: &Matrix) -> Matrix {
176 if self.cols != other.rows {
177 panic!("Matrices have incompatible dimensions for multiplication");
178 }
179
180 let other_t = other.transpose(); let mut result = Matrix::new(self.rows, other.cols);
182
183 let self_data = &self.data;
184 let other_data = &other_t.data;
185 let result_data = &mut result.data;
186
187 let m = self.rows;
188 let n = self.cols;
189 let p = other.cols;
190
191 for i in 0..m {
192 for j in 0..p {
193 let mut sum = 0.0;
194 let a_row = i * n;
195 let b_row = j * n; for k in 0..n {
197 sum += self_data[a_row + k] * other_data[b_row + k];
198 }
199 result_data[i * p + j] = sum;
200 }
201 }
202
203 result
204 }
205
206 pub fn multiply_matrix(&self, other: &Matrix) -> Matrix {
207 if self.rows * other.cols >= 128 * 128 {
208 self.multiply_matrix_parallelized(other)
209 } else {
210 self.multiply_matrix_naive(other)
211 }
212 }
213}
214
215impl Add<&Matrix> for Matrix {
216 type Output = Matrix;
217
218 fn add(self, other: &Matrix) -> Matrix {
221 if self.rows != other.rows || self.cols != other.cols {
222 panic!("Matrices must have the same dimensions to be added");
223 }
224 let mut result = Matrix::new(self.rows, self.cols);
225 for i in 0..self.rows {
226 for j in 0..self.cols {
227 result.set(i, j, self.get(i, j) + other.get(i, j));
228 }
229 }
230 result
231 }
232}
233
234impl AddAssign<&Matrix> for Matrix {
235 fn add_assign(&mut self, other: &Matrix) {
239 if self.rows != other.rows || self.cols != other.cols {
240 panic!("Matrices must have the same dimensions to be added");
241 }
242 for i in 0..self.rows {
243 for j in 0..self.cols {
244 self.set(i, j, self.get(i, j) + other.get(i, j));
245 }
246 }
247 }
248}
249
250impl Sub<&Matrix> for Matrix {
251 type Output = Matrix;
252
253 fn sub(self, rhs: &Matrix) -> Self::Output {
256 if self.rows != rhs.rows || self.cols != rhs.cols {
257 panic!("Matrices must have the same dimensions to be subtracted");
258 }
259 let mut result = Matrix::new(self.rows, self.cols);
260 for i in 0..self.rows {
261 for j in 0..self.cols {
262 result.set(i, j, self.get(i, j) - rhs.get(i, j));
263 }
264 }
265 result
266 }
267}
268
269impl SubAssign<&Matrix> for Matrix {
270 fn sub_assign(&mut self, other: &Matrix) {
271 if self.rows != other.rows || self.cols != other.cols {
272 panic!("Matrices must have the same dimensions to be added");
273 }
274 for i in 0..self.rows {
275 for j in 0..self.cols {
276 self.set(i, j, self.get(i, j) - other.get(i, j));
277 }
278 }
279 }
280}
281
282impl Mul<f64> for Matrix {
283 type Output = Matrix;
284
285 fn mul(self, scalar: f64) -> Matrix {
287 let mut result = Matrix::new(self.rows, self.cols);
288 for i in 0..self.rows {
289 for j in 0..self.cols {
290 result.data[i * self.cols + j] = self.data[i * self.cols + j] * scalar;
291 }
292 }
293 result
294 }
295}
296
297impl MulAssign<f64> for Matrix {
298 fn mul_assign(&mut self, scalar: f64) {
300 for i in 0..self.rows {
301 for j in 0..self.cols {
302 self.data[i * self.cols + j] *= scalar;
303 }
304 }
305 }
306}
307
308use std::sync::Arc;
309
310impl Mul<&Matrix> for &Matrix {
311 type Output = Matrix;
312
313 fn mul(self, other: &Matrix) -> Matrix {
314 self.multiply_matrix(other)
315 }
316}
317
#[cfg(test)]
mod matrix_tests {
    use rand::SeedableRng;

    use super::*;

    #[test]
    fn it_works() {
        let m = Matrix::new(2, 3);
        assert_eq!(m.rows(), 2);
        assert_eq!(m.cols(), 3);
        assert_eq!(m.data().len(), 2 * 3);
    }

    #[test]
    fn it_creates_random_matrix() {
        // A fixed seed keeps the test reproducible.
        let mut rng = StdRng::seed_from_u64(42);
        let m = Matrix::random(&mut rng, 2, 3);
        assert_eq!(m.rows, 2);
        assert_eq!(m.cols, 3);
        assert_eq!(m.data.len(), 2 * 3);
        for i in 0..2 {
            for j in 0..3 {
                assert!(m.get(i, j) >= -1.0 && m.get(i, j) <= 1.0);
            }
        }
    }

    #[test]
    fn it_creates_a_matrix_from_a_vector() {
        let v = vec![1.0, 2.0, 5.0, 3.0, 4.0, 6.0];
        let m = Matrix::from_vec(2, 3, v.clone());
        assert_eq!(m.rows, 2);
        assert_eq!(m.cols, 3);
        assert_eq!(m.data, v);
    }

    #[test]
    fn it_creates_a_column_vector() {
        let m = Matrix::from_col_vec(vec![1.0, 2.0, 3.0]);
        assert_eq!(m.rows(), 3);
        assert_eq!(m.cols(), 1);
        assert_eq!(m.col(0), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    #[should_panic(expected = "data length does not match row and col count")]
    fn it_panics_on_mismatched_vector_length() {
        Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn it_transposes_matrix() {
        let m = Matrix::from_vec(3, 2, vec![1.0, 2.0, 5.0, 3.0, 4.0, 6.0]);
        let transposed = m.transpose();
        assert_eq!(transposed.rows, 2);
        assert_eq!(transposed.cols, 3);
        assert_eq!(transposed.get(0, 0), 1.0);
        assert_eq!(transposed.get(0, 1), 5.0);
        assert_eq!(transposed.get(0, 2), 4.0);
        assert_eq!(transposed.get(1, 0), 2.0);
        assert_eq!(transposed.get(1, 1), 3.0);
        assert_eq!(transposed.get(1, 2), 6.0);
    }

    #[test]
    fn it_round_trips_a_double_transpose() {
        let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(m.transpose().transpose(), m);
    }

    #[test]
    fn it_returns_a_column() {
        let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(m.col(1), vec![2.0, 5.0]);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_col() {
        let m = Matrix::new(2, 3);
        m.col(3);
    }

    #[test]
    fn it_gets_and_sets_values() {
        let mut m = Matrix::new(2, 3);
        m.set(0, 0, 1.0);
        m.set(1, 2, 2.0);
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(1, 2), 2.0);
        assert_eq!(m.get(0, 1), 0.0);
        assert_eq!(m.get(1, 0), 0.0);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_get() {
        let m = Matrix::new(2, 3);
        m.get(2, 0);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_set() {
        let mut m = Matrix::new(2, 3);
        m.set(2, 0, 1.0);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_get_mut() {
        let mut m = Matrix::new(2, 3);
        m.get_mut(2, 0);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn it_panics_on_out_of_bounds_set_mut() {
        // Writes through `get_mut`, rather than merely obtaining the reference.
        let mut m = Matrix::new(2, 3);
        *m.get_mut(0, 3) = 1.0;
    }

    #[test]
    fn it_gets_and_sets_mutable_values() {
        let mut m = Matrix::new(2, 3);
        *m.get_mut(0, 0) = 1.0;
        *m.get_mut(1, 2) = 2.0;
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(1, 2), 2.0);
        assert_eq!(m.get(0, 1), 0.0);
        assert_eq!(m.get(1, 0), 0.0);
    }

    #[test]
    fn it_returns_mutable_data() {
        let mut m = Matrix::new(2, 3);
        m.data_mut()[0] = 1.0;
        m.data_mut()[1 * 3 + 2] = 2.0;
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(1, 2), 2.0);
        assert_eq!(m.get(0, 1), 0.0);
        assert_eq!(m.get(1, 0), 0.0);
    }

    #[test]
    fn it_adds_matrices() {
        let m1 = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let m2 = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]);
        let result = m1 + &m2;
        assert_eq!(result.get(0, 0), 6.0);
        assert_eq!(result.get(0, 1), 8.0);
        assert_eq!(result.get(1, 0), 10.0);
        assert_eq!(result.get(1, 1), 12.0);
    }

    #[test]
    #[should_panic(expected = "Matrices must have the same dimensions")]
    fn it_panics_when_adding_mismatched_matrices() {
        let _ = Matrix::new(2, 2) + &Matrix::new(2, 3);
    }

    #[test]
    fn it_adds_and_assigns() {
        let mut m1 = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let m2 = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]);
        m1 += &m2;
        assert_eq!(m1.get(0, 0), 6.0);
        assert_eq!(m1.get(0, 1), 8.0);
        assert_eq!(m1.get(1, 0), 10.0);
        assert_eq!(m1.get(1, 1), 12.0);
    }

    #[test]
    fn it_subtracts_matrices() {
        let m1 = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]);
        let m2 = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let result = m1 - &m2;
        assert_eq!(result, Matrix::from_vec(2, 2, vec![4.0, 4.0, 4.0, 4.0]));
    }

    #[test]
    fn it_subtracts_and_assigns() {
        let mut m1 = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]);
        let m2 = Matrix::from_vec(2, 2, vec![1.0, 3.0, 5.0, 7.0]);
        m1 -= &m2;
        assert_eq!(m1, Matrix::from_vec(2, 2, vec![4.0, 3.0, 2.0, 1.0]));
    }

    #[test]
    fn it_computes_the_hadamard_product() {
        let mut m1 = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let m2 = Matrix::from_vec(2, 2, vec![5.0, 6.0, 7.0, 8.0]);
        m1.hadamard_product(&m2);
        assert_eq!(m1, Matrix::from_vec(2, 2, vec![5.0, 12.0, 21.0, 32.0]));
    }

    #[test]
    fn it_multiplies_by_scalar() {
        let m = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let result = m * 2.0;
        assert_eq!(result.get(0, 0), 2.0);
        assert_eq!(result.get(0, 1), 4.0);
        assert_eq!(result.get(1, 0), 6.0);
        assert_eq!(result.get(1, 1), 8.0);
    }

    #[test]
    fn it_multiplies_by_scalar_in_place() {
        let mut m = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        m *= 2.0;
        assert_eq!(m.get(0, 0), 2.0);
        assert_eq!(m.get(0, 1), 4.0);
        assert_eq!(m.get(1, 0), 6.0);
        assert_eq!(m.get(1, 1), 8.0);
    }

    #[test]
    fn it_multiplies_matrices() {
        let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let n = Matrix::from_vec(3, 2, vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);
        let e = Matrix::from_vec(2, 2, vec![58.0, 64.0, 139.0, 154.0]);
        let r = &m * &n;
        assert_eq!(r, e);
    }

    #[test]
    #[should_panic(expected = "incompatible dimensions for multiplication")]
    fn it_panics_on_incompatible_multiplication() {
        let _ = &Matrix::new(2, 3) * &Matrix::new(2, 3);
    }

    #[test]
    fn it_multiplies_with_an_empty_inner_dimension() {
        let a = Matrix::new(2, 0);
        let b = Matrix::new(0, 3);
        assert_eq!(&a * &b, Matrix::new(2, 3));
    }

    #[test]
    fn it_multiplies_large_matrices_consistently_across_kernels() {
        let mut rng = StdRng::seed_from_u64(7);
        let a = Matrix::random(&mut rng, 130, 32);
        let b = Matrix::random(&mut rng, 32, 130);
        // 130 * 130 >= 128 * 128, so the operator takes the parallel path.
        let product = &a * &b;
        assert_eq!(product.rows(), 130);
        assert_eq!(product.cols(), 130);
        let naive = a.multiply_matrix_naive(&b);
        assert_eq!(product, naive);
        assert_eq!(a.multiply_matrix_parallelized(&b), naive);
    }

    #[test]
    fn it_maps() {
        let mut m = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        m.apply(|x| x * 2.0);
        assert_eq!(m.get(0, 0), 2.0);
        assert_eq!(m.get(0, 1), 4.0);
        assert_eq!(m.get(1, 0), 6.0);
        assert_eq!(m.get(1, 1), 8.0);
    }
}