Skip to main content

scenix_math/
quat.rs

1use core::ops::{Mul, MulAssign, Neg};
2
3use crate::{EPSILON, Mat4, Vec3, Vec4, acos, clamp, cos, sin, sqrt};
4
/// A 3D rotation stored as the quaternion `x·i + y·j + z·k + w`.
///
/// All four components are public and nothing enforces unit length, so a
/// hand-built value may need `normalize` before being treated as a rotation.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Quat {
    /// Coefficient of the `i` imaginary unit.
    pub x: f32,
    /// Coefficient of the `j` imaginary unit.
    pub y: f32,
    /// Coefficient of the `k` imaginary unit.
    pub z: f32,
    /// Scalar (real) part; `1.0` for the identity rotation.
    pub w: f32,
}
18
19impl Quat {
20    /// Identity rotation.
21    pub const IDENTITY: Self = Self::new(0.0, 0.0, 0.0, 1.0);
22
23    /// Creates a quaternion from components.
24    #[inline]
25    pub const fn new(x: f32, y: f32, z: f32, w: f32) -> Self {
26        Self { x, y, z, w }
27    }
28
29    /// Creates a quaternion from an axis and angle in radians.
30    pub fn from_axis_angle(axis: Vec3, angle_rad: f32) -> Self {
31        let axis = axis.normalize();
32        if axis.length_squared() <= EPSILON {
33            return Self::IDENTITY;
34        }
35        let half = angle_rad * 0.5;
36        let s = sin(half);
37        Self::new(axis.x * s, axis.y * s, axis.z * s, cos(half)).normalize()
38    }
39
40    /// Creates a quaternion from XYZ Euler angles in radians.
41    #[inline]
42    pub fn from_euler_xyz(x: f32, y: f32, z: f32) -> Self {
43        crate::Euler::new(x, y, z, crate::RotationOrder::XYZ).to_quat()
44    }
45
46    /// Creates the shortest rotation from one direction to another.
47    pub fn from_rotation_arc(from: Vec3, to: Vec3) -> Self {
48        let from = from.normalize();
49        let to = to.normalize();
50        if from.length_squared() <= EPSILON || to.length_squared() <= EPSILON {
51            return Self::IDENTITY;
52        }
53
54        let dot = from.dot(to);
55        if dot > 1.0 - EPSILON {
56            return Self::IDENTITY;
57        }
58        if dot < -1.0 + EPSILON {
59            let axis = if from.x.abs() < 0.9 {
60                from.cross(Vec3::X).normalize()
61            } else {
62                from.cross(Vec3::Y).normalize()
63            };
64            return Self::from_axis_angle(axis, core::f32::consts::PI);
65        }
66
67        let cross = from.cross(to);
68        Self::new(cross.x, cross.y, cross.z, 1.0 + dot).normalize()
69    }
70
71    /// Extracts a quaternion from the rotation part of a matrix.
72    pub fn from_mat4(matrix: Mat4) -> Self {
73        let m00 = matrix.get(0, 0);
74        let m11 = matrix.get(1, 1);
75        let m22 = matrix.get(2, 2);
76        let trace = m00 + m11 + m22;
77
78        if trace > 0.0 {
79            let s = sqrt(trace + 1.0) * 2.0;
80            Self::new(
81                (matrix.get(2, 1) - matrix.get(1, 2)) / s,
82                (matrix.get(0, 2) - matrix.get(2, 0)) / s,
83                (matrix.get(1, 0) - matrix.get(0, 1)) / s,
84                0.25 * s,
85            )
86        } else if m00 > m11 && m00 > m22 {
87            let s = sqrt(1.0 + m00 - m11 - m22) * 2.0;
88            Self::new(
89                0.25 * s,
90                (matrix.get(0, 1) + matrix.get(1, 0)) / s,
91                (matrix.get(0, 2) + matrix.get(2, 0)) / s,
92                (matrix.get(2, 1) - matrix.get(1, 2)) / s,
93            )
94        } else if m11 > m22 {
95            let s = sqrt(1.0 + m11 - m00 - m22) * 2.0;
96            Self::new(
97                (matrix.get(0, 1) + matrix.get(1, 0)) / s,
98                0.25 * s,
99                (matrix.get(1, 2) + matrix.get(2, 1)) / s,
100                (matrix.get(0, 2) - matrix.get(2, 0)) / s,
101            )
102        } else {
103            let s = sqrt(1.0 + m22 - m00 - m11) * 2.0;
104            Self::new(
105                (matrix.get(0, 2) + matrix.get(2, 0)) / s,
106                (matrix.get(1, 2) + matrix.get(2, 1)) / s,
107                0.25 * s,
108                (matrix.get(1, 0) - matrix.get(0, 1)) / s,
109            )
110        }
111        .normalize()
112    }
113
114    /// Returns the dot product.
115    #[inline]
116    pub fn dot(self, rhs: Self) -> f32 {
117        self.x * rhs.x + self.y * rhs.y + self.z * rhs.z + self.w * rhs.w
118    }
119
120    /// Returns the squared length.
121    #[inline]
122    pub fn length_squared(self) -> f32 {
123        self.dot(self)
124    }
125
126    /// Returns the length.
127    #[inline]
128    pub fn length(self) -> f32 {
129        sqrt(self.length_squared())
130    }
131
132    /// Multiplies two quaternions.
133    #[inline]
134    pub fn mul_quat(self, rhs: Self) -> Self {
135        Self::new(
136            self.w * rhs.x + self.x * rhs.w + self.y * rhs.z - self.z * rhs.y,
137            self.w * rhs.y - self.x * rhs.z + self.y * rhs.w + self.z * rhs.x,
138            self.w * rhs.z + self.x * rhs.y - self.y * rhs.x + self.z * rhs.w,
139            self.w * rhs.w - self.x * rhs.x - self.y * rhs.y - self.z * rhs.z,
140        )
141    }
142
143    /// Rotates a vector by this quaternion.
144    #[inline]
145    pub fn mul_vec3(self, rhs: Vec3) -> Vec3 {
146        let q = self.normalize();
147        let u = Vec3::new(q.x, q.y, q.z);
148        let s = q.w;
149        u * (2.0 * u.dot(rhs)) + rhs * (s * s - u.dot(u)) + u.cross(rhs) * (2.0 * s)
150    }
151
152    /// Returns the conjugate quaternion.
153    #[inline]
154    pub fn conjugate(self) -> Self {
155        Self::new(-self.x, -self.y, -self.z, self.w)
156    }
157
158    /// Returns the inverse quaternion.
159    #[inline]
160    pub fn inverse(self) -> Self {
161        let len_sq = self.length_squared();
162        if len_sq <= EPSILON {
163            Self::IDENTITY
164        } else {
165            self.conjugate() * (1.0 / len_sq)
166        }
167    }
168
169    /// Returns a normalized quaternion, or identity for a near-zero input.
170    #[inline]
171    pub fn normalize(self) -> Self {
172        let length = self.length();
173        if length <= EPSILON {
174            Self::IDENTITY
175        } else {
176            self * (1.0 / length)
177        }
178    }
179
180    /// Spherically interpolates toward `rhs`.
181    pub fn slerp(self, rhs: Self, t: f32) -> Self {
182        let t = clamp(t, 0.0, 1.0);
183        if t <= 0.0 {
184            return self;
185        }
186        if t >= 1.0 {
187            return rhs;
188        }
189        let mut end = rhs;
190        let mut cos_half_theta = self.dot(end);
191
192        if cos_half_theta < -EPSILON {
193            end = -end;
194            cos_half_theta = -cos_half_theta;
195        }
196
197        if cos_half_theta >= 1.0 - EPSILON {
198            return Self::new(
199                self.x + t * (end.x - self.x),
200                self.y + t * (end.y - self.y),
201                self.z + t * (end.z - self.z),
202                self.w + t * (end.w - self.w),
203            )
204            .normalize();
205        }
206
207        let half_theta = acos(clamp(cos_half_theta, -1.0, 1.0));
208        let sin_half_theta = sin(half_theta);
209        if sin_half_theta.abs() <= EPSILON {
210            return self;
211        }
212
213        let ratio_a = sin((1.0 - t) * half_theta) / sin_half_theta;
214        let ratio_b = sin(t * half_theta) / sin_half_theta;
215        Self::new(
216            self.x * ratio_a + end.x * ratio_b,
217            self.y * ratio_a + end.y * ratio_b,
218            self.z * ratio_a + end.z * ratio_b,
219            self.w * ratio_a + end.w * ratio_b,
220        )
221        .normalize()
222    }
223
224    /// Converts this quaternion to a rotation matrix.
225    pub fn to_mat4(self) -> Mat4 {
226        let q = self.normalize();
227        let x2 = q.x + q.x;
228        let y2 = q.y + q.y;
229        let z2 = q.z + q.z;
230        let xx = q.x * x2;
231        let xy = q.x * y2;
232        let xz = q.x * z2;
233        let yy = q.y * y2;
234        let yz = q.y * z2;
235        let zz = q.z * z2;
236        let wx = q.w * x2;
237        let wy = q.w * y2;
238        let wz = q.w * z2;
239
240        Mat4::from_cols(
241            Vec4::new(1.0 - (yy + zz), xy + wz, xz - wy, 0.0),
242            Vec4::new(xy - wz, 1.0 - (xx + zz), yz + wx, 0.0),
243            Vec4::new(xz + wy, yz - wx, 1.0 - (xx + yy), 0.0),
244            Vec4::W,
245        )
246    }
247
248    /// Extracts XYZ Euler angles in radians.
249    #[inline]
250    pub fn to_euler_xyz(self) -> Vec3 {
251        let euler = crate::Euler::from_quat(self, crate::RotationOrder::XYZ);
252        Vec3::new(euler.x, euler.y, euler.z)
253    }
254
255    /// Returns the absolute angular distance to another quaternion.
256    #[inline]
257    pub fn angle_between(self, rhs: Self) -> f32 {
258        2.0 * acos(clamp(
259            self.normalize().dot(rhs.normalize()).abs(),
260            -1.0,
261            1.0,
262        ))
263    }
264
265    /// Returns the quaternion as an array `[x, y, z, w]`.
266    #[inline]
267    pub const fn to_array(self) -> [f32; 4] {
268        [self.x, self.y, self.z, self.w]
269    }
270}
271
272impl Default for Quat {
273    #[inline]
274    fn default() -> Self {
275        Self::IDENTITY
276    }
277}
278
279impl Mul for Quat {
280    type Output = Self;
281
282    #[inline]
283    fn mul(self, rhs: Self) -> Self::Output {
284        self.mul_quat(rhs)
285    }
286}
287
288impl MulAssign for Quat {
289    #[inline]
290    fn mul_assign(&mut self, rhs: Self) {
291        *self = self.mul_quat(rhs);
292    }
293}
294
295impl Mul<Vec3> for Quat {
296    type Output = Vec3;
297
298    #[inline]
299    fn mul(self, rhs: Vec3) -> Self::Output {
300        self.mul_vec3(rhs)
301    }
302}
303
304impl Mul<f32> for Quat {
305    type Output = Self;
306
307    #[inline]
308    fn mul(self, rhs: f32) -> Self::Output {
309        Self::new(self.x * rhs, self.y * rhs, self.z * rhs, self.w * rhs)
310    }
311}
312
313impl Neg for Quat {
314    type Output = Self;
315
316    #[inline]
317    fn neg(self) -> Self::Output {
318        Self::new(-self.x, -self.y, -self.z, -self.w)
319    }
320}
321
#[cfg(feature = "approx")]
impl approx::AbsDiffEq for Quat {
    type Epsilon = f32;

