matrix_oxide/
matrix.rs

1use crate::random;
2use std::fmt::Debug;
3use std::ops::{Add, Div, Mul, Sub};
4
/// A dense M×N matrix stored in row-major order.
///
/// The element at (`row`, `col`) lives at `data[row * col_size + col]`, so a
/// well-formed matrix has `data.len() == row_size * col_size`.
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix<T> {
    /// Elements laid out row by row.
    pub data: Vec<T>,
    /// Number of rows (M).
    pub row_size: usize,
    /// Number of columns (N).
    pub col_size: usize,
}
11impl<T: Default + Clone> Matrix<T> {
12    /// Construct a new *non-empty* and *sized* `Matrix`
13    pub fn new(row_size: usize, col_size: usize) -> Self {
14        Matrix {
15            data: vec![T::default(); row_size * col_size],
16            row_size,
17            col_size,
18        }
19    }
20
21    /// Construct a new *non-empty* and *sized* `Matrix` with random values of `T`
22    pub fn new_random(row_size: usize, col_size: usize) -> Self
23    where
24        T: random::Random,
25    {
26        let random_data: Vec<T> = random::gen_rand_vec(row_size * col_size);
27        Matrix {
28            data: random_data,
29            row_size,
30            col_size,
31        }
32    }
33
34    /// Try to get a reference to the value at a given row and column from the matrix
35    pub fn get(&self, row: usize, col: usize) -> Option<&T> {
36        if row < self.row_size && col < self.col_size {
37            Some(&self.data[col + row * self.col_size])
38        } else {
39            None
40        }
41    }
42
43    /// Get a vector of the diagonal elements of the matrix
44    pub fn get_diagonal(&self) -> Vec<T> {
45        (0..self.col_size)
46            .filter_map(|col_idx| self.get(col_idx, col_idx).cloned())
47            .collect()
48    }
49
50    /// Try to get a mutable reference to the value at a given row and column from the matrix
51    pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
52        if row < self.row_size && col < self.col_size {
53            Some(&mut self.data[col + row * self.col_size])
54        } else {
55            None
56        }
57    }
58
59    /// Try to set a value at a given row and column in the matrix
60    pub fn set(&mut self, row: usize, column: usize, value: T) -> bool {
61        if let Some(cell) = self.get_mut(row, column) {
62            *cell = value;
63            true
64        } else {
65            false
66        }
67    }
68
69    /// Try to get all the values for a given column
70    ///
71    /// NOTE: If you pass a column value larger than the number of columns
72    /// this function will return None.
73    pub fn try_get_column(&self, column: usize) -> Option<Vec<T>> {
74        // Bounds check
75        if column >= self.col_size {
76            return None;
77        }
78
79        // Iterate over all the rows grabbing a specific column each time
80        let col_data: Vec<T> = (0..self.row_size)
81            .map(|row| self.data[row * self.col_size + column].clone())
82            .collect();
83
84        Some(col_data)
85    }
86
87    /// Try to get all the values for a given row
88    ///
89    /// NOTE: If you pass a row value larger than the number of rows
90    /// this function will return None.
91    pub fn try_get_row(&self, row: usize) -> Option<Vec<T>> {
92        // Bounds check
93        if row >= self.row_size {
94            return None;
95        }
96
97        // Iterate over all the rows grabbing a specific column each time
98        let row_data: Vec<T> = (0..self.col_size)
99            .map(|col| self.data[row * self.col_size + col].clone())
100            .collect();
101
102        Some(row_data)
103    }
104
105    /// Create a `Matrix` from a columns (vec of vec)
106    pub fn from_columns(cols: Vec<Vec<T>>) -> Matrix<T> {
107        if cols.is_empty() {
108            return Matrix {
109                data: Vec::new(),
110                row_size: 0,
111                col_size: 0,
112            };
113        }
114
115        let row_size = cols[0].len();
116        let col_size = cols.len();
117
118        let data = (0..row_size)
119            .flat_map(|row| cols.iter().filter_map(move |col| col.get(row).cloned()))
120            .collect();
121
122        Matrix {
123            data,
124            row_size,
125            col_size,
126        }
127    }
128
129    /// Create a sub matrix with a specific row and column to exclude
130    pub fn sub_matrix(&self, skip_row: usize, skip_col: usize) -> Matrix<T> {
131        let columns: Vec<Vec<T>> = (0..self.col_size)
132            .filter_map(|col| {
133                if col != skip_col {
134                    Some(
135                        self.try_get_column(col)?
136                            .into_iter()
137                            .enumerate()
138                            .filter_map(|(row, val)| if row != skip_row { Some(val) } else { None })
139                            .collect(),
140                    )
141                } else {
142                    None
143                }
144            })
145            .collect();
146
147        Matrix::from_columns(columns)
148    }
149
150    /// Perform a transpose operation (swap rows for columns and vice versa)
151    pub fn transpose(&self) -> Matrix<T> {
152        Matrix {
153            data: (0..self.col_size)
154                .flat_map(|col| {
155                    (0..self.row_size).map(move |row| self.data[row * self.col_size + col].clone())
156                })
157                .collect(),
158
159            row_size: self.col_size,
160            col_size: self.row_size,
161        }
162    }
163}
164
165impl<T: Default + Clone> Default for Matrix<T> {
166    /// Create a default `Matrix` instance
167    fn default() -> Self {
168        Self::new(2, 3)
169    }
170}
171impl<T: Default + Clone + Debug> Add for Matrix<T>
172where
173    T: Add<Output = T> + Clone,
174{
175    type Output = Matrix<T>;
176
177    /// Matrix addition
178    /// NOTE: the matrices you add MUST have the same dimensionality
179    fn add(self, rhs: Self) -> Matrix<T> {
180        let data: Vec<T> = (0..self.row_size)
181            .flat_map(|row| {
182                let row_a = self.try_get_row(row).expect("Invalid row in self");
183                let row_b = rhs.try_get_row(row).expect("Invalid row in rhs");
184                row_a.into_iter().zip(row_b).map(|(a, b)| a + b)
185            })
186            .collect();
187
188        Matrix {
189            data,
190            col_size: self.col_size,
191            row_size: self.row_size,
192        }
193    }
194}
195impl<T> Matrix<T>
196where
197    T: Default + Clone + Add<Output = T>,
198{
199    /// Perform the trace operation that computes the sum of all diagonal
200    /// elements in the matrix.
201    ///
202    /// NOTE: off-diagnonal elements do NOT contribute to the trace of the
203    /// matrix, so 2 very different matrices can have the same trace.
204    pub fn trace(&self) -> T {
205        self.get_diagonal()
206            .into_iter()
207            .fold(T::default(), |acc, diagonal| acc + diagonal)
208    }
209
210    /// Perform a summation over the matrix.
211    pub fn sum(&self) -> T {
212        self.data
213            .iter()
214            .cloned()
215            .fold(T::default(), |acc, x| acc + x)
216    }
217}
218impl<T: Default + Clone + Debug> Sub for Matrix<T>
219where
220    T: Sub<Output = T> + Clone,
221{
222    type Output = Matrix<T>;
223
224    /// Subtract a matrix by another matrix
225    /// NOTE: the matrix you subtract by MUST have the same dimensionality
226    fn sub(self, rhs: Self) -> Matrix<T> {
227        let data: Vec<T> = (0..self.row_size)
228            .flat_map(|row| {
229                let row_a = self.try_get_row(row).expect("Invalid row in self");
230                let row_b = rhs.try_get_row(row).expect("Invalid row in rhs");
231                row_a.into_iter().zip(row_b).map(|(a, b)| a - b)
232            })
233            .collect();
234
235        Matrix {
236            data,
237            row_size: self.row_size,
238            col_size: self.col_size,
239        }
240    }
241}
242impl<T: Default> Matrix<T>
243where
244    T: Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Clone,
245{
246    /// Multiply a matrix by a single number (scalar)
247    /// NOTE: The scalar type MUST match the matrix type.
248    pub fn scalar_multiply(&self, scalar: T) -> Matrix<T> {
249        let data = self
250            .data
251            .iter()
252            .map(|value| value.clone() * scalar.clone())
253            .collect();
254
255        Matrix {
256            data,
257            row_size: self.row_size,
258            col_size: self.col_size,
259        }
260    }
261
262    /// Multiply `Matrix` with another `Matrix` using standard matrix multiplication
263    /// NOTE: The matrices inner dimensions MUST match else returns None
264    pub fn multiply(&self, multiplier: &Matrix<T>) -> Option<Matrix<T>> {
265        // Validity check for the matrices inner dimensions
266        if self.col_size != multiplier.row_size {
267            return None;
268        }
269
270        let data: Vec<T> = (0..self.row_size)
271            .flat_map(|i| {
272                (0..multiplier.col_size).map(move |j| {
273                    (0..self.col_size)
274                        .map(|k| {
275                            self.data[i * self.col_size + k].clone()
276                                * multiplier.data[k * multiplier.col_size + j].clone()
277                        })
278                        .fold(T::default(), |acc, x| acc + x)
279                })
280            })
281            .collect();
282
283        Some(Matrix {
284            data,
285            col_size: multiplier.col_size,
286            row_size: self.row_size,
287        })
288    }
289
290    /// Multiply the `Matrix` by a vector
291    /// NOTE: The vectors length MUST match the vector columns, else returns None
292    pub fn vector_multiply(&self, multiplier: &[T]) -> Option<Vec<T>> {
293        // Validity check that the `Matrix` column size matches the vector column size
294        if self.col_size != multiplier.len() {
295            return None;
296        }
297
298        let data: Vec<T> = (0..self.row_size)
299            .map(|i| {
300                (0..multiplier.len())
301                    .map(|j| self.data[i * self.col_size + j].clone() * multiplier[j].clone())
302                    .fold(T::default(), |acc, x| acc + x)
303            })
304            .collect();
305
306        Some(data)
307    }
308
309    /// Compute a unique determinant for a `Matrix`
310    /// NOTE: Only computable for square (M x M) matrices.
311    /// NOTE: The determinant is 0 for a `Matrix` with rank r < M (non-invertable).
312    pub fn determinant(&self) -> Option<T> {
313        // Validity check that it's a square matrix
314        if self.col_size != self.row_size {
315            return None;
316        }
317
318        // Base case for recursion: 1x1 matrix
319        if self.row_size == 1 {
320            return Some(self.data[0].clone());
321        }
322
323        if let Some(first_row) = self.try_get_row(0) {
324            let determinant =
325                first_row
326                    .iter()
327                    .enumerate()
328                    .fold(T::default(), |acc, (col_idx, item)| {
329                        let sub_matrix = self.sub_matrix(0, col_idx);
330
331                        if let Some(sub_determinant) = sub_matrix.determinant() {
332                            // Alternate between addition and subtraction
333                            // operations every iteration
334                            if col_idx % 2 == 0 {
335                                acc + item.clone() * sub_determinant
336                            } else {
337                                acc - item.clone() * sub_determinant
338                            }
339                        } else {
340                            T::default()
341                        }
342                    });
343
344            Some(determinant)
345        } else {
346            None
347        }
348    }
349}
350impl<T> Matrix<T>
351where
352    T: Default + Clone + From<f64>,
353{
354    /// Create an identity matrix of given size
355    pub fn identity(size: usize) -> Matrix<T> {
356        let data = (0..size * size)
357            .map(|i| {
358                if i % (size + 1) == 0 {
359                    T::from(1.0)
360                } else {
361                    T::default()
362                }
363            })
364            .collect();
365
366        Matrix {
367            data,
368            row_size: size,
369            col_size: size,
370        }
371    }
372}
373impl<T> Matrix<T>
374where
375    T: Default + Clone + Mul<Output = T> + Add<Output = T> + Into<f64>,
376{
377    /// Compute the frobenius norm of a `Matrix`
378    pub fn frobenius_norm(&self) -> f64 {
379        let sum_of_squares: f64 = self
380            .data
381            .iter()
382            .map(|val| {
383                let val_f64: f64 = val.clone().into();
384                val_f64 * val_f64
385            })
386            .fold(f64::default(), |acc, x| acc + x);
387
388        sum_of_squares.sqrt()
389    }
390}
391impl<T> Matrix<T>
392where
393    T: Copy
394        + PartialOrd
395        + Default
396        + From<f64>
397        + Into<f64>
398        + Sub<Output = T>
399        + Add<Output = T>
400        + Mul<Output = T>
401        + Div<Output = T>
402        + Abs,
403{
404    /// Compute the inverse of a `Matrix`
405    /// NOTE: Only computable for square (M x M) matrices.
406    /// NOTE: Only computable for a `Matrix` where r = M (full rank).
407    pub fn inverse(&self) -> Option<Matrix<T>> {
408        // Validity check that it's a square matrix
409        if self.row_size != self.col_size {
410            return None;
411        }
412
413        let n = self.row_size;
414        let epsilon = T::from(1e-10);
415
416        // Regular matrix
417        let mut a = self.data.clone();
418        // Inverted matrix
419        let mut b = Matrix::<T>::identity(n).data;
420
421        for i in 0..n {
422            let pivot = (i..n).fold(i, |acc, j| {
423                match a[j * n + i].abs() > a[acc * n + i].abs() {
424                    true => j,
425                    false => acc,
426                }
427            });
428
429            // Validity check that it's not a singular matrix
430            if a[pivot * n + i].abs() < epsilon {
431                return None;
432            }
433
434            if pivot != i {
435                (0..n).for_each(|k| {
436                    a.swap(i * n + k, pivot * n + k);
437                    b.swap(i * n + k, pivot * n + k);
438                });
439            }
440
441            let pivot_val = a[i * n + i];
442            (0..n).for_each(|k| {
443                a[i * n + k] = a[i * n + k] / pivot_val;
444                b[i * n + k] = b[i * n + k] / pivot_val;
445            });
446
447            (0..n).filter(|&j| j != i).for_each(|j| {
448                let factor = a[j * n + i];
449                (0..n).for_each(|k| {
450                    a[j * n + k] = a[j * n + k] - factor * a[i * n + k];
451                    b[j * n + k] = b[j * n + k] - factor * b[i * n + k];
452                })
453            })
454        }
455
456        Some(Matrix {
457            data: b,
458            row_size: n,
459            col_size: n,
460        })
461    }
462}
463
/// Absolute value for the numeric element types used by `Matrix::inverse`
/// (needed for pivot selection and the singularity check).
pub trait Abs {
    /// Return the magnitude of `self`, consuming it.
    fn abs(self) -> Self;
}
467impl Abs for f64 {
468    fn abs(self) -> Self {
469        self.abs()
470    }
471}
472impl Abs for i32 {
473    fn abs(self) -> Self {
474        self.abs()
475    }
476}
477impl Abs for i64 {
478    fn abs(self) -> Self {
479        self.abs()
480    }
481}
482impl Abs for f32 {
483    fn abs(self) -> Self {
484        self.abs()
485    }
486}
487
#[cfg(test)]
mod tests {
    use super::*;

