1use std::{
4 fmt,
5 ops::{
6 Add,
7 Index,
8 IndexMut,
9 Sub,
10 },
11};
12
13use super::Vector;
14
/// A square `N x N` matrix of `f64` entries.
///
/// Storage is row-major: `values[row]` holds one row, and `values[row][column]` is a single
/// entry (this is the layout the `(row, column)` indexing and `Display` impls rely on).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Matrix<const N: usize> {
    /// Entries indexed as `values[row][column]`.
    values: [[f64; N]; N],
}
21
22impl<const N: usize> Matrix<N> {
24 pub fn zero() -> Self {
26 Self {
27 values: [[0.0; N]; N],
28 }
29 }
30
31 pub fn identity() -> Self {
33 let mut values = [[0.0; N]; N];
34
35 for i in 0..N {
36 values[i][i] = 1.0;
37 }
38
39 Self {
40 values
41 }
42 }
43
44 pub fn new(values: [[f64; N]; N]) -> Self {
46 Self {
47 values,
48 }
49 }
50
51 pub fn rotation(
56 axis: Vector<3>,
57 angle: f64,
58 ) -> Matrix<3> {
59 let basis = Matrix::<3>::identity();
60 let mut r = [Vector::<3>::zero(); 3];
61
62 for i in 0..3 {
63 r[i] = basis.column(i).rotate(axis, angle);
64 }
65
66 Matrix::<3>::new([
67 [r[0][0], r[1][0], r[2][0]],
68 [r[0][1], r[1][1], r[2][1]],
69 [r[0][2], r[1][2], r[2][2]],
70 ])
71 }
72
73 pub fn decompose(&self) -> [Vector<N>; N] {
75 let mut columns = [Vector::zero(); N];
76
77 for i in 0..N {
78 columns[i] = self.column(i);
79 }
80
81 columns
82 }
83
84 pub fn column(&self, j: usize) -> Vector<N> {
86 let mut vector = Vector::zero();
87
88 for i in 0..N {
89 vector[i] = self[(i, j)];
90 }
91
92 vector
93 }
94
95 pub fn mult(&self, vector: Vector<N>) -> Vector<N> {
97 let mut output = Vector::<N>::zero();
98
99 for i in 0..N {
100 for j in 0..N {
101 output[i] += self[(i, j)] * vector[j];
102 }
103 }
104
105 output
106 }
107
108 pub fn matmult(&self, matrix: Matrix<N>) -> Matrix<N> {
110 let mut output = Matrix::<N>::zero();
111
112 for i in 0..N {
113 for j in 0..N {
114 for k in 0..N {
115 output[(i, j)] += self[(i, k)] * matrix[(k, j)];
116 }
117 }
118 }
119
120 output
121 }
122
123 fn swaprow(&mut self, i: usize, j: usize) {
125 let temp = self.values[i];
126 self.values[i] = self.values[j];
127 self.values[j] = temp;
128 }
129
130 fn scalerow(&mut self, i: usize, s: f64) {
132 for j in 0..N {
133 self[(i, j)] *= s;
134 }
135 }
136
137 fn subrow(&mut self, i: usize, j: usize, s: f64) {
139 for k in 0..N {
140 self[(i, k)] -= s * self[(j, k)];
141 }
142 }
143
144 pub fn inverse(&self) -> Self {
146 let mut output = *self;
147 let mut inverse = Self::identity();
148
149 for i in 0..N {
150 let mut j = i;
153 for k in i..N {
154 if output[(k, i)] > output[(i, i)] {
155 j = k;
156 }
157 }
158
159 output.swaprow(i, j);
161 inverse.swaprow(i, j);
162
163 let s = 1.0 / output[(i, i)];
165 output.scalerow(i, s);
166 inverse.scalerow(i, s);
167
168 for k in (i + 1)..N {
170 let s = output[(k, i)];
171 output.subrow(k, i, s);
172 inverse.subrow(k, i, s);
173 }
174 }
175
176 for i in 0..N {
179 for j in (i + 1)..N {
180 let s = output[(i, j)];
181 output.subrow(i, j, s);
182 inverse.subrow(i, j, s);
183 }
184 }
185
186 inverse
187 }
188
189 pub fn scale(&self, scalar: f64) -> Self {
191 let mut newvalues = [[0.0; N]; N];
192 for i in 0..N {
193 for j in 0..N {
194 newvalues[i][j] = scalar * self[(i, j)];
195 }
196 }
197
198 Self {
199 values: newvalues,
200 }
201 }
202}
203
204impl<const N: usize> Index<(usize, usize)> for Matrix<N> {
205 type Output = f64;
206
207 fn index(&self, idx: (usize, usize)) -> &Self::Output {
208 &self.values[idx.0][idx.1]
209 }
210}
211
212impl<const N: usize> IndexMut<(usize, usize)> for Matrix<N> {
213 fn index_mut(&mut self, idx: (usize, usize)) -> &mut Self::Output {
214 &mut self.values[idx.0][idx.1]
215 }
216}
217
218impl<const N: usize> Add for Matrix<N> {
219 type Output = Self;
220
221 fn add(self, other: Self) -> Self {
222 let mut new = Self::zero();
223
224 for i in 0..N {
225 for j in 0..N {
226 new[(i, j)] = self[(i, j)] + other[(i, j)];
227 }
228 }
229
230 new
231 }
232}
233
234impl<const N: usize> Sub for Matrix<N> {
235 type Output = Self;
236
237 fn sub(self, other: Self) -> Self {
238 let mut new = Self::zero();
239
240 for i in 0..N {
241 for j in 0..N {
242 new[(i, j)] = self[(i, j)] - other[(i, j)];
243 }
244 }
245
246 new
247 }
248}
249
250impl<const N: usize> fmt::Display for Matrix<N> {
251 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
252 let mut values = Vec::new();
253 let mut maxlen = 0;
254 for i in 0..N {
255 for j in 0..N {
256 let value = self[(i, j)];
257 let row = if value >= 0.0 {
258 format!(" {:.8}", value)
259 } else {
260 format!("{:.8}", value)
261 };
262 let l = row.len();
263 values.push(row);
264 if l > maxlen {
265 maxlen = l;
266 }
267 }
268 }
269
270 let mut output = String::new();
271 for i in 0..N {
272 output.push_str("[");
273 for j in 0..N {
274 output.push_str(
275 &format!("{:^i$}", values[j + N*i], i = maxlen + 2)
276 );
277 }
278 output.push_str("]\n");
279 }
280
281 write!(f, "{}", output)
282 }
283}
284
/// Checks `matmult` against a hand-computed product and the identity on both sides.
#[test]
fn matrix_multiply() {
    let a = Matrix::new([
        [1.0, 2.0, 3.0],
        [4.0, 5.0, 6.0],
        [7.0, 8.0, 9.0],
    ]);

    let b = Matrix::new([
        [9.0, 8.0, 7.0],
        [6.0, 5.0, 4.0],
        [3.0, 2.0, 1.0],
    ]);

    let c = Matrix::new([
        [ 30.0,  24.0, 18.0],
        [ 84.0,  69.0, 54.0],
        [138.0, 114.0, 90.0],
    ]);

    assert_eq!(a.matmult(b), c);

    // Multiplying by the identity on either side must leave the matrix unchanged.
    let identity = Matrix::<3>::identity();
    assert_eq!(a.matmult(identity), a);
    assert_eq!(identity.matmult(a), a);
}
309
/// Checks that `column` and `decompose` both return the matrix columns in order.
#[test]
fn decompose() {
    let a = Matrix::new([
        [1.0, 2.0, 3.0],
        [4.0, 5.0, 6.0],
        [7.0, 8.0, 9.0],
    ]);

    let basis: [Vector<3>; 3] = [
        [1.0, 4.0, 7.0].into(),
        [2.0, 5.0, 8.0].into(),
        [3.0, 6.0, 9.0].into(),
    ];

    for j in 0..3 {
        assert_eq!(a.column(j), basis[j]);
    }
    assert_eq!(a.decompose(), basis);
}
326
/// Checks a 30 degree rotation about the z axis.
///
/// The assertions hold for either rotation sign convention of `Vector::rotate`. The diagonal
/// x/y entries are `cos(angle)`, the off-diagonal x/y entries are `±sin(angle)` with opposite
/// signs, and the z axis is left unchanged.
#[test]
fn z_rotation_matrix() {
    let axis = [0.0, 0.0, 1.0].into();
    let angle = 30.0_f64.to_radians();

    let rotation = Matrix::<3>::rotation(axis, angle);

    let (c, s) = (angle.cos(), angle.sin());
    let close = |a: f64, b: f64| (a - b).abs() < 1e-6;

    assert!(close(rotation[(0, 0)], c));
    assert!(close(rotation[(1, 1)], c));
    assert!(close(rotation[(1, 0)].abs(), s));
    assert!(close(rotation[(0, 1)], -rotation[(1, 0)]));

    // The rotation axis (z) is fixed, and nothing leaks into or out of it.
    assert!(close(rotation[(2, 2)], 1.0));
    assert!(close(rotation[(0, 2)], 0.0));
    assert!(close(rotation[(1, 2)], 0.0));
    assert!(close(rotation[(2, 0)], 0.0));
    assert!(close(rotation[(2, 1)], 0.0));
}