1#![feature(generic_const_exprs)]
2
3use checks::{usize::Zero, Failed};
4
5#[derive(Debug, PartialEq, Copy, Clone)]
6pub struct Matrix<const R: usize, const C: usize>([[f32; C]; R])
7where
8 Zero<R>: Failed,
9 Zero<C>: Failed;
10
11pub type SquareMatrix<const D: usize> = Matrix<D, D>;
12pub type VecMatrix = Vec<Vec<f32>>;
13
/// Builds a [`Matrix`] from rows separated by `,`, whose entries are separated by
/// whitespace. For example, `matrix![1 2 3, 4 5 6]` builds a 2x3 matrix.
///
/// Invoking the macro with no tokens is a compile error.
///
/// When every entry is a literal, each entry is cast to `f32`. This lets integer and
/// float literals be mixed. Negative literals stay separate entries, so
/// `matrix![1 -2, 3 4]` has the rows `[1, -2]` and `[3, 4]`.
///
/// Otherwise every entry is parsed as a full expression and the rows are passed to
/// `Matrix::from` unchanged. In that form, an entry starting with `-` that follows
/// another entry is read as a subtraction; wrap it in braces (e.g. `{-b}`) to keep it
/// separate.
#[macro_export]
macro_rules! matrix {
    () => (compile_error!("Empty matrix not allowed"));
    // Literal-only form. The `literal` fragment accepts a leading `-`, so `1 -2` is
    // two entries rather than the expression `1 - 2`.
    ($($($value:literal)+),+) => {
        $crate::Matrix::from([
            $([$($value as f32),+],)+
        ])
    };
    // General form: any expression per entry.
    // `$crate` keeps the path valid for callers that have not imported `Matrix`.
    ($($($value:expr)*),*) => {
        $crate::Matrix::from([
            $([$($value),*],)*
        ])
    };
}
23
24impl<const R: usize, const C: usize> Matrix<R, C>
25where
26 Zero<R>: Failed,
27 Zero<C>: Failed,
28{
29 pub fn new(closure: impl Fn(usize, usize) -> f32) -> Self {
30 Self(std::array::from_fn(|row| {
31 std::array::from_fn(|column| closure(row, column))
32 }))
33 }
34
35 pub fn zero() -> Self {
36 Self::new(|_, _| 0.0)
37 }
38 pub const fn is_square(&self) -> bool {
39 R == C
40 }
41
42 pub fn rows(&self) -> [[f32; C]; R] {
43 self.0
44 }
45 pub fn columns(&self) -> [[f32; R]; C] {
46 std::array::from_fn(|i| self.rows().map(|row| row[i]))
47 }
48
49 pub fn transpose(&self) -> Matrix<C, R> {
50 Matrix::from(self.columns())
51 }
52
53 pub fn map<F>(&self, f: F) -> Self
54 where
55 F: Fn(f32) -> f32,
56 {
57 Self::new(|r, c| f(self[r][c]))
58 }
59
60 pub fn merge<F>(&self, other: Matrix<R, C>, f: F) -> Self
61 where
62 F: Fn(f32, f32) -> f32,
63 {
64 Self::new(|r, c| f(self[r][c], other[r][c]))
65 }
66}
67
68impl<const D: usize> SquareMatrix<D>
69where
70 Zero<D>: Failed,
71{
72 pub fn identity() -> Self {
73 Self::new(|row, column| if row == column { 1.0 } else { 0.0 })
74 }
75
76 pub fn determinant(&self) -> f32 {
77 Self::determinant_vec_impl(&self.into())
78 }
79
80 pub fn has_inverse(&self) -> bool {
81 self.determinant() != 0.0
82 }
83
84 pub fn inverse(&self) -> Option<Self> {
85 let det = self.determinant();
86 if det == 0.0 {
87 None
88 } else {
89 todo!()
90 }
91 }
92
93 fn determinant_vec_impl(vec: &VecMatrix) -> f32 {
94 let side_len = vec.len();
95 match side_len {
96 0 => 1.0,
97 1 => vec[0][0],
98 2 => (vec[0][0] * vec[1][1]) - (vec[0][1] * vec[1][0]),
99 _ => {
100 let mut det = 0.0;
101 let main_row = &vec[0];
102 for i in 0..vec.len() {
103 let to = side_len - 1;
104 let sub: VecMatrix = (0..to)
105 .map(|ri| {
106 (0..to)
107 .map(|ci| {
108 let row = &vec[ri + 1];
109 row[if ci >= i { ci + 1 } else { ci }]
110 })
111 .collect()
112 })
113 .collect();
114 det += (main_row[i] * Self::determinant_vec_impl(&sub))
115 * (if i % 2 == 0 { 1.0 } else { -1.0 })
116 }
117 det
118 }
119 }
120 }
121}
122
123macro_rules! matrix_merge_op {
124 ($type:path => $op:tt) => {
125 impl<const R: usize, const C: usize> $type for Matrix<R, C>
126 where
127 Zero<R>: Failed,
128 Zero<C>: Failed,
129 {
130 type Output = Self;
131
132 fn $op(self, rhs: Self) -> Self::Output {
133 self.merge(rhs, |a, b| a.$op(b))
134 }
135 }
136 };
137}
138
139matrix_merge_op!(std::ops::Add => add);
140matrix_merge_op!(std::ops::Sub => sub);
141
142impl<const R: usize, const C: usize, const C2: usize> std::ops::Mul<Matrix<C, C2>> for Matrix<R, C>
143where
144 Zero<R>: Failed,
145 Zero<C>: Failed,
146 Zero<C2>: Failed,
147{
148 type Output = Matrix<R, C2>;
149
150 fn mul(self, other: Matrix<C, C2>) -> Self::Output {
151 Matrix::new(|ri, ci| {
152 let row = self.rows()[ri];
153 let column = other.columns()[ci];
154 let mut sum = 0.0;
155 for i in 0..C {
156 sum += row[i] * column[i];
157 }
158 sum
159 })
160 }
161}
162
163impl<const R: usize, const C: usize> std::ops::Mul<f32> for Matrix<R, C>
164where
165 Zero<R>: Failed,
166 Zero<C>: Failed,
167{
168 type Output = Self;
169 fn mul(self, rhs: f32) -> Self::Output {
170 self.map(|v| v * rhs)
171 }
172}
173
174macro_rules! matrix_from_2d_num_array {
175 ($($num:ty)*) => ($(
176 impl<const R: usize, const C: usize> From<[[$num; C]; R]> for Matrix<R, C>
177 where
178 Zero<R>: Failed,
179 Zero<C>: Failed
180 {
181 fn from(value: [[$num; C]; R]) -> Self {
182 Self(value.map(|a| a.map(|b| b as f32)))
183 }
184 }
185 )*)
186}
187
188matrix_from_2d_num_array!(f32 i32 usize);
189
190impl<const R: usize, const C: usize> From<&Matrix<R, C>> for VecMatrix
191where
192 Zero<R>: Failed,
193 Zero<C>: Failed,
194{
195 fn from(val: &Matrix<R, C>) -> Self {
196 val.rows().map(|r| r.to_vec()).to_vec()
197 }
198}
199
200impl<const R: usize, const C: usize> Default for Matrix<R, C>
201where
202 Zero<R>: Failed,
203 Zero<C>: Failed,
204{
205 fn default() -> Self {
206 Self::zero()
207 }
208}
209
210impl<const R: usize, const C: usize> std::ops::Index<usize> for Matrix<R, C>
211where
212 Zero<R>: Failed,
213 Zero<C>: Failed,
214{
215 type Output = [f32; C];
216
217 fn index(&self, row: usize) -> &Self::Output {
218 &self.0[row]
219 }
220}
221
222impl<const R: usize, const C: usize> std::fmt::Display for Matrix<R, C>
223where
224 Zero<R>: Failed,
225 Zero<C>: Failed,
226{
227 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228 let lines = self.rows().map(|row| format!("{:?}", row));
229 let longest = lines.iter().map(|s| s.len()).max().unwrap_or(0);
230 writeln!(
231 f,
232 "{:^len$}",
233 format!("({}x{} matrix)", R, C),
234 len = longest
235 )?;
236 for line in lines {
237 writeln!(f, "{line}")?;
238 }
239 Ok(())
240 }
241}