Skip to main content

euca_math/
mat.rs

1use crate::{Quat, Vec3, Vec4};
2use serde::{Deserialize, Serialize};
3use std::ops::Mul;
4
5/// 4x4 column-major matrix.
6#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
7#[repr(C, align(16))]
8pub struct Mat4 {
9    /// Column 0 (x-axis)
10    pub cols: [[f32; 4]; 4],
11}
12
13impl Default for Mat4 {
14    #[inline(always)]
15    fn default() -> Self {
16        Self::IDENTITY
17    }
18}
19
20impl Mat4 {
21    /// The identity matrix.
22    pub const IDENTITY: Self = Self {
23        cols: [
24            [1.0, 0.0, 0.0, 0.0],
25            [0.0, 1.0, 0.0, 0.0],
26            [0.0, 0.0, 1.0, 0.0],
27            [0.0, 0.0, 0.0, 1.0],
28        ],
29    };
30
31    /// All zeros.
32    pub const ZERO: Self = Self {
33        cols: [[0.0; 4]; 4],
34    };
35
36    /// Returns column `i` as a 4-element array.
37    #[inline(always)]
38    pub fn col(&self, i: usize) -> [f32; 4] {
39        self.cols[i]
40    }
41
42    /// Column-major element access: col i, row j.
43    #[inline(always)]
44    pub fn get(&self, col: usize, row: usize) -> f32 {
45        self.cols[col][row]
46    }
47
48    /// Creates a translation matrix from a 3D offset.
49    pub fn from_translation(t: Vec3) -> Self {
50        let mut m = Self::IDENTITY;
51        m.cols[3][0] = t.x;
52        m.cols[3][1] = t.y;
53        m.cols[3][2] = t.z;
54        m
55    }
56
57    /// Creates a non-uniform scale matrix.
58    pub fn from_scale(s: Vec3) -> Self {
59        let mut m = Self::IDENTITY;
60        m.cols[0][0] = s.x;
61        m.cols[1][1] = s.y;
62        m.cols[2][2] = s.z;
63        m
64    }
65
66    /// Creates a rotation matrix from a unit quaternion.
67    pub fn from_rotation(q: Quat) -> Self {
68        let x2 = q.x + q.x;
69        let y2 = q.y + q.y;
70        let z2 = q.z + q.z;
71        let xx = q.x * x2;
72        let xy = q.x * y2;
73        let xz = q.x * z2;
74        let yy = q.y * y2;
75        let yz = q.y * z2;
76        let zz = q.z * z2;
77        let wx = q.w * x2;
78        let wy = q.w * y2;
79        let wz = q.w * z2;
80
81        Self {
82            cols: [
83                [1.0 - (yy + zz), xy + wz, xz - wy, 0.0],
84                [xy - wz, 1.0 - (xx + zz), yz + wx, 0.0],
85                [xz + wy, yz - wx, 1.0 - (xx + yy), 0.0],
86                [0.0, 0.0, 0.0, 1.0],
87            ],
88        }
89    }
90
91    /// Creates a combined scale-rotation-translation matrix.
92    pub fn from_scale_rotation_translation(s: Vec3, r: Quat, t: Vec3) -> Self {
93        let rot = Self::from_rotation(r);
94        Self {
95            cols: [
96                [
97                    rot.cols[0][0] * s.x,
98                    rot.cols[0][1] * s.x,
99                    rot.cols[0][2] * s.x,
100                    0.0,
101                ],
102                [
103                    rot.cols[1][0] * s.y,
104                    rot.cols[1][1] * s.y,
105                    rot.cols[1][2] * s.y,
106                    0.0,
107                ],
108                [
109                    rot.cols[2][0] * s.z,
110                    rot.cols[2][1] * s.z,
111                    rot.cols[2][2] * s.z,
112                    0.0,
113                ],
114                [t.x, t.y, t.z, 1.0],
115            ],
116        }
117    }
118
119    /// Orthographic projection (left-handed, depth 0..1).
120    pub fn orthographic_lh(
121        left: f32,
122        right: f32,
123        bottom: f32,
124        top: f32,
125        z_near: f32,
126        z_far: f32,
127    ) -> Self {
128        let rml = right - left;
129        let tmb = top - bottom;
130        let fmn = z_far - z_near;
131        Self {
132            cols: [
133                [2.0 / rml, 0.0, 0.0, 0.0],
134                [0.0, 2.0 / tmb, 0.0, 0.0],
135                [0.0, 0.0, 1.0 / fmn, 0.0],
136                [
137                    -(right + left) / rml,
138                    -(top + bottom) / tmb,
139                    -z_near / fmn,
140                    1.0,
141                ],
142            ],
143        }
144    }
145
146    /// Perspective projection (left-handed, depth 0..1).
147    pub fn perspective_lh(fov_y_radians: f32, aspect: f32, z_near: f32, z_far: f32) -> Self {
148        let h = 1.0 / (fov_y_radians * 0.5).tan();
149        let w = h / aspect;
150        let r = z_far / (z_far - z_near);
151
152        Self {
153            cols: [
154                [w, 0.0, 0.0, 0.0],
155                [0.0, h, 0.0, 0.0],
156                [0.0, 0.0, r, 1.0],
157                [0.0, 0.0, -r * z_near, 0.0],
158            ],
159        }
160    }
161
162    /// Left-handed look-at view matrix.
163    pub fn look_at_lh(eye: Vec3, target: Vec3, up: Vec3) -> Self {
164        let f = (target - eye).normalize();
165        let s = up.cross(f).normalize();
166        let u = f.cross(s);
167
168        Self {
169            cols: [
170                [s.x, u.x, f.x, 0.0],
171                [s.y, u.y, f.y, 0.0],
172                [s.z, u.z, f.z, 0.0],
173                [-s.dot(eye), -u.dot(eye), -f.dot(eye), 1.0],
174            ],
175        }
176    }
177
178    /// Computes the matrix inverse via cofactor expansion.
179    pub fn inverse(self) -> Self {
180        // Cofactor expansion for 4x4 inverse
181        let m = &self.cols;
182        let a2323 = m[2][2] * m[3][3] - m[3][2] * m[2][3];
183        let a1323 = m[1][2] * m[3][3] - m[3][2] * m[1][3];
184        let a1223 = m[1][2] * m[2][3] - m[2][2] * m[1][3];
185        let a0323 = m[0][2] * m[3][3] - m[3][2] * m[0][3];
186        let a0223 = m[0][2] * m[2][3] - m[2][2] * m[0][3];
187        let a0123 = m[0][2] * m[1][3] - m[1][2] * m[0][3];
188        let a2313 = m[2][1] * m[3][3] - m[3][1] * m[2][3];
189        let a1313 = m[1][1] * m[3][3] - m[3][1] * m[1][3];
190        let a1213 = m[1][1] * m[2][3] - m[2][1] * m[1][3];
191        let a2312 = m[2][1] * m[3][2] - m[3][1] * m[2][2];
192        let a1312 = m[1][1] * m[3][2] - m[3][1] * m[1][2];
193        let a1212 = m[1][1] * m[2][2] - m[2][1] * m[1][2];
194        let a0313 = m[0][1] * m[3][3] - m[3][1] * m[0][3];
195        let a0213 = m[0][1] * m[2][3] - m[2][1] * m[0][3];
196        let a0312 = m[0][1] * m[3][2] - m[3][1] * m[0][2];
197        let a0212 = m[0][1] * m[2][2] - m[2][1] * m[0][2];
198        let a0113 = m[0][1] * m[1][3] - m[1][1] * m[0][3];
199        let a0112 = m[0][1] * m[1][2] - m[1][1] * m[0][2];
200
201        let det = m[0][0] * (m[1][1] * a2323 - m[2][1] * a1323 + m[3][1] * a1223)
202            - m[1][0] * (m[0][1] * a2323 - m[2][1] * a0323 + m[3][1] * a0223)
203            + m[2][0] * (m[0][1] * a1323 - m[1][1] * a0323 + m[3][1] * a0123)
204            - m[3][0] * (m[0][1] * a1223 - m[1][1] * a0223 + m[2][1] * a0123);
205
206        let inv_det = 1.0 / det;
207
208        Self {
209            cols: [
210                [
211                    inv_det * (m[1][1] * a2323 - m[2][1] * a1323 + m[3][1] * a1223),
212                    inv_det * -(m[0][1] * a2323 - m[2][1] * a0323 + m[3][1] * a0223),
213                    inv_det * (m[0][1] * a1323 - m[1][1] * a0323 + m[3][1] * a0123),
214                    inv_det * -(m[0][1] * a1223 - m[1][1] * a0223 + m[2][1] * a0123),
215                ],
216                [
217                    inv_det * -(m[1][0] * a2323 - m[2][0] * a1323 + m[3][0] * a1223),
218                    inv_det * (m[0][0] * a2323 - m[2][0] * a0323 + m[3][0] * a0223),
219                    inv_det * -(m[0][0] * a1323 - m[1][0] * a0323 + m[3][0] * a0123),
220                    inv_det * (m[0][0] * a1223 - m[1][0] * a0223 + m[2][0] * a0123),
221                ],
222                [
223                    inv_det * (m[1][0] * a2313 - m[2][0] * a1313 + m[3][0] * a1213),
224                    inv_det * -(m[0][0] * a2313 - m[2][0] * a0313 + m[3][0] * a0213),
225                    inv_det * (m[0][0] * a1313 - m[1][0] * a0313 + m[3][0] * a0113),
226                    inv_det * -(m[0][0] * a1213 - m[1][0] * a0213 + m[2][0] * a0113),
227                ],
228                [
229                    inv_det * -(m[1][0] * a2312 - m[2][0] * a1312 + m[3][0] * a1212),
230                    inv_det * (m[0][0] * a2312 - m[2][0] * a0312 + m[3][0] * a0212),
231                    inv_det * -(m[0][0] * a1312 - m[1][0] * a0312 + m[3][0] * a0112),
232                    inv_det * (m[0][0] * a1212 - m[1][0] * a0212 + m[2][0] * a0112),
233                ],
234            ],
235        }
236    }
237
238    /// Returns the transpose of this matrix.
239    pub fn transpose(self) -> Self {
240        let m = &self.cols;
241        Self {
242            cols: [
243                [m[0][0], m[1][0], m[2][0], m[3][0]],
244                [m[0][1], m[1][1], m[2][1], m[3][1]],
245                [m[0][2], m[1][2], m[2][2], m[3][2]],
246                [m[0][3], m[1][3], m[2][3], m[3][3]],
247            ],
248        }
249    }
250
251    /// Convert to column-major 2D array (for GPU upload).
252    #[inline(always)]
253    pub fn to_cols_array_2d(&self) -> [[f32; 4]; 4] {
254        self.cols
255    }
256
257    /// Create from a column-major 2D array.
258    #[inline(always)]
259    pub fn from_cols_array_2d(cols: &[[f32; 4]; 4]) -> Self {
260        Self { cols: *cols }
261    }
262
263    /// Transform a point (w=1, applies translation).
264    pub fn transform_point3(&self, p: Vec3) -> Vec3 {
265        let m = &self.cols;
266        Vec3::new(
267            m[0][0] * p.x + m[1][0] * p.y + m[2][0] * p.z + m[3][0],
268            m[0][1] * p.x + m[1][1] * p.y + m[2][1] * p.z + m[3][1],
269            m[0][2] * p.x + m[1][2] * p.y + m[2][2] * p.z + m[3][2],
270        )
271    }
272}
273
274impl Mul for Mat4 {
275    type Output = Self;
276    fn mul(self, rhs: Self) -> Self {
277        let a = &self.cols;
278        let b = &rhs.cols;
279        let mut out = [[0.0f32; 4]; 4];
280
281        for c in 0..4 {
282            for r in 0..4 {
283                out[c][r] =
284                    a[0][r] * b[c][0] + a[1][r] * b[c][1] + a[2][r] * b[c][2] + a[3][r] * b[c][3];
285            }
286        }
287
288        Self { cols: out }
289    }
290}
291
292impl Mul<Vec4> for Mat4 {
293    type Output = Vec4;
294    fn mul(self, v: Vec4) -> Vec4 {
295        let m = &self.cols;
296        Vec4::new(
297            m[0][0] * v.x + m[1][0] * v.y + m[2][0] * v.z + m[3][0] * v.w,
298            m[0][1] * v.x + m[1][1] * v.y + m[2][1] * v.z + m[3][1] * v.w,
299            m[0][2] * v.x + m[1][2] * v.y + m[2][2] * v.z + m[3][2] * v.w,
300            m[0][3] * v.x + m[1][3] * v.y + m[2][3] * v.z + m[3][3] * v.w,
301        )
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts every element of `m` is within `1e-4` of the identity matrix.
    fn assert_near_identity(m: Mat4, label: &str) {
        for c in 0..4 {
            for r in 0..4 {
                let expected = if c == r { 1.0 } else { 0.0 };
                assert!(
                    (m.cols[c][r] - expected).abs() < 1e-4,
                    "{label} [{c}][{r}] = {} (expected {expected})",
                    m.cols[c][r]
                );
            }
        }
    }

    #[test]
    fn identity_mul() {
        let m = Mat4::from_translation(Vec3::new(1.0, 2.0, 3.0));
        let result = Mat4::IDENTITY * m;
        assert_eq!(result, m);
    }

    #[test]
    fn inverse_identity() {
        let m = Mat4::from_scale_rotation_translation(
            Vec3::new(2.0, 2.0, 2.0),
            Quat::from_axis_angle(Vec3::Z, 0.5),
            Vec3::new(10.0, 20.0, 30.0),
        );
        let inv = m.inverse();
        // A correct inverse cancels from either side.
        assert_near_identity(m * inv, "M*M^-1");
        assert_near_identity(inv * m, "M^-1*M");
    }

    #[test]
    fn mul_applies_right_operand_first() {
        let t = Mat4::from_translation(Vec3::new(1.0, 2.0, 3.0));
        let s = Mat4::from_scale(Vec3::new(2.0, 2.0, 2.0));
        let p = Vec3::new(1.0, 1.0, 1.0);
        // T * S: scale, then translate.
        assert_eq!((t * s).transform_point3(p), Vec3::new(3.0, 4.0, 5.0));
        // S * T: translate, then scale.
        assert_eq!((s * t).transform_point3(p), Vec3::new(4.0, 6.0, 8.0));
    }

    #[test]
    fn mul_vec4_respects_w() {
        let m = Mat4::from_translation(Vec3::new(1.0, 2.0, 3.0));
        // w = 1: points are translated.
        let p = m * Vec4::new(1.0, 1.0, 1.0, 1.0);
        assert_eq!((p.x, p.y, p.z, p.w), (2.0, 3.0, 4.0, 1.0));
        // w = 0: directions ignore translation.
        let d = m * Vec4::new(1.0, 1.0, 1.0, 0.0);
        assert_eq!((d.x, d.y, d.z, d.w), (1.0, 1.0, 1.0, 0.0));
    }

    #[test]
    fn transpose_swaps_rows_and_cols() {
        let m = Mat4::from_translation(Vec3::new(1.0, 2.0, 3.0));
        let t = m.transpose();
        assert_eq!(t.cols[0][3], 1.0);
        assert_eq!(t.cols[1][3], 2.0);
        assert_eq!(t.cols[2][3], 3.0);
        assert_eq!(t.transpose(), m);
    }

    #[test]
    fn perspective_maps_near_far_to_unit_depth() {
        let m = Mat4::perspective_lh(std::f32::consts::FRAC_PI_2, 1.0, 1.0, 10.0);
        let near = m * Vec4::new(0.0, 0.0, 1.0, 1.0);
        let far = m * Vec4::new(0.0, 0.0, 10.0, 1.0);
        assert!((near.z / near.w).abs() < 1e-5, "near depth = {}", near.z / near.w);
        assert!((far.z / far.w - 1.0).abs() < 1e-5, "far depth = {}", far.z / far.w);
    }

    #[test]
    fn transform_point() {
        let m = Mat4::from_translation(Vec3::new(10.0, 0.0, 0.0));
        let p = Vec3::new(1.0, 2.0, 3.0);
        let result = m.transform_point3(p);
        assert_eq!(result, Vec3::new(11.0, 2.0, 3.0));
    }
}