fj_math/
transform.rs

1use std::ops;
2
3use nalgebra::Perspective3;
4
5use crate::{Circle, Line, Scalar};
6
7use super::{Aabb, Point, Segment, Triangle, Vector};
8
9/// An affine transform
10#[repr(C)]
11#[derive(Debug, Clone, Copy, Default)]
12pub struct Transform(nalgebra::Transform<f64, nalgebra::TAffine, 3>);
13
14impl Transform {
15    /// Construct an identity transform
16    pub fn identity() -> Self {
17        Self(nalgebra::Transform::identity())
18    }
19
20    /// Construct a translation
21    pub fn translation(offset: impl Into<Vector<3>>) -> Self {
22        let offset = offset.into();
23
24        Self(nalgebra::Transform::from_matrix_unchecked(
25            nalgebra::OMatrix::new_translation(&offset.to_na()),
26        ))
27    }
28
29    /// Construct a rotation
30    ///
31    /// The direction of the vector defines the rotation axis. Its length
32    /// defines the angle of the rotation.
33    pub fn rotation(axis_angle: impl Into<Vector<3>>) -> Self {
34        let axis_angle = axis_angle.into();
35
36        Self(nalgebra::Transform::from_matrix_unchecked(
37            nalgebra::OMatrix::<_, nalgebra::Const<4>, _>::new_rotation(
38                axis_angle.to_na(),
39            ),
40        ))
41    }
42
43    /// Construct a scaling
44    pub fn scale(scaling_factor: f64) -> Self {
45        Self(nalgebra::Transform::from_matrix_unchecked(
46            nalgebra::OMatrix::new_scaling(scaling_factor),
47        ))
48    }
49
50    /// Transform the given point
51    pub fn transform_point(&self, point: &Point<3>) -> Point<3> {
52        Point::from(self.0.transform_point(&point.to_na()))
53    }
54
55    /// Inverse transform given point
56    pub fn inverse_transform_point(&self, point: &Point<3>) -> Point<3> {
57        Point::from(self.0.inverse_transform_point(&point.to_na()))
58    }
59
60    /// Transform the given vector
61    pub fn transform_vector(&self, vector: &Vector<3>) -> Vector<3> {
62        Vector::from(self.0.transform_vector(&vector.to_na()))
63    }
64
65    /// Transform the given line
66    pub fn transform_line(&self, line: &Line<3>) -> Line<3> {
67        Line::from_origin_and_direction(
68            self.transform_point(&line.origin()),
69            self.transform_vector(&line.direction()),
70        )
71    }
72
73    /// Transform the given segment
74    pub fn transform_segment(&self, segment: &Segment<3>) -> Segment<3> {
75        let [a, b] = &segment.points();
76        Segment::from([self.transform_point(a), self.transform_point(b)])
77    }
78
79    /// Transform the given triangle
80    pub fn transform_triangle(&self, triangle: &Triangle<3>) -> Triangle<3> {
81        let [a, b, c] = &triangle.points();
82        Triangle::from([
83            self.transform_point(a),
84            self.transform_point(b),
85            self.transform_point(c),
86        ])
87    }
88
89    /// Transform the given circle
90    pub fn transform_circle(&self, circle: &Circle<3>) -> Circle<3> {
91        Circle::new(
92            self.transform_point(&circle.center()),
93            self.transform_vector(&circle.a()),
94            self.transform_vector(&circle.b()),
95        )
96    }
97
98    /// Inverse transform
99    pub fn inverse(&self) -> Self {
100        Self(self.0.inverse())
101    }
102
103    /// Transpose transform
104    pub fn transpose(&self) -> Self {
105        Self(nalgebra::Transform::from_matrix_unchecked(
106            self.0.to_homogeneous().transpose(),
107        ))
108    }
109
110    /// Project transform according to camera specification, return data as an array.
111    /// Used primarily for graphics code.
112    pub fn project_to_array(
113        &self,
114        aspect_ratio: f64,
115        fovy: f64,
116        znear: f64,
117        zfar: f64,
118    ) -> [Scalar; 16] {
119        let projection = Perspective3::new(aspect_ratio, fovy, znear, zfar);
120
121        let mut array = [0.; 16];
122        array.copy_from_slice(
123            (projection.to_projective() * self.0).matrix().as_slice(),
124        );
125
126        array.map(Scalar::from)
127    }
128
129    /// Return a copy of the inner nalgebra transform
130    pub fn get_inner(&self) -> nalgebra::Transform<f64, nalgebra::TAffine, 3> {
131        self.0
132    }
133
134    /// Transform the given axis-aligned bounding box
135    pub fn transform_aabb(&self, aabb: &Aabb<3>) -> Aabb<3> {
136        Aabb {
137            min: self.transform_point(&aabb.min),
138            max: self.transform_point(&aabb.max),
139        }
140    }
141
142    /// Exposes the data of this Transform as a slice of f64.
143    pub fn data(&self) -> &[f64] {
144        self.0.matrix().data.as_slice()
145    }
146
147    /// Extract the rotation component of this transform
148    pub fn extract_rotation(&self) -> Self {
149        Self(nalgebra::Transform::from_matrix_unchecked(
150            self.0.matrix().fixed_resize::<3, 3>(0.).to_homogeneous(),
151        ))
152    }
153
154    /// Extract the translation component of this transform
155    pub fn extract_translation(&self) -> Self {
156        *self * self.extract_rotation().inverse()
157    }
158}
159
160impl ops::Mul<Self> for Transform {
161    type Output = Self;
162
163    fn mul(self, rhs: Self) -> Self::Output {
164        Self(self.0.mul(rhs.0))
165    }
166}
167
#[cfg(test)]
mod tests {
    use approx::assert_abs_diff_eq;

