matrix_rs/
lib.rs

1#![feature(generic_const_exprs)]
2
3use checks::{usize::Zero, Failed};
4
5#[derive(Debug, PartialEq, Copy, Clone)]
6pub struct Matrix<const R: usize, const C: usize>([[f32; C]; R])
7where
8    Zero<R>: Failed,
9    Zero<C>: Failed;
10
11pub type SquareMatrix<const D: usize> = Matrix<D, D>;
12pub type VecMatrix = Vec<Vec<f32>>;
13
/// Builds a [`Matrix`] from a grid of values: entries within a row are separated by
/// whitespace and rows are separated by commas, so `matrix![1 2 3, 4 5 6]` is a 2x3 matrix.
///
/// Every row must contain the same number of entries (otherwise the generated row arrays have
/// different lengths and the expansion fails to type-check), and all entries must share an
/// element type for which `Matrix` implements `From<[[T; C]; R]>`.
///
/// Each entry is parsed as a full expression, so a `-` between two entries is read as
/// subtraction: `matrix![1 -2]` is a 1x1 matrix holding `-1`. Wrap negative entries that are
/// not first in their row in braces, e.g. `matrix![1 {-2}]`.
///
/// An empty invocation (`matrix![]`) is rejected at compile time.
#[macro_export]
macro_rules! matrix {
    () => (compile_error!("Empty matrix not allowed"));
    ($($($value:expr)*),*) => {
        // `$crate::` keeps the expansion working in crates that have not imported `Matrix`
        // (or that have an unrelated `Matrix` in scope at the call site).
        $crate::Matrix::from([
            $([$($value),*],)*
        ])
    };
}
23
24impl<const R: usize, const C: usize> Matrix<R, C>
25where
26    Zero<R>: Failed,
27    Zero<C>: Failed,
28{
29    pub fn new(closure: impl Fn(usize, usize) -> f32) -> Self {
30        Self(std::array::from_fn(|row| {
31            std::array::from_fn(|column| closure(row, column))
32        }))
33    }
34
35    pub fn zero() -> Self {
36        Self::new(|_, _| 0.0)
37    }
38    pub const fn is_square(&self) -> bool {
39        R == C
40    }
41
42    pub fn rows(&self) -> [[f32; C]; R] {
43        self.0
44    }
45    pub fn columns(&self) -> [[f32; R]; C] {
46        std::array::from_fn(|i| self.rows().map(|row| row[i]))
47    }
48
49    pub fn transpose(&self) -> Matrix<C, R> {
50        Matrix::from(self.columns())
51    }
52
53    pub fn map<F>(&self, f: F) -> Self
54    where
55        F: Fn(f32) -> f32,
56    {
57        Self::new(|r, c| f(self[r][c]))
58    }
59
60    pub fn merge<F>(&self, other: Matrix<R, C>, f: F) -> Self
61    where
62        F: Fn(f32, f32) -> f32,
63    {
64        Self::new(|r, c| f(self[r][c], other[r][c]))
65    }
66}
67
68impl<const D: usize> SquareMatrix<D>
69where
70    Zero<D>: Failed,
71{
72    pub fn identity() -> Self {
73        Self::new(|row, column| if row == column { 1.0 } else { 0.0 })
74    }
75
76    pub fn determinant(&self) -> f32 {
77        Self::determinant_vec_impl(&self.into())
78    }
79
80    pub fn has_inverse(&self) -> bool {
81        self.determinant() != 0.0
82    }
83
84    pub fn inverse(&self) -> Option<Self> {
85        let det = self.determinant();
86        if det == 0.0 {
87            None
88        } else {
89            todo!()
90        }
91    }
92
93    fn determinant_vec_impl(vec: &VecMatrix) -> f32 {
94        let side_len = vec.len();
95        match side_len {
96            0 => 1.0,
97            1 => vec[0][0],
98            2 => (vec[0][0] * vec[1][1]) - (vec[0][1] * vec[1][0]),
99            _ => {
100                let mut det = 0.0;
101                let main_row = &vec[0];
102                for i in 0..vec.len() {
103                    let to = side_len - 1;
104                    let sub: VecMatrix = (0..to)
105                        .map(|ri| {
106                            (0..to)
107                                .map(|ci| {
108                                    let row = &vec[ri + 1];
109                                    row[if ci >= i { ci + 1 } else { ci }]
110                                })
111                                .collect()
112                        })
113                        .collect();
114                    det += (main_row[i] * Self::determinant_vec_impl(&sub))
115                        * (if i % 2 == 0 { 1.0 } else { -1.0 })
116                }
117                det
118            }
119        }
120    }
121}
122
123macro_rules! matrix_merge_op {
124    ($type:path => $op:tt) => {
125        impl<const R: usize, const C: usize> $type for Matrix<R, C>
126        where
127            Zero<R>: Failed,
128            Zero<C>: Failed,
129        {
130            type Output = Self;
131
132            fn $op(self, rhs: Self) -> Self::Output {
133                self.merge(rhs, |a, b| a.$op(b))
134            }
135        }
136    };
137}
138
139matrix_merge_op!(std::ops::Add => add);
140matrix_merge_op!(std::ops::Sub => sub);
141
142impl<const R: usize, const C: usize, const C2: usize> std::ops::Mul<Matrix<C, C2>> for Matrix<R, C>
143where
144    Zero<R>: Failed,
145    Zero<C>: Failed,
146    Zero<C2>: Failed,
147{
148    type Output = Matrix<R, C2>;
149
150    fn mul(self, other: Matrix<C, C2>) -> Self::Output {
151        Matrix::new(|ri, ci| {
152            let row = self.rows()[ri];
153            let column = other.columns()[ci];
154            let mut sum = 0.0;
155            for i in 0..C {
156                sum += row[i] * column[i];
157            }
158            sum
159        })
160    }
161}
162
163impl<const R: usize, const C: usize> std::ops::Mul<f32> for Matrix<R, C>
164where
165    Zero<R>: Failed,
166    Zero<C>: Failed,
167{
168    type Output = Self;
169    fn mul(self, rhs: f32) -> Self::Output {
170        self.map(|v| v * rhs)
171    }
172}
173
174macro_rules! matrix_from_2d_num_array {
175    ($($num:ty)*) => ($(
176        impl<const R: usize, const C: usize> From<[[$num; C]; R]> for Matrix<R, C>
177        where
178            Zero<R>: Failed,
179            Zero<C>: Failed
180        {
181            fn from(value: [[$num; C]; R]) -> Self {
182                Self(value.map(|a| a.map(|b| b as f32)))
183            }
184        }
185    )*)
186}
187
188matrix_from_2d_num_array!(f32 i32 usize);
189
190impl<const R: usize, const C: usize> From<&Matrix<R, C>> for VecMatrix
191where
192    Zero<R>: Failed,
193    Zero<C>: Failed,
194{
195    fn from(val: &Matrix<R, C>) -> Self {
196        val.rows().map(|r| r.to_vec()).to_vec()
197    }
198}
199
200impl<const R: usize, const C: usize> Default for Matrix<R, C>
201where
202    Zero<R>: Failed,
203    Zero<C>: Failed,
204{
205    fn default() -> Self {
206        Self::zero()
207    }
208}
209
210impl<const R: usize, const C: usize> std::ops::Index<usize> for Matrix<R, C>
211where
212    Zero<R>: Failed,
213    Zero<C>: Failed,
214{
215    type Output = [f32; C];
216
217    fn index(&self, row: usize) -> &Self::Output {
218        &self.0[row]
219    }
220}
221
222impl<const R: usize, const C: usize> std::fmt::Display for Matrix<R, C>
223where
224    Zero<R>: Failed,
225    Zero<C>: Failed,
226{
227    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228        let lines = self.rows().map(|row| format!("{:?}", row));
229        let longest = lines.iter().map(|s| s.len()).max().unwrap_or(0);
230        writeln!(
231            f,
232            "{:^len$}",
233            format!("({}x{} matrix)", R, C),
234            len = longest
235        )?;
236        for line in lines {
237            writeln!(f, "{line}")?;
238        }
239        Ok(())
240    }
241}