Skip to main content

astrelis_geometry/
transform.rs

//! 2D affine transformations.
//!
//! Provides a 2D transform matrix for translation, rotation, scaling, and skewing.
4
5use glam::{Mat3, Vec2};
6
7/// A 2D affine transformation matrix.
8///
9/// Internally uses a 3x3 matrix for affine transforms.
10/// The last row is always [0, 0, 1].
11#[derive(Debug, Clone, Copy, PartialEq)]
12pub struct Transform2D {
13    matrix: Mat3,
14}
15
16impl Default for Transform2D {
17    fn default() -> Self {
18        Self::IDENTITY
19    }
20}
21
22impl Transform2D {
23    /// Identity transform (no transformation).
24    pub const IDENTITY: Self = Self {
25        matrix: Mat3::IDENTITY,
26    };
27
28    /// Create from a 3x3 matrix.
29    pub fn from_mat3(matrix: Mat3) -> Self {
30        Self { matrix }
31    }
32
33    /// Create a translation transform.
34    pub fn translate(offset: Vec2) -> Self {
35        Self {
36            matrix: Mat3::from_translation(offset),
37        }
38    }
39
40    /// Create a rotation transform (angle in radians).
41    pub fn rotate(angle: f32) -> Self {
42        Self {
43            matrix: Mat3::from_angle(angle),
44        }
45    }
46
47    /// Create a uniform scale transform.
48    pub fn scale(factor: f32) -> Self {
49        Self {
50            matrix: Mat3::from_scale(Vec2::splat(factor)),
51        }
52    }
53
54    /// Create a non-uniform scale transform.
55    pub fn scale_xy(scale: Vec2) -> Self {
56        Self {
57            matrix: Mat3::from_scale(scale),
58        }
59    }
60
61    /// Create a skew transform.
62    ///
63    /// `skew_x` is the horizontal skew angle in radians.
64    /// `skew_y` is the vertical skew angle in radians.
65    pub fn skew(skew_x: f32, skew_y: f32) -> Self {
66        Self {
67            matrix: Mat3::from_cols(
68                glam::Vec3::new(1.0, skew_y.tan(), 0.0),
69                glam::Vec3::new(skew_x.tan(), 1.0, 0.0),
70                glam::Vec3::new(0.0, 0.0, 1.0),
71            ),
72        }
73    }
74
75    /// Combine two transforms (self then other).
76    pub fn then(&self, other: &Transform2D) -> Self {
77        Self {
78            matrix: other.matrix * self.matrix,
79        }
80    }
81
82    /// Add a translation after this transform.
83    pub fn then_translate(&self, offset: Vec2) -> Self {
84        self.then(&Transform2D::translate(offset))
85    }
86
87    /// Add a rotation after this transform.
88    pub fn then_rotate(&self, angle: f32) -> Self {
89        self.then(&Transform2D::rotate(angle))
90    }
91
92    /// Add a scale after this transform.
93    pub fn then_scale(&self, factor: f32) -> Self {
94        self.then(&Transform2D::scale(factor))
95    }
96
97    /// Add a non-uniform scale after this transform.
98    pub fn then_scale_xy(&self, scale: Vec2) -> Self {
99        self.then(&Transform2D::scale_xy(scale))
100    }
101
102    /// Transform a point.
103    pub fn transform_point(&self, point: Vec2) -> Vec2 {
104        self.matrix.transform_point2(point)
105    }
106
107    /// Transform a vector (ignores translation).
108    pub fn transform_vector(&self, vector: Vec2) -> Vec2 {
109        self.matrix.transform_vector2(vector)
110    }
111
112    /// Get the inverse transform, if it exists.
113    pub fn inverse(&self) -> Option<Self> {
114        let det = self.matrix.determinant();
115        if det.abs() < f32::EPSILON {
116            None
117        } else {
118            Some(Self {
119                matrix: self.matrix.inverse(),
120            })
121        }
122    }
123
124    /// Get the underlying 3x3 matrix.
125    pub fn as_mat3(&self) -> &Mat3 {
126        &self.matrix
127    }
128
129    /// Get the translation component.
130    pub fn translation(&self) -> Vec2 {
131        Vec2::new(self.matrix.z_axis.x, self.matrix.z_axis.y)
132    }
133
134    /// Get the scale component (approximate for non-uniform transforms).
135    pub fn scale_factor(&self) -> Vec2 {
136        Vec2::new(
137            Vec2::new(self.matrix.x_axis.x, self.matrix.x_axis.y).length(),
138            Vec2::new(self.matrix.y_axis.x, self.matrix.y_axis.y).length(),
139        )
140    }
141
142    /// Get the rotation angle in radians (approximate for skewed transforms).
143    pub fn rotation(&self) -> f32 {
144        self.matrix.x_axis.y.atan2(self.matrix.x_axis.x)
145    }
146}
147
148impl std::ops::Mul<Transform2D> for Transform2D {
149    type Output = Transform2D;
150
151    fn mul(self, rhs: Transform2D) -> Transform2D {
152        self.then(&rhs)
153    }
154}
155
156impl std::ops::Mul<Vec2> for Transform2D {
157    type Output = Vec2;
158
159    fn mul(self, rhs: Vec2) -> Vec2 {
160        self.transform_point(rhs)
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use std::f32::consts::FRAC_PI_2;

    /// The identity transform leaves a point untouched.
    #[test]
    fn test_identity() {
        let p = Vec2::new(10.0, 20.0);
        let mapped = Transform2D::IDENTITY.transform_point(p);
        assert_eq!(mapped, p);
    }

    /// A translation adds its offset to the point.
    #[test]
    fn test_translate() {
        let shift = Transform2D::translate(Vec2::new(5.0, 10.0));
        let mapped = shift.transform_point(Vec2::new(10.0, 20.0));
        assert_eq!(mapped, Vec2::new(15.0, 30.0));
    }

    /// A uniform scale multiplies both coordinates by the factor.
    #[test]
    fn test_scale() {
        let doubled = Transform2D::scale(2.0).transform_point(Vec2::new(10.0, 20.0));
        assert_eq!(doubled, Vec2::new(20.0, 40.0));
    }

    /// Rotating the unit X vector by a quarter turn yields the unit Y vector.
    #[test]
    fn test_rotate_90() {
        let quarter_turn = Transform2D::rotate(FRAC_PI_2);
        let rotated = quarter_turn.transform_point(Vec2::new(1.0, 0.0));
        assert!(rotated.x.abs() < 0.001);
        assert!((rotated.y - 1.0).abs() < 0.001);
    }

    /// `then_*` appends: the translation runs first, the scale second.
    #[test]
    fn test_chain_transforms() {
        let chained = Transform2D::translate(Vec2::new(10.0, 0.0)).then_scale(2.0);
        // (5, 5) -> translate -> (15, 5) -> scale -> (30, 10)
        let mapped = chained.transform_point(Vec2::new(5.0, 5.0));
        assert_eq!(mapped, Vec2::new(30.0, 10.0));
    }

    /// Applying a transform and then its inverse restores the original point.
    #[test]
    fn test_inverse() {
        let forward = Transform2D::translate(Vec2::new(10.0, 20.0)).then_scale(2.0);
        let backward = forward.inverse().unwrap();
        let start = Vec2::new(5.0, 5.0);
        let round_trip = backward.transform_point(forward.transform_point(start));
        assert!(round_trip.distance(start) < 0.001);
    }
}
217}