Skip to main content

euca_math/
transform.rs

1use crate::{Mat4, Quat, Vec3};
2use serde::{Deserialize, Serialize};
3
4/// TRS transform: Translation, Rotation, Scale.
5/// Transformation order: Scale -> Rotate -> Translate.
6#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
7pub struct Transform {
8    pub translation: Vec3,
9    pub rotation: Quat,
10    pub scale: Vec3,
11}
12
13impl Default for Transform {
14    #[inline]
15    fn default() -> Self {
16        Self::IDENTITY
17    }
18}
19
20impl Transform {
21    /// The identity transform (no translation, rotation, or scale).
22    pub const IDENTITY: Self = Self {
23        translation: Vec3::ZERO,
24        rotation: Quat::IDENTITY,
25        scale: Vec3::ONE,
26    };
27
28    /// Creates a transform with only a translation.
29    #[inline]
30    pub fn from_translation(translation: Vec3) -> Self {
31        Self {
32            translation,
33            ..Self::IDENTITY
34        }
35    }
36
37    /// Creates a transform with only a rotation.
38    #[inline]
39    pub fn from_rotation(rotation: Quat) -> Self {
40        Self {
41            rotation,
42            ..Self::IDENTITY
43        }
44    }
45
46    /// Creates a transform with only a scale.
47    #[inline]
48    pub fn from_scale(scale: Vec3) -> Self {
49        Self {
50            scale,
51            ..Self::IDENTITY
52        }
53    }
54
55    /// Convert to a 4x4 matrix (Scale -> Rotate -> Translate).
56    #[inline]
57    pub fn to_matrix(self) -> Mat4 {
58        Mat4::from_scale_rotation_translation(self.scale, self.rotation, self.translation)
59    }
60
61    /// Transform a point (applies scale, rotation, and translation).
62    #[inline]
63    pub fn transform_point(self, point: Vec3) -> Vec3 {
64        let scaled = Vec3::new(
65            point.x * self.scale.x,
66            point.y * self.scale.y,
67            point.z * self.scale.z,
68        );
69        let rotated = self.rotation * scaled;
70        rotated + self.translation
71    }
72
73    /// Transform a direction vector (scale + rotation only, no translation).
74    #[inline]
75    pub fn transform_vector(self, vector: Vec3) -> Vec3 {
76        let scaled = Vec3::new(
77            vector.x * self.scale.x,
78            vector.y * self.scale.y,
79            vector.z * self.scale.z,
80        );
81        self.rotation * scaled
82    }
83
84    /// Compose two transforms.
85    #[inline]
86    #[allow(clippy::should_implement_trait)]
87    pub fn mul(self, other: Self) -> Self {
88        Self {
89            translation: self.transform_point(other.translation),
90            rotation: (self.rotation * other.rotation).normalize(),
91            scale: Vec3::new(
92                self.scale.x * other.scale.x,
93                self.scale.y * other.scale.y,
94                self.scale.z * other.scale.z,
95            ),
96        }
97    }
98
99    /// Compute inverse via matrix decomposition.
100    pub fn inverse(self) -> Self {
101        let mat = self.to_matrix().inverse();
102        // Extract translation from column 3
103        let translation = Vec3::new(mat.cols[3][0], mat.cols[3][1], mat.cols[3][2]);
104        // Extract scale from column lengths
105        let sx = Vec3::new(mat.cols[0][0], mat.cols[0][1], mat.cols[0][2]).length();
106        let sy = Vec3::new(mat.cols[1][0], mat.cols[1][1], mat.cols[1][2]).length();
107        let sz = Vec3::new(mat.cols[2][0], mat.cols[2][1], mat.cols[2][2]).length();
108        let scale = Vec3::new(sx, sy, sz);
109        // Extract rotation (divide columns by scale)
110        let rot_mat = Mat4 {
111            cols: [
112                [
113                    mat.cols[0][0] / sx,
114                    mat.cols[0][1] / sx,
115                    mat.cols[0][2] / sx,
116                    0.0,
117                ],
118                [
119                    mat.cols[1][0] / sy,
120                    mat.cols[1][1] / sy,
121                    mat.cols[1][2] / sy,
122                    0.0,
123                ],
124                [
125                    mat.cols[2][0] / sz,
126                    mat.cols[2][1] / sz,
127                    mat.cols[2][2] / sz,
128                    0.0,
129                ],
130                [0.0, 0.0, 0.0, 1.0],
131            ],
132        };
133        // Extract quaternion from rotation matrix
134        let trace = rot_mat.cols[0][0] + rot_mat.cols[1][1] + rot_mat.cols[2][2];
135        let rotation = if trace > 0.0 {
136            let s = (trace + 1.0).sqrt() * 2.0;
137            Quat::from_xyzw(
138                (rot_mat.cols[1][2] - rot_mat.cols[2][1]) / s,
139                (rot_mat.cols[2][0] - rot_mat.cols[0][2]) / s,
140                (rot_mat.cols[0][1] - rot_mat.cols[1][0]) / s,
141                0.25 * s,
142            )
143        } else {
144            Quat::IDENTITY
145        };
146
147        Self {
148            translation,
149            rotation: rotation.normalize(),
150            scale,
151        }
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use std::f32::consts::FRAC_PI_2;

    /// The identity transform must leave a point unchanged.
    #[test]
    fn identity_transform() {
        let t = Transform::IDENTITY;
        let p = Vec3::new(1.0, 2.0, 3.0);
        let r = t.transform_point(p);
        assert!((r.x - p.x).abs() < 1e-6);
        assert!((r.y - p.y).abs() < 1e-6);
        assert!((r.z - p.z).abs() < 1e-6);
    }

    /// Translation moves points but leaves direction vectors alone.
    #[test]
    fn translation_only() {
        let t = Transform::from_translation(Vec3::new(10.0, 20.0, 30.0));
        let p = Vec3::new(1.0, 2.0, 3.0);
        let r = t.transform_point(p);
        assert_eq!(r, Vec3::new(11.0, 22.0, 33.0));
        // Vectors unaffected by translation
        let v = t.transform_vector(p);
        assert_eq!(v, p);
    }

    /// Scale multiplies each component independently.
    #[test]
    fn scale_only() {
        let t = Transform::from_scale(Vec3::new(2.0, 3.0, 4.0));
        let r = t.transform_point(Vec3::ONE);
        assert_eq!(r, Vec3::new(2.0, 3.0, 4.0));
    }

    /// A quarter turn about Z maps +X onto +Y.
    #[test]
    fn rotation_only() {
        let t = Transform::from_rotation(Quat::from_axis_angle(Vec3::Z, FRAC_PI_2));
        let r = t.transform_point(Vec3::X);
        assert!(r.x.abs() < 1e-5);
        assert!((r.y - 1.0).abs() < 1e-5);
        assert!(r.z.abs() < 1e-5);
    }

    /// `a.mul(b)` applies `b` first (scale) and then `a` (translate).
    #[test]
    fn compose_transforms() {
        let a = Transform::from_translation(Vec3::new(5.0, 0.0, 0.0));
        let b = Transform::from_scale(Vec3::new(2.0, 2.0, 2.0));
        let composed = a.mul(b);
        let r = composed.transform_point(Vec3::X);
        assert!((r.x - 7.0).abs() < 1e-5);
        assert!(r.y.abs() < 1e-5);
        assert!(r.z.abs() < 1e-5);
    }

    /// `transform_point` and the matrix form must agree.
    #[test]
    fn matrix_consistency() {
        let t = Transform {
            translation: Vec3::new(1.0, 2.0, 3.0),
            rotation: Quat::from_axis_angle(Vec3::Z, 0.5),
            scale: Vec3::new(2.0, 2.0, 2.0),
        };
        let mat = t.to_matrix();
        let p = Vec3::X;

        let from_transform = t.transform_point(p);
        let from_matrix = mat.transform_point3(p);

        assert!((from_transform.x - from_matrix.x).abs() < 1e-5);
        assert!((from_transform.y - from_matrix.y).abs() < 1e-5);
        assert!((from_transform.z - from_matrix.z).abs() < 1e-5);
    }

    /// Matrix-level round trip; non-uniform scale is fine here because the
    /// inverse stays in matrix form and is never decomposed back to TRS.
    #[test]
    fn inverse_roundtrip() {
        let t = Transform {
            translation: Vec3::new(1.0, 2.0, 3.0),
            rotation: Quat::from_axis_angle(Vec3::Y, 0.7),
            scale: Vec3::new(2.0, 0.5, 3.0),
        };
        let p = Vec3::new(4.0, 5.0, 6.0);
        let forward = t.to_matrix();
        let inv = forward.inverse();
        let transformed = forward.transform_point3(p);
        let result = inv.transform_point3(transformed);
        assert!((result.x - p.x).abs() < 1e-4);
        assert!((result.y - p.y).abs() < 1e-4);
        assert!((result.z - p.z).abs() < 1e-4);
    }

    /// Round trip through `Transform::inverse` itself. Uniform scale is used
    /// because only then is the inverse exactly representable as a TRS
    /// transform.
    #[test]
    fn transform_inverse_roundtrip_uniform_scale() {
        let t = Transform {
            translation: Vec3::new(1.0, 2.0, 3.0),
            rotation: Quat::from_axis_angle(Vec3::Y, 0.7),
            scale: Vec3::new(2.0, 2.0, 2.0),
        };
        let p = Vec3::new(4.0, 5.0, 6.0);
        let result = t.inverse().transform_point(t.transform_point(p));
        assert!((result.x - p.x).abs() < 1e-4);
        assert!((result.y - p.y).abs() < 1e-4);
        assert!((result.z - p.z).abs() < 1e-4);
    }
}