    /// Check whether two `f64` slices have the same length and are
    /// element-wise equal to within `epsilon`.
    ///
    /// The length check matters: `zip` alone stops at the shorter slice, which
    /// would let a truncated result compare as "equal".
    fn approx_equal(a: &[f64], b: &[f64], epsilon: f64) -> bool {
        a.len() == b.len()
            && a.iter()
                .zip(b.iter())
                .all(|(&a, &b)| (a - b).abs() < epsilon)
    }

    #[test]
    fn test_matrix_vector_multiplication() {
        let matrix = Matrix::<i32> {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };
        let vector = vec![5, 6];

        let expected = vec![17, 39];
        let result = matrix.vector_multiply(&vector).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_matrix_multiplication() {
        let matrix_a = Matrix::<i32> {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };

        let matrix_b = Matrix::<i32> {
            data: vec![2, 0, 1, 2],
            row_size: 2,
            col_size: 2,
        };

        let expected = Matrix::<i32> {
            data: vec![4, 4, 10, 8],
            row_size: 2,
            col_size: 2,
        };
        let result = matrix_a.multiply(&matrix_b).unwrap();
        assert_eq!(result.data, expected.data);
    }

    #[test]
    fn test_matrix_trace() {
        let matrix = Matrix::<i32> {
            data: vec![1, 2, 3, 4],
            col_size: 2,
            row_size: 2,
        };

        let expected: i32 = 5;
        let result = matrix.trace();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_matrix_diagonal() {
        let matrix = Matrix::<i32> {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };

        let expected: Vec<i32> = vec![1, 4];
        let result = matrix.get_diagonal();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_matrix_scalar_multiplication() {
        let matrix = Matrix::<i32> {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };

        let expected = Matrix::<i32> {
            data: vec![2, 4, 6, 8],
            row_size: 2,
            col_size: 2,
        };
        let result = matrix.scalar_multiply(2);

        assert_eq!(result.data, expected.data);
    }

    #[test]
    fn test_matrix_subtraction() {
        let matrix_a = Matrix::<i32> {
            data: vec![1, 2, 3, 4, 5, 6],
            row_size: 2,
            col_size: 3,
        };
        let matrix_b = Matrix::<i32> {
            data: vec![6, 5, 4, 3, 2, 1],
            row_size: 2,
            col_size: 3,
        };

        let expected = Matrix::<i32> {
            data: vec![-5, -3, -1, 1, 3, 5],
            row_size: 2,
            col_size: 3,
        };
        let result = matrix_a - matrix_b;
        assert_eq!(result.data, expected.data);
    }

    #[test]
    fn test_matrix_addition() {
        let matrix_a = Matrix::<i32> {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };
        let matrix_b = Matrix::<i32> {
            data: vec![4, 3, 2, 1],
            row_size: 2,
            col_size: 2,
        };

        let expected = Matrix::<i32> {
            data: vec![5, 5, 5, 5],
            row_size: 2,
            col_size: 2,
        };

        let result = matrix_a + matrix_b;
        assert_eq!(result.data, expected.data);
    }

    #[test]
    fn new_matrix_has_correct_size() {
        let matrix: Matrix<i32> = Matrix::new(2, 3);
        assert_eq!(matrix.data.len(), 6);
    }

    #[test]
    fn default_matrix_is_same_as_new() {
        let default_matrix: Matrix<i32> = Matrix::default();
        let new_matrix: Matrix<i32> = Matrix::new(2, 3);
        assert_eq!(default_matrix.data, new_matrix.data);
    }

    #[test]
    fn try_get_column_valid() {
        let matrix: Matrix<i32> = Matrix::new(2, 3);
        let column = matrix.try_get_column(1);
        assert!(column.is_some());
        assert_eq!(column.unwrap(), vec![0, 0]);
    }

    #[test]
    fn try_get_column_invalid() {
        let matrix: Matrix<i32> = Matrix::new(2, 3);
        let column = matrix.try_get_column(3);
        assert!(column.is_none());
    }

    #[test]
    fn try_get_row_valid() {
        let matrix: Matrix<i32> = Matrix::new(2, 3);
        let row = matrix.try_get_row(0);
        assert!(row.is_some());
        assert_eq!(row.unwrap(), vec![0, 0, 0]);
    }

    #[test]
    fn try_get_row_invalid() {
        let matrix: Matrix<i32> = Matrix::new(2, 3);
        let row = matrix.try_get_row(2);
        assert!(row.is_none());
    }

    #[test]
    fn transpose_works_correctly() {
        let mut matrix: Matrix<i32> = Matrix::new(2, 3);
        for i in 0..matrix.data.len() {
            matrix.data[i] = i as i32;
        }
        let transposed = matrix.transpose();
        assert_eq!(transposed.data, vec![0, 3, 1, 4, 2, 5]);
    }

    #[test]
    fn determinant_of_1x1_matrix() {
        let matrix = Matrix {
            data: vec![7],
            row_size: 1,
            col_size: 1,
        };
        assert_eq!(matrix.determinant(), Some(7));
    }

    #[test]
    fn determinant_of_2x2_matrix() {
        let matrix = Matrix {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };
        assert_eq!(matrix.determinant(), Some(-2));
    }

    #[test]
    fn determinant_of_3x3_matrix() {
        let matrix = Matrix {
            data: vec![3, 2, 1, 0, 1, 4, 5, 6, 0],
            row_size: 3,
            col_size: 3,
        };
        let expected = -37;
        assert_eq!(matrix.determinant(), Some(expected));
    }

    #[test]
    fn determinant_non_square_matrix() {
        let matrix = Matrix {
            data: vec![1, 2, 3, 4, 5, 6],
            row_size: 2,
            col_size: 3,
        };
        assert_eq!(matrix.determinant(), None);
    }

    #[test]
    fn test_frobenius_norm_i32() {
        let matrix = Matrix::<i32> {
            data: vec![1, 2, 3, 4],
            row_size: 2,
            col_size: 2,
        };

        let expected: f64 = 5.477225575051661;
        let result = matrix.frobenius_norm();
        assert!((result - expected).abs() < 1e-10);
    }

    #[test]
    fn test_frobenius_norm_f32() {
        let matrix = Matrix::<f32> {
            data: vec![1.0, 2.0, 3.0, 4.0],
            row_size: 2,
            col_size: 2,
        };

        let expected: f64 = 5.477225575051661;
        let result = matrix.frobenius_norm();
        assert!((result - expected).abs() < 1e-10);
    }

    #[test]
    fn test_inverse_2x2() {
        let matrix = Matrix {
            data: vec![4.0, 7.0, 2.0, 6.0],
            row_size: 2,
            col_size: 2,
        };

        let expected_inverse = vec![0.6, -0.7, -0.2, 0.4];
        let result = matrix.inverse().unwrap();

        assert!(approx_equal(&result.data, &expected_inverse, 1e-6));
    }

    #[test]
    fn test_inverse_3x3() {
        let matrix = Matrix {
            data: vec![1.0, 2.0, 3.0, 0.0, 1.0, 4.0, 5.0, 6.0, 0.0],
            row_size: 3,
            col_size: 3,
        };

        let expected_inverse = vec![-24.0, 18.0, 5.0, 20.0, -15.0, -4.0, -5.0, 4.0, 1.0];
        let result = matrix.inverse().unwrap();

        assert!(approx_equal(&result.data, &expected_inverse, 1e-6));
    }

    #[test]
    fn test_inverse_identity() {
        let matrix = Matrix::<f64>::identity(3);
        let result = matrix.inverse().unwrap();

        assert_eq!(result.data, matrix.data);
    }

    #[test]
    fn test_inverse_singular() {
        let matrix = Matrix {
            data: vec![1.0, 2.0, 2.0, 4.0],
            row_size: 2,
            col_size: 2,
        };

        let result = matrix.inverse();

        assert!(result.is_none());
    }

    #[test]
    fn test_inverse_1x1() {
        let matrix = Matrix {
            data: vec![2.0],
            row_size: 1,
            col_size: 1,
        };

        let expected_inverse = vec![0.5];
        let result = matrix.inverse().unwrap();

        assert!(approx_equal(&result.data, &expected_inverse, 1e-6));
    }

    #[test]
    /// Test that a one by one matrix only has one value
    fn test_random_1x1() {
        let result: Matrix<f64> = Matrix::new_random(1, 1);
        assert_eq!(result.data.len(), 1);
    }

    #[test]
    /// Test that a two by three matrix has six values
    fn test_random_2x3() {
        let result: Matrix<i64> = Matrix::new_random(2, 3);
        assert_eq!(result.data.len(), 6);
    }

    #[test]
    /// Test that a twenty by twenty matrix has four hundred values
    fn test_random_20x20() {
        let result: Matrix<i64> = Matrix::new_random(20, 20);
        assert_eq!(result.data.len(), 400);
    }

    #[test]
    fn test_sum_empty_matrix() {
        let mat = Matrix::<f64>::new(0, 0);
        assert_eq!(mat.sum(), 0.0);
    }

    #[test]
    fn test_sum_zero_matrix() {
        let mat = Matrix::<f64>::new(3, 3);
        assert_eq!(mat.sum(), 0.0);
    }

    #[test]
    fn test_sum_positive_elements() {
        let matrix = Matrix {
            data: vec![1.0, 2.0, 3.0, 4.0],
            row_size: 2,
            col_size: 2,
        };
        assert_eq!(matrix.sum(), 10.0);
    }

    #[test]
    fn test_sum_negative_elements() {
        let matrix = Matrix {
            data: vec![-1.0, -2.0, -3.0, -4.0],
            row_size: 2,
            col_size: 2,
        };
        assert_eq!(matrix.sum(), -10.0);
    }

    #[test]
    fn test_sum_mixed_elements() {
        let matrix = Matrix {
            data: vec![-1.0, 2.0, -3.0, 4.0],
            row_size: 2,
            col_size: 2,
        };
        assert_eq!(matrix.sum(), 2.0);
    }
}