1use super::Vector;
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
17pub struct Matrix<T> {
18 data: Vec<T>,
19 rows: usize,
20 cols: usize,
21}
22
23impl<T: Copy> Matrix<T> {
24 pub fn from_vec(rows: usize, cols: usize, data: Vec<T>) -> Result<Self, &'static str> {
30 if data.len() != rows * cols {
31 return Err("Data length must equal rows * cols");
32 }
33 Ok(Self { data, rows, cols })
34 }
35
36 #[must_use]
38 pub fn shape(&self) -> (usize, usize) {
39 (self.rows, self.cols)
40 }
41
42 #[must_use]
44 pub fn n_rows(&self) -> usize {
45 self.rows
46 }
47
48 #[must_use]
50 pub fn n_cols(&self) -> usize {
51 self.cols
52 }
53
54 #[must_use]
60 pub fn get(&self, row: usize, col: usize) -> T {
61 self.data[row * self.cols + col]
62 }
63
64 pub fn set(&mut self, row: usize, col: usize, value: T) {
70 self.data[row * self.cols + col] = value;
71 }
72
73 #[must_use]
75 pub fn row(&self, row_idx: usize) -> Vector<T> {
76 let start = row_idx * self.cols;
77 let end = start + self.cols;
78 Vector::from_slice(&self.data[start..end])
79 }
80
81 #[must_use]
83 pub fn column(&self, col_idx: usize) -> Vector<T> {
84 let data: Vec<T> = (0..self.rows)
85 .map(|row| self.data[row * self.cols + col_idx])
86 .collect();
87 Vector::from_vec(data)
88 }
89
90 #[must_use]
92 pub fn as_slice(&self) -> &[T] {
93 &self.data
94 }
95}
96
97impl Matrix<f32> {
98 #[must_use]
100 pub fn zeros(rows: usize, cols: usize) -> Self {
101 Self {
102 data: vec![0.0; rows * cols],
103 rows,
104 cols,
105 }
106 }
107
108 #[must_use]
110 pub fn ones(rows: usize, cols: usize) -> Self {
111 Self {
112 data: vec![1.0; rows * cols],
113 rows,
114 cols,
115 }
116 }
117
118 #[must_use]
120 pub fn eye(n: usize) -> Self {
121 let mut data = vec![0.0; n * n];
122 for i in 0..n {
123 data[i * n + i] = 1.0;
124 }
125 Self {
126 data,
127 rows: n,
128 cols: n,
129 }
130 }
131
132 #[must_use]
134 pub fn transpose(&self) -> Self {
135 let mut data = vec![0.0; self.rows * self.cols];
136 for i in 0..self.rows {
137 for j in 0..self.cols {
138 data[j * self.rows + i] = self.data[i * self.cols + j];
139 }
140 }
141 Self {
142 data,
143 rows: self.cols,
144 cols: self.rows,
145 }
146 }
147
148 pub fn matmul(&self, other: &Self) -> Result<Self, &'static str> {
154 if self.cols != other.rows {
155 return Err("Matrix dimensions don't match for multiplication");
156 }
157
158 let mut result = vec![0.0; self.rows * other.cols];
159 for i in 0..self.rows {
160 for j in 0..other.cols {
161 let mut sum = 0.0;
162 for k in 0..self.cols {
163 sum += self.get(i, k) * other.get(k, j);
164 }
165 result[i * other.cols + j] = sum;
166 }
167 }
168
169 Ok(Self {
170 data: result,
171 rows: self.rows,
172 cols: other.cols,
173 })
174 }
175
176 pub fn matvec(&self, vec: &Vector<f32>) -> Result<Vector<f32>, &'static str> {
182 if self.cols != vec.len() {
183 return Err("Matrix columns must match vector length");
184 }
185
186 let result: Vec<f32> = (0..self.rows)
187 .map(|i| {
188 let row = self.row(i);
189 row.dot(vec)
190 })
191 .collect();
192
193 Ok(Vector::from_vec(result))
194 }
195
196 pub fn add(&self, other: &Self) -> Result<Self, &'static str> {
202 if self.rows != other.rows || self.cols != other.cols {
203 return Err("Matrix dimensions must match for addition");
204 }
205
206 let data: Vec<f32> = self
207 .data
208 .iter()
209 .zip(other.data.iter())
210 .map(|(a, b)| a + b)
211 .collect();
212
213 Ok(Self {
214 data,
215 rows: self.rows,
216 cols: self.cols,
217 })
218 }
219
220 pub fn sub(&self, other: &Self) -> Result<Self, &'static str> {
226 if self.rows != other.rows || self.cols != other.cols {
227 return Err("Matrix dimensions must match for subtraction");
228 }
229
230 let data: Vec<f32> = self
231 .data
232 .iter()
233 .zip(other.data.iter())
234 .map(|(a, b)| a - b)
235 .collect();
236
237 Ok(Self {
238 data,
239 rows: self.rows,
240 cols: self.cols,
241 })
242 }
243
244 #[must_use]
246 pub fn mul_scalar(&self, scalar: f32) -> Self {
247 Self {
248 data: self.data.iter().map(|x| x * scalar).collect(),
249 rows: self.rows,
250 cols: self.cols,
251 }
252 }
253
254 pub fn cholesky_solve(&self, b: &Vector<f32>) -> Result<Vector<f32>, &'static str> {
262 if self.rows != self.cols {
263 return Err("Matrix must be square for Cholesky decomposition");
264 }
265 if self.rows != b.len() {
266 return Err("Matrix rows must match vector length");
267 }
268
269 let n = self.rows;
270
271 let mut l = vec![0.0; n * n];
273
274 for i in 0..n {
275 for j in 0..=i {
276 let mut sum = 0.0;
277
278 if i == j {
279 for k in 0..j {
280 sum += l[j * n + k] * l[j * n + k];
281 }
282 let diag = self.get(j, j) - sum;
283 if diag <= 0.0 {
284 return Err("Matrix is not positive definite");
285 }
286 l[j * n + j] = diag.sqrt();
287 } else {
288 for k in 0..j {
289 sum += l[i * n + k] * l[j * n + k];
290 }
291 l[i * n + j] = (self.get(i, j) - sum) / l[j * n + j];
292 }
293 }
294 }
295
296 let mut y = vec![0.0; n];
298 for i in 0..n {
299 let mut sum = 0.0;
300 for j in 0..i {
301 sum += l[i * n + j] * y[j];
302 }
303 y[i] = (b[i] - sum) / l[i * n + i];
304 }
305
306 let mut x = vec![0.0; n];
308 for i in (0..n).rev() {
309 let mut sum = 0.0;
310 for j in (i + 1)..n {
311 sum += l[j * n + i] * x[j];
312 }
313 x[i] = (y[i] - sum) / l[i * n + i];
314 }
315
316 Ok(Vector::from_vec(x))
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_vec() {
        let m = Matrix::from_vec(2, 3, vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("test data has correct dimensions: 2*3=6 elements");
        assert_eq!(m.shape(), (2, 3));
        assert!((m.get(0, 0) - 1.0).abs() < 1e-6);
        assert!((m.get(1, 2) - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_from_vec_error() {
        let result = Matrix::from_vec(2, 3, vec![1.0_f32, 2.0, 3.0]);
        assert!(result.is_err());
    }

    #[test]
    fn test_n_rows_n_cols() {
        let m = Matrix::<f32>::zeros(4, 7);
        assert_eq!(m.n_rows(), 4);
        assert_eq!(m.n_cols(), 7);
    }

    #[test]
    fn test_zeros() {
        let m = Matrix::<f32>::zeros(2, 3);
        assert_eq!(m.shape(), (2, 3));
        assert!(m.as_slice().iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_ones() {
        let m = Matrix::<f32>::ones(2, 3);
        assert_eq!(m.shape(), (2, 3));
        assert!(m.as_slice().iter().all(|&x| x == 1.0));
    }

    #[test]
    fn test_eye() {
        let m = Matrix::<f32>::eye(3);
        assert!((m.get(0, 0) - 1.0).abs() < 1e-6);
        assert!((m.get(1, 1) - 1.0).abs() < 1e-6);
        assert!((m.get(2, 2) - 1.0).abs() < 1e-6);
        assert!((m.get(0, 1) - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_transpose() {
        let m = Matrix::from_vec(2, 3, vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("test data has correct dimensions: 2*3=6 elements");
        let t = m.transpose();
        assert_eq!(t.shape(), (3, 2));
        assert!((t.get(0, 0) - 1.0).abs() < 1e-6);
        assert!((t.get(0, 1) - 4.0).abs() < 1e-6);
        assert!((t.get(2, 1) - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_transpose_roundtrip() {
        let m = Matrix::from_vec(2, 3, vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("test data has correct dimensions: 2*3=6 elements");
        assert_eq!(m.transpose().transpose(), m);
    }

    #[test]
    fn test_row() {
        let m = Matrix::from_vec(2, 3, vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("test data has correct dimensions: 2*3=6 elements");
        let row = m.row(1);
        assert_eq!(row.len(), 3);
        assert!((row[0] - 4.0).abs() < 1e-6);
        assert!((row[1] - 5.0).abs() < 1e-6);
        assert!((row[2] - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_column() {
        let m = Matrix::from_vec(2, 3, vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("test data has correct dimensions: 2*3=6 elements");
        let col = m.column(1);
        assert_eq!(col.len(), 2);
        assert!((col[0] - 2.0).abs() < 1e-6);
        assert!((col[1] - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_matmul() {
        let a = Matrix::from_vec(2, 3, vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("test data has correct dimensions: 2*3=6 elements");
        let b = Matrix::from_vec(3, 2, vec![7.0_f32, 8.0, 9.0, 10.0, 11.0, 12.0])
            .expect("test data has correct dimensions: 3*2=6 elements");
        let c = a
            .matmul(&b)
            .expect("matrix dimensions are compatible for multiplication: 2x3 * 3x2");

        assert_eq!(c.shape(), (2, 2));
        assert!((c.get(0, 0) - 58.0).abs() < 1e-6);
        assert!((c.get(0, 1) - 64.0).abs() < 1e-6);
    }

    #[test]
    fn test_matmul_identity() {
        let a = Matrix::from_vec(2, 3, vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("test data has correct dimensions: 2*3=6 elements");
        let product = Matrix::<f32>::eye(2)
            .matmul(&a)
            .expect("matrix dimensions are compatible for multiplication: 2x2 * 2x3");
        assert_eq!(product, a);
    }

    #[test]
    fn test_matmul_dimension_error() {
        let a = Matrix::from_vec(2, 3, vec![1.0_f32; 6])
            .expect("test data has correct dimensions: 2*3=6 elements");
        let b = Matrix::from_vec(2, 2, vec![1.0_f32; 4])
            .expect("test data has correct dimensions: 2*2=4 elements");
        assert!(a.matmul(&b).is_err());
    }

    #[test]
    fn test_matvec() {
        let m = Matrix::from_vec(2, 3, vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("test data has correct dimensions: 2*3=6 elements");
        let v = Vector::from_slice(&[1.0_f32, 2.0, 3.0]);
        let result = m
            .matvec(&v)
            .expect("matrix columns match vector length: both 3");

        assert_eq!(result.len(), 2);
        assert!((result[0] - 14.0).abs() < 1e-6);
        assert!((result[1] - 32.0).abs() < 1e-6);
    }

    #[test]
    fn test_matvec_dimension_error() {
        let m = Matrix::from_vec(2, 3, vec![1.0_f32; 6])
            .expect("test data has correct dimensions: 2*3=6 elements");
        let v = Vector::from_slice(&[1.0_f32, 2.0]);
        assert!(m.matvec(&v).is_err());
    }

    #[test]
    fn test_add() {
        let a = Matrix::from_vec(2, 2, vec![1.0_f32, 2.0, 3.0, 4.0])
            .expect("test data has correct dimensions: 2*2=4 elements");
        let b = Matrix::from_vec(2, 2, vec![5.0_f32, 6.0, 7.0, 8.0])
            .expect("test data has correct dimensions: 2*2=4 elements");
        let c = a.add(&b).expect("both matrices have same dimensions: 2x2");

        assert!((c.get(0, 0) - 6.0).abs() < 1e-6);
        assert!((c.get(1, 1) - 12.0).abs() < 1e-6);
    }

    #[test]
    fn test_add_dimension_mismatch() {
        let a = Matrix::from_vec(2, 2, vec![1.0_f32; 4])
            .expect("test data has correct dimensions: 2*2=4 elements");
        let b = Matrix::from_vec(3, 2, vec![1.0_f32; 6])
            .expect("test data has correct dimensions: 3*2=6 elements");
        assert!(a.add(&b).is_err());

        let c = Matrix::from_vec(2, 3, vec![1.0_f32; 6])
            .expect("test data has correct dimensions: 2*3=6 elements");
        assert!(a.add(&c).is_err());
    }

    #[test]
    fn test_sub() {
        let a = Matrix::from_vec(2, 2, vec![10.0_f32, 8.0, 6.0, 12.0])
            .expect("test data has correct dimensions: 2*2=4 elements");
        let b = Matrix::from_vec(2, 2, vec![4.0_f32, 3.0, 2.0, 7.0])
            .expect("test data has correct dimensions: 2*2=4 elements");
        let c = a.sub(&b).expect("both matrices have same dimensions: 2x2");

        assert!((c.get(0, 0) - 6.0).abs() < 1e-6);
        assert!((c.get(0, 1) - 5.0).abs() < 1e-6);
        assert!((c.get(1, 0) - 4.0).abs() < 1e-6);
        assert!((c.get(1, 1) - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_sub_dimension_mismatch_rows() {
        let a = Matrix::from_vec(2, 2, vec![1.0_f32; 4])
            .expect("test data has correct dimensions: 2*2=4 elements");
        let b = Matrix::from_vec(3, 2, vec![1.0_f32; 6])
            .expect("test data has correct dimensions: 3*2=6 elements");
        assert!(a.sub(&b).is_err());
    }

    #[test]
    fn test_sub_dimension_mismatch_cols() {
        let a = Matrix::from_vec(2, 2, vec![1.0_f32; 4])
            .expect("test data has correct dimensions: 2*2=4 elements");
        let b = Matrix::from_vec(2, 3, vec![1.0_f32; 6])
            .expect("test data has correct dimensions: 2*3=6 elements");
        assert!(a.sub(&b).is_err());
    }

    #[test]
    fn test_cholesky_solve() {
        let a = Matrix::from_vec(2, 2, vec![4.0_f32, 2.0, 2.0, 3.0])
            .expect("test data has correct dimensions: 2*2=4 elements");
        let b = Vector::from_slice(&[1.0_f32, 2.0]);
        let x = a
            .cholesky_solve(&b)
            .expect("matrix is square, symmetric positive definite, and vector matches size");

        assert_eq!(x.len(), 2);
        assert!((x[0] - (-0.125)).abs() < 1e-5);
        assert!((x[1] - 0.75).abs() < 1e-5);
    }

    #[test]
    fn test_cholesky_solve_3x3() {
        let a = Matrix::from_vec(
            3,
            3,
            vec![4.0_f32, 12.0, -16.0, 12.0, 37.0, -43.0, -16.0, -43.0, 98.0],
        )
        .expect("test data has correct dimensions: 3*3=9 elements");
        let b = Vector::from_slice(&[1.0_f32, 2.0, 3.0]);
        let x = a
            .cholesky_solve(&b)
            .expect("matrix is square, symmetric positive definite, and vector matches size");

        let result = a
            .matvec(&x)
            .expect("matrix columns match vector length: both 3");
        for i in 0..3 {
            assert!((result[i] - b[i]).abs() < 1e-4);
        }
    }

    #[test]
    fn test_cholesky_solve_strict() {
        let a = Matrix::from_vec(
            4,
            4,
            vec![
                4.0_f32, 2.0, 1.0, 1.0, 2.0, 5.0, 2.0, 1.0, 1.0, 2.0, 6.0, 2.0, 1.0, 1.0, 2.0, 7.0,
            ],
        )
        .expect("test data has correct dimensions: 4*4=16 elements");
        let b = Vector::from_slice(&[1.0_f32, 2.0, 3.0, 4.0]);
        let x = a
            .cholesky_solve(&b)
            .expect("matrix is square, symmetric positive definite, and vector matches size");

        let result = a
            .matvec(&x)
            .expect("matrix columns match vector length: both 4");
        for i in 0..4 {
            assert!(
                (result[i] - b[i]).abs() < 1e-5,
                "Failed at index {}: expected {}, got {}",
                i,
                b[i],
                result[i]
            );
        }

        let a3 = Matrix::from_vec(3, 3, vec![9.0_f32, 3.0, 3.0, 3.0, 5.0, 1.0, 3.0, 1.0, 4.0])
            .expect("test data has correct dimensions: 3*3=9 elements");
        let b3 = Vector::from_slice(&[15.0_f32, 9.0, 8.0]);
        let x3 = a3
            .cholesky_solve(&b3)
            .expect("matrix is square, symmetric positive definite, and vector matches size");

        assert!((x3[0] - 1.0).abs() < 1e-6);
        assert!((x3[1] - 1.0).abs() < 1e-6);
        assert!((x3[2] - 1.0).abs() < 1e-6);

        let verify3 = a3
            .matvec(&x3)
            .expect("matrix columns match vector length: both 3");
        assert!((verify3[0] - 15.0).abs() < 1e-6);
        assert!((verify3[1] - 9.0).abs() < 1e-6);
        assert!((verify3[2] - 8.0).abs() < 1e-6);
    }

    #[test]
    fn test_cholesky_solve_non_square() {
        let a = Matrix::from_vec(2, 3, vec![1.0_f32; 6])
            .expect("test data has correct dimensions: 2*3=6 elements");
        let b = Vector::from_slice(&[1.0_f32, 2.0]);
        assert!(a.cholesky_solve(&b).is_err());
    }

    #[test]
    fn test_cholesky_solve_length_mismatch() {
        let a = Matrix::<f32>::eye(3);
        let b = Vector::from_slice(&[1.0_f32, 2.0]);
        assert!(a.cholesky_solve(&b).is_err());
    }

    #[test]
    fn test_cholesky_solve_not_positive_definite() {
        // Symmetric but indefinite: eigenvalues are 3 and -1.
        let a = Matrix::from_vec(2, 2, vec![1.0_f32, 2.0, 2.0, 1.0])
            .expect("test data has correct dimensions: 2*2=4 elements");
        let b = Vector::from_slice(&[1.0_f32, 1.0]);
        assert!(a.cholesky_solve(&b).is_err());
    }

    #[test]
    fn test_mul_scalar() {
        let m = Matrix::from_vec(2, 2, vec![1.0_f32, 2.0, 3.0, 4.0])
            .expect("test data has correct dimensions: 2*2=4 elements");
        let result = m.mul_scalar(2.0);
        assert!((result.get(0, 0) - 2.0).abs() < 1e-6);
        assert!((result.get(1, 1) - 8.0).abs() < 1e-6);
    }

    #[test]
    fn test_set() {
        let mut m = Matrix::<f32>::zeros(2, 2);
        m.set(0, 1, 5.0);
        assert!((m.get(0, 1) - 5.0).abs() < 1e-6);
    }
}