math_concept/lalg/matrix/
mod.rs

1use std::default::Default;
2use std::ops::{Add, AddAssign, Div, Mul};
3
4#[cfg(test)]
5mod tests;
6
/// A dense matrix with `r` rows and `c` columns whose elements are stored
/// contiguously in row-major order (element `(i, j)` lives at buffer index
/// `i * c + j`).
///
/// # Examples
///
/// ```
/// # // Doctests are compiled as a separate crate, so the type must be imported.
/// # use math_concept::lalg::matrix::Matrix;
/// let mut m = Matrix::new(2, 2, vec![0; 4]);
///
/// assert_eq!(m.at(0, 0), 0);
///
/// *m.at_mut(1, 1) = 5;
/// assert_eq!(m.at(1, 1), 5);
/// ```
/// ```
/// # use math_concept::lalg::matrix::Matrix;
/// let m = Matrix::from_vec(1, 2, &vec![9, 8]);
///
/// assert_eq!(&vec![9, 8], m.data());
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix<T> {
    r: usize,
    c: usize,
    buf: Vec<T>,
}

impl<T> Matrix<T> {
    /// Constructs a new `Matrix<T>` with `r` rows and `c` columns, taking
    /// ownership of `v` and using it as the row-major backing buffer.
    ///
    /// # Panics
    ///
    /// Panics if `r * c` differs from `v.len()`, including when `r * c`
    /// overflows `usize`.
    pub fn new(r: usize, c: usize, v: Vec<T>) -> Matrix<T> {
        // `checked_mul` stops a wrapped product (release builds do not trap
        // on overflow) from accidentally matching the buffer length.
        assert_eq!(
            r.checked_mul(c),
            Some(v.len()),
            "Matrix dimensions and data size differ"
        );
        Matrix { r, c, buf: v }
    }

    /// Constructs a new `Matrix<T>` by cloning the elements of `v` into a
    /// freshly allocated buffer. `v` itself is left untouched.
    ///
    /// Any slice is accepted, so `&Vec<T>`, `&[T; N]` and `&[T]` all work.
    ///
    /// # Panics
    ///
    /// Panics if `r * c` differs from `v.len()`.
    pub fn from_vec(r: usize, c: usize, v: &[T]) -> Matrix<T>
    where
        T: Clone,
    {
        // `new` performs the dimension check.
        Matrix::new(r, c, v.to_vec())
    }

    /// Returns the number of rows of the matrix.
    pub fn rows(&self) -> usize {
        self.r
    }

    /// Returns the number of columns of the matrix.
    pub fn cols(&self) -> usize {
        self.c
    }

    /// Returns a copy of the element at row `r`, column `c`.
    ///
    /// # Panics
    ///
    /// Panics with "Out of bounds" if `r >= rows()` or `c >= cols()`.
    pub fn at(&self, r: usize, c: usize) -> T
    where
        T: Copy,
    {
        self.buf[self.flat_index(r, c)]
    }

    /// Returns a mutable reference to the element at row `r`, column `c`.
    ///
    /// # Panics
    ///
    /// Panics with "Out of bounds" if `r >= rows()` or `c >= cols()`.
    pub fn at_mut(&mut self, r: usize, c: usize) -> &mut T {
        let idx = self.flat_index(r, c);
        &mut self.buf[idx]
    }

    // TODO: implement `pub fn determinant(&self) -> T`.

    /// Returns a reference to the row-major backing buffer.
    pub fn data(&self) -> &Vec<T> {
        &self.buf
    }

    /// Builds a new matrix of the same shape by applying `f` to every element
    /// in row-major order.
    ///
    /// # Examples
    ///
    /// ```
    /// # // Doctests are compiled as a separate crate, so the type must be imported.
    /// # use math_concept::lalg::matrix::Matrix;
    /// let m = Matrix::new(2, 2, vec![5; 4]);
    ///
    /// let mapped = m.map(|&x| x * 8);
    /// assert_eq!(mapped.data(), &vec![40; 4]);
    /// ```
    pub fn map<U, F>(&self, f: F) -> Matrix<U>
    where
        F: FnMut(&T) -> U,
    {
        Matrix::new(self.r, self.c, self.buf.iter().map(f).collect())
    }

    /// Returns the main diagonal of a square matrix, from top-left to
    /// bottom-right.
    ///
    /// # Panics
    ///
    /// Panics with "Matrix is not square" if `rows() != cols()`.
    ///
    /// # Examples
    ///
    /// ```
    /// # // Doctests are compiled as a separate crate, so the type must be imported.
    /// # use math_concept::lalg::matrix::Matrix;
    /// let m = Matrix::new(2, 2, vec![1, 2, 3, 4]);
    ///
    /// let diag = m.diagonal();
    /// assert_eq!(diag, vec![1, 4]);
    /// ```
    pub fn diagonal(&self) -> Vec<T>
    where
        T: Copy,
    {
        self.assert_square();

        // Visit only the `r` diagonal slots instead of filtering all r * c.
        (0..self.r).map(|i| self.buf[i * self.c + i]).collect()
    }

    /// Translates `(r, c)` into an index of the row-major buffer.
    ///
    /// Panics with "Out of bounds" if either coordinate is out of range.
    fn flat_index(&self, r: usize, c: usize) -> usize {
        assert!(r < self.r && c < self.c, "Out of bounds");
        r * self.c + c
    }

    /// Panics with "Matrix is not square" unless `rows() == cols()`.
    fn assert_square(&self) {
        assert_eq!(self.r, self.c, "Matrix is not square");
    }
}
131
132impl<'a, T> Add<&'a Matrix<T>> for &'a Matrix<T>
133where
134    T: Copy + Add<Output = T> + Div<Output = T>,
135{
136    type Output = Matrix<T>;
137
138    fn add(self, other: Self) -> Self::Output {
139        assert!(
140            self.r == other.rows() && self.c == other.cols(),
141            "Matrices are not of the same size"
142        );
143
144        let new_v = self
145            .buf
146            .iter()
147            .zip(other.data())
148            .map(|(&x, &y)| x + y)
149            .collect();
150
151        Matrix::new(self.r, self.c, new_v)
152    }
153}
154
155impl<'a, T> Mul<&'a Matrix<T>> for &'a Matrix<T>
156where
157    T: Copy + Add<Output = T> + AddAssign<T> + Mul<Output = T> + Div<Output = T> + Default,
158{
159    type Output = Matrix<T>;
160
161    fn mul(self, other: Self) -> Self::Output {
162        assert_eq!(self.c, other.r);
163
164        let mut v: Vec<T> = vec![Default::default(); self.r * other.c];
165
166        for i in 0..self.r {
167            for j in 0..other.c {
168                for k in 0..self.c {
169                    v[i * self.r + j] += self.buf[i * self.c + k] * other.buf[k * other.c + j];
170                }
171            }
172        }
173
174        Matrix::new(self.r, other.c, v)
175    }
176}