1use crate::{Mat4, Quat, Vec3};
2use serde::{Deserialize, Serialize};
3
4#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
7pub struct Transform {
8 pub translation: Vec3,
9 pub rotation: Quat,
10 pub scale: Vec3,
11}
12
13impl Default for Transform {
14 #[inline]
15 fn default() -> Self {
16 Self::IDENTITY
17 }
18}
19
20impl Transform {
21 pub const IDENTITY: Self = Self {
23 translation: Vec3::ZERO,
24 rotation: Quat::IDENTITY,
25 scale: Vec3::ONE,
26 };
27
28 #[inline]
30 pub fn from_translation(translation: Vec3) -> Self {
31 Self {
32 translation,
33 ..Self::IDENTITY
34 }
35 }
36
37 #[inline]
39 pub fn from_rotation(rotation: Quat) -> Self {
40 Self {
41 rotation,
42 ..Self::IDENTITY
43 }
44 }
45
46 #[inline]
48 pub fn from_scale(scale: Vec3) -> Self {
49 Self {
50 scale,
51 ..Self::IDENTITY
52 }
53 }
54
55 #[inline]
57 pub fn to_matrix(self) -> Mat4 {
58 Mat4::from_scale_rotation_translation(self.scale, self.rotation, self.translation)
59 }
60
61 #[inline]
63 pub fn transform_point(self, point: Vec3) -> Vec3 {
64 let scaled = Vec3::new(
65 point.x * self.scale.x,
66 point.y * self.scale.y,
67 point.z * self.scale.z,
68 );
69 let rotated = self.rotation * scaled;
70 rotated + self.translation
71 }
72
73 #[inline]
75 pub fn transform_vector(self, vector: Vec3) -> Vec3 {
76 let scaled = Vec3::new(
77 vector.x * self.scale.x,
78 vector.y * self.scale.y,
79 vector.z * self.scale.z,
80 );
81 self.rotation * scaled
82 }
83
84 #[inline]
86 #[allow(clippy::should_implement_trait)]
87 pub fn mul(self, other: Self) -> Self {
88 Self {
89 translation: self.transform_point(other.translation),
90 rotation: (self.rotation * other.rotation).normalize(),
91 scale: Vec3::new(
92 self.scale.x * other.scale.x,
93 self.scale.y * other.scale.y,
94 self.scale.z * other.scale.z,
95 ),
96 }
97 }
98
99 pub fn inverse(self) -> Self {
101 let mat = self.to_matrix().inverse();
102 let translation = Vec3::new(mat.cols[3][0], mat.cols[3][1], mat.cols[3][2]);
104 let sx = Vec3::new(mat.cols[0][0], mat.cols[0][1], mat.cols[0][2]).length();
106 let sy = Vec3::new(mat.cols[1][0], mat.cols[1][1], mat.cols[1][2]).length();
107 let sz = Vec3::new(mat.cols[2][0], mat.cols[2][1], mat.cols[2][2]).length();
108 let scale = Vec3::new(sx, sy, sz);
109 let rot_mat = Mat4 {
111 cols: [
112 [
113 mat.cols[0][0] / sx,
114 mat.cols[0][1] / sx,
115 mat.cols[0][2] / sx,
116 0.0,
117 ],
118 [
119 mat.cols[1][0] / sy,
120 mat.cols[1][1] / sy,
121 mat.cols[1][2] / sy,
122 0.0,
123 ],
124 [
125 mat.cols[2][0] / sz,
126 mat.cols[2][1] / sz,
127 mat.cols[2][2] / sz,
128 0.0,
129 ],
130 [0.0, 0.0, 0.0, 1.0],
131 ],
132 };
133 let trace = rot_mat.cols[0][0] + rot_mat.cols[1][1] + rot_mat.cols[2][2];
135 let rotation = if trace > 0.0 {
136 let s = (trace + 1.0).sqrt() * 2.0;
137 Quat::from_xyzw(
138 (rot_mat.cols[1][2] - rot_mat.cols[2][1]) / s,
139 (rot_mat.cols[2][0] - rot_mat.cols[0][2]) / s,
140 (rot_mat.cols[0][1] - rot_mat.cols[1][0]) / s,
141 0.25 * s,
142 )
143 } else {
144 Quat::IDENTITY
145 };
146
147 Self {
148 translation,
149 rotation: rotation.normalize(),
150 scale,
151 }
152 }
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use std::f32::consts::FRAC_PI_2;

    /// Asserts that every component of `actual` is within `eps` of `expected`.
    fn assert_vec3_near(actual: Vec3, expected: Vec3, eps: f32) {
        let near = |a: f32, b: f32| (a - b).abs() < eps;
        assert!(
            near(actual.x, expected.x) && near(actual.y, expected.y) && near(actual.z, expected.z),
            "{:?} is not within {} of {:?}",
            actual,
            eps,
            expected
        );
    }

    #[test]
    fn default_is_identity() {
        assert_eq!(Transform::default(), Transform::IDENTITY);
    }

    #[test]
    fn identity_transform() {
        let t = Transform::IDENTITY;
        let p = Vec3::new(1.0, 2.0, 3.0);
        assert_vec3_near(t.transform_point(p), p, 1e-6);
    }

    #[test]
    fn translation_only() {
        let t = Transform::from_translation(Vec3::new(10.0, 20.0, 30.0));
        let p = Vec3::new(1.0, 2.0, 3.0);
        let r = t.transform_point(p);
        assert_eq!(r, Vec3::new(11.0, 22.0, 33.0));
        // Vectors (directions/offsets) are not affected by translation.
        let v = t.transform_vector(p);
        assert_eq!(v, p);
    }

    #[test]
    fn scale_only() {
        let t = Transform::from_scale(Vec3::new(2.0, 3.0, 4.0));
        let r = t.transform_point(Vec3::ONE);
        assert_eq!(r, Vec3::new(2.0, 3.0, 4.0));
    }

    #[test]
    fn rotation_only() {
        let t = Transform::from_rotation(Quat::from_axis_angle(Vec3::Z, FRAC_PI_2));
        let r = t.transform_point(Vec3::X);
        assert!(r.x.abs() < 1e-5);
        assert!((r.y - 1.0).abs() < 1e-5);
    }

    #[test]
    fn compose_transforms() {
        let a = Transform::from_translation(Vec3::new(5.0, 0.0, 0.0));
        let b = Transform::from_scale(Vec3::new(2.0, 2.0, 2.0));
        let composed = a.mul(b);
        let r = composed.transform_point(Vec3::X);
        assert!((r.x - 7.0).abs() < 1e-5);
    }

    #[test]
    fn matrix_consistency() {
        let t = Transform {
            translation: Vec3::new(1.0, 2.0, 3.0),
            rotation: Quat::from_axis_angle(Vec3::Z, 0.5),
            scale: Vec3::new(2.0, 2.0, 2.0),
        };
        let mat = t.to_matrix();
        let p = Vec3::X;

        let from_transform = t.transform_point(p);
        let from_matrix = mat.transform_point3(p);

        assert_vec3_near(from_transform, from_matrix, 1e-5);
    }

    /// Exercises `Mat4::inverse` on the matrix form (not `Transform::inverse`).
    #[test]
    fn inverse_roundtrip() {
        let t = Transform {
            translation: Vec3::new(1.0, 2.0, 3.0),
            rotation: Quat::from_axis_angle(Vec3::Y, 0.7),
            scale: Vec3::new(2.0, 0.5, 3.0),
        };
        let p = Vec3::new(4.0, 5.0, 6.0);
        let forward = t.to_matrix();
        let inv = forward.inverse();
        let transformed = forward.transform_point3(p);
        let result = inv.transform_point3(transformed);
        assert_vec3_near(result, p, 1e-4);
    }

    /// `Transform::inverse` must undo the transform; uniform scale keeps the inverse
    /// exactly representable as a TRS transform.
    #[test]
    fn transform_inverse_roundtrip() {
        let t = Transform {
            translation: Vec3::new(1.0, 2.0, 3.0),
            rotation: Quat::from_axis_angle(Vec3::Y, 0.7),
            scale: Vec3::new(2.0, 2.0, 2.0),
        };
        let p = Vec3::new(4.0, 5.0, 6.0);
        let back = t.inverse().transform_point(t.transform_point(p));
        assert_vec3_near(back, p, 1e-4);
    }
}