irox_tools/math/
matrix.rs

1// SPDX-License-Identifier: MIT
2// Copyright 2025 IROX Contributors
3//
4
5#![allow(clippy::indexing_slicing)]
6
7use crate::ToSigned;
8use core::ops::{Add, AddAssign, Deref, DerefMut, Index, IndexMut, Mul, Sub};
9
10cfg_feature_std! {
11    use crate::ToF64;
12}
13
14pub trait AsMatrix<const M: usize, const N: usize, T: Sized + Copy + Default> {
15    fn as_matrix(&self) -> Matrix<M, N, T>;
16}
17
/// A dense `M x N` matrix stored inline and indexed as `values[row][column]`.
///
/// `M` is the number of rows and `N` the number of columns.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Matrix<const M: usize, const N: usize, T>
where
    T: Sized + Copy + Default,
{
    /// The elements, row by row: `values[row][column]`.
    pub values: [[T; N]; M],
}
22
23impl<const M: usize, const N: usize, T: Sized + Copy + Default> Matrix<M, N, T> {
24    #[must_use]
25    pub const fn new(values: [[T; N]; M]) -> Matrix<M, N, T> {
26        Matrix { values }
27    }
28}
29impl<const M: usize, const N: usize, T: Sized + Copy + Default> From<[[T; N]; M]>
30    for Matrix<M, N, T>
31{
32    fn from(value: [[T; N]; M]) -> Self {
33        Self { values: value }
34    }
35}
36impl<const M: usize, const N: usize, T: Sized + Copy + Default> From<&[[T; N]; M]>
37    for Matrix<M, N, T>
38{
39    fn from(value: &[[T; N]; M]) -> Self {
40        Self { values: *value }
41    }
42}
43impl<const M: usize, const N: usize, T: Sized + Copy + Default> AsMatrix<M, N, T> for [[T; N]; M] {
44    fn as_matrix(&self) -> Matrix<M, N, T> {
45        self.into()
46    }
47}
48impl<const M: usize, const N: usize> Matrix<M, N, f64> {
49    #[must_use]
50    pub const fn mul<const P: usize>(&self, other: Matrix<N, P, f64>) -> Matrix<M, P, f64> {
51        let mut out = [[0.0f64; P]; M];
52        let mut m = 0;
53        while m < M {
54            let mut p = 0;
55            while p < P {
56                let mut n = 0;
57                let mut sum = 0.0;
58                while n < N {
59                    sum += self.values[m][n] * other.values[n][p];
60                    n += 1;
61                }
62                out[m][p] = sum;
63                p += 1;
64            }
65            m += 1;
66        }
67        Matrix { values: out }
68    }
69}
70
71macro_rules! impl_square {
72    ($N:literal) => {
73        impl<T: Sized + Copy + Default> Matrix<$N, $N, T> {
74            #[must_use]
75            pub fn empty() -> Self {
76                Self {
77                    values: [<[T; $N]>::default(); $N],
78                }
79            }
80        }
81        impl Matrix<$N, $N, f64> {
82            #[must_use]
83            pub fn identity() -> Self {
84                let mut out = Self::empty();
85                for i in 0..$N {
86                    out[i][i] = 1.0;
87                }
88                out
89            }
90
91            #[must_use]
92            pub fn transpose(&self) -> Self {
93                let mut out = Self::empty();
94                for i in 0..$N {
95                    for j in 0..$N {
96                        out[i][j] = self.values[j][i];
97                    }
98                }
99                out
100            }
101        }
102    };
103}
104impl_square!(2);
105impl_square!(3);
106impl_square!(4);
107impl_square!(5);
108impl_square!(6);
109impl_square!(7);
110impl_square!(8);
111impl_square!(9);
112impl_square!(10);
113
114impl Matrix<2, 2, f64> {
115    cfg_feature_std! {
116        #[must_use]
117        pub fn rotation_counterclockwise(angle: f64) -> Self {
118            Self::new([[angle.cos(), -angle.sin()], [angle.sin(), angle.cos()]])
119        }
120        #[must_use]
121        pub fn rotate_counterclockwise(&self, angle: f64) -> Self {
122            self.mul(Self::rotation_counterclockwise(angle))
123        }
124        #[must_use]
125        pub fn rotation_clockwise(angle: f64) -> Self {
126            Self::new([[angle.cos(), angle.sin()], [-angle.sin(), angle.cos()]])
127        }
128        #[must_use]
129        pub fn rotate_clockwise(&self, angle: f64) -> Self {
130            self.mul(Self::rotation_clockwise(angle))
131        }
132    }
133
134    #[must_use]
135    pub const fn sheered_x(factor: f64) -> Self {
136        Self::new([[1., factor], [0., 1.]])
137    }
138    #[must_use]
139    pub const fn sheer_x(&self, factor: f64) -> Self {
140        self.mul(Self::sheered_x(factor))
141    }
142    #[must_use]
143    pub const fn sheered_y(factor: f64) -> Self {
144        Self::new([[1., 0.], [factor, 1.]])
145    }
146    #[must_use]
147    pub const fn sheer_y(&self, factor: f64) -> Self {
148        self.mul(Self::sheered_y(factor))
149    }
150
151    #[must_use]
152    pub const fn scaled_x(factor: f64) -> Self {
153        Self::new([[factor, 0.], [0., 1.]])
154    }
155    #[must_use]
156    pub const fn scale_x(&self, factor: f64) -> Self {
157        self.mul(Self::scaled_x(factor))
158    }
159
160    #[must_use]
161    pub const fn scaled_y(factor: f64) -> Self {
162        Self::new([[1., 0.], [0., factor]])
163    }
164    #[must_use]
165    pub const fn scale_y(&self, factor: f64) -> Self {
166        self.mul(Self::scaled_y(factor))
167    }
168
169    #[must_use]
170    pub const fn scaled(factor: f64) -> Self {
171        Self::new([[factor, 0.], [0., factor]])
172    }
173    #[must_use]
174    pub const fn scale(&self, factor: f64) -> Self {
175        self.mul(Self::scaled(factor))
176    }
177    #[must_use]
178    pub const fn reflected() -> Self {
179        Self::new([[-1., 0.], [0., -1.]])
180    }
181    #[must_use]
182    pub const fn reflect(&self) -> Self {
183        self.mul(Self::reflected())
184    }
185
186    #[must_use]
187    pub const fn reflected_x() -> Self {
188        Self::new([[1., 0.], [0., -1.]])
189    }
190    #[must_use]
191    pub const fn reflect_x(&self) -> Self {
192        self.mul(Self::reflected_x())
193    }
194    #[must_use]
195    pub const fn reflected_y() -> Self {
196        Self::new([[-1., 0.], [0., 1.]])
197    }
198    #[must_use]
199    pub const fn reflect_y(&self) -> Self {
200        self.mul(Self::reflected_y())
201    }
202}
203impl Matrix<3, 1, f64> {
204    #[must_use]
205    pub const fn translate(&self, x: f64, y: f64) -> Self {
206        Matrix::mul(
207            &Matrix::new([[1., 0., x], [0., 1., y], [0., 0., 1.]]),
208            *self,
209        )
210    }
211    cfg_feature_std! {
212        #[must_use]
213        pub fn rotate_x<T: ToF64 + Copy>(&self, angle: T) -> Self {
214            Matrix::<3, 3, f64>::rotated_x(angle).mul(*self)
215        }
216        #[must_use]
217        pub fn rotate_y<T: ToF64 + Copy>(&self, angle: T) -> Self {
218            Matrix::<3, 3, f64>::rotated_y(angle).mul(*self)
219        }
220        #[must_use]
221        pub fn rotate_z<T: ToF64 + Copy>(&self, angle: T) -> Self {
222            Matrix::<3, 3, f64>::rotated_z(angle).mul(*self)
223        }
224        #[must_use]
225        pub fn rotate_zyx<T: ToF64 + Copy>(&self, x_angle: T, y_angle: T, z_angle: T) -> Self {
226            Matrix::<3, 3, f64>::rotated_zyx(x_angle, y_angle, z_angle).mul(*self)
227        }
228    }
229}
230impl Matrix<3, 3, f64> {
231    cfg_feature_std! {
232        #[must_use]
233        pub fn rotated_x<T: ToF64 + Copy>(angle: T) -> Self {
234        let angle = angle.to_f64();
235            Self::new([
236                [1., 0., 0.],
237                [0., angle.cos(), -angle.sin()],
238                [0., angle.sin(), angle.cos()],
239            ])
240        }
241        #[must_use]
242        pub fn rotate_x<T: ToF64 + Copy>(&self, angle: T) -> Self {
243            self.mul(Self::rotated_x(angle))
244        }
245        #[must_use]
246        pub fn rotated_y<T: ToF64 + Copy>(angle: T) -> Self {
247            let angle = angle.to_f64();
248                Self::new([
249                    [angle.cos(), 0.0, angle.sin()],
250                    [0., 1., 0.],
251                    [-angle.sin(), 0., angle.cos()],
252                ])
253            }
254            #[must_use]
255        pub fn rotate_y<T: ToF64 + Copy>(&self, angle: T) -> Self {
256            self.mul(Self::rotated_y(angle))
257        }
258        #[must_use]
259        pub fn rotated_z<T: ToF64 + Copy>(angle: T) -> Self {
260            let angle = angle.to_f64();
261                Self::new([
262                    [angle.cos(), angle.sin(), 0.],
263                    [-angle.sin(), angle.cos(), 0.],
264                    [0., 0., 1.],
265                ])
266            }
267        #[must_use]
268        pub fn rotate_z<T: ToF64 + Copy>(&self, angle: T) -> Self {
269            self.mul(Self::rotated_z(angle))
270        }
271
272        #[must_use]
273        pub fn rotated_zyx<T: ToF64 + Copy>(x_angle: T, y_angle: T, z_angle: T) -> Self {
274            Self::rotated_z(z_angle).rotate_y(y_angle).rotate_x(x_angle)
275        }
276        #[must_use]
277        pub fn rotate_zyx<T: ToF64 + Copy>(&self, x_angle: T, y_angle: T, z_angle: T) -> Self {
278            self.rotate_z(z_angle).rotate_y(y_angle).rotate_x(x_angle)
279        }
280    }
281}
282
283impl<const M: usize, const N: usize, T: Sized + Copy + Default> Index<usize> for Matrix<M, N, T> {
284    type Output = [T; N];
285
286    fn index(&self, index: usize) -> &Self::Output {
287        self.values.index(index)
288    }
289}
290impl<const M: usize, const N: usize, T: Sized + Copy + Default> IndexMut<usize>
291    for Matrix<M, N, T>
292{
293    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
294        self.values.index_mut(index)
295    }
296}
297
298impl<const M: usize, const N: usize, T: Sized + Copy + Default> Deref for Matrix<M, N, T> {
299    type Target = [[T; N]; M];
300
301    fn deref(&self) -> &Self::Target {
302        &self.values
303    }
304}
305impl<const M: usize, const N: usize, T: Sized + Copy + Default> DerefMut for Matrix<M, N, T> {
306    fn deref_mut(&mut self) -> &mut Self::Target {
307        &mut self.values
308    }
309}
310// matrix multiply
311macro_rules! impl_mul {
312    ($($ty:ty)+) => {
313        impl<
314            const M: usize,
315            const N: usize,
316            const P: usize,
317            T: Sized + Copy + Default + Add + Mul + AddAssign<<T as Mul<T>>::Output>,
318            > Mul<Matrix<N, P, T>> for $($ty)+
319        {
320            type Output = Matrix<M, P, T>;
321            fn mul(self, other: Matrix<N, P, T>) -> Matrix<M, P, T> {
322                let mut out = [[T::default(); P]; M];
323                let mut m = 0;
324                while m < M {
325                    let mut p = 0;
326                    while p < P {
327                        let mut n = 0;
328                        let mut sum = T::default();
329                        while n < N {
330                            sum += self.values[m][n] * other.values[n][p];
331                            n += 1;
332                        }
333                        out[m][p] = sum;
334                        p += 1;
335                    }
336                    m += 1;
337                }
338                Matrix { values: out }
339            }
340        }
341    };
342}
343impl_mul!(Matrix<M, N, T>);
344impl_mul!(&Matrix<M, N, T>);
345impl_mul!(&mut Matrix<M, N, T>);
346
347// matrix add
348impl<const M: usize, const N: usize, T: Sized + Copy + Default + Add<Output = T>> Add
349    for Matrix<M, N, T>
350{
351    type Output = Matrix<M, N, T>;
352    fn add(self, other: Matrix<M, N, T>) -> Matrix<M, N, T> {
353        let mut out = [[T::default(); N]; M];
354
355        for (i, ith) in out.iter_mut().enumerate().take(M) {
356            for (j, val) in ith.iter_mut().enumerate().take(N) {
357                *val = self.values[i][j] + other.values[i][j];
358            }
359        }
360        Matrix { values: out }
361    }
362}
363// scalar multiply
364impl<const M: usize, const N: usize, T: Sized + Copy + Default + Mul<T, Output = T>> Mul<T>
365    for Matrix<M, N, T>
366{
367    type Output = Matrix<M, N, T>;
368    fn mul(self, other: T) -> Matrix<M, N, T> {
369        let mut out = [[T::default(); N]; M];
370
371        for (i, ith) in out.iter_mut().enumerate().take(M) {
372            for (j, val) in ith.iter_mut().enumerate().take(N) {
373                *val = self.values[i][j] * other;
374            }
375        }
376        Matrix { values: out }
377    }
378}
379
380impl<
381        const M: usize,
382        const N: usize,
383        T: Sized + Copy + Default + Add<Output = T> + Mul<T, Output = T> + ToSigned<Output = T>,
384    > Sub for Matrix<M, N, T>
385{
386    type Output = Matrix<M, N, T>;
387
388    fn sub(self, rhs: Self) -> Self::Output {
389        let v = rhs * <T as ToSigned>::negative_one();
390        self + v
391    }
392}
393
#[cfg(test)]
mod test {
    use crate::math::{AsMatrix, Matrix};