    /// Uses the same default tolerance as a single `f32` comparison.
    #[inline]
    fn default_epsilon() -> Self::Epsilon {
        // Fully qualified on purpose: no `approx` trait is imported in this module,
        // and implementing a trait does not bring it into scope, so the shorthand
        // `f32::default_epsilon()` fails to resolve (E0599) when `approx` is enabled.
        <f32 as approx::AbsDiffEq>::default_epsilon()
    }

    /// Compares `x`, `y`, `z`, `w` component-wise with an absolute tolerance.
    ///
    /// Note: `q` and `-q` represent the same rotation but compare unequal here.
    #[inline]
    fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
        <f32 as approx::AbsDiffEq>::abs_diff_eq(&self.x, &other.x, epsilon)
            && <f32 as approx::AbsDiffEq>::abs_diff_eq(&self.y, &other.y, epsilon)
            && <f32 as approx::AbsDiffEq>::abs_diff_eq(&self.z, &other.z, epsilon)
            && <f32 as approx::AbsDiffEq>::abs_diff_eq(&self.w, &other.w, epsilon)
    }
}
339
#[cfg(feature = "approx")]
impl approx::RelativeEq for Quat {
    /// Uses the same default relative tolerance as a single `f32` comparison.
    #[inline]
    fn default_max_relative() -> Self::Epsilon {
        // Fully qualified: `approx::RelativeEq` is not imported in this module, so
        // `f32::default_max_relative()` would not resolve with `approx` enabled.
        <f32 as approx::RelativeEq>::default_max_relative()
    }

