1use serde::{Deserialize, Serialize};
2
3use crate::quicksilver_compat::about_equal;
4use crate::{Scalar, Vector};
5use std::{
6 cmp::{Eq, PartialEq},
7 default::Default,
8 f32::consts::PI,
9 fmt,
10 ops::Mul,
11};
12
13#[derive(Clone, Copy, Debug, Deserialize, Serialize)]
37pub struct Transform([f32; 9]);
38
39impl Transform {
40 pub const IDENTITY: Transform =
42 Transform::from_array([[1f32, 0f32, 0f32], [0f32, 1f32, 0f32], [0f32, 0f32, 1f32]]);
43
44 pub const fn from_array(array: [[f32; 3]; 3]) -> Transform {
46 Transform([
47 array[0][0],
48 array[0][1],
49 array[0][2],
50 array[1][0],
51 array[1][1],
52 array[1][2],
53 array[2][0],
54 array[2][1],
55 array[2][2],
56 ])
57 }
58
59 pub fn rotate<T: Scalar>(angle: T) -> Transform {
61 let angle = angle.float();
62 let c = (angle * PI / 180f32).cos();
63 let s = (angle * PI / 180f32).sin();
64 Transform::from_array([[c, -s, 0f32], [s, c, 0f32], [0f32, 0f32, 1f32]])
65 }
66
67 pub fn translate(vec: impl Into<Vector>) -> Transform {
69 let vec = vec.into();
70 Transform::from_array([[1f32, 0f32, vec.x], [0f32, 1f32, vec.y], [0f32, 0f32, 1f32]])
71 }
72
73 pub fn scale(vec: impl Into<Vector>) -> Transform {
75 let vec = vec.into();
76 Transform::from_array([[vec.x, 0f32, 0f32], [0f32, vec.y, 0f32], [0f32, 0f32, 1f32]])
77 }
78
79 pub fn horizontal_flip() -> Transform {
80 Transform::from_array([[-1f32, 0f32, 0f32], [0f32, 1f32, 0f32], [0f32, 0f32, 1f32]])
81 }
82
83 pub fn vertical_flip() -> Transform {
84 Transform::from_array([[1f32, 0f32, 0f32], [0f32, -1f32, 0f32], [0f32, 0f32, 1f32]])
85 }
86
87 pub fn as_slice(&self) -> &[f32] {
88 &self.0
89 }
90 pub fn row_major(&self) -> Vec<f32> {
91 vec![
92 self.0[0], self.0[3], self.0[6], self.0[1], self.0[4], self.0[7], self.0[2], self.0[5],
93 self.0[8],
94 ]
95 }
96
97 #[must_use]
109 pub fn inverse(&self) -> Transform {
110 let det = self.0[0] * (self.0[4] * self.0[8] - self.0[7] * self.0[5])
111 - self.0[1] * (self.0[3] * self.0[8] - self.0[5] * self.0[6])
112 + self.0[2] * (self.0[3] * self.0[7] - self.0[4] * self.0[6]);
113
114 let inv_det = det.recip();
115
116 let mut inverse = Transform::IDENTITY;
117 inverse.0[0] = self.0[4] * self.0[8] - self.0[7] * self.0[5];
118 inverse.0[1] = self.0[2] * self.0[7] - self.0[1] * self.0[8];
119 inverse.0[2] = self.0[1] * self.0[5] - self.0[2] * self.0[4];
120 inverse.0[3] = self.0[5] * self.0[6] - self.0[3] * self.0[8];
121 inverse.0[4] = self.0[0] * self.0[8] - self.0[2] * self.0[6];
122 inverse.0[5] = self.0[3] * self.0[2] - self.0[0] * self.0[5];
123 inverse.0[6] = self.0[3] * self.0[7] - self.0[6] * self.0[4];
124 inverse.0[7] = self.0[6] * self.0[1] - self.0[0] * self.0[7];
125 inverse.0[8] = self.0[0] * self.0[4] - self.0[3] * self.0[1];
126 inverse * inv_det
127 }
128}
129
130impl Mul<Transform> for Transform {
132 type Output = Transform;
133
134 fn mul(self, other: Transform) -> Transform {
135 let mut returnval = Transform::IDENTITY;
136 for i in 0..3 {
137 for j in 0..3 {
138 returnval.0[i * 3 + j] = 0f32;
139 for k in 0..3 {
140 returnval.0[i * 3 + j] += other.0[k * 3 + j] * self.0[i * 3 + k];
141 }
142 }
143 }
144 returnval
145 }
146}
147
148impl Mul<Vector> for Transform {
150 type Output = Vector;
151
152 fn mul(self, other: Vector) -> Vector {
153 Vector::new(
154 other.x * self.0[0] + other.y * self.0[1] + self.0[2],
155 other.x * self.0[3] + other.y * self.0[4] + self.0[5],
156 )
157 }
158}
159impl Mul<Transform> for Vector {
160 type Output = Vector;
161
162 fn mul(self, t: Transform) -> Vector {
163 Vector::new(
164 self.x * t.0[0] + self.y * t.0[3] + t.0[6],
165 self.x * t.0[1] + self.y * t.0[4] + t.0[7],
166 )
167 }
168}
169
170impl<T: Scalar> Mul<T> for Transform {
175 type Output = Transform;
176
177 fn mul(self, other: T) -> Transform {
178 let other = other.float();
179 let mut ret = Transform::IDENTITY;
180 for i in 0..3 {
181 for j in 0..3 {
182 ret.0[i * 3 + j] = self.0[i * 3 + j] * other;
183 }
184 }
185 ret
186 }
187}
188
189impl fmt::Display for Transform {
190 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
191 write!(f, "[")?;
192 for i in 0..3 {
193 for j in 0..3 {
194 write!(f, "{},", self.0[i * 3 + j])?;
195 }
196 write!(f, "\n")?;
197 }
198 write!(f, "]")
199 }
200}
201
202impl Default for Transform {
203 fn default() -> Transform {
204 Transform::IDENTITY
205 }
206}
207
208impl PartialEq for Transform {
209 fn eq(&self, other: &Transform) -> bool {
210 for i in 0..3 {
211 for j in 0..3 {
212 if !about_equal(self.0[i * 3 + j], other.0[i * 3 + j]) {
213 return false;
214 }
215 }
216 }
217 true
218 }
219}
220
// Marks `Transform` as `Eq`, so it can be used where total equality is required.
//
// NOTE(review): `PartialEq` for `Transform` compares entries with `about_equal`.
// If that is a tolerance-based check, as its name suggests (TODO confirm), the
// relation is not transitive, and possibly not reflexive for NaN entries. That
// means this marker does not strictly uphold `Eq`'s contract. Removing it would
// be a public API change, so it is left in place.
impl Eq for Transform {}
222
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn equality() {
        assert_eq!(Transform::IDENTITY, Transform::IDENTITY);
        assert_eq!(Transform::rotate(5), Transform::rotate(5));
    }

    #[test]
    fn inverse() {
        let vec = Vector::new(2, 4);
        // A halving scale; its inverse should double the vector.
        let scale = Transform::scale(Vector::ONE * 0.5);
        let inverse = scale.inverse();
        let transformed = inverse * vec;
        let expected = vec * 2;
        assert_eq!(transformed, expected);
    }

    #[test]
    fn singular_inverse_is_not_finite() {
        // A zero scale collapses the plane and has no inverse. `inverse` does
        // not check for this, so its result contains non-finite entries.
        let singular = Transform::scale(Vector::ZERO);
        assert!(singular.inverse().as_slice().iter().any(|v| !v.is_finite()));
    }

    #[test]
    fn scale() {
        let trans = Transform::scale(Vector::ONE * 2);
        let vec = Vector::new(2, 5);
        let scaled = trans * vec;
        let expected = vec * 2;
        assert_eq!(scaled, expected);
    }

    #[test]
    fn translate() {
        let offset = Vector::new(3, 4);
        let trans = Transform::translate(offset);
        let vec = Vector::ONE;
        let translated = trans * vec;
        let expected = vec + offset;
        assert_eq!(translated, expected);
    }

    #[test]
    fn identity() {
        let trans = Transform::IDENTITY
            * Transform::translate(Vector::ZERO)
            * Transform::rotate(0f32)
            * Transform::scale(Vector::ONE);
        let vec = Vector::new(15, 12);
        assert_eq!(vec, trans * vec);
    }

    #[test]
    fn complex_inverse() {
        let a = Transform::rotate(5f32)
            * Transform::scale(Vector::new(0.2, 1.23))
            * Transform::translate(Vector::ONE * 100f32);
        let a_inv = a.inverse();
        let vec = Vector::new(120f32, 151f32);
        assert_eq!(vec, a * a_inv * vec);
        assert_eq!(vec, a_inv * a * vec);
    }

    #[test]
    fn row_major_emits_columns() {
        // Pins the current output order: `row_major` walks the columns of the
        // `from_array` layout.
        let m = Transform::from_array([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
        assert_eq!(m.as_slice(), &[1f32, 2., 3., 4., 5., 6., 7., 8., 9.][..]);
        assert_eq!(m.row_major(), vec![1f32, 4., 7., 2., 5., 8., 3., 6., 9.]);
    }

    #[test]
    fn display_format() {
        assert_eq!(Transform::IDENTITY.to_string(), "[1,0,0,\n0,1,0,\n0,0,1,\n]");
    }

    #[test]
    fn default_is_identity() {
        assert_eq!(Transform::default(), Transform::IDENTITY);
    }
}