bs_trace/linalg/matrix.rs

1use crate::internal::{self, uninit};
2use crate::linalg::vector::Vector;
3use num_traits::{Float, One, Zero};
4use std::fmt::{Debug, Display, Formatter};
5use std::iter::Sum;
6use std::ops::{Add, AddAssign, Div, DivAssign, Index, IndexMut, Mul, MulAssign, Sub, SubAssign};
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
9pub struct Matrix<T, const M: usize, const N: usize> {
10    cols: [Vector<T, M>; N],
11}
12
13impl<T, const M: usize, const N: usize> Matrix<T, M, N> {
14    pub const fn new(cols: [Vector<T, M>; N]) -> Self {
15        Self { cols }
16    }
17
18    pub fn gen<F>(f: F) -> Self
19    where
20        F: Fn(usize, usize) -> T,
21    {
22        let mut col_data = uninit::new_uninit_array::<Vector<T, M>, N>();
23        for (col_idx, col) in col_data.iter_mut().enumerate() {
24            col.write(Vector::gen(|row_idx| f(col_idx, row_idx)));
25        }
26        let col_data = unsafe { uninit::array_assume_init(col_data) };
27        Matrix::new(col_data)
28    }
29
30    pub fn submatrix<const P: usize, const Q: usize>(
31        self,
32        col_offset: usize,
33        row_offset: usize,
34    ) -> Matrix<T, P, Q>
35    where
36        T: Clone,
37    {
38        debug_assert_eq!(P + row_offset, M);
39        debug_assert_eq!(Q + col_offset, N);
40        let mut submatrix_cols = uninit::new_uninit_array::<Vector<T, P>, Q>();
41        for (col_idx, col) in submatrix_cols.iter_mut().enumerate() {
42            let column = Vector::<T, P>::gen(|row_idx| {
43                self[col_idx + col_offset][row_idx + row_offset].clone()
44            });
45            col.write(column);
46        }
47        let submatrix_cols = unsafe { uninit::array_assume_init(submatrix_cols) };
48        Matrix::new(submatrix_cols)
49    }
50
51    pub fn transpose(self) -> Matrix<T, N, M>
52    where
53        T: Clone,
54    {
55        // TODO inplace to remove Clone bound
56        let mut new_cols = uninit::new_uninit_array::<Vector<T, N>, M>();
57        for (row_idx, new_col) in new_cols.iter_mut().enumerate() {
58            new_col.write(Vector::gen(|col_idx| self.cols[col_idx][row_idx].clone()));
59        }
60        let new_cols = unsafe { uninit::array_assume_init(new_cols) };
61        Matrix::new(new_cols)
62    }
63
64    pub fn map<F, U>(self, f: F) -> Matrix<U, M, N>
65    where
66        F: Fn(T) -> U,
67    {
68        let mut new_cols = uninit::new_uninit_array::<Vector<U, M>, N>();
69        for (new_col, old_col) in new_cols.iter_mut().zip(self.cols.into_iter()) {
70            new_col.write(old_col.map(&f));
71        }
72        let new_cols = unsafe { uninit::array_assume_init(new_cols) };
73        Matrix::new(new_cols)
74    }
75
76    pub fn map_mut<F>(&mut self, f: F)
77    where
78        F: Fn(&mut T),
79    {
80        for col in self.cols.iter_mut() {
81            col.map_mut(&f);
82        }
83    }
84
85    pub fn map_column<F>(self, f: F, col_idx: usize) -> Self
86    where
87        F: Fn(T) -> T,
88        T: Debug,
89    {
90        debug_assert!(col_idx < N);
91        let new_cols: Vec<_> = self
92            .cols
93            .into_iter()
94            .enumerate()
95            .map(|(i, col)| if i == col_idx { col.map(&f) } else { col })
96            .collect();
97        let new_cols: [Vector<T, M>; N] = new_cols.try_into().unwrap();
98        Matrix::new(new_cols)
99    }
100
101    pub fn map_column_mut<F>(&mut self, f: F, col_idx: usize)
102    where
103        F: Fn(&mut T),
104        T: Debug,
105    {
106        debug_assert!(col_idx < N);
107        self.cols
108            .iter_mut()
109            .enumerate()
110            .nth(col_idx)
111            .unwrap()
112            .1
113            .map_mut(f);
114    }
115
116    pub fn apply<F, U, V>(self, f: F, rhs: Matrix<U, M, N>) -> Matrix<V, M, N>
117    where
118        F: Fn(T, U) -> V,
119    {
120        let mut new_cols = uninit::new_uninit_array::<Vector<V, M>, N>();
121        for (new_col, (lhs, rhs)) in new_cols
122            .iter_mut()
123            .zip(self.cols.into_iter().zip(rhs.cols.into_iter()))
124        {
125            new_col.write(lhs.apply(&f, rhs));
126        }
127        let new_cols = unsafe { uninit::array_assume_init(new_cols) };
128        Matrix::new(new_cols)
129    }
130
131    pub fn apply_mut<F, U>(&mut self, f: F, rhs: Matrix<U, M, N>)
132    where
133        F: Fn(&mut T, U),
134    {
135        for (lhs_col, rhs_col) in self.cols.iter_mut().zip(rhs.cols.into_iter()) {
136            lhs_col.apply_mut(&f, rhs_col);
137        }
138    }
139
140    pub fn iter(&self) -> <&Self as IntoIterator>::IntoIter {
141        <&Self as IntoIterator>::into_iter(self)
142    }
143
144    pub fn into_iter(self) -> <Self as IntoIterator>::IntoIter {
145        <Self as IntoIterator>::into_iter(self)
146    }
147
148    pub fn iter_mut(&mut self) -> <&mut Self as IntoIterator>::IntoIter {
149        <&mut Self as IntoIterator>::into_iter(self)
150    }
151
152    pub fn row_echelon_form(self) -> Self
153    where
154        T: Float + Debug,
155    {
156        self.transpose().column_echelon_form().transpose()
157    }
158
159    pub fn reduced_row_echelon_form(self) -> Self
160    where
161        T: Float + Debug,
162    {
163        self.transpose().reduced_column_echelon_form().transpose()
164    }
165
166    pub fn column_echelon_form(self) -> Self
167    where
168        T: Float + Debug,
169    {
170        let mut arranged = self.sort_columns_by_leading_coefficient_index();
171        for i in 0..N {
172            let col_leading_coef_idx = arranged[i]
173                .iter()
174                .enumerate()
175                .find(|(_, ele)| ele.is_zero())
176                .map(|(i, _)| i);
177            if let Some(leading_idx) = col_leading_coef_idx {
178                let pivot_col = arranged[i];
179                for (j, col) in arranged.iter_mut().enumerate() {
180                    if i == j {
181                        continue;
182                    }
183                    if !col[leading_idx].is_zero() {
184                        let pivot_ratio = col[leading_idx] / pivot_col[leading_idx];
185                        *col = *col - pivot_col * pivot_ratio;
186                    }
187                }
188            }
189        }
190        arranged
191    }
192
193    pub fn reduced_column_echelon_form(self) -> Self
194    where
195        T: Float + Debug,
196    {
197        let mut cef = self.column_echelon_form();
198        for (col, leading_idx) in cef.leading_coefficient_indices_mut() {
199            if let Some(leading_idx) = leading_idx {
200                let pivot_divisor = col[leading_idx];
201                *col = *col / pivot_divisor;
202            }
203        }
204        cef
205    }
206
207    pub fn sort_columns_by_leading_coefficient_index(self) -> Self
208    where
209        T: Zero + Clone + Debug,
210    {
211        let mut leading_coef_idxs: Vec<_> = self.leading_coefficient_indices().collect();
212        leading_coef_idxs
213            .sort_by(|(_, idx1), (_, idx2)| internal::option::option_ordering_max_none(idx1, idx2));
214        let sorted_cols: [Vector<T, M>; N] = leading_coef_idxs
215            .into_iter()
216            .map(|(col, _)| col)
217            .cloned()
218            .collect::<Vec<_>>()
219            .try_into()
220            .unwrap();
221        Matrix::new(sorted_cols)
222    }
223
224    fn leading_coefficient_indices(&self) -> impl Iterator<Item = (&Vector<T, M>, Option<usize>)>
225    where
226        T: Zero,
227    {
228        self.cols.iter().map(|col| {
229            let idx = col
230                .iter()
231                .enumerate()
232                .skip_while(|(_, ele)| ele.is_zero())
233                .map(|(idx, _)| idx)
234                .next();
235            (col, idx)
236        })
237    }
238
239    fn leading_coefficient_indices_mut(
240        &mut self,
241    ) -> impl Iterator<Item = (&mut Vector<T, M>, Option<usize>)>
242    where
243        T: Zero,
244    {
245        self.cols.iter_mut().map(|col| {
246            let idx = col
247                .iter()
248                .enumerate()
249                .skip_while(|(_, ele)| ele.is_zero())
250                .map(|(idx, _)| idx)
251                .next();
252            (col, idx)
253        })
254    }
255
256    pub fn swap_columns(&mut self, col1_idx: usize, col2_idx: usize) {
257        debug_assert!(col1_idx < N);
258        debug_assert!(col2_idx < N);
259        self.cols.swap(col1_idx, col2_idx);
260    }
261}
262
263impl<T, const N: usize> Matrix<T, N, 1> {
264    pub fn new_column(col_data: [T; N]) -> Self {
265        Vector::new(col_data).into()
266    }
267
268    pub(super) fn into_vector(self) -> Vector<T, N> {
269        self.cols.into_iter().next().unwrap()
270    }
271}
272
273impl<T, const N: usize> Matrix<T, 1, N> {
274    pub fn new_row(row_data: [T; N]) -> Self {
275        let mut cols = uninit::new_uninit_array::<Vector<T, 1>, N>();
276        for (col, datum) in cols.iter_mut().zip(row_data.into_iter()) {
277            col.write(Vector::new1(datum));
278        }
279        let cols = unsafe { uninit::array_assume_init(cols) };
280        Matrix::new(cols)
281    }
282}
283
284impl<T, const M: usize, const N: usize> Add<Matrix<T, M, N>> for Matrix<T, M, N>
285where
286    T: Add<T, Output = T>,
287{
288    type Output = Matrix<T, M, N>;
289
290    fn add(self, rhs: Matrix<T, M, N>) -> Self::Output {
291        self.apply(T::add, rhs)
292    }
293}
294
295impl<T, const M: usize, const N: usize> AddAssign<Matrix<T, M, N>> for Matrix<T, M, N>
296where
297    T: AddAssign<T>,
298{
299    fn add_assign(&mut self, rhs: Matrix<T, M, N>) {
300        self.apply_mut(T::add_assign, rhs);
301    }
302}
303
304impl<T, const M: usize, const N: usize> Sub<Matrix<T, M, N>> for Matrix<T, M, N>
305where
306    T: Sub<T, Output = T>,
307{
308    type Output = Matrix<T, M, N>;
309
310    fn sub(self, rhs: Matrix<T, M, N>) -> Self::Output {
311        self.apply(T::sub, rhs)
312    }
313}
314
315impl<T, const M: usize, const N: usize> SubAssign<Matrix<T, M, N>> for Matrix<T, M, N>
316where
317    T: SubAssign<T>,
318{
319    fn sub_assign(&mut self, rhs: Matrix<T, M, N>) {
320        self.apply_mut(T::sub_assign, rhs)
321    }
322}
323
324impl<T, const M: usize, const N: usize> Mul<T> for Matrix<T, M, N>
325where
326    T: Mul<T, Output = T> + Clone,
327{
328    type Output = Matrix<T, M, N>;
329
330    fn mul(self, rhs: T) -> Self::Output {
331        self.map(|x| x * rhs.clone())
332    }
333}
334
335impl<T, const M: usize, const N: usize, const P: usize> Mul<Matrix<T, N, P>> for Matrix<T, M, N>
336where
337    T: Clone + Mul<T, Output = T> + Sum,
338{
339    type Output = Matrix<T, M, P>;
340
341    fn mul(self, rhs: Matrix<T, N, P>) -> Self::Output {
342        let tp = self.transpose();
343        let mut new_cols = uninit::new_uninit_array::<Vector<T, M>, P>();
344        for (new_col, rhs) in new_cols.iter_mut().zip(rhs.cols.into_iter()) {
345            let mut new_col_data = uninit::new_uninit_array::<T, M>();
346            for (new_col_datum, lhs) in new_col_data.iter_mut().zip(tp.cols.iter().cloned()) {
347                new_col_datum.write(lhs.clone().dot(rhs.clone()));
348            }
349            let new_col_data = unsafe { uninit::array_assume_init(new_col_data) };
350            new_col.write(Vector::new(new_col_data));
351        }
352        let new_cols = unsafe { uninit::array_assume_init(new_cols) };
353        Matrix::new(new_cols)
354    }
355}
356
357impl<T, const M: usize, const N: usize> Mul<Vector<T, N>> for Matrix<T, M, N>
358where
359    T: Clone + Mul<T, Output = T> + Sum,
360{
361    type Output = Vector<T, M>;
362
363    fn mul(self, rhs: Vector<T, N>) -> Self::Output {
364        (self * Matrix::from(rhs)).into()
365    }
366}
367
368impl<T, const M: usize, const N: usize> Div<T> for Matrix<T, M, N>
369where
370    T: Div<T, Output = T> + Clone,
371{
372    type Output = Matrix<T, M, N>;
373
374    fn div(self, rhs: T) -> Self::Output {
375        self.map(|x| x / rhs.clone())
376    }
377}
378
379impl<T, const M: usize, const N: usize> DivAssign<T> for Matrix<T, M, N>
380where
381    T: DivAssign<T> + Clone,
382{
383    fn div_assign(&mut self, rhs: T) {
384        self.map_mut(|x| *x /= rhs.clone());
385    }
386}
387
388impl<T, const M: usize, const N: usize> Zero for Matrix<T, M, N>
389where
390    T: Zero,
391{
392    fn zero() -> Self {
393        Matrix::gen(|_, _| T::zero())
394    }
395
396    fn is_zero(&self) -> bool {
397        self.cols.iter().all(|col| col.iter().all(|x| x.is_zero()))
398    }
399}
400
401impl<T, const M: usize> One for Matrix<T, M, M>
402where
403    T: Clone + Mul<T, Output = T> + Sum + One + Zero,
404{
405    fn one() -> Self {
406        Matrix::gen(|col_idx, row_idx| {
407            if col_idx == row_idx {
408                T::one()
409            } else {
410                T::zero()
411            }
412        })
413    }
414}
415
416impl<T, const M: usize, const N: usize> IntoIterator for Matrix<T, M, N> {
417    type Item = Vector<T, M>;
418    type IntoIter = std::array::IntoIter<Vector<T, M>, N>;
419
420    fn into_iter(self) -> Self::IntoIter {
421        self.cols.into_iter()
422    }
423}
424
425impl<'a, T, const M: usize, const N: usize> IntoIterator for &'a Matrix<T, M, N> {
426    type Item = &'a Vector<T, M>;
427    type IntoIter = std::slice::Iter<'a, Vector<T, M>>;
428
429    fn into_iter(self) -> Self::IntoIter {
430        self.cols.iter()
431    }
432}
433
434impl<'a, T, const M: usize, const N: usize> IntoIterator for &'a mut Matrix<T, M, N> {
435    type Item = &'a mut Vector<T, M>;
436    type IntoIter = std::slice::IterMut<'a, Vector<T, M>>;
437    fn into_iter(self) -> Self::IntoIter {
438        self.cols.iter_mut()
439    }
440}
441
442impl<T, const M: usize, const N: usize> Index<usize> for Matrix<T, M, N> {
443    type Output = Vector<T, M>;
444
445    fn index(&self, index: usize) -> &Self::Output {
446        &self.cols[index]
447    }
448}
449
450impl<T, const M: usize, const N: usize> IndexMut<usize> for Matrix<T, M, N> {
451    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
452        &mut self.cols[index]
453    }
454}
455
456impl<T, const N: usize> From<Vector<T, N>> for Matrix<T, N, 1> {
457    fn from(value: Vector<T, N>) -> Self {
458        value.into_matrix()
459    }
460}
461
462impl<T, const M: usize, const N: usize> From<[[T; M]; N]> for Matrix<T, M, N> {
463    fn from(value: [[T; M]; N]) -> Self {
464        let cols = value.map(|row| Vector::new(row));
465        Matrix::new(cols)
466    }
467}
468
469impl<T, const M: usize, const N: usize> Display for Matrix<T, M, N>
470where
471    T: Display,
472{
473    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
474        for row_idx in 0..M {
475            for col_idx in 0..N {
476                write!(f, "{}\t", self.cols[col_idx][row_idx])?;
477            }
478            writeln!(f)?;
479        }
480        Ok(())
481    }
482}
483
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn check_add() {
        let lhs = Matrix::<isize, 3, 4>::from([[2, 3, 5], [7, 11, 13], [17, 19, 23], [29, 31, 37]]);
        let rhs =
            Matrix::<isize, 3, 4>::from([[41, 43, 47], [53, 59, 61], [67, 71, 73], [79, 83, 89]]);
        let expected_sum = Matrix::<isize, 3, 4>::from([
            [43, 46, 52],
            [60, 70, 74],
            [84, 90, 96],
            [108, 114, 126],
        ]);
        let actual_sum = lhs + rhs;
        assert_eq!(expected_sum, actual_sum);
    }

    #[test]
    fn check_sub() {
        let lhs = Matrix::<isize, 2, 3>::from([[97, 101], [103, 107], [109, 113]]);
        let rhs = Matrix::<isize, 2, 3>::from([[127, 131], [137, 139], [149, 151]]);
        let expected_diff = Matrix::<isize, 2, 3>::from([[-30, -30], [-34, -32], [-40, -38]]);
        let actual_diff = lhs - rhs;
        assert_eq!(expected_diff, actual_diff);
    }

    #[test]
    fn check_mul() {
        let lhs = Matrix::<isize, 4, 2>::from([[2, 3, 5, 7], [11, 13, 17, 19]]);
        let rhs = Matrix::<isize, 2, 3>::from([[23, 29], [31, 37], [41, 43]]);
        let expected_prod = Matrix::<isize, 4, 3>::from([
            [365, 446, 608, 712],
            [469, 574, 784, 920],
            [555, 682, 936, 1104],
        ]);
        let actual_prod = lhs * rhs;
        assert_eq!(expected_prod, actual_prod);
    }

    #[test]
    fn check_mul_vector() {
        // Columns [1, 3] and [2, 4]; 5 * [1, 3] + 6 * [2, 4] = [17, 39].
        let mat = Matrix::<isize, 2, 2>::from([[1, 3], [2, 4]]);
        let expected = Vector::new([17, 39]);
        assert_eq!(expected, mat * Vector::new([5, 6]));
    }

    #[test]
    fn check_identity() {
        let data = [[1, 2], [3, 4], [5, 6]];
        let product = Matrix::<isize, 2, 2>::one() * Matrix::<isize, 2, 3>::from(data);
        assert_eq!(Matrix::<isize, 2, 3>::from(data), product);
    }

    #[test]
    fn check_transpose() {
        let mat = Matrix::<isize, 3, 4>::from([[2, 3, 5], [7, 11, 13], [17, 19, 23], [29, 31, 37]]);
        let expected_transpose =
            Matrix::<isize, 4, 3>::from([[2, 7, 17, 29], [3, 11, 19, 31], [5, 13, 23, 37]]);
        let actual_transpose = mat.transpose();
        assert_eq!(expected_transpose, actual_transpose);
    }

    #[test]
    fn check_display() {
        // Columns [1, 3] and [2, 4] print as rows "1 2" and "3 4", each entry tab-terminated.
        let mat = Matrix::<isize, 2, 2>::from([[1, 3], [2, 4]]);
        assert_eq!("1\t2\t\n3\t4\t\n", mat.to_string());
    }

    #[test]
    fn check_rref() {
        let mat = Matrix::<f64, 3, 4>::from([
            [2.0, -3.0, -2.0],
            [1.0, -1.0, 1.0],
            [-1.0, 2.0, 2.0],
            [8.0, -11.0, -3.0],
        ]);
        let expected_rref = Matrix::<f64, 3, 4>::from([
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [2.0, 3.0, -1.0],
        ]);
        let actual_rref = mat.reduced_row_echelon_form();
        assert_eq!(expected_rref, actual_rref);
    }
}