    use crate::{Line, Point, Scalar, Vector};

    use super::Transform;

    /// A quarter turn about z followed by a translation, applied to a line.
    #[test]
    fn transform() {
        let combined = Transform::translation([1., 2., 3.])
            * Transform::rotation(Vector::unit_z() * (Scalar::PI / 2.));

        let input = Line::from_origin_and_direction(
            Point::from([1., 0., 0.]),
            Vector::from([0., 1., 0.]),
        );
        let output = combined.transform_line(&input);

        // The origin (1, 0, 0) rotates to (0, 1, 0) and is then moved to
        // (1, 3, 3). The direction only rotates: (0, 1, 0) -> (-1, 0, 0).
        assert_abs_diff_eq!(
            output,
            Line::from_origin_and_direction(
                Point::from([1., 3., 3.]),
                Vector::from([-1., 0., 0.]),
            ),
            epsilon = Scalar::from(1e-8),
        );
    }

    /// Splitting a composed transform back into rotation and translation.
    #[test]
    fn extract_rotation_translation() {
        let rotation =
            Transform::rotation(Vector::unit_z() * (Scalar::PI / 2.));
        let translation = Transform::translation([1., 2., 3.]);

        // Rotation applied first, translation second.
        let rotate_then_translate = translation * rotation;
        // Translation applied first, rotation second.
        let translate_then_rotate = rotation * translation;

        assert_abs_diff_eq!(
            rotate_then_translate.extract_rotation().data(),
            rotation.data(),
            epsilon = 1e-8,
        );

        assert_abs_diff_eq!(
            rotate_then_translate.extract_translation().data(),
            translation.data(),
            epsilon = 1e-8,
        );

        assert_abs_diff_eq!(
            translate_then_rotate.extract_rotation().data(),
            rotation.data(),
            epsilon = 1e-8,
        );

        // Here the offset itself gets rotated: (1, 2, 3) -> (-2, 1, 3).
        assert_abs_diff_eq!(
            translate_then_rotate.extract_translation().data(),
            Transform::translation([-2., 1., 3.]).data(),
            epsilon = 1e-8,
        );
    }
}