// matrix_basic/lib.rs

//! This is a crate for very basic matrix operations
//! with any type that implements [`Add`], [`Sub`], [`Mul`],
//! [`Zero`], [`Neg`] and [`Copy`]. Additional properties might be
//! needed for certain operations.
//!
//! I created it mostly to learn how to use generic types
//! and traits.
//!
//! Sayantan Santra (2023)
10
11use errors::MatrixError;
12use num::{
13    traits::{One, Zero},
14    Integer,
15};
16use std::{
17    fmt::{self, Debug, Display, Formatter},
18    ops::{Add, Div, Mul, Neg, Sub},
19    result::Result,
20};
21
22pub mod errors;
23mod tests;
24
25/// Trait a type must satisfy to be element of a matrix. This is
26/// mostly to reduce writing trait bounds afterwards.
27pub trait ToMatrix:
28    Mul<Output = Self>
29    + Add<Output = Self>
30    + Sub<Output = Self>
31    + Zero<Output = Self>
32    + Neg<Output = Self>
33    + Copy
34{
35}
36
37/// Blanket implementation for [`ToMatrix`] for any type that satisfies its bounds.
38impl<T> ToMatrix for T where
39    T: Mul<Output = T>
40        + Add<Output = T>
41        + Sub<Output = T>
42        + Zero<Output = T>
43        + Neg<Output = T>
44        + Copy
45{
46}
47
48/// A generic matrix struct (over any type with [`Add`], [`Sub`], [`Mul`],
49/// [`Zero`], [`Neg`] and [`Copy`] implemented).
50/// Look at [`from`](Self::from()) to see examples.
51#[derive(PartialEq, Debug, Clone)]
52pub struct Matrix<T: ToMatrix> {
53    entries: Vec<Vec<T>>,
54}
55
56impl<T: ToMatrix> Matrix<T> {
57    /// Creates a matrix from given 2D "array" in a `Vec<Vec<T>>` form.
58    /// It'll throw an error if all the given rows aren't of the same size.
59    /// # Example
60    /// ```
61    /// use matrix_basic::Matrix;
62    /// let m = Matrix::from(vec![vec![1, 2, 3], vec![4, 5, 6]]);
63    /// ```
64    /// will create the following matrix:  
65    /// ⌈1, 2, 3⌉  
66    /// ⌊4, 5, 6⌋
67    pub fn from(entries: Vec<Vec<T>>) -> Result<Matrix<T>, MatrixError> {
68        let mut equal_rows = true;
69        let row_len = entries[0].len();
70        for row in &entries {
71            if row_len != row.len() {
72                equal_rows = false;
73                break;
74            }
75        }
76        if equal_rows {
77            Ok(Matrix { entries })
78        } else {
79            Err(MatrixError::UnequalRows)
80        }
81    }
82
83    /// Returns the height of a matrix.
84    pub fn height(&self) -> usize {
85        self.entries.len()
86    }
87
88    /// Returns the width of a matrix.
89    pub fn width(&self) -> usize {
90        self.entries[0].len()
91    }
92
93    /// Returns the transpose of a matrix.
94    pub fn transpose(&self) -> Self {
95        let mut out = Vec::new();
96        for i in 0..self.width() {
97            let mut column = Vec::new();
98            for row in &self.entries {
99                column.push(row[i]);
100            }
101            out.push(column)
102        }
103        Matrix { entries: out }
104    }
105
106    /// Returns a reference to the rows of a matrix as `&Vec<Vec<T>>`.
107    pub fn rows(&self) -> &Vec<Vec<T>> {
108        &self.entries
109    }
110
111    /// Return the columns of a matrix as `Vec<Vec<T>>`.
112    pub fn columns(&self) -> Vec<Vec<T>> {
113        self.transpose().entries
114    }
115
116    /// Return true if a matrix is square and false otherwise.
117    pub fn is_square(&self) -> bool {
118        self.height() == self.width()
119    }
120
121    /// Returns a matrix after removing the provided row and column from it.
122    /// Note: Row and column numbers are 0-indexed.
123    /// # Example
124    /// ```
125    /// use matrix_basic::Matrix;
126    /// let m = Matrix::from(vec![vec![1, 2, 3], vec![4, 5, 6]]).unwrap();
127    /// let n = Matrix::from(vec![vec![5, 6]]).unwrap();
128    /// assert_eq!(m.submatrix(0, 0), n);
129    /// ```
130    pub fn submatrix(&self, row: usize, col: usize) -> Self {
131        let mut out = Vec::new();
132        for (m, row_iter) in self.entries.iter().enumerate() {
133            if m == row {
134                continue;
135            }
136            let mut new_row = Vec::new();
137            for (n, entry) in row_iter.iter().enumerate() {
138                if n != col {
139                    new_row.push(*entry);
140                }
141            }
142            out.push(new_row);
143        }
144        Matrix { entries: out }
145    }
146
147    /// Returns the determinant of a square matrix.
148    /// This uses basic recursive algorithm using cofactor-minor.
149    /// See [`det_in_field`](Self::det_in_field()) for faster determinant calculation in fields.
150    /// It'll throw an error if the provided matrix isn't square.
151    /// # Example
152    /// ```
153    /// use matrix_basic::Matrix;
154    /// let m = Matrix::from(vec![vec![1, 2], vec![3, 4]]).unwrap();
155    /// assert_eq!(m.det(), Ok(-2));
156    /// ```
157    pub fn det(&self) -> Result<T, MatrixError> {
158        if self.is_square() {
159            // It's a recursive algorithm using minors.
160            // TODO: Implement a faster algorithm.
161            let out = if self.width() == 1 {
162                self.entries[0][0]
163            } else {
164                // Add the minors multiplied by cofactors.
165                let n = 0..self.width();
166                let mut out = T::zero();
167                for i in n {
168                    if i.is_even() {
169                        out = out + (self.entries[0][i] * self.submatrix(0, i).det().unwrap());
170                    } else {
171                        out = out - (self.entries[0][i] * self.submatrix(0, i).det().unwrap());
172                    }
173                }
174                out
175            };
176            Ok(out)
177        } else {
178            Err(MatrixError::NotSquare)
179        }
180    }
181
182    /// Returns the determinant of a square matrix over a field i.e. needs [`One`] and [`Div`] traits.
183    /// See [`det`](Self::det()) for determinants in rings.
184    /// This method uses row reduction as is much faster.
185    /// It'll throw an error if the provided matrix isn't square.
186    /// # Example
187    /// ```
188    /// use matrix_basic::Matrix;
189    /// let m = Matrix::from(vec![vec![1.0, 2.0], vec![3.0, 4.0]]).unwrap();
190    /// assert_eq!(m.det_in_field(), Ok(-2.0));
191    /// ```
192    pub fn det_in_field(&self) -> Result<T, MatrixError>
193    where
194        T: One,
195        T: PartialEq,
196        T: Div<Output = T>,
197    {
198        if self.is_square() {
199            // Cloning is necessary as we'll be doing row operations on it.
200            let mut rows = self.entries.clone();
201            let mut multiplier = T::one();
202            let h = self.height();
203            let w = self.width();
204            for i in 0..(h - 1) {
205                // First check if the row has diagonal element 0, if yes, then swap.
206                if rows[i][i] == T::zero() {
207                    let mut zero_column = true;
208                    for j in (i + 1)..h {
209                        if rows[j][i] != T::zero() {
210                            rows.swap(i, j);
211                            multiplier = -multiplier;
212                            zero_column = false;
213                            break;
214                        }
215                    }
216                    if zero_column {
217                        return Ok(T::zero());
218                    }
219                }
220                for j in (i + 1)..h {
221                    let ratio = rows[j][i] / rows[i][i];
222                    for k in i..w {
223                        rows[j][k] = rows[j][k] - rows[i][k] * ratio;
224                    }
225                }
226            }
227            for (i, row) in rows.iter().enumerate() {
228                multiplier = multiplier * row[i];
229            }
230            Ok(multiplier)
231        } else {
232            Err(MatrixError::NotSquare)
233        }
234    }
235
236    /// Returns the row echelon form of a matrix over a field i.e. needs the [`Div`] trait.
237    /// # Example
238    /// ```
239    /// use matrix_basic::Matrix;
240    /// let m = Matrix::from(vec![vec![1.0, 2.0, 3.0], vec![3.0, 4.0, 5.0]]).unwrap();
241    /// let n = Matrix::from(vec![vec![1.0, 2.0, 3.0], vec![0.0, -2.0, -4.0]]).unwrap();
242    /// assert_eq!(m.row_echelon(), n);
243    /// ```
244    pub fn row_echelon(&self) -> Self
245    where
246        T: PartialEq,
247        T: Div<Output = T>,
248    {
249        // Cloning is necessary as we'll be doing row operations on it.
250        let mut rows = self.entries.clone();
251        let mut offset = 0;
252        let h = self.height();
253        let w = self.width();
254        for i in 0..(h - 1) {
255            // Check if all the rows below are 0
256            if i + offset >= self.width() {
257                break;
258            }
259            // First check if the row has diagonal element 0, if yes, then swap.
260            if rows[i][i + offset] == T::zero() {
261                let mut zero_column = true;
262                for j in (i + 1)..h {
263                    if rows[j][i + offset] != T::zero() {
264                        rows.swap(i, j);
265                        zero_column = false;
266                        break;
267                    }
268                }
269                if zero_column {
270                    offset += 1;
271                }
272            }
273            for j in (i + 1)..h {
274                let ratio = rows[j][i + offset] / rows[i][i + offset];
275                for k in (i + offset)..w {
276                    rows[j][k] = rows[j][k] - rows[i][k] * ratio;
277                }
278            }
279        }
280        Matrix { entries: rows }
281    }
282
283    /// Returns the column echelon form of a matrix over a field i.e. needs the [`Div`] trait.
284    /// It's just the transpose of the row echelon form of the transpose.
285    /// See [`row_echelon`](Self::row_echelon()) and [`transpose`](Self::transpose()).
286    pub fn column_echelon(&self) -> Self
287    where
288        T: PartialEq,
289        T: Div<Output = T>,
290    {
291        self.transpose().row_echelon().transpose()
292    }
293
294    /// Returns the reduced row echelon form of a matrix over a field i.e. needs the `Div`] trait.
295    /// # Example
296    /// ```
297    /// use matrix_basic::Matrix;
298    /// let m = Matrix::from(vec![vec![1.0, 2.0, 3.0], vec![3.0, 4.0, 5.0]]).unwrap();
299    /// let n = Matrix::from(vec![vec![1.0, 2.0, 3.0], vec![0.0, 1.0, 2.0]]).unwrap();
300    /// assert_eq!(m.reduced_row_echelon(), n);
301    /// ```
302    pub fn reduced_row_echelon(&self) -> Self
303    where
304        T: PartialEq,
305        T: Div<Output = T>,
306    {
307        let mut echelon = self.row_echelon();
308        let mut offset = 0;
309        for row in &mut echelon.entries {
310            while row[offset] == T::zero() {
311                offset += 1;
312            }
313            let divisor = row[offset];
314            for entry in row.iter_mut().skip(offset) {
315                *entry = *entry / divisor;
316            }
317            offset += 1;
318        }
319        echelon
320    }
321
322    /// Creates a zero matrix of a given size.
323    pub fn zero(height: usize, width: usize) -> Self {
324        let mut out = Vec::new();
325        for _ in 0..height {
326            let mut new_row = Vec::new();
327            for _ in 0..width {
328                new_row.push(T::zero());
329            }
330            out.push(new_row);
331        }
332        Matrix { entries: out }
333    }
334
335    /// Creates an identity matrix of a given size.
336    /// It needs the [`One`] trait.
337    pub fn identity(size: usize) -> Self
338    where
339        T: One,
340    {
341        let mut out = Matrix::zero(size, size);
342        for (i, row) in out.entries.iter_mut().enumerate() {
343            row[i] = T::one();
344        }
345        out
346    }
347
348    /// Returns the trace of a square matrix.
349    /// It'll throw an error if the provided matrix isn't square.
350    /// # Example
351    /// ```
352    /// use matrix_basic::Matrix;
353    /// let m = Matrix::from(vec![vec![1, 2], vec![3, 4]]).unwrap();
354    /// assert_eq!(m.trace(), Ok(5));
355    /// ```
356    pub fn trace(self) -> Result<T, MatrixError> {
357        if self.is_square() {
358            let mut out = self.entries[0][0];
359            for i in 1..self.height() {
360                out = out + self.entries[i][i];
361            }
362            Ok(out)
363        } else {
364            Err(MatrixError::NotSquare)
365        }
366    }
367
368    /// Returns a diagonal matrix with a given diagonal.
369    /// # Example
370    /// ```
371    /// use matrix_basic::Matrix;
372    /// let m = Matrix::diagonal_matrix(vec![1, 2, 3]);
373    /// let n = Matrix::from(vec![vec![1, 0, 0], vec![0, 2, 0], vec![0, 0, 3]]).unwrap();
374    ///
375    /// assert_eq!(m, n);
376    /// ```
377    pub fn diagonal_matrix(diag: Vec<T>) -> Self {
378        let size = diag.len();
379        let mut out = Matrix::zero(size, size);
380        for (i, row) in out.entries.iter_mut().enumerate() {
381            row[i] = diag[i];
382        }
383        out
384    }
385
386    /// Multiplies all entries of a matrix by a scalar.
387    /// Note that it modifies the supplied matrix.
388    /// # Example
389    /// ```
390    /// use matrix_basic::Matrix;
391    /// let mut m = Matrix::from(vec![vec![1, 2, 0], vec![0, 2, 5], vec![0, 0, 3]]).unwrap();
392    /// let n = Matrix::from(vec![vec![2, 4, 0], vec![0, 4, 10], vec![0, 0, 6]]).unwrap();
393    /// m.mul_scalar(2);
394    ///
395    /// assert_eq!(m, n);
396    /// ```
397    pub fn mul_scalar(&mut self, scalar: T) {
398        for row in &mut self.entries {
399            for entry in row {
400                *entry = *entry * scalar;
401            }
402        }
403    }
404
405    /// Returns the inverse of a square matrix. Throws an error if the matrix isn't square.
406    /// /// # Example
407    /// ```
408    /// use matrix_basic::Matrix;
409    /// let m = Matrix::from(vec![vec![1.0, 2.0], vec![3.0, 4.0]]).unwrap();
410    /// let n = Matrix::from(vec![vec![-2.0, 1.0], vec![1.5, -0.5]]).unwrap();
411    /// assert_eq!(m.inverse(), Ok(n));
412    /// ```
413    pub fn inverse(&self) -> Result<Self, MatrixError>
414    where
415        T: Div<Output = T>,
416        T: One,
417        T: PartialEq,
418    {
419        if self.is_square() {
420            // We'll use the basic technique of using an augmented matrix (in essence)
421            // Cloning is necessary as we'll be doing row operations on it.
422            let mut rows = self.entries.clone();
423            let h = self.height();
424            let w = self.width();
425            let mut out = Self::identity(h).entries;
426
427            // First we get row echelon form
428            for i in 0..(h - 1) {
429                // First check if the row has diagonal element 0, if yes, then swap.
430                if rows[i][i] == T::zero() {
431                    let mut zero_column = true;
432                    for j in (i + 1)..h {
433                        if rows[j][i] != T::zero() {
434                            rows.swap(i, j);
435                            out.swap(i, j);
436                            zero_column = false;
437                            break;
438                        }
439                    }
440                    if zero_column {
441                        return Err(MatrixError::Singular);
442                    }
443                }
444                for j in (i + 1)..h {
445                    let ratio = rows[j][i] / rows[i][i];
446                    for k in i..w {
447                        rows[j][k] = rows[j][k] - rows[i][k] * ratio;
448                    }
449                    // We cannot skip entries here as they might not be 0
450                    for k in 0..w {
451                        out[j][k] = out[j][k] - out[i][k] * ratio;
452                    }
453                }
454            }
455
456            // Then we reduce the rows
457            for i in 0..h {
458                if rows[i][i] == T::zero() {
459                    return Err(MatrixError::Singular);
460                }
461                let divisor = rows[i][i];
462                for entry in rows[i].iter_mut().skip(i) {
463                    *entry = *entry / divisor;
464                }
465                for entry in out[i].iter_mut() {
466                    *entry = *entry / divisor;
467                }
468            }
469
470            // Finally, we do upside down row reduction
471            for i in (1..h).rev() {
472                for j in (0..i).rev() {
473                    let ratio = rows[j][i];
474                    for k in 0..w {
475                        out[j][k] = out[j][k] - out[i][k] * ratio;
476                    }
477                }
478            }
479
480            Ok(Matrix { entries: out })
481        } else {
482            Err(MatrixError::NotSquare)
483        }
484    }
485
486    // TODO: Canonical forms, eigenvalues, eigenvectors etc.
487}
488
489impl<T: Debug + ToMatrix> Display for Matrix<T> {
490    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
491        write!(f, "{:?}", self.entries)
492    }
493}
494
495impl<T: Mul<Output = T> + ToMatrix> Mul for Matrix<T> {
496    // TODO: Implement a faster algorithm.
497    type Output = Self;
498    fn mul(self, other: Self) -> Self::Output {
499        let width = self.width();
500        if width != other.height() {
501            panic!("row length of first matrix != column length of second matrix");
502        } else {
503            let mut out = Vec::new();
504            for row in self.rows() {
505                let mut new_row = Vec::new();
506                for col in other.columns() {
507                    let mut prod = row[0] * col[0];
508                    for i in 1..width {
509                        prod = prod + (row[i] * col[i]);
510                    }
511                    new_row.push(prod)
512                }
513                out.push(new_row);
514            }
515            Matrix { entries: out }
516        }
517    }
518}
519
520impl<T: Mul<Output = T> + ToMatrix> Add for Matrix<T> {
521    type Output = Self;
522    fn add(self, other: Self) -> Self::Output {
523        if self.height() == other.height() && self.width() == other.width() {
524            let mut out = self.entries.clone();
525            for (i, row) in self.rows().iter().enumerate() {
526                for (j, entry) in other.rows()[i].iter().enumerate() {
527                    out[i][j] = row[j] + *entry;
528                }
529            }
530            Matrix { entries: out }
531        } else {
532            panic!("provided matrices have different dimensions");
533        }
534    }
535}
536
537impl<T: ToMatrix> Neg for Matrix<T> {
538    type Output = Self;
539    fn neg(self) -> Self::Output {
540        let mut out = self;
541        for row in &mut out.entries {
542            for entry in row {
543                *entry = -*entry;
544            }
545        }
546        out
547    }
548}
549
550impl<T: ToMatrix> Sub for Matrix<T> {
551    type Output = Self;
552    fn sub(self, other: Self) -> Self::Output {
553        if self.height() == other.height() && self.width() == other.width() {
554            self + -other
555        } else {
556            panic!("provided matrices have different dimensions");
557        }
558    }
559}
560
561/// Trait for conversion between matrices of different types.
562/// It only has a [`matrix_from()`](Self::matrix_from()) method.
563/// This is needed since negative trait bound are not supported in stable Rust
564/// yet, so we'll have a conflict trying to implement [`From`].
565/// I plan to change this to the default From trait as soon as some sort
566/// of specialization system is implemented.
567/// You can track this issue [here](https://github.com/rust-lang/rust/issues/42721).
568pub trait MatrixFrom<T: ToMatrix> {
569    /// Method for getting a matrix of a new type from a matrix of type [`Matrix<T>`].
570    /// # Example
571    /// ```
572    /// use matrix_basic::Matrix;
573    /// use matrix_basic::MatrixFrom;
574    ///
575    /// let a = Matrix::from(vec![vec![1, 2, 3], vec![0, 1, 2]]).unwrap();
576    /// let b = Matrix::from(vec![vec![1.0, 2.0, 3.0], vec![0.0, 1.0, 2.0]]).unwrap();
577    /// let c = Matrix::<f64>::matrix_from(a); // Type annotation is needed here
578    ///
579    /// assert_eq!(c, b);
580    /// ```
581    fn matrix_from(input: Matrix<T>) -> Self;
582}
583
584/// Blanket implementation of [`MatrixFrom<T>`] for converting [`Matrix<S>`] to [`Matrix<T>`] whenever
585/// `S` implements [`From(T)`]. Look at [`matrix_into`](Self::matrix_into()).
586impl<T: ToMatrix, S: ToMatrix + From<T>> MatrixFrom<T> for Matrix<S> {
587    fn matrix_from(input: Matrix<T>) -> Self {
588        let mut out = Vec::new();
589        for row in input.entries {
590            let mut new_row: Vec<S> = Vec::new();
591            for entry in row {
592                new_row.push(entry.into());
593            }
594            out.push(new_row)
595        }
596        Matrix { entries: out }
597    }
598}
599
/// Sister trait of [`MatrixFrom`]. It performs the same conversion, but is called on the
/// source matrix instead of on the target type (just as [`Into`] mirrors [`From`]).
pub trait MatrixInto<T> {
    /// Converts `self` into `T`, typically a [`Matrix`] with a different entry type.
    /// # Example
    /// ```
    /// use matrix_basic::Matrix;
    /// use matrix_basic::MatrixInto;
    ///
    /// let a = Matrix::from(vec![vec![1, 2, 3], vec![0, 1, 2]]).unwrap();
    /// let b = Matrix::from(vec![vec![1.0, 2.0, 3.0], vec![0.0, 1.0, 2.0]]).unwrap();
    /// let c: Matrix<f64> = a.matrix_into(); // Type annotation is needed here
    ///
    /// assert_eq!(c, b);
    /// ```
    fn matrix_into(self) -> T;
}
618
619/// Blanket implementation of [`MatrixInto<T>`] for [`Matrix<S>`] whenever `T`
620/// (which is actually some)[`Matrix<U>`] implements [`MatrixFrom<S>`].
621impl<T: MatrixFrom<S>, S: ToMatrix> MatrixInto<T> for Matrix<S> {
622    fn matrix_into(self) -> T {
623        T::matrix_from(self)
624    }
625}