avila_math/geometry/
matrix.rs

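//! 4x4 matrices for 3D transformations.
//!
//! Matrices are stored column-major for compatibility with OpenGL/Vulkan. The projection
//! helpers follow the OpenGL clip-space convention (view space looks down -Z, NDC depth in
//! [-1, 1]); Vulkan's [0, 1] depth range needs an additional correction.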
4use super::vector::{Vector3, Vector4};
5use std::ops::Mul;
6
/// A 4x4 matrix of `f64`, stored column by column.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Matrix4 {
    /// Element storage indexed as `data[column][row]` (column-major).
    pub data: [[f64; 4]; 4],
}
13
14impl Matrix4 {
15    /// Create identity matrix
16    #[inline]
17    pub fn identity() -> Self {
18        Self {
19            data: [
20                [1.0, 0.0, 0.0, 0.0],
21                [0.0, 1.0, 0.0, 0.0],
22                [0.0, 0.0, 1.0, 0.0],
23                [0.0, 0.0, 0.0, 1.0],
24            ],
25        }
26    }
27
28    /// Create zero matrix
29    #[inline]
30    pub fn zero() -> Self {
31        Self {
32            data: [[0.0; 4]; 4],
33        }
34    }
35
36    /// Create translation matrix
37    #[inline]
38    pub fn translate(translation: Vector3) -> Self {
39        let mut m = Self::identity();
40        m.data[3][0] = translation.x;
41        m.data[3][1] = translation.y;
42        m.data[3][2] = translation.z;
43        m
44    }
45
46    /// Create uniform scale matrix
47    #[inline]
48    pub fn scale_uniform(scale: f64) -> Self {
49        Self::scale(Vector3::new(scale, scale, scale))
50    }
51
52    /// Create non-uniform scale matrix
53    #[inline]
54    pub fn scale(scale: Vector3) -> Self {
55        let mut m = Self::identity();
56        m.data[0][0] = scale.x;
57        m.data[1][1] = scale.y;
58        m.data[2][2] = scale.z;
59        m
60    }
61
62    /// Create rotation matrix around X axis
63    pub fn rotate_x(angle_rad: f64) -> Self {
64        let c = angle_rad.cos();
65        let s = angle_rad.sin();
66        let mut m = Self::identity();
67        m.data[1][1] = c;
68        m.data[1][2] = s;
69        m.data[2][1] = -s;
70        m.data[2][2] = c;
71        m
72    }
73
74    /// Create rotation matrix around Y axis
75    pub fn rotate_y(angle_rad: f64) -> Self {
76        let c = angle_rad.cos();
77        let s = angle_rad.sin();
78        let mut m = Self::identity();
79        m.data[0][0] = c;
80        m.data[0][2] = -s;
81        m.data[2][0] = s;
82        m.data[2][2] = c;
83        m
84    }
85
86    /// Create rotation matrix around Z axis
87    pub fn rotate_z(angle_rad: f64) -> Self {
88        let c = angle_rad.cos();
89        let s = angle_rad.sin();
90        let mut m = Self::identity();
91        m.data[0][0] = c;
92        m.data[0][1] = s;
93        m.data[1][0] = -s;
94        m.data[1][1] = c;
95        m
96    }
97
98    /// Create look-at view matrix
99    pub fn look_at(eye: Vector3, target: Vector3, up: Vector3) -> Self {
100        let f = (target - eye).normalize();
101        let r = f.cross(up).normalize();
102        let u = r.cross(f);
103
104        let mut m = Self::identity();
105        m.data[0][0] = r.x;
106        m.data[1][0] = r.y;
107        m.data[2][0] = r.z;
108        m.data[0][1] = u.x;
109        m.data[1][1] = u.y;
110        m.data[2][1] = u.z;
111        m.data[0][2] = -f.x;
112        m.data[1][2] = -f.y;
113        m.data[2][2] = -f.z;
114        m.data[3][0] = -r.dot(eye);
115        m.data[3][1] = -u.dot(eye);
116        m.data[3][2] = f.dot(eye);
117        m
118    }
119
120    /// Create perspective projection matrix
121    pub fn perspective(fov_y_rad: f64, aspect: f64, near: f64, far: f64) -> Self {
122        let tan_half_fovy = (fov_y_rad / 2.0).tan();
123        let mut m = Self::zero();
124        m.data[0][0] = 1.0 / (aspect * tan_half_fovy);
125        m.data[1][1] = 1.0 / tan_half_fovy;
126        m.data[2][2] = -(far + near) / (far - near);
127        m.data[2][3] = -1.0;
128        m.data[3][2] = -(2.0 * far * near) / (far - near);
129        m
130    }
131
132    /// Create orthographic projection matrix
133    pub fn orthographic(left: f64, right: f64, bottom: f64, top: f64, near: f64, far: f64) -> Self {
134        let mut m = Self::identity();
135        m.data[0][0] = 2.0 / (right - left);
136        m.data[1][1] = 2.0 / (top - bottom);
137        m.data[2][2] = -2.0 / (far - near);
138        m.data[3][0] = -(right + left) / (right - left);
139        m.data[3][1] = -(top + bottom) / (top - bottom);
140        m.data[3][2] = -(far + near) / (far - near);
141        m
142    }
143
144    /// Transform a point (w=1)
145    pub fn transform_point(&self, point: Vector3) -> Vector3 {
146        let v = Vector4::new(point.x, point.y, point.z, 1.0);
147        let transformed = self.transform_vector4(v);
148
149        // Perspective division
150        if transformed.w != 0.0 {
151            Vector3::new(
152                transformed.x / transformed.w,
153                transformed.y / transformed.w,
154                transformed.z / transformed.w,
155            )
156        } else {
157            transformed.to_vector3()
158        }
159    }
160
161    /// Transform a direction (w=0)
162    pub fn transform_direction(&self, dir: Vector3) -> Vector3 {
163        let v = Vector4::new(dir.x, dir.y, dir.z, 0.0);
164        self.transform_vector4(v).to_vector3()
165    }
166
167    /// Transform a Vector4
168    pub fn transform_vector4(&self, v: Vector4) -> Vector4 {
169        Vector4::new(
170            self.data[0][0] * v.x
171                + self.data[1][0] * v.y
172                + self.data[2][0] * v.z
173                + self.data[3][0] * v.w,
174            self.data[0][1] * v.x
175                + self.data[1][1] * v.y
176                + self.data[2][1] * v.z
177                + self.data[3][1] * v.w,
178            self.data[0][2] * v.x
179                + self.data[1][2] * v.y
180                + self.data[2][2] * v.z
181                + self.data[3][2] * v.w,
182            self.data[0][3] * v.x
183                + self.data[1][3] * v.y
184                + self.data[2][3] * v.z
185                + self.data[3][3] * v.w,
186        )
187    }
188
189    /// Transpose matrix
190    pub fn transpose(&self) -> Self {
191        let mut result = Self::zero();
192        for i in 0..4 {
193            for j in 0..4 {
194                result.data[i][j] = self.data[j][i];
195            }
196        }
197        result
198    }
199
200    /// Determinant of matrix
201    pub fn determinant(&self) -> f64 {
202        let m = &self.data;
203
204        let a = m[0][0]
205            * (m[1][1] * (m[2][2] * m[3][3] - m[2][3] * m[3][2])
206                - m[1][2] * (m[2][1] * m[3][3] - m[2][3] * m[3][1])
207                + m[1][3] * (m[2][1] * m[3][2] - m[2][2] * m[3][1]));
208
209        let b = m[0][1]
210            * (m[1][0] * (m[2][2] * m[3][3] - m[2][3] * m[3][2])
211                - m[1][2] * (m[2][0] * m[3][3] - m[2][3] * m[3][0])
212                + m[1][3] * (m[2][0] * m[3][2] - m[2][2] * m[3][0]));
213
214        let c = m[0][2]
215            * (m[1][0] * (m[2][1] * m[3][3] - m[2][3] * m[3][1])
216                - m[1][1] * (m[2][0] * m[3][3] - m[2][3] * m[3][0])
217                + m[1][3] * (m[2][0] * m[3][1] - m[2][1] * m[3][0]));
218
219        let d = m[0][3]
220            * (m[1][0] * (m[2][1] * m[3][2] - m[2][2] * m[3][1])
221                - m[1][1] * (m[2][0] * m[3][2] - m[2][2] * m[3][0])
222                + m[1][2] * (m[2][0] * m[3][1] - m[2][1] * m[3][0]));
223
224        a - b + c - d
225    }
226}
227
228/// Matrix multiplication
229impl Mul for Matrix4 {
230    type Output = Self;
231
232    fn mul(self, rhs: Self) -> Self {
233        let mut result = Self::zero();
234        for i in 0..4 {
235            for j in 0..4 {
236                for k in 0..4 {
237                    result.data[i][j] += self.data[k][j] * rhs.data[i][k];
238                }
239            }
240        }
241        result
242    }
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::FRAC_PI_2;

    /// Absolute tolerance used for every floating-point comparison below.
    const EPS: f64 = 1e-10;

    /// Asserts that `actual` is within `EPS` of `expected`, reporting both on failure.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Asserts that each component of `actual` is within `EPS` of `(x, y, z)`.
    fn assert_vec3_close(actual: Vector3, x: f64, y: f64, z: f64) {
        assert_close(actual.x, x);
        assert_close(actual.y, y);
        assert_close(actual.z, z);
    }

    #[test]
    fn test_identity() {
        let v = Vector3::new(1.0, 2.0, 3.0);
        assert_vec3_close(Matrix4::identity().transform_point(v), 1.0, 2.0, 3.0);
    }

    #[test]
    fn test_translation() {
        let m = Matrix4::translate(Vector3::new(10.0, 5.0, -3.0));
        assert_vec3_close(m.transform_point(Vector3::new(1.0, 2.0, 3.0)), 11.0, 7.0, 0.0);
    }

    #[test]
    fn test_translation_does_not_affect_directions() {
        let m = Matrix4::translate(Vector3::new(10.0, 5.0, -3.0));
        assert_vec3_close(m.transform_direction(Vector3::new(1.0, 2.0, 3.0)), 1.0, 2.0, 3.0);
    }

    #[test]
    fn test_scale() {
        let uniform = Matrix4::scale_uniform(2.0);
        assert_vec3_close(uniform.transform_point(Vector3::new(1.0, 2.0, 3.0)), 2.0, 4.0, 6.0);

        let per_axis = Matrix4::scale(Vector3::new(2.0, 3.0, 4.0));
        assert_vec3_close(per_axis.transform_point(Vector3::new(1.0, 1.0, 1.0)), 2.0, 3.0, 4.0);
    }

    #[test]
    fn test_rotations_follow_right_hand_rule() {
        let x = Matrix4::rotate_x(FRAC_PI_2);
        assert_vec3_close(x.transform_point(Vector3::new(0.0, 1.0, 0.0)), 0.0, 0.0, 1.0);

        let y = Matrix4::rotate_y(FRAC_PI_2);
        assert_vec3_close(y.transform_point(Vector3::new(0.0, 0.0, 1.0)), 1.0, 0.0, 0.0);

        let z = Matrix4::rotate_z(FRAC_PI_2);
        assert_vec3_close(z.transform_point(Vector3::new(1.0, 0.0, 0.0)), 0.0, 1.0, 0.0);
    }

    #[test]
    fn test_matrix_multiplication() {
        let t = Matrix4::translate(Vector3::new(1.0, 0.0, 0.0));
        let s = Matrix4::scale_uniform(2.0);
        let v = Vector3::new(1.0, 1.0, 1.0);

        // `t * s` scales first, then translates.
        assert_vec3_close((t * s).transform_point(v), 3.0, 2.0, 2.0);
        // `s * t` translates first, then scales.
        assert_vec3_close((s * t).transform_point(v), 4.0, 2.0, 2.0);
        // Multiplying by the identity leaves the matrix unchanged.
        assert_eq!(t * Matrix4::identity(), t);
    }

    #[test]
    fn test_transpose() {
        let m = Matrix4::translate(Vector3::new(1.0, 2.0, 3.0));
        let t = m.transpose();
        assert_eq!(t.data[0][3], 1.0);
        assert_eq!(t.data[1][3], 2.0);
        assert_eq!(t.data[2][3], 3.0);
        assert_eq!(t.data[3][0], 0.0);
        assert_eq!(t.transpose(), m);
    }

    #[test]
    fn test_determinant() {
        assert_close(Matrix4::identity().determinant(), 1.0);
        assert_close(Matrix4::scale(Vector3::new(2.0, 3.0, 4.0)).determinant(), 24.0);
        assert_close(Matrix4::translate(Vector3::new(7.0, -2.0, 5.0)).determinant(), 1.0);
        assert_close(Matrix4::rotate_y(0.7).determinant(), 1.0);

        // Swapping two columns of the identity flips the sign.
        let mut swapped = Matrix4::identity();
        swapped.data.swap(0, 1);
        assert_close(swapped.determinant(), -1.0);

        // det(A * B) == det(A) * det(B) == 1 * 8.
        let product = Matrix4::rotate_z(0.3) * Matrix4::scale_uniform(2.0);
        assert_close(product.determinant(), 8.0);
    }

    #[test]
    fn test_perspective_maps_near_and_far_planes_to_ndc() {
        let m = Matrix4::perspective(FRAC_PI_2, 1.0, 1.0, 10.0);
        // Near plane (z = -near) lands on NDC depth -1, far plane on +1.
        assert_vec3_close(m.transform_point(Vector3::new(0.0, 0.0, -1.0)), 0.0, 0.0, -1.0);
        assert_vec3_close(m.transform_point(Vector3::new(0.0, 0.0, -10.0)), 0.0, 0.0, 1.0);
        // With a 90-degree vertical FOV, a point 45 degrees above the axis hits the top edge.
        assert_vec3_close(m.transform_point(Vector3::new(0.0, 1.0, -1.0)), 0.0, 1.0, -1.0);
    }

    #[test]
    fn test_orthographic_maps_box_corners_to_ndc() {
        let m = Matrix4::orthographic(-2.0, 2.0, -1.0, 1.0, 1.0, 10.0);
        assert_vec3_close(m.transform_point(Vector3::new(2.0, 1.0, -1.0)), 1.0, 1.0, -1.0);
        assert_vec3_close(m.transform_point(Vector3::new(-2.0, -1.0, -10.0)), -1.0, -1.0, 1.0);
    }

    #[test]
    fn test_look_at_moves_eye_to_origin_and_target_down_negative_z() {
        let eye = Vector3::new(0.0, 0.0, 5.0);
        let target = Vector3::new(0.0, 0.0, 0.0);
        let view = Matrix4::look_at(eye, target, Vector3::new(0.0, 1.0, 0.0));
        assert_vec3_close(view.transform_point(eye), 0.0, 0.0, 0.0);
        assert_vec3_close(view.transform_point(target), 0.0, 0.0, -5.0);
    }
}