    /// The 3D column vector `[3, 7, 4]` shared by the rotation tests.
    #[cfg(feature = "std")]
    fn sample_vector() -> Matrix<3, 1, f64> {
        [[3.], [7.], [4.]].as_matrix()
    }

    #[test]
    pub fn test_scalar() {
        let scaled = Matrix::new([[4, 0], [1, -9]]) * 2;
        assert_eq!(scaled, Matrix::new([[8, 0], [2, -18]]));
    }

    #[test]
    pub fn test_product() {
        let lhs = Matrix::new([[1, 2, 3], [4, 5, 6]]);
        let rhs = Matrix::new([[7, 8], [9, 10], [11, 12]]);
        assert_eq!(lhs * rhs, Matrix::new([[58, 64], [139, 154]]));
    }

    #[cfg(feature = "std")]
    #[test]
    pub fn test_rotate1() {
        let [[x], [y], [z]] = sample_vector()
            .rotate_x(core::f64::consts::FRAC_PI_2)
            .values;
        assert_eq_eps!(3., x, 2. * f64::EPSILON);
        assert_eq_eps!(-4., y, 2. * f64::EPSILON);
        assert_eq_eps!(7., z, 2. * f64::EPSILON);
    }

    #[cfg(feature = "std")]
    #[test]
    pub fn test_rotate2() {
        let [[x], [y], [z]] = sample_vector()
            .rotate_y(core::f64::consts::FRAC_PI_2)
            .values;
        assert_eq_eps!(4., x, 2. * f64::EPSILON);
        assert_eq_eps!(7., y, 2. * f64::EPSILON);
        assert_eq_eps!(-3., z, 2. * f64::EPSILON);
    }

    #[cfg(feature = "std")]
    #[test]
    pub fn test_rotate3() {
        let [[x], [y], [z]] = sample_vector()
            .rotate_z(core::f64::consts::FRAC_PI_2)
            .values;
        assert_eq_eps!(7., x, 2. * f64::EPSILON);
        assert_eq_eps!(-3., y, 2. * f64::EPSILON);
        assert_eq_eps!(4., z, 2. * f64::EPSILON);
    }
}