1use crate::internal::{self, uninit};
2use crate::linalg::vector::Vector;
3use num_traits::{Float, One, Zero};
4use std::fmt::{Debug, Display, Formatter};
5use std::iter::Sum;
6use std::ops::{Add, AddAssign, Div, DivAssign, Index, IndexMut, Mul, MulAssign, Sub, SubAssign};
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
9pub struct Matrix<T, const M: usize, const N: usize> {
10 cols: [Vector<T, M>; N],
11}
12
13impl<T, const M: usize, const N: usize> Matrix<T, M, N> {
14 pub const fn new(cols: [Vector<T, M>; N]) -> Self {
15 Self { cols }
16 }
17
18 pub fn gen<F>(f: F) -> Self
19 where
20 F: Fn(usize, usize) -> T,
21 {
22 let mut col_data = uninit::new_uninit_array::<Vector<T, M>, N>();
23 for (col_idx, col) in col_data.iter_mut().enumerate() {
24 col.write(Vector::gen(|row_idx| f(col_idx, row_idx)));
25 }
26 let col_data = unsafe { uninit::array_assume_init(col_data) };
27 Matrix::new(col_data)
28 }
29
30 pub fn submatrix<const P: usize, const Q: usize>(
31 self,
32 col_offset: usize,
33 row_offset: usize,
34 ) -> Matrix<T, P, Q>
35 where
36 T: Clone,
37 {
38 debug_assert_eq!(P + row_offset, M);
39 debug_assert_eq!(Q + col_offset, N);
40 let mut submatrix_cols = uninit::new_uninit_array::<Vector<T, P>, Q>();
41 for (col_idx, col) in submatrix_cols.iter_mut().enumerate() {
42 let column = Vector::<T, P>::gen(|row_idx| {
43 self[col_idx + col_offset][row_idx + row_offset].clone()
44 });
45 col.write(column);
46 }
47 let submatrix_cols = unsafe { uninit::array_assume_init(submatrix_cols) };
48 Matrix::new(submatrix_cols)
49 }
50
51 pub fn transpose(self) -> Matrix<T, N, M>
52 where
53 T: Clone,
54 {
55 let mut new_cols = uninit::new_uninit_array::<Vector<T, N>, M>();
57 for (row_idx, new_col) in new_cols.iter_mut().enumerate() {
58 new_col.write(Vector::gen(|col_idx| self.cols[col_idx][row_idx].clone()));
59 }
60 let new_cols = unsafe { uninit::array_assume_init(new_cols) };
61 Matrix::new(new_cols)
62 }
63
64 pub fn map<F, U>(self, f: F) -> Matrix<U, M, N>
65 where
66 F: Fn(T) -> U,
67 {
68 let mut new_cols = uninit::new_uninit_array::<Vector<U, M>, N>();
69 for (new_col, old_col) in new_cols.iter_mut().zip(self.cols.into_iter()) {
70 new_col.write(old_col.map(&f));
71 }
72 let new_cols = unsafe { uninit::array_assume_init(new_cols) };
73 Matrix::new(new_cols)
74 }
75
76 pub fn map_mut<F>(&mut self, f: F)
77 where
78 F: Fn(&mut T),
79 {
80 for col in self.cols.iter_mut() {
81 col.map_mut(&f);
82 }
83 }
84
85 pub fn map_column<F>(self, f: F, col_idx: usize) -> Self
86 where
87 F: Fn(T) -> T,
88 T: Debug,
89 {
90 debug_assert!(col_idx < N);
91 let new_cols: Vec<_> = self
92 .cols
93 .into_iter()
94 .enumerate()
95 .map(|(i, col)| if i == col_idx { col.map(&f) } else { col })
96 .collect();
97 let new_cols: [Vector<T, M>; N] = new_cols.try_into().unwrap();
98 Matrix::new(new_cols)
99 }
100
101 pub fn map_column_mut<F>(&mut self, f: F, col_idx: usize)
102 where
103 F: Fn(&mut T),
104 T: Debug,
105 {
106 debug_assert!(col_idx < N);
107 self.cols
108 .iter_mut()
109 .enumerate()
110 .nth(col_idx)
111 .unwrap()
112 .1
113 .map_mut(f);
114 }
115
116 pub fn apply<F, U, V>(self, f: F, rhs: Matrix<U, M, N>) -> Matrix<V, M, N>
117 where
118 F: Fn(T, U) -> V,
119 {
120 let mut new_cols = uninit::new_uninit_array::<Vector<V, M>, N>();
121 for (new_col, (lhs, rhs)) in new_cols
122 .iter_mut()
123 .zip(self.cols.into_iter().zip(rhs.cols.into_iter()))
124 {
125 new_col.write(lhs.apply(&f, rhs));
126 }
127 let new_cols = unsafe { uninit::array_assume_init(new_cols) };
128 Matrix::new(new_cols)
129 }
130
131 pub fn apply_mut<F, U>(&mut self, f: F, rhs: Matrix<U, M, N>)
132 where
133 F: Fn(&mut T, U),
134 {
135 for (lhs_col, rhs_col) in self.cols.iter_mut().zip(rhs.cols.into_iter()) {
136 lhs_col.apply_mut(&f, rhs_col);
137 }
138 }
139
140 pub fn iter(&self) -> <&Self as IntoIterator>::IntoIter {
141 <&Self as IntoIterator>::into_iter(self)
142 }
143
144 pub fn into_iter(self) -> <Self as IntoIterator>::IntoIter {
145 <Self as IntoIterator>::into_iter(self)
146 }
147
148 pub fn iter_mut(&mut self) -> <&mut Self as IntoIterator>::IntoIter {
149 <&mut Self as IntoIterator>::into_iter(self)
150 }
151
152 pub fn row_echelon_form(self) -> Self
153 where
154 T: Float + Debug,
155 {
156 self.transpose().column_echelon_form().transpose()
157 }
158
159 pub fn reduced_row_echelon_form(self) -> Self
160 where
161 T: Float + Debug,
162 {
163 self.transpose().reduced_column_echelon_form().transpose()
164 }
165
166 pub fn column_echelon_form(self) -> Self
167 where
168 T: Float + Debug,
169 {
170 let mut arranged = self.sort_columns_by_leading_coefficient_index();
171 for i in 0..N {
172 let col_leading_coef_idx = arranged[i]
173 .iter()
174 .enumerate()
175 .find(|(_, ele)| ele.is_zero())
176 .map(|(i, _)| i);
177 if let Some(leading_idx) = col_leading_coef_idx {
178 let pivot_col = arranged[i];
179 for (j, col) in arranged.iter_mut().enumerate() {
180 if i == j {
181 continue;
182 }
183 if !col[leading_idx].is_zero() {
184 let pivot_ratio = col[leading_idx] / pivot_col[leading_idx];
185 *col = *col - pivot_col * pivot_ratio;
186 }
187 }
188 }
189 }
190 arranged
191 }
192
193 pub fn reduced_column_echelon_form(self) -> Self
194 where
195 T: Float + Debug,
196 {
197 let mut cef = self.column_echelon_form();
198 for (col, leading_idx) in cef.leading_coefficient_indices_mut() {
199 if let Some(leading_idx) = leading_idx {
200 let pivot_divisor = col[leading_idx];
201 *col = *col / pivot_divisor;
202 }
203 }
204 cef
205 }
206
207 pub fn sort_columns_by_leading_coefficient_index(self) -> Self
208 where
209 T: Zero + Clone + Debug,
210 {
211 let mut leading_coef_idxs: Vec<_> = self.leading_coefficient_indices().collect();
212 leading_coef_idxs
213 .sort_by(|(_, idx1), (_, idx2)| internal::option::option_ordering_max_none(idx1, idx2));
214 let sorted_cols: [Vector<T, M>; N] = leading_coef_idxs
215 .into_iter()
216 .map(|(col, _)| col)
217 .cloned()
218 .collect::<Vec<_>>()
219 .try_into()
220 .unwrap();
221 Matrix::new(sorted_cols)
222 }
223
224 fn leading_coefficient_indices(&self) -> impl Iterator<Item = (&Vector<T, M>, Option<usize>)>
225 where
226 T: Zero,
227 {
228 self.cols.iter().map(|col| {
229 let idx = col
230 .iter()
231 .enumerate()
232 .skip_while(|(_, ele)| ele.is_zero())
233 .map(|(idx, _)| idx)
234 .next();
235 (col, idx)
236 })
237 }
238
239 fn leading_coefficient_indices_mut(
240 &mut self,
241 ) -> impl Iterator<Item = (&mut Vector<T, M>, Option<usize>)>
242 where
243 T: Zero,
244 {
245 self.cols.iter_mut().map(|col| {
246 let idx = col
247 .iter()
248 .enumerate()
249 .skip_while(|(_, ele)| ele.is_zero())
250 .map(|(idx, _)| idx)
251 .next();
252 (col, idx)
253 })
254 }
255
256 pub fn swap_columns(&mut self, col1_idx: usize, col2_idx: usize) {
257 debug_assert!(col1_idx < N);
258 debug_assert!(col2_idx < N);
259 self.cols.swap(col1_idx, col2_idx);
260 }
261}
262
263impl<T, const N: usize> Matrix<T, N, 1> {
264 pub fn new_column(col_data: [T; N]) -> Self {
265 Vector::new(col_data).into()
266 }
267
268 pub(super) fn into_vector(self) -> Vector<T, N> {
269 self.cols.into_iter().next().unwrap()
270 }
271}
272
273impl<T, const N: usize> Matrix<T, 1, N> {
274 pub fn new_row(row_data: [T; N]) -> Self {
275 let mut cols = uninit::new_uninit_array::<Vector<T, 1>, N>();
276 for (col, datum) in cols.iter_mut().zip(row_data.into_iter()) {
277 col.write(Vector::new1(datum));
278 }
279 let cols = unsafe { uninit::array_assume_init(cols) };
280 Matrix::new(cols)
281 }
282}
283
284impl<T, const M: usize, const N: usize> Add<Matrix<T, M, N>> for Matrix<T, M, N>
285where
286 T: Add<T, Output = T>,
287{
288 type Output = Matrix<T, M, N>;
289
290 fn add(self, rhs: Matrix<T, M, N>) -> Self::Output {
291 self.apply(T::add, rhs)
292 }
293}
294
295impl<T, const M: usize, const N: usize> AddAssign<Matrix<T, M, N>> for Matrix<T, M, N>
296where
297 T: AddAssign<T>,
298{
299 fn add_assign(&mut self, rhs: Matrix<T, M, N>) {
300 self.apply_mut(T::add_assign, rhs);
301 }
302}
303
304impl<T, const M: usize, const N: usize> Sub<Matrix<T, M, N>> for Matrix<T, M, N>
305where
306 T: Sub<T, Output = T>,
307{
308 type Output = Matrix<T, M, N>;
309
310 fn sub(self, rhs: Matrix<T, M, N>) -> Self::Output {
311 self.apply(T::sub, rhs)
312 }
313}
314
315impl<T, const M: usize, const N: usize> SubAssign<Matrix<T, M, N>> for Matrix<T, M, N>
316where
317 T: SubAssign<T>,
318{
319 fn sub_assign(&mut self, rhs: Matrix<T, M, N>) {
320 self.apply_mut(T::sub_assign, rhs)
321 }
322}
323
324impl<T, const M: usize, const N: usize> Mul<T> for Matrix<T, M, N>
325where
326 T: Mul<T, Output = T> + Clone,
327{
328 type Output = Matrix<T, M, N>;
329
330 fn mul(self, rhs: T) -> Self::Output {
331 self.map(|x| x * rhs.clone())
332 }
333}
334
335impl<T, const M: usize, const N: usize, const P: usize> Mul<Matrix<T, N, P>> for Matrix<T, M, N>
336where
337 T: Clone + Mul<T, Output = T> + Sum,
338{
339 type Output = Matrix<T, M, P>;
340
341 fn mul(self, rhs: Matrix<T, N, P>) -> Self::Output {
342 let tp = self.transpose();
343 let mut new_cols = uninit::new_uninit_array::<Vector<T, M>, P>();
344 for (new_col, rhs) in new_cols.iter_mut().zip(rhs.cols.into_iter()) {
345 let mut new_col_data = uninit::new_uninit_array::<T, M>();
346 for (new_col_datum, lhs) in new_col_data.iter_mut().zip(tp.cols.iter().cloned()) {
347 new_col_datum.write(lhs.clone().dot(rhs.clone()));
348 }
349 let new_col_data = unsafe { uninit::array_assume_init(new_col_data) };
350 new_col.write(Vector::new(new_col_data));
351 }
352 let new_cols = unsafe { uninit::array_assume_init(new_cols) };
353 Matrix::new(new_cols)
354 }
355}
356
357impl<T, const M: usize, const N: usize> Mul<Vector<T, N>> for Matrix<T, M, N>
358where
359 T: Clone + Mul<T, Output = T> + Sum,
360{
361 type Output = Vector<T, M>;
362
363 fn mul(self, rhs: Vector<T, N>) -> Self::Output {
364 (self * Matrix::from(rhs)).into()
365 }
366}
367
368impl<T, const M: usize, const N: usize> Div<T> for Matrix<T, M, N>
369where
370 T: Div<T, Output = T> + Clone,
371{
372 type Output = Matrix<T, M, N>;
373
374 fn div(self, rhs: T) -> Self::Output {
375 self.map(|x| x / rhs.clone())
376 }
377}
378
379impl<T, const M: usize, const N: usize> DivAssign<T> for Matrix<T, M, N>
380where
381 T: DivAssign<T> + Clone,
382{
383 fn div_assign(&mut self, rhs: T) {
384 self.map_mut(|x| *x /= rhs.clone());
385 }
386}
387
388impl<T, const M: usize, const N: usize> Zero for Matrix<T, M, N>
389where
390 T: Zero,
391{
392 fn zero() -> Self {
393 Matrix::gen(|_, _| T::zero())
394 }
395
396 fn is_zero(&self) -> bool {
397 self.cols.iter().all(|col| col.iter().all(|x| x.is_zero()))
398 }
399}
400
401impl<T, const M: usize> One for Matrix<T, M, M>
402where
403 T: Clone + Mul<T, Output = T> + Sum + One + Zero,
404{
405 fn one() -> Self {
406 Matrix::gen(|col_idx, row_idx| {
407 if col_idx == row_idx {
408 T::one()
409 } else {
410 T::zero()
411 }
412 })
413 }
414}
415
416impl<T, const M: usize, const N: usize> IntoIterator for Matrix<T, M, N> {
417 type Item = Vector<T, M>;
418 type IntoIter = std::array::IntoIter<Vector<T, M>, N>;
419
420 fn into_iter(self) -> Self::IntoIter {
421 self.cols.into_iter()
422 }
423}
424
425impl<'a, T, const M: usize, const N: usize> IntoIterator for &'a Matrix<T, M, N> {
426 type Item = &'a Vector<T, M>;
427 type IntoIter = std::slice::Iter<'a, Vector<T, M>>;
428
429 fn into_iter(self) -> Self::IntoIter {
430 self.cols.iter()
431 }
432}
433
434impl<'a, T, const M: usize, const N: usize> IntoIterator for &'a mut Matrix<T, M, N> {
435 type Item = &'a mut Vector<T, M>;
436 type IntoIter = std::slice::IterMut<'a, Vector<T, M>>;
437 fn into_iter(self) -> Self::IntoIter {
438 self.cols.iter_mut()
439 }
440}
441
442impl<T, const M: usize, const N: usize> Index<usize> for Matrix<T, M, N> {
443 type Output = Vector<T, M>;
444
445 fn index(&self, index: usize) -> &Self::Output {
446 &self.cols[index]
447 }
448}
449
450impl<T, const M: usize, const N: usize> IndexMut<usize> for Matrix<T, M, N> {
451 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
452 &mut self.cols[index]
453 }
454}
455
456impl<T, const N: usize> From<Vector<T, N>> for Matrix<T, N, 1> {
457 fn from(value: Vector<T, N>) -> Self {
458 value.into_matrix()
459 }
460}
461
462impl<T, const M: usize, const N: usize> From<[[T; M]; N]> for Matrix<T, M, N> {
463 fn from(value: [[T; M]; N]) -> Self {
464 let cols = value.map(|row| Vector::new(row));
465 Matrix::new(cols)
466 }
467}
468
469impl<T, const M: usize, const N: usize> Display for Matrix<T, M, N>
470where
471 T: Display,
472{
473 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
474 for row_idx in 0..M {
475 for col_idx in 0..N {
476 write!(f, "{}\t", self.cols[col_idx][row_idx])?;
477 }
478 writeln!(f)?;
479 }
480 Ok(())
481 }
482}
483
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn check_add() {
        let lhs = Matrix::<isize, 3, 4>::from([[2, 3, 5], [7, 11, 13], [17, 19, 23], [29, 31, 37]]);
        let rhs =
            Matrix::<isize, 3, 4>::from([[41, 43, 47], [53, 59, 61], [67, 71, 73], [79, 83, 89]]);
        let expected_sum = Matrix::<isize, 3, 4>::from([
            [43, 46, 52],
            [60, 70, 74],
            [84, 90, 96],
            [108, 114, 126],
        ]);
        let actual_sum = lhs + rhs;
        assert_eq!(expected_sum, actual_sum);
    }

    #[test]
    fn check_sub() {
        let lhs = Matrix::<isize, 2, 3>::from([[97, 101], [103, 107], [109, 113]]);
        let rhs = Matrix::<isize, 2, 3>::from([[127, 131], [137, 139], [149, 151]]);
        let expected_diff = Matrix::<isize, 2, 3>::from([[-30, -30], [-34, -32], [-40, -38]]);
        let actual_diff = lhs - rhs;
        assert_eq!(expected_diff, actual_diff);
    }

    #[test]
    fn check_mul() {
        let lhs = Matrix::<isize, 4, 2>::from([[2, 3, 5, 7], [11, 13, 17, 19]]);
        let rhs = Matrix::<isize, 2, 3>::from([[23, 29], [31, 37], [41, 43]]);
        let expected_prod = Matrix::<isize, 4, 3>::from([
            [365, 446, 608, 712],
            [469, 574, 784, 920],
            [555, 682, 936, 1104],
        ]);
        let actual_prod = lhs * rhs;
        assert_eq!(expected_prod, actual_prod);
    }

    #[test]
    fn check_transpose() {
        let mat = Matrix::<isize, 3, 4>::from([[2, 3, 5], [7, 11, 13], [17, 19, 23], [29, 31, 37]]);
        let expected_transpose =
            Matrix::<isize, 4, 3>::from([[2, 7, 17, 29], [3, 11, 19, 31], [5, 13, 23, 37]]);
        assert_eq!(expected_transpose, mat.transpose());
    }

    #[test]
    fn check_rref() {
        let mat = Matrix::<f64, 3, 4>::from([
            [2.0, -3.0, -2.0],
            [1.0, -1.0, 1.0],
            [-1.0, 2.0, 2.0],
            [8.0, -11.0, -3.0],
        ]);
        let expected_rref = Matrix::<f64, 3, 4>::from([
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [2.0, 3.0, -1.0],
        ]);
        let actual_rref = mat.reduced_row_echelon_form();
        assert_eq!(expected_rref, actual_rref);
    }

    #[test]
    fn check_rref_with_row_swap_and_dependent_row() {
        // Rows are [0, 1, 2], [1, 2, 3] and [2, 4, 6]. The first row has a zero leading
        // entry, so a swap is required, and the third row is twice the second, so the
        // result must end with a zero row.
        let mat =
            Matrix::<f64, 3, 3>::from([[0.0, 1.0, 2.0], [1.0, 2.0, 4.0], [2.0, 3.0, 6.0]]);
        // Expected rows: [1, 0, -1], [0, 1, 2], [0, 0, 0].
        let expected_rref =
            Matrix::<f64, 3, 3>::from([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [-1.0, 2.0, 0.0]]);
        assert_eq!(expected_rref, mat.reduced_row_echelon_form());
    }

    #[test]
    fn check_rref_of_zero_matrix() {
        let zero = Matrix::<f64, 2, 3>::zero();
        assert_eq!(zero, zero.reduced_row_echelon_form());
    }

    #[test]
    fn check_submatrix_bottom_right() {
        let mat = Matrix::<isize, 3, 3>::from([[1, 2, 3], [4, 5, 6], [7, 8, 9]]);
        let expected = Matrix::<isize, 2, 2>::from([[5, 6], [8, 9]]);
        assert_eq!(expected, mat.submatrix::<2, 2>(1, 1));
    }

    #[test]
    fn check_identity_and_zero() {
        let identity = Matrix::<isize, 2, 2>::one();
        assert_eq!(Matrix::<isize, 2, 2>::from([[1, 0], [0, 1]]), identity);
        let mat = Matrix::<isize, 2, 2>::from([[1, 2], [3, 4]]);
        assert_eq!(mat, mat * identity);
        assert!(Matrix::<isize, 2, 2>::zero().is_zero());
        assert!(!mat.is_zero());
    }

    #[test]
    fn check_row_and_column_constructors() {
        assert_eq!(
            Matrix::<isize, 1, 3>::from([[1], [2], [3]]),
            Matrix::<isize, 1, 3>::new_row([1, 2, 3])
        );
        assert_eq!(
            Matrix::<isize, 3, 1>::from([[1, 2, 3]]),
            Matrix::<isize, 3, 1>::new_column([1, 2, 3])
        );
    }

    #[test]
    fn check_map_column_and_swap_columns() {
        let mat = Matrix::<isize, 2, 2>::from([[1, 2], [3, 4]]);
        assert_eq!(
            Matrix::<isize, 2, 2>::from([[1, 2], [30, 40]]),
            mat.map_column(|x| x * 10, 1)
        );
        let mut swapped = mat;
        swapped.swap_columns(0, 1);
        assert_eq!(Matrix::<isize, 2, 2>::from([[3, 4], [1, 2]]), swapped);
    }

    #[test]
    fn check_display() {
        let mat = Matrix::<isize, 2, 2>::from([[1, 2], [3, 4]]);
        assert_eq!("1\t3\t\n2\t4\t\n", mat.to_string());
    }
}