maria_linalg/
matrix.rs

1//! Implements necessary methods on *square* matrices.
2
3use std::{
4    fmt,
5    ops::{
6        Add,
7        Index,
8        IndexMut,
9        Sub,
10    },
11};
12
13use super::Vector;
14
/// A square `N x N` matrix of `f64` entries.
///
/// Entries are stored row by row: `values[i][j]` is the entry in row `i`,
/// column `j`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Matrix<const N: usize> {
    /// Row-major storage of the matrix entries.
    values: [[f64; N]; N],
}
21
22/// Implements necessary behaviors of a matrix.
23impl<const N: usize> Matrix<N> {
24    /// Constructs a zero matrix.
25    pub fn zero() -> Self {
26        Self {
27            values: [[0.0; N]; N],
28        }
29    }
30
31    /// Constructs an identity matrix.
32    pub fn identity() -> Self {
33        let mut values = [[0.0; N]; N];
34
35        for i in 0..N {
36            values[i][i] = 1.0;
37        }
38
39        Self {
40            values
41        }
42    }
43
44    /// Constructs a matrix of provided values.
45    pub fn new(values: [[f64; N]; N]) -> Self {
46        Self {
47            values,
48        }
49    }
50
51    /// Returns a 3D rotation matrix representing a right-handed rotation about the
52    ///     provided axis by the provided angle.
53    /// 
54    /// *Note*: the provided angle is in radians.
55    pub fn rotation(
56        axis: Vector<3>,
57        angle: f64,
58    ) -> Matrix<3> {
59        let basis = Matrix::<3>::identity();
60        let mut r = [Vector::<3>::zero(); 3];
61
62        for i in 0..3 {
63            r[i] = basis.column(i).rotate(axis, angle);
64        }
65
66        Matrix::<3>::new([
67            [r[0][0], r[1][0], r[2][0]],
68            [r[0][1], r[1][1], r[2][1]],
69            [r[0][2], r[1][2], r[2][2]],
70        ])
71    }
72
73    /// Decomposes this matrix into its columns.
74    pub fn decompose(&self) -> [Vector<N>; N] {
75        let mut columns = [Vector::zero(); N];
76
77        for i in 0..N {
78            columns[i] = self.column(i);
79        }
80
81        columns
82    }
83
84    /// Gets a column of this vector.
85    pub fn column(&self, j: usize) -> Vector<N> {
86        let mut vector = Vector::zero();
87
88        for i in 0..N {
89            vector[i] = self[(i, j)];
90        }
91
92        vector
93    }
94
95    /// Right-multiplies this matrix by the provided vector, returning the result.
96    pub fn mult(&self, vector: Vector<N>) -> Vector<N> {
97        let mut output = Vector::<N>::zero();
98
99        for i in 0..N {
100            for j in 0..N {
101                output[i] += self[(i, j)] * vector[j];
102            }
103        }
104
105        output
106    }
107
108    /// Right-multiplies this matrix by the provided matrix, returning the result.
109    pub fn matmult(&self, matrix: Matrix<N>) -> Matrix<N> {
110        let mut output = Matrix::<N>::zero();
111
112        for i in 0..N {
113            for j in 0..N {
114                for k in 0..N {
115                    output[(i, j)] += self[(i, k)] * matrix[(k, j)];
116                }
117            }
118        }
119
120        output
121    }
122
123    /// Swap rows `i` and `j`.
124    fn swaprow(&mut self, i: usize, j: usize) {
125        let temp = self.values[i];
126        self.values[i] = self.values[j];
127        self.values[j] = temp;
128    }
129
130    /// Scale row `i` by factor `s`.
131    fn scalerow(&mut self, i: usize, s: f64) {
132        for j in 0..N {
133            self[(i, j)] *= s;
134        }
135    }
136
137    /// Subtract `s` times row `j` from row `i`.
138    fn subrow(&mut self, i: usize, j: usize, s: f64) {
139        for k in 0..N {
140            self[(i, k)] -= s * self[(j, k)];
141        }
142    }
143
144    /// Returns the inverse of this matrix.
145    pub fn inverse(&self) -> Self {
146        let mut output = *self;
147        let mut inverse = Self::identity();
148
149        for i in 0..N {
150            // Determine the index of the row with the largest pivot
151            // Start from the working row
152            let mut j = i;
153            for k in i..N {
154                if output[(k, i)] > output[(i, i)] {
155                    j = k;
156                }
157            }
158
159            // Swap largest pivot to working row
160            output.swaprow(i, j);
161            inverse.swaprow(i, j);
162
163            // Normalize this row
164            let s = 1.0 / output[(i, i)];
165            output.scalerow(i, s);
166            inverse.scalerow(i, s);
167
168            // Subtract this row from all lower rows
169            for k in (i + 1)..N {
170                let s = output[(k, i)];
171                output.subrow(k, i, s);
172                inverse.subrow(k, i, s);
173            }
174        }
175
176        // We're now in upper triangular, let's get to GJ normal form
177
178        for i in 0..N {
179            for j in (i + 1)..N {
180                let s = output[(i, j)];
181                output.subrow(i, j, s);
182                inverse.subrow(i, j, s);
183            }
184        }
185
186        inverse
187    }
188
189    /// Scales a matrix by a provided scalar, returning the new matrix.
190    pub fn scale(&self, scalar: f64) -> Self {
191        let mut newvalues = [[0.0; N]; N];
192        for i in 0..N {
193            for j in 0..N {
194                newvalues[i][j] = scalar * self[(i, j)];
195            }
196        }
197
198        Self {
199            values: newvalues,
200        }
201    }
202}
203
204impl<const N: usize> Index<(usize, usize)> for Matrix<N> {
205    type Output = f64;
206
207    fn index(&self, idx: (usize, usize)) -> &Self::Output {
208        &self.values[idx.0][idx.1]
209    }
210}
211
212impl<const N: usize> IndexMut<(usize, usize)> for Matrix<N> {
213    fn index_mut(&mut self, idx: (usize, usize)) -> &mut Self::Output {
214        &mut self.values[idx.0][idx.1]
215    }
216}
217
218impl<const N: usize> Add for Matrix<N> {
219    type Output = Self;
220
221    fn add(self, other: Self) -> Self {
222        let mut new = Self::zero();
223    
224        for i in 0..N {
225            for j in 0..N {
226                new[(i, j)] = self[(i, j)] + other[(i, j)];   
227            }
228        }
229
230        new
231    }
232}
233
234impl<const N: usize> Sub for Matrix<N> {
235    type Output = Self;
236
237    fn sub(self, other: Self) -> Self {
238        let mut new = Self::zero();
239    
240        for i in 0..N {
241            for j in 0..N {
242                new[(i, j)] = self[(i, j)] - other[(i, j)];   
243            }
244        }
245
246        new
247    }
248}
249
250impl<const N: usize> fmt::Display for Matrix<N> {
251    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
252        let mut values = Vec::new();
253        let mut maxlen = 0;
254        for i in 0..N {
255            for j in 0..N {
256                let value = self[(i, j)];
257                let row = if value >= 0.0 {
258                    format!(" {:.8}", value)
259                } else {
260                    format!("{:.8}", value)
261                };
262                let l = row.len();
263                values.push(row);
264                if l > maxlen {
265                    maxlen = l;
266                }
267            }
268        }
269
270        let mut output = String::new();
271        for i in 0..N {
272            output.push_str("[");
273            for j in 0..N {
274                output.push_str(
275                    &format!("{:^i$}", values[j + N*i], i = maxlen + 2)
276                );
277            }
278            output.push_str("]\n");
279        }
280
281        write!(f, "{}", output)
282    }
283}
284
#[test]
fn matrix_multiply() {
    let lhs = Matrix::new([
        [1.0, 2.0, 3.0],
        [4.0, 5.0, 6.0],
        [7.0, 8.0, 9.0],
    ]);
    let rhs = Matrix::new([
        [9.0, 8.0, 7.0],
        [6.0, 5.0, 4.0],
        [3.0, 2.0, 1.0],
    ]);

    // Hand-computed product `lhs * rhs`.
    let expected = Matrix::new([
        [ 30.0,  24.0,  18.0],
        [ 84.0,  69.0,  54.0],
        [138.0, 114.0,  90.0],
    ]);

    println!("{}", expected);

    let product = lhs.matmult(rhs);
    assert_eq!(product, expected);
}
309
/// Checks that `column` and `decompose` both return the matrix's columns
/// (the rows of its transpose).
#[test]
fn decompose() {
    let a = Matrix::new([
        [1.0, 2.0, 3.0],
        [4.0, 5.0, 6.0],
        [7.0, 8.0, 9.0],
    ]);

    // Expected columns of `a`.
    let basis: [Vector<3>; 3] = [
        [1.0, 4.0, 7.0].into(),
        [2.0, 5.0, 8.0].into(),
        [3.0, 6.0, 9.0].into(),
    ];

    for (j, column) in basis.iter().enumerate() {
        assert_eq!(a.column(j), *column);
    }

    assert_eq!(a.decompose(), basis);
}
326
/// Checks that a rotation about +z matches the analytic right-handed rotation
/// matrix `[[c, -s, 0], [s, c, 0], [0, 0, 1]]`.
#[test]
fn z_rotation_matrix() {
    let axis = [0.0, 0.0, 1.0].into();
    let angle = 30.0_f64.to_radians();

    let rotation = Matrix::<3>::rotation(axis, angle);

    println!("{}", rotation);

    let (s, c) = angle.sin_cos();
    let expected = Matrix::<3>::new([
        [c, -s, 0.0],
        [s, c, 0.0],
        [0.0, 0.0, 1.0],
    ]);

    for i in 0..3 {
        for j in 0..3 {
            let (got, want) = (rotation[(i, j)], expected[(i, j)]);
            assert!(
                (got - want).abs() < 1e-9,
                "entry ({}, {}): got {}, expected {}",
                i,
                j,
                got,
                want
            );
        }
    }
}