1use crate::{Quat, Vec3, Vec4};
2use serde::{Deserialize, Serialize};
3use std::ops::Mul;
4
/// A 4x4 `f32` matrix stored in column-major order.
///
/// `cols[c][r]` is the element in column `c`, row `r`. Vectors are treated as
/// columns and multiplied on the right (`M * v`), so the translation part of
/// an affine transform lives in `cols[3]`.
///
/// `#[repr(C, align(16))]` pins the in-memory layout to sixteen contiguous
/// `f32`s, column by column, on a 16-byte boundary.
/// NOTE(review): presumably this is for GPU upload / SIMD interop; confirm.
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
#[repr(C, align(16))]
pub struct Mat4 {
    /// The four columns; `cols[c][r]` is column `c`, row `r`.
    pub cols: [[f32; 4]; 4],
}
12
13impl Default for Mat4 {
14 #[inline(always)]
15 fn default() -> Self {
16 Self::IDENTITY
17 }
18}
19
20impl Mat4 {
21 pub const IDENTITY: Self = Self {
23 cols: [
24 [1.0, 0.0, 0.0, 0.0],
25 [0.0, 1.0, 0.0, 0.0],
26 [0.0, 0.0, 1.0, 0.0],
27 [0.0, 0.0, 0.0, 1.0],
28 ],
29 };
30
31 pub const ZERO: Self = Self {
33 cols: [[0.0; 4]; 4],
34 };
35
36 #[inline(always)]
38 pub fn col(&self, i: usize) -> [f32; 4] {
39 self.cols[i]
40 }
41
42 #[inline(always)]
44 pub fn get(&self, col: usize, row: usize) -> f32 {
45 self.cols[col][row]
46 }
47
48 pub fn from_translation(t: Vec3) -> Self {
50 let mut m = Self::IDENTITY;
51 m.cols[3][0] = t.x;
52 m.cols[3][1] = t.y;
53 m.cols[3][2] = t.z;
54 m
55 }
56
57 pub fn from_scale(s: Vec3) -> Self {
59 let mut m = Self::IDENTITY;
60 m.cols[0][0] = s.x;
61 m.cols[1][1] = s.y;
62 m.cols[2][2] = s.z;
63 m
64 }
65
66 pub fn from_rotation(q: Quat) -> Self {
68 let x2 = q.x + q.x;
69 let y2 = q.y + q.y;
70 let z2 = q.z + q.z;
71 let xx = q.x * x2;
72 let xy = q.x * y2;
73 let xz = q.x * z2;
74 let yy = q.y * y2;
75 let yz = q.y * z2;
76 let zz = q.z * z2;
77 let wx = q.w * x2;
78 let wy = q.w * y2;
79 let wz = q.w * z2;
80
81 Self {
82 cols: [
83 [1.0 - (yy + zz), xy + wz, xz - wy, 0.0],
84 [xy - wz, 1.0 - (xx + zz), yz + wx, 0.0],
85 [xz + wy, yz - wx, 1.0 - (xx + yy), 0.0],
86 [0.0, 0.0, 0.0, 1.0],
87 ],
88 }
89 }
90
91 pub fn from_scale_rotation_translation(s: Vec3, r: Quat, t: Vec3) -> Self {
93 let rot = Self::from_rotation(r);
94 Self {
95 cols: [
96 [
97 rot.cols[0][0] * s.x,
98 rot.cols[0][1] * s.x,
99 rot.cols[0][2] * s.x,
100 0.0,
101 ],
102 [
103 rot.cols[1][0] * s.y,
104 rot.cols[1][1] * s.y,
105 rot.cols[1][2] * s.y,
106 0.0,
107 ],
108 [
109 rot.cols[2][0] * s.z,
110 rot.cols[2][1] * s.z,
111 rot.cols[2][2] * s.z,
112 0.0,
113 ],
114 [t.x, t.y, t.z, 1.0],
115 ],
116 }
117 }
118
119 pub fn orthographic_lh(
121 left: f32,
122 right: f32,
123 bottom: f32,
124 top: f32,
125 z_near: f32,
126 z_far: f32,
127 ) -> Self {
128 let rml = right - left;
129 let tmb = top - bottom;
130 let fmn = z_far - z_near;
131 Self {
132 cols: [
133 [2.0 / rml, 0.0, 0.0, 0.0],
134 [0.0, 2.0 / tmb, 0.0, 0.0],
135 [0.0, 0.0, 1.0 / fmn, 0.0],
136 [
137 -(right + left) / rml,
138 -(top + bottom) / tmb,
139 -z_near / fmn,
140 1.0,
141 ],
142 ],
143 }
144 }
145
146 pub fn perspective_lh(fov_y_radians: f32, aspect: f32, z_near: f32, z_far: f32) -> Self {
148 let h = 1.0 / (fov_y_radians * 0.5).tan();
149 let w = h / aspect;
150 let r = z_far / (z_far - z_near);
151
152 Self {
153 cols: [
154 [w, 0.0, 0.0, 0.0],
155 [0.0, h, 0.0, 0.0],
156 [0.0, 0.0, r, 1.0],
157 [0.0, 0.0, -r * z_near, 0.0],
158 ],
159 }
160 }
161
162 pub fn look_at_lh(eye: Vec3, target: Vec3, up: Vec3) -> Self {
164 let f = (target - eye).normalize();
165 let s = up.cross(f).normalize();
166 let u = f.cross(s);
167
168 Self {
169 cols: [
170 [s.x, u.x, f.x, 0.0],
171 [s.y, u.y, f.y, 0.0],
172 [s.z, u.z, f.z, 0.0],
173 [-s.dot(eye), -u.dot(eye), -f.dot(eye), 1.0],
174 ],
175 }
176 }
177
178 pub fn inverse(self) -> Self {
180 let m = &self.cols;
182 let a2323 = m[2][2] * m[3][3] - m[3][2] * m[2][3];
183 let a1323 = m[1][2] * m[3][3] - m[3][2] * m[1][3];
184 let a1223 = m[1][2] * m[2][3] - m[2][2] * m[1][3];
185 let a0323 = m[0][2] * m[3][3] - m[3][2] * m[0][3];
186 let a0223 = m[0][2] * m[2][3] - m[2][2] * m[0][3];
187 let a0123 = m[0][2] * m[1][3] - m[1][2] * m[0][3];
188 let a2313 = m[2][1] * m[3][3] - m[3][1] * m[2][3];
189 let a1313 = m[1][1] * m[3][3] - m[3][1] * m[1][3];
190 let a1213 = m[1][1] * m[2][3] - m[2][1] * m[1][3];
191 let a2312 = m[2][1] * m[3][2] - m[3][1] * m[2][2];
192 let a1312 = m[1][1] * m[3][2] - m[3][1] * m[1][2];
193 let a1212 = m[1][1] * m[2][2] - m[2][1] * m[1][2];
194 let a0313 = m[0][1] * m[3][3] - m[3][1] * m[0][3];
195 let a0213 = m[0][1] * m[2][3] - m[2][1] * m[0][3];
196 let a0312 = m[0][1] * m[3][2] - m[3][1] * m[0][2];
197 let a0212 = m[0][1] * m[2][2] - m[2][1] * m[0][2];
198 let a0113 = m[0][1] * m[1][3] - m[1][1] * m[0][3];
199 let a0112 = m[0][1] * m[1][2] - m[1][1] * m[0][2];
200
201 let det = m[0][0] * (m[1][1] * a2323 - m[2][1] * a1323 + m[3][1] * a1223)
202 - m[1][0] * (m[0][1] * a2323 - m[2][1] * a0323 + m[3][1] * a0223)
203 + m[2][0] * (m[0][1] * a1323 - m[1][1] * a0323 + m[3][1] * a0123)
204 - m[3][0] * (m[0][1] * a1223 - m[1][1] * a0223 + m[2][1] * a0123);
205
206 let inv_det = 1.0 / det;
207
208 Self {
209 cols: [
210 [
211 inv_det * (m[1][1] * a2323 - m[2][1] * a1323 + m[3][1] * a1223),
212 inv_det * -(m[0][1] * a2323 - m[2][1] * a0323 + m[3][1] * a0223),
213 inv_det * (m[0][1] * a1323 - m[1][1] * a0323 + m[3][1] * a0123),
214 inv_det * -(m[0][1] * a1223 - m[1][1] * a0223 + m[2][1] * a0123),
215 ],
216 [
217 inv_det * -(m[1][0] * a2323 - m[2][0] * a1323 + m[3][0] * a1223),
218 inv_det * (m[0][0] * a2323 - m[2][0] * a0323 + m[3][0] * a0223),
219 inv_det * -(m[0][0] * a1323 - m[1][0] * a0323 + m[3][0] * a0123),
220 inv_det * (m[0][0] * a1223 - m[1][0] * a0223 + m[2][0] * a0123),
221 ],
222 [
223 inv_det * (m[1][0] * a2313 - m[2][0] * a1313 + m[3][0] * a1213),
224 inv_det * -(m[0][0] * a2313 - m[2][0] * a0313 + m[3][0] * a0213),
225 inv_det * (m[0][0] * a1313 - m[1][0] * a0313 + m[3][0] * a0113),
226 inv_det * -(m[0][0] * a1213 - m[1][0] * a0213 + m[2][0] * a0113),
227 ],
228 [
229 inv_det * -(m[1][0] * a2312 - m[2][0] * a1312 + m[3][0] * a1212),
230 inv_det * (m[0][0] * a2312 - m[2][0] * a0312 + m[3][0] * a0212),
231 inv_det * -(m[0][0] * a1312 - m[1][0] * a0312 + m[3][0] * a0112),
232 inv_det * (m[0][0] * a1212 - m[1][0] * a0212 + m[2][0] * a0112),
233 ],
234 ],
235 }
236 }
237
238 pub fn transpose(self) -> Self {
240 let m = &self.cols;
241 Self {
242 cols: [
243 [m[0][0], m[1][0], m[2][0], m[3][0]],
244 [m[0][1], m[1][1], m[2][1], m[3][1]],
245 [m[0][2], m[1][2], m[2][2], m[3][2]],
246 [m[0][3], m[1][3], m[2][3], m[3][3]],
247 ],
248 }
249 }
250
251 #[inline(always)]
253 pub fn to_cols_array_2d(&self) -> [[f32; 4]; 4] {
254 self.cols
255 }
256
257 #[inline(always)]
259 pub fn from_cols_array_2d(cols: &[[f32; 4]; 4]) -> Self {
260 Self { cols: *cols }
261 }
262
263 pub fn transform_point3(&self, p: Vec3) -> Vec3 {
265 let m = &self.cols;
266 Vec3::new(
267 m[0][0] * p.x + m[1][0] * p.y + m[2][0] * p.z + m[3][0],
268 m[0][1] * p.x + m[1][1] * p.y + m[2][1] * p.z + m[3][1],
269 m[0][2] * p.x + m[1][2] * p.y + m[2][2] * p.z + m[3][2],
270 )
271 }
272}
273
274impl Mul for Mat4 {
275 type Output = Self;
276 fn mul(self, rhs: Self) -> Self {
277 let a = &self.cols;
278 let b = &rhs.cols;
279 let mut out = [[0.0f32; 4]; 4];
280
281 for c in 0..4 {
282 for r in 0..4 {
283 out[c][r] =
284 a[0][r] * b[c][0] + a[1][r] * b[c][1] + a[2][r] * b[c][2] + a[3][r] * b[c][3];
285 }
286 }
287
288 Self { cols: out }
289 }
290}
291
292impl Mul<Vec4> for Mat4 {
293 type Output = Vec4;
294 fn mul(self, v: Vec4) -> Vec4 {
295 let m = &self.cols;
296 Vec4::new(
297 m[0][0] * v.x + m[1][0] * v.y + m[2][0] * v.z + m[3][0] * v.w,
298 m[0][1] * v.x + m[1][1] * v.y + m[2][1] * v.z + m[3][1] * v.w,
299 m[0][2] * v.x + m[1][2] * v.y + m[2][2] * v.z + m[3][2] * v.w,
300 m[0][3] * v.x + m[1][3] * v.y + m[2][3] * v.z + m[3][3] * v.w,
301 )
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for results that pass through trigonometry or division.
    const EPS: f32 = 1e-4;

    /// The identity must be neutral on both sides of the product.
    #[test]
    fn identity_mul() {
        let m = Mat4::from_translation(Vec3::new(1.0, 2.0, 3.0));
        assert_eq!(Mat4::IDENTITY * m, m);
        assert_eq!(m * Mat4::IDENTITY, m);
    }

    /// `M * M^-1` must be the identity for a well-conditioned scale/rotation/translation matrix.
    #[test]
    fn inverse_identity() {
        let m = Mat4::from_scale_rotation_translation(
            Vec3::new(2.0, 2.0, 2.0),
            Quat::from_axis_angle(Vec3::Z, 0.5),
            Vec3::new(10.0, 20.0, 30.0),
        );
        let inv = m.inverse();
        let result = m * inv;
        for c in 0..4 {
            for r in 0..4 {
                let expected = if c == r { 1.0 } else { 0.0 };
                assert!(
                    (result.cols[c][r] - expected).abs() < EPS,
                    "M*M^-1 [{c}][{r}] = {} (expected {expected})",
                    result.cols[c][r]
                );
            }
        }
    }

    /// Pins the operand order of `Mat4 * Mat4`: the right-hand matrix acts on
    /// points first. The identity and inverse tests are symmetric in their
    /// operands, so on their own they cannot detect a swapped product.
    #[test]
    fn mul_composition_order() {
        let t = Mat4::from_translation(Vec3::new(10.0, 0.0, 0.0));
        let s = Mat4::from_scale(Vec3::new(2.0, 2.0, 2.0));
        let p = Vec3::new(1.0, 1.0, 1.0);
        // Scale first, then translate.
        assert_eq!((t * s).transform_point3(p), Vec3::new(12.0, 2.0, 2.0));
        // Translate first, then scale (the offset is scaled too).
        assert_eq!((s * t).transform_point3(p), Vec3::new(22.0, 2.0, 2.0));
    }

    #[test]
    fn transform_point() {
        let m = Mat4::from_translation(Vec3::new(10.0, 0.0, 0.0));
        let p = Vec3::new(1.0, 2.0, 3.0);
        let result = m.transform_point3(p);
        assert_eq!(result, Vec3::new(11.0, 2.0, 3.0));
    }

    /// Inverting a translation must move the translated point back to the origin.
    #[test]
    fn translation_inverse_undoes_translation() {
        let t = Vec3::new(1.0, -2.0, 3.5);
        let back = Mat4::from_translation(t).inverse().transform_point3(t);
        assert!(
            back.x.abs() < EPS && back.y.abs() < EPS && back.z.abs() < EPS,
            "{back:?}"
        );
    }

    /// Transposition swaps the column and row indices and is its own inverse.
    #[test]
    fn transpose_swaps_indices() {
        let m = Mat4::from_scale_rotation_translation(
            Vec3::new(1.0, 2.0, 3.0),
            Quat::from_axis_angle(Vec3::Z, 0.3),
            Vec3::new(4.0, 5.0, 6.0),
        );
        let t = m.transpose();
        for c in 0..4 {
            for r in 0..4 {
                assert_eq!(t.get(c, r), m.get(r, c));
            }
        }
        assert_eq!(t.transpose(), m);
    }

    /// A point (`w = 1`) picks up the translation column; a direction (`w = 0`) does not.
    #[test]
    fn mul_vec4_moves_points_not_directions() {
        let m = Mat4::from_translation(Vec3::new(1.0, 2.0, 3.0));
        let point = m * Vec4::new(0.0, 0.0, 0.0, 1.0);
        assert_eq!((point.x, point.y, point.z, point.w), (1.0, 2.0, 3.0, 1.0));
        let dir = m * Vec4::new(1.0, 0.0, 0.0, 0.0);
        assert_eq!((dir.x, dir.y, dir.z, dir.w), (1.0, 0.0, 0.0, 0.0));
    }

    /// After the perspective divide, the near plane maps to depth 0 and the far plane to depth 1.
    #[test]
    fn perspective_depth_range() {
        let p = Mat4::perspective_lh(1.0, 1.5, 0.5, 100.0);
        let near = p * Vec4::new(0.0, 0.0, 0.5, 1.0);
        let far = p * Vec4::new(0.0, 0.0, 100.0, 1.0);
        assert!((near.z / near.w).abs() < EPS, "near depth {}", near.z / near.w);
        assert!((far.z / far.w - 1.0).abs() < EPS, "far depth {}", far.z / far.w);
    }
}