Skip to main content

euca_math/
quat.rs

1use crate::Vec3;
2use serde::{Deserialize, Serialize};
3use std::ops::Mul;
4
5/// Quaternion (xyzw layout, unit quaternion for rotations).
6#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
7#[repr(C, align(16))]
8pub struct Quat {
9    pub x: f32,
10    pub y: f32,
11    pub z: f32,
12    pub w: f32,
13}
14
15impl Default for Quat {
16    #[inline(always)]
17    fn default() -> Self {
18        Self::IDENTITY
19    }
20}
21
22impl Quat {
23    /// The identity quaternion (no rotation).
24    pub const IDENTITY: Self = Self {
25        x: 0.0,
26        y: 0.0,
27        z: 0.0,
28        w: 1.0,
29    };
30
31    /// Creates a quaternion from raw x, y, z, w components.
32    #[inline(always)]
33    pub const fn from_xyzw(x: f32, y: f32, z: f32, w: f32) -> Self {
34        Self { x, y, z, w }
35    }
36
37    /// Create a quaternion from axis-angle rotation.
38    #[inline]
39    pub fn from_axis_angle(axis: Vec3, angle: f32) -> Self {
40        let half = angle * 0.5;
41        let s = half.sin();
42        let c = half.cos();
43        let a = axis.normalize();
44        Self {
45            x: a.x * s,
46            y: a.y * s,
47            z: a.z * s,
48            w: c,
49        }
50    }
51
52    /// Create from Euler angles (yaw, pitch, roll) in YXZ order.
53    #[inline]
54    pub fn from_euler(yaw: f32, pitch: f32, roll: f32) -> Self {
55        let (sy, cy) = (yaw * 0.5).sin_cos();
56        let (sp, cp) = (pitch * 0.5).sin_cos();
57        let (sr, cr) = (roll * 0.5).sin_cos();
58
59        Self {
60            x: cy * sp * cr + sy * cp * sr,
61            y: sy * cp * cr - cy * sp * sr,
62            z: cy * cp * sr - sy * sp * cr,
63            w: cy * cp * cr + sy * sp * sr,
64        }
65    }
66
67    /// Returns the length (norm) of the quaternion.
68    #[inline(always)]
69    pub fn length(self) -> f32 {
70        (self.x * self.x + self.y * self.y + self.z * self.z + self.w * self.w).sqrt()
71    }
72
73    /// Returns a unit-length quaternion in the same direction.
74    #[inline]
75    pub fn normalize(self) -> Self {
76        let inv = 1.0 / self.length();
77        Self {
78            x: self.x * inv,
79            y: self.y * inv,
80            z: self.z * inv,
81            w: self.w * inv,
82        }
83    }
84
85    /// Returns the conjugate (inverse for unit quaternions).
86    #[inline]
87    pub fn inverse(self) -> Self {
88        Self {
89            x: -self.x,
90            y: -self.y,
91            z: -self.z,
92            w: self.w,
93        }
94    }
95
96    /// Spherical linear interpolation.
97    #[inline]
98    pub fn slerp(self, mut end: Self, t: f32) -> Self {
99        let mut dot = self.x * end.x + self.y * end.y + self.z * end.z + self.w * end.w;
100
101        if dot < 0.0 {
102            end = Self {
103                x: -end.x,
104                y: -end.y,
105                z: -end.z,
106                w: -end.w,
107            };
108            dot = -dot;
109        }
110
111        if dot > 0.9995 {
112            return Self {
113                x: self.x + (end.x - self.x) * t,
114                y: self.y + (end.y - self.y) * t,
115                z: self.z + (end.z - self.z) * t,
116                w: self.w + (end.w - self.w) * t,
117            }
118            .normalize();
119        }
120
121        let theta = dot.acos();
122        let sin_theta = theta.sin();
123        let s0 = ((1.0 - t) * theta).sin() / sin_theta;
124        let s1 = (t * theta).sin() / sin_theta;
125
126        Self {
127            x: self.x * s0 + end.x * s1,
128            y: self.y * s0 + end.y * s1,
129            z: self.z * s0 + end.z * s1,
130            w: self.w * s0 + end.w * s1,
131        }
132    }
133}
134
135/// Compose rotations: `a * b` applies `b` first, then `a`.
136impl Mul for Quat {
137    type Output = Self;
138    #[inline]
139    fn mul(self, rhs: Self) -> Self {
140        Self {
141            x: self.w * rhs.x + self.x * rhs.w + self.y * rhs.z - self.z * rhs.y,
142            y: self.w * rhs.y - self.x * rhs.z + self.y * rhs.w + self.z * rhs.x,
143            z: self.w * rhs.z + self.x * rhs.y - self.y * rhs.x + self.z * rhs.w,
144            w: self.w * rhs.w - self.x * rhs.x - self.y * rhs.y - self.z * rhs.z,
145        }
146    }
147}
148
149/// Rotate a Vec3 by this quaternion.
150impl Mul<Vec3> for Quat {
151    type Output = Vec3;
152    #[inline]
153    fn mul(self, v: Vec3) -> Vec3 {
154        let u = Vec3::new(self.x, self.y, self.z);
155        let s = self.w;
156        u * (2.0 * u.dot(v)) + v * (s * s - u.dot(u)) + u.cross(v) * (2.0 * s)
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use std::f32::consts::FRAC_PI_2;

    /// Asserts that every component of `actual` is within `tol` of `expected`.
    ///
    /// Components are compared and reported individually so a failure names
    /// the axis that drifted.
    fn assert_vec3_near(actual: Vec3, expected: Vec3, tol: f32) {
        assert!(
            (actual.x - expected.x).abs() < tol,
            "x: got {}, expected {}",
            actual.x,
            expected.x
        );
        assert!(
            (actual.y - expected.y).abs() < tol,
            "y: got {}, expected {}",
            actual.y,
            expected.y
        );
        assert!(
            (actual.z - expected.z).abs() < tol,
            "z: got {}, expected {}",
            actual.z,
            expected.z
        );
    }

    /// The identity quaternion must leave any vector untouched.
    #[test]
    fn identity_rotation() {
        let v = Vec3::new(1.0, 2.0, 3.0);
        assert_vec3_near(Quat::IDENTITY * v, v, 1e-6);
    }

    /// A quarter turn about +Z carries +X onto +Y and stays in the XY plane.
    #[test]
    fn rotate_90_around_z() {
        let q = Quat::from_axis_angle(Vec3::Z, FRAC_PI_2);
        assert_vec3_near(q * Vec3::X, Vec3::new(0.0, 1.0, 0.0), 1e-5);
    }

    /// Rotating by `q` and then by `q.inverse()` returns the original vector.
    #[test]
    fn inverse_undoes_rotation() {
        let q = Quat::from_axis_angle(Vec3::Y, 1.0);
        let v = Vec3::new(1.0, 2.0, 3.0);
        let back = q.inverse() * (q * v);
        assert_vec3_near(back, v, 1e-4);
    }

    /// Halfway between identity and a quarter turn about Z is an eighth turn
    /// about Z, so +X lands at (cos 45°, sin 45°, 0).
    #[test]
    fn slerp_halfway() {
        let a = Quat::IDENTITY;
        let b = Quat::from_axis_angle(Vec3::Z, FRAC_PI_2);
        let mid = a.slerp(b, 0.5);
        let half_angle = FRAC_PI_2 / 2.0;
        let expected = Vec3::new(half_angle.cos(), half_angle.sin(), 0.0);
        assert_vec3_near(mid * Vec3::X, expected, 1e-4);
    }
}