// est_render/math/matrix.rs

1use std::ops::{Add, Mul, Sub};
2
3use bytemuck::{Pod, Zeroable};
4use num_traits::ToPrimitive;
5
6use super::{Vector2, Vector3, Vector4};
7
8#[derive(Debug, Clone, Copy, Default, Pod, Zeroable)]
9#[repr(C)]
10pub struct Matrix4 {
11    pub m: [[f32; 4]; 4],
12}
13
14impl Matrix4 {
15    pub fn new() -> Self {
16        Self {
17            m: [
18                [0.0, 0.0, 0.0, 0.0],
19                [0.0, 0.0, 0.0, 0.0],
20                [0.0, 0.0, 0.0, 0.0],
21                [0.0, 0.0, 0.0, 0.0],
22            ],
23        }
24    }
25
26    pub fn identity() -> Self {
27        Self {
28            m: [
29                [1.0, 0.0, 0.0, 0.0],
30                [0.0, 1.0, 0.0, 0.0],
31                [0.0, 0.0, 1.0, 0.0],
32                [0.0, 0.0, 0.0, 1.0],
33            ],
34        }
35    }
36
37    pub fn look_at(eye: Vector3, target: Vector3, up: Vector3) -> Self {
38        let f = (target - eye).normalize();
39        let s = f.cross(&up.normalize()).normalize();
40        let u = s.cross(&f);
41
42        Self {
43            m: [
44                [s.x, s.y, s.z, -s.dot(&eye)],
45                [u.x, u.y, u.z, -u.dot(&eye)],
46                [-f.x, -f.y, -f.z, f.dot(&eye)],
47                [0.0, 0.0, 0.0, 1.0],
48            ],
49        }
50    }
51
52    pub fn frustum<T: ToPrimitive>(left: T, right: T, bottom: T, top: T, near: T, far: T) -> Self {
53        let left = left.to_f32().unwrap();
54        let right = right.to_f32().unwrap();
55        let bottom = bottom.to_f32().unwrap();
56        let top = top.to_f32().unwrap();
57        let near = near.to_f32().unwrap();
58        let far = far.to_f32().unwrap();
59
60        let rl = 1.0 / (right - left);
61        let bt = 1.0 / (top - bottom);
62        let nf = 1.0 / (near - far);
63
64        Self {
65            m: [
66                [2.0 * near * rl, 0.0, 0.0, 0.0],
67                [0.0, 2.0 * near * bt, 0.0, 0.0],
68                [
69                    (right + left) * rl,
70                    (top + bottom) * bt,
71                    (far + near) * nf,
72                    -1.0,
73                ],
74                [0.0, 0.0, 2.0 * far * near * nf, 0.0],
75            ],
76        }
77    }
78
79    pub fn perspective<T: ToPrimitive>(fov: T, aspect: T, near: T, far: T) -> Self {
80        let fov = fov.to_f32().unwrap();
81        let aspect = aspect.to_f32().unwrap();
82        let near = near.to_f32().unwrap();
83        let far = far.to_f32().unwrap();
84
85        let f = 1.0 / (fov / 2.0).tan();
86        let nf = 1.0 / (near - far);
87
88        Self {
89            m: [
90                [f / aspect, 0.0, 0.0, 0.0],
91                [0.0, f, 0.0, 0.0],
92                [0.0, 0.0, (far + near) * nf, 2.0 * far * near * nf],
93                [0.0, 0.0, -1.0, 0.0],
94            ],
95        }
96    }
97
98    pub fn translate<T: ToPrimitive>(x: T, y: T, z: T) -> Self {
99        let x = x.to_f32().unwrap();
100        let y = y.to_f32().unwrap();
101        let z = z.to_f32().unwrap();
102
103        Self {
104            m: [
105                [1.0, 0.0, 0.0, x],
106                [0.0, 1.0, 0.0, y],
107                [0.0, 0.0, 1.0, z],
108                [0.0, 0.0, 0.0, 1.0],
109            ],
110        }
111    }
112
113    pub fn scale<T: ToPrimitive>(x: T, y: T, z: T) -> Self {
114        let x = x.to_f32().unwrap();
115        let y = y.to_f32().unwrap();
116        let z = z.to_f32().unwrap();
117
118        Self {
119            m: [
120                [x, 0.0, 0.0, 0.0],
121                [0.0, y, 0.0, 0.0],
122                [0.0, 0.0, z, 0.0],
123                [0.0, 0.0, 0.0, 1.0],
124            ],
125        }
126    }
127
128    pub fn orthographic<T: ToPrimitive>(
129        left: T,
130        right: T,
131        bottom: T,
132        top: T,
133        near: T,
134        far: T,
135    ) -> Self {
136        let left = left.to_f32().unwrap();
137        let right = right.to_f32().unwrap();
138        let bottom = bottom.to_f32().unwrap();
139        let top = top.to_f32().unwrap();
140        let near = near.to_f32().unwrap();
141        let far = far.to_f32().unwrap();
142
143        let lr = 1.0 / (left - right);
144        let bt = 1.0 / (bottom - top);
145        let nf = 1.0 / (near - far);
146
147        Self {
148            m: [
149                [-2.0 * lr, 0.0, 0.0, (left + right) * lr],
150                [0.0, -2.0 * bt, 0.0, (top + bottom) * bt],
151                [0.0, 0.0, 2.0 * nf, (far + near) * nf],
152                [0.0, 0.0, 0.0, 1.0],
153            ],
154        }
155    }
156
157    pub fn rotate<T: ToPrimitive>(angle: T, x: T, y: T, z: T) -> Self {
158        let angle = angle.to_f32().unwrap();
159        let x = x.to_f32().unwrap();
160        let y = y.to_f32().unwrap();
161        let z = z.to_f32().unwrap();
162
163        let c = angle.cos();
164        let s = angle.sin();
165        let len = (x * x + y * y + z * z).sqrt();
166        let (x, y, z) = if len == 0.0 {
167            (1.0, 0.0, 0.0)
168        } else {
169            (x / len, y / len, z / len)
170        };
171        let omc = 1.0 - c;
172
173        Self {
174            m: [
175                [
176                    x * x * omc + c,
177                    x * y * omc - z * s,
178                    x * z * omc + y * s,
179                    0.0,
180                ],
181                [
182                    y * x * omc + z * s,
183                    y * y * omc + c,
184                    y * z * omc - x * s,
185                    0.0,
186                ],
187                [
188                    z * x * omc - y * s,
189                    z * y * omc + x * s,
190                    z * z * omc + c,
191                    0.0,
192                ],
193                [0.0, 0.0, 0.0, 1.0],
194            ],
195        }
196    }
197
198    pub fn transform_point(&self, point: Vector3) -> Vector3 {
199        let mut result = Vector3::new(0.0, 0.0, 0.0);
200
201        result.x = self.m[0][0] * point.x + self.m[0][1] * point.y + self.m[0][2] * point.z;
202        result.y = self.m[1][0] * point.x + self.m[1][1] * point.y + self.m[1][2] * point.z;
203        result.z = self.m[2][0] * point.x + self.m[2][1] * point.y + self.m[2][2] * point.z;
204
205        result
206    }
207
208    pub unsafe fn address_of(&self) -> *const f32 {
209        &self.m[0][0] as *const f32
210    }
211
212    pub fn get_fov(&self) -> f32 {
213        let f = self.m[1][1];
214        1.0 / f.atan() * 2.0
215    }
216
217    pub fn get_aspect(&self) -> f32 {
218        self.m[0][0] / self.m[1][1]
219    }
220
221    pub fn get_near(&self) -> f32 {
222        let nf = 1.0 / self.m[2][2];
223        (2.0 * self.m[3][2]) / (self.m[2][2] - nf)
224    }
225
226    pub fn inverse(&self) -> Matrix4 {
227        let m = &self.m;
228
229        let mut inv = [[0.0; 4]; 4];
230
231        inv[0][0] =
232            m[1][1] * m[2][2] * m[3][3] - m[1][1] * m[2][3] * m[3][2] - m[2][1] * m[1][2] * m[3][3]
233                + m[2][1] * m[1][3] * m[3][2]
234                + m[3][1] * m[1][2] * m[2][3]
235                - m[3][1] * m[1][3] * m[2][2];
236        inv[0][1] = -m[0][1] * m[2][2] * m[3][3]
237            + m[0][1] * m[2][3] * m[3][2]
238            + m[2][1] * m[0][2] * m[3][3]
239            - m[2][1] * m[0][3] * m[3][2]
240            - m[3][1] * m[0][2] * m[2][3]
241            + m[3][1] * m[0][3] * m[2][2];
242        inv[0][2] =
243            m[0][1] * m[1][2] * m[3][3] - m[0][1] * m[1][3] * m[3][2] - m[1][1] * m[0][2] * m[3][3]
244                + m[1][1] * m[0][3] * m[3][2]
245                + m[3][1] * m[0][2] * m[1][3]
246                - m[3][1] * m[0][3] * m[1][2];
247        inv[0][3] = -m[0][1] * m[1][2] * m[2][3]
248            + m[0][1] * m[1][3] * m[2][2]
249            + m[1][1] * m[0][2] * m[2][3]
250            - m[1][1] * m[0][3] * m[2][2]
251            - m[2][1] * m[0][2] * m[1][3]
252            + m[2][1] * m[0][3] * m[1][2];
253
254        inv[1][0] = -m[1][0] * m[2][2] * m[3][3]
255            + m[1][0] * m[2][3] * m[3][2]
256            + m[2][0] * m[1][2] * m[3][3]
257            - m[2][0] * m[1][3] * m[3][2]
258            - m[3][0] * m[1][2] * m[2][3]
259            + m[3][0] * m[1][3] * m[2][2];
260        inv[1][1] =
261            m[0][0] * m[2][2] * m[3][3] - m[0][0] * m[2][3] * m[3][2] - m[2][0] * m[0][2] * m[3][3]
262                + m[2][0] * m[0][3] * m[3][2]
263                + m[3][0] * m[0][2] * m[2][3]
264                - m[3][0] * m[0][3] * m[2][2];
265        inv[1][2] = -m[0][0] * m[1][2] * m[3][3]
266            + m[0][0] * m[1][3] * m[3][2]
267            + m[1][0] * m[0][2] * m[3][3]
268            - m[1][0] * m[0][3] * m[3][2]
269            - m[3][0] * m[0][2] * m[1][3]
270            + m[3][0] * m[0][3] * m[1][2];
271        inv[1][3] =
272            m[0][0] * m[1][2] * m[2][3] - m[0][0] * m[1][3] * m[2][2] - m[1][0] * m[0][2] * m[2][3]
273                + m[1][0] * m[0][3] * m[2][2]
274                + m[2][0] * m[0][2] * m[1][3]
275                - m[2][0] * m[0][3] * m[1][2];
276
277        inv[2][0] =
278            m[1][0] * m[2][1] * m[3][3] - m[1][0] * m[2][3] * m[3][1] - m[2][0] * m[1][1] * m[3][3]
279                + m[2][0] * m[1][3] * m[3][1]
280                + m[3][0] * m[1][1] * m[2][3]
281                - m[3][0] * m[1][3] * m[2][1];
282        inv[2][1] = -m[0][0] * m[2][1] * m[3][3]
283            + m[0][0] * m[2][3] * m[3][1]
284            + m[2][0] * m[0][1] * m[3][3]
285            - m[2][0] * m[0][3] * m[3][1]
286            - m[3][0] * m[0][1] * m[2][3]
287            + m[3][0] * m[0][3] * m[2][1];
288        inv[2][2] =
289            m[0][0] * m[1][1] * m[3][3] - m[0][0] * m[1][3] * m[3][1] - m[1][0] * m[0][1] * m[3][3]
290                + m[1][0] * m[0][3] * m[3][1]
291                + m[3][0] * m[0][1] * m[1][3]
292                - m[3][0] * m[0][3] * m[1][1];
293        inv[2][3] = -m[0][0] * m[1][1] * m[2][3]
294            + m[0][0] * m[1][3] * m[2][1]
295            + m[1][0] * m[0][1] * m[2][3]
296            - m[1][0] * m[0][3] * m[2][1]
297            - m[2][0] * m[0][1] * m[1][3]
298            + m[2][0] * m[0][3] * m[1][1];
299
300        inv[3][0] = -m[1][0] * m[2][1] * m[3][2]
301            + m[1][0] * m[2][2] * m[3][1]
302            + m[2][0] * m[1][1] * m[3][2]
303            - m[2][0] * m[1][2] * m[3][1]
304            - m[3][0] * m[1][1] * m[2][2]
305            + m[3][0] * m[1][2] * m[2][1];
306        inv[3][1] =
307            m[0][0] * m[2][1] * m[3][2] - m[0][0] * m[2][2] * m[3][1] - m[2][0] * m[0][1] * m[3][2]
308                + m[2][0] * m[0][2] * m[3][1]
309                + m[3][0] * m[0][1] * m[2][2]
310                - m[3][0] * m[0][2] * m[2][1];
311        inv[3][2] = -m[0][0] * m[1][1] * m[3][2]
312            + m[0][0] * m[1][2] * m[3][1]
313            + m[1][0] * m[0][1] * m[3][2]
314            - m[1][0] * m[0][2] * m[3][1]
315            - m[3][0] * m[0][1] * m[1][2]
316            + m[3][0] * m[0][2] * m[1][1];
317        inv[3][3] =
318            m[0][0] * m[1][1] * m[2][2] - m[0][0] * m[1][2] * m[2][1] - m[1][0] * m[0][1] * m[2][2]
319                + m[1][0] * m[0][2] * m[2][1]
320                + m[2][0] * m[0][1] * m[1][2]
321                - m[2][0] * m[0][2] * m[1][1];
322
323        let det =
324            m[0][0] * inv[0][0] + m[0][1] * inv[1][0] + m[0][2] * inv[2][0] + m[0][3] * inv[3][0];
325
326        if det == 0.0 {
327            return Matrix4::identity();
328        }
329
330        let det = 1.0 / det;
331
332        for i in 0..4 {
333            for j in 0..4 {
334                inv[i][j] *= det;
335            }
336        }
337
338        Matrix4 { m: inv }
339    }
340
341    pub const OPENGL_TO_WGPU_MATRIX: Self = Self {
342        m: [
343            [1.0, 0.0, 0.0, 0.0],
344            [0.0, 1.0, 0.0, 0.0],
345            [0.0, 0.0, 0.5, 0.5],
346            [0.0, 0.0, 0.0, 1.0],
347        ],
348    };
349}
350
351impl PartialEq for Matrix4 {
352    fn eq(&self, other: &Self) -> bool {
353        self.m[0] == other.m[0]
354            && self.m[1] == other.m[1]
355            && self.m[2] == other.m[2]
356            && self.m[3] == other.m[3]
357    }
358}
359
360impl Eq for Matrix4 {}
361
362impl Mul for Matrix4 {
363    type Output = Self;
364
365    fn mul(self, rhs: Self) -> Self {
366        let mut result = Self::new();
367
368        for i in 0..4 {
369            for j in 0..4 {
370                result.m[i][j] = self.m[i][0] * rhs.m[0][j]
371                    + self.m[i][1] * rhs.m[1][j]
372                    + self.m[i][2] * rhs.m[2][j]
373                    + self.m[i][3] * rhs.m[3][j];
374            }
375        }
376
377        result
378    }
379}
380
381impl Add for Matrix4 {
382    type Output = Self;
383
384    fn add(self, rhs: Self) -> Self {
385        let mut result = Self::new();
386
387        for i in 0..4 {
388            for j in 0..4 {
389                result.m[i][j] = self.m[i][j] + rhs.m[i][j];
390            }
391        }
392
393        result
394    }
395}
396
397impl Sub for Matrix4 {
398    type Output = Self;
399
400    fn sub(self, rhs: Self) -> Self {
401        let mut result = Self::new();
402
403        for i in 0..4 {
404            for j in 0..4 {
405                result.m[i][j] = self.m[i][j] - rhs.m[i][j];
406            }
407        }
408
409        result
410    }
411}
412
413impl Mul<Vector2> for Matrix4 {
414    type Output = Vector2;
415
416    fn mul(self, rhs: Vector2) -> Vector2 {
417        let mut result = Vector2::new(0.0, 0.0);
418
419        result.x = self.m[0][0] * rhs.x + self.m[0][1] * rhs.y + self.m[0][3];
420        result.y = self.m[1][0] * rhs.x + self.m[1][1] * rhs.y + self.m[1][3];
421
422        result
423    }
424}
425
426impl Mul<Vector3> for Matrix4 {
427    type Output = Vector3;
428
429    fn mul(self, rhs: Vector3) -> Vector3 {
430        let mut result = Vector3::new(0.0, 0.0, 0.0);
431
432        result.x =
433            self.m[0][0] * rhs.x + self.m[0][1] * rhs.y + self.m[0][2] * rhs.z + self.m[0][3];
434        result.y =
435            self.m[1][0] * rhs.x + self.m[1][1] * rhs.y + self.m[1][2] * rhs.z + self.m[1][3];
436        result.z =
437            self.m[2][0] * rhs.x + self.m[2][1] * rhs.y + self.m[2][2] * rhs.z + self.m[2][3];
438
439        result
440    }
441}
442
443impl Mul<Vector4> for Matrix4 {
444    type Output = Vector4;
445
446    fn mul(self, rhs: Vector4) -> Vector4 {
447        let mut result = Vector4::new(0.0, 0.0, 0.0, 0.0);
448
449        result.x = self.m[0][0] * rhs.x
450            + self.m[0][1] * rhs.y
451            + self.m[0][2] * rhs.z
452            + self.m[0][3] * rhs.w;
453        result.y = self.m[1][0] * rhs.x
454            + self.m[1][1] * rhs.y
455            + self.m[1][2] * rhs.z
456            + self.m[1][3] * rhs.w;
457        result.z = self.m[2][0] * rhs.x
458            + self.m[2][1] * rhs.y
459            + self.m[2][2] * rhs.z
460            + self.m[2][3] * rhs.w;
461        result.w = self.m[3][0] * rhs.x
462            + self.m[3][1] * rhs.y
463            + self.m[3][2] * rhs.z
464            + self.m[3][3] * rhs.w;
465
466        result
467    }
468}