    /// Compares `x`, `y`, `z`, `w` component-wise with absolute and relative tolerances.
    ///
    /// Note: `q` and `-q` represent the same rotation but compare unequal here.
    #[inline]
    fn relative_eq(
        &self,
        other: &Self,
        epsilon: Self::Epsilon,
        max_relative: Self::Epsilon,
    ) -> bool {
        <f32 as approx::RelativeEq>::relative_eq(&self.x, &other.x, epsilon, max_relative)
            && <f32 as approx::RelativeEq>::relative_eq(&self.y, &other.y, epsilon, max_relative)
            && <f32 as approx::RelativeEq>::relative_eq(&self.z, &other.z, epsilon, max_relative)
            && <f32 as approx::RelativeEq>::relative_eq(&self.w, &other.w, epsilon, max_relative)
    }
}
360
#[cfg(feature = "approx")]
impl approx::UlpsEq for Quat {
    /// Uses the same default ULP budget as a single `f32` comparison.
    #[inline]
    fn default_max_ulps() -> u32 {
        // Fully qualified: `approx::UlpsEq` is not imported in this module, so
        // `f32::default_max_ulps()` would not resolve with `approx` enabled.
        <f32 as approx::UlpsEq>::default_max_ulps()
    }

    /// Compares `x`, `y`, `z`, `w` component-wise in units of least precision.
    ///
    /// Note: `q` and `-q` represent the same rotation but compare unequal here.
    #[inline]
    fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool {
        <f32 as approx::UlpsEq>::ulps_eq(&self.x, &other.x, epsilon, max_ulps)
            && <f32 as approx::UlpsEq>::ulps_eq(&self.y, &other.y, epsilon, max_ulps)
            && <f32 as approx::UlpsEq>::ulps_eq(&self.z, &other.z, epsilon, max_ulps)
            && <f32 as approx::UlpsEq>::ulps_eq(&self.w, &other.w, epsilon, max_ulps)
    }
}
376
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_close;

    #[test]
    fn slerp_handles_endpoints_and_midpoint() {
        let a = Quat::IDENTITY;
        let b = Quat::from_axis_angle(Vec3::Y, core::f32::consts::PI);
        assert_eq!(a.slerp(b, 0.0), a);
        assert_eq!(a.slerp(b, 1.0), b);

        // Halfway from identity to a half-turn about +Y is a quarter-turn about +Y,
        // which maps +X onto -Z.
        let midpoint = a.slerp(b, 0.5);
        let rotated = midpoint.mul_vec3(Vec3::X);
        assert_close(rotated.x, 0.0);
        assert_close(rotated.y, 0.0);
        assert_close(rotated.z, -1.0);
    }

    #[test]
    fn inverse_undoes_rotation() {
        let q = Quat::from_axis_angle(Vec3::Y, 0.8);
        let v = Vec3::new(1.0, 2.0, 3.0);
        let rotated = q.mul_vec3(v);
        let restored = q.inverse().mul_vec3(rotated);
        assert_close(restored.x, v.x);
        assert_close(restored.y, v.y);
        assert_close(restored.z, v.z);
    }

    #[test]
    fn inverse_of_non_unit_quaternion_yields_identity_product() {
        // Scaling by 2 makes length² = 4, exercising the division in `inverse`.
        let q = Quat::from_axis_angle(Vec3::new(1.0, 2.0, 3.0), 0.7) * 2.0;
        let product = q.inverse() * q;
        assert_close(product.x, 0.0);
        assert_close(product.y, 0.0);
        assert_close(product.z, 0.0);
        assert_close(product.w, 1.0);
    }

    #[test]
    fn rotation_arc_rotates_between_vectors() {
        let q = Quat::from_rotation_arc(Vec3::X, Vec3::Y);
        let rotated = q.mul_vec3(Vec3::X);
        assert_close(rotated.x, 0.0);
        assert_close(rotated.y, 1.0);
        assert_close(rotated.z, 0.0);
    }

    #[test]
    fn rotation_arc_handles_opposite_vectors() {
        // Exercises the antiparallel branch, where the cross product vanishes.
        let target = Vec3::new(-1.0, 0.0, 0.0);
        let q = Quat::from_rotation_arc(Vec3::X, target);
        let rotated = q.mul_vec3(Vec3::X);
        assert_close(rotated.x, -1.0);
        assert_close(rotated.y, 0.0);
        assert_close(rotated.z, 0.0);
    }

    #[test]
    fn angle_between_ignores_quaternion_sign() {
        let q = Quat::from_axis_angle(Vec3::Y, 0.8);
        assert_close(q.angle_between(Quat::IDENTITY), 0.8);
        assert_close((-q).angle_between(Quat::IDENTITY), 0.8);
    }
}