Skip to main content

oximedia_virtual/math/
quaternion.rs

//! Quaternion types for 3D rotation.
2
3use super::matrix::Matrix3;
4use super::vector::Vector3;
5use serde::{Deserialize, Serialize};
6use std::ops::Mul;
7
// ---------------------------------------------------------------------------
// Quaternion
// ---------------------------------------------------------------------------
11
12/// A quaternion (w + xi + yj + zk).
13#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
14pub struct Quaternion<T> {
15    pub w: T,
16    pub x: T,
17    pub y: T,
18    pub z: T,
19}
20
21impl Quaternion<f64> {
22    /// Construct.
23    #[must_use]
24    pub fn new(w: f64, x: f64, y: f64, z: f64) -> Self {
25        Self { w, x, y, z }
26    }
27
28    /// Norm.
29    #[must_use]
30    pub fn norm(&self) -> f64 {
31        (self.w * self.w + self.x * self.x + self.y * self.y + self.z * self.z).sqrt()
32    }
33
34    /// Normalise.
35    #[must_use]
36    pub fn normalize(&self) -> Self {
37        let n = self.norm();
38        if n < 1e-15 {
39            return *self;
40        }
41        Self::new(self.w / n, self.x / n, self.y / n, self.z / n)
42    }
43
44    /// Conjugate.
45    #[must_use]
46    pub fn conjugate(&self) -> Self {
47        Self::new(self.w, -self.x, -self.y, -self.z)
48    }
49}
50
51impl Mul for Quaternion<f64> {
52    type Output = Self;
53    fn mul(self, rhs: Self) -> Self {
54        Self::new(
55            self.w * rhs.w - self.x * rhs.x - self.y * rhs.y - self.z * rhs.z,
56            self.w * rhs.x + self.x * rhs.w + self.y * rhs.z - self.z * rhs.y,
57            self.w * rhs.y - self.x * rhs.z + self.y * rhs.w + self.z * rhs.x,
58            self.w * rhs.z + self.x * rhs.y - self.y * rhs.x + self.z * rhs.w,
59        )
60    }
61}
62
// ---------------------------------------------------------------------------
// Unit (for axis normalisation)
// ---------------------------------------------------------------------------
66
/// Wrapper marking a value as unit length, in the spirit of `nalgebra::Unit`.
///
/// The field is public, so the unit-length property is only guaranteed when
/// the value is built through a normalising constructor such as
/// `new_normalize`.
#[derive(Debug, Clone, Copy)]
pub struct Unit<T>(
    /// The wrapped value, expected to have unit length.
    pub T,
);
70
71impl Unit<Vector3<f64>> {
72    /// Normalise a vector into a Unit wrapper.
73    #[must_use]
74    pub fn new_normalize(v: Vector3<f64>) -> Self {
75        Unit(v.normalize())
76    }
77}
78
// ---------------------------------------------------------------------------
// UnitQuaternion
// ---------------------------------------------------------------------------
82
83/// A unit quaternion representing a rotation in 3D.
84#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
85pub struct UnitQuaternion<T> {
86    q: Quaternion<T>,
87}
88
89impl UnitQuaternion<f64> {
90    /// Identity rotation.
91    #[must_use]
92    pub fn identity() -> Self {
93        Self {
94            q: Quaternion::new(1.0, 0.0, 0.0, 0.0),
95        }
96    }
97
98    /// Construct from an already-normalised Quaternion.
99    #[must_use]
100    pub fn from_quaternion(q: Quaternion<f64>) -> Self {
101        Self { q: q.normalize() }
102    }
103
104    /// Access inner quaternion.
105    #[must_use]
106    pub fn quaternion(&self) -> &Quaternion<f64> {
107        &self.q
108    }
109
110    /// Construct from Euler angles (roll, pitch, yaw) — ZYX convention.
111    #[must_use]
112    pub fn from_euler_angles(roll: f64, pitch: f64, yaw: f64) -> Self {
113        let (sr, cr) = (roll / 2.0).sin_cos();
114        let (sp, cp) = (pitch / 2.0).sin_cos();
115        let (sy, cy) = (yaw / 2.0).sin_cos();
116
117        let w = cr * cp * cy + sr * sp * sy;
118        let x = sr * cp * cy - cr * sp * sy;
119        let y = cr * sp * cy + sr * cp * sy;
120        let z = cr * cp * sy - sr * sp * cy;
121
122        Self {
123            q: Quaternion::new(w, x, y, z).normalize(),
124        }
125    }
126
127    /// Extract Euler angles (roll, pitch, yaw).
128    #[must_use]
129    pub fn euler_angles(&self) -> (f64, f64, f64) {
130        let q = &self.q;
131        // Roll (x-axis rotation)
132        let sinr_cosp = 2.0 * (q.w * q.x + q.y * q.z);
133        let cosr_cosp = 1.0 - 2.0 * (q.x * q.x + q.y * q.y);
134        let roll = sinr_cosp.atan2(cosr_cosp);
135
136        // Pitch (y-axis rotation)
137        let sinp = 2.0 * (q.w * q.y - q.z * q.x);
138        let pitch = if sinp.abs() >= 1.0 {
139            std::f64::consts::FRAC_PI_2.copysign(sinp)
140        } else {
141            sinp.asin()
142        };
143
144        // Yaw (z-axis rotation)
145        let siny_cosp = 2.0 * (q.w * q.z + q.x * q.y);
146        let cosy_cosp = 1.0 - 2.0 * (q.y * q.y + q.z * q.z);
147        let yaw = siny_cosp.atan2(cosy_cosp);
148
149        (roll, pitch, yaw)
150    }
151
152    /// Construct from axis-angle.
153    #[must_use]
154    pub fn from_axis_angle(axis: &Unit<Vector3<f64>>, angle: f64) -> Self {
155        let half = angle / 2.0;
156        let s = half.sin();
157        let c = half.cos();
158        let a = &axis.0;
159        Self {
160            q: Quaternion::new(c, a.x * s, a.y * s, a.z * s).normalize(),
161        }
162    }
163
164    /// Construct from a rotation matrix (3x3).
165    #[must_use]
166    pub fn from_matrix(m: &Matrix3<f64>) -> Self {
167        let d = &m.data;
168        let trace = d[0][0] + d[1][1] + d[2][2];
169
170        let (w, x, y, z) = if trace > 0.0 {
171            let s = (trace + 1.0).sqrt() * 2.0; // s = 4*w
172            (
173                0.25 * s,
174                (d[2][1] - d[1][2]) / s,
175                (d[0][2] - d[2][0]) / s,
176                (d[1][0] - d[0][1]) / s,
177            )
178        } else if d[0][0] > d[1][1] && d[0][0] > d[2][2] {
179            let s = (1.0 + d[0][0] - d[1][1] - d[2][2]).sqrt() * 2.0;
180            (
181                (d[2][1] - d[1][2]) / s,
182                0.25 * s,
183                (d[0][1] + d[1][0]) / s,
184                (d[0][2] + d[2][0]) / s,
185            )
186        } else if d[1][1] > d[2][2] {
187            let s = (1.0 + d[1][1] - d[0][0] - d[2][2]).sqrt() * 2.0;
188            (
189                (d[0][2] - d[2][0]) / s,
190                (d[0][1] + d[1][0]) / s,
191                0.25 * s,
192                (d[1][2] + d[2][1]) / s,
193            )
194        } else {
195            let s = (1.0 + d[2][2] - d[0][0] - d[1][1]).sqrt() * 2.0;
196            (
197                (d[1][0] - d[0][1]) / s,
198                (d[0][2] + d[2][0]) / s,
199                (d[1][2] + d[2][1]) / s,
200                0.25 * s,
201            )
202        };
203
204        Self {
205            q: Quaternion::new(w, x, y, z).normalize(),
206        }
207    }
208
209    /// SLERP interpolation.
210    #[must_use]
211    pub fn slerp(&self, other: &Self, t: f64) -> Self {
212        let mut dot = self.q.w * other.q.w
213            + self.q.x * other.q.x
214            + self.q.y * other.q.y
215            + self.q.z * other.q.z;
216
217        // If the dot product is negative, negate one to take the shorter path.
218        let mut other_q = other.q;
219        if dot < 0.0 {
220            other_q = Quaternion::new(-other_q.w, -other_q.x, -other_q.y, -other_q.z);
221            dot = -dot;
222        }
223
224        // Clamp for numerical safety.
225        dot = dot.min(1.0);
226
227        if dot > 0.9995 {
228            // Very close — linear interpolation then normalise.
229            let result = Quaternion::new(
230                self.q.w + t * (other_q.w - self.q.w),
231                self.q.x + t * (other_q.x - self.q.x),
232                self.q.y + t * (other_q.y - self.q.y),
233                self.q.z + t * (other_q.z - self.q.z),
234            );
235            return Self {
236                q: result.normalize(),
237            };
238        }
239
240        let theta = dot.acos();
241        let sin_theta = theta.sin();
242        let a = ((1.0 - t) * theta).sin() / sin_theta;
243        let b = (t * theta).sin() / sin_theta;
244
245        Self {
246            q: Quaternion::new(
247                a * self.q.w + b * other_q.w,
248                a * self.q.x + b * other_q.x,
249                a * self.q.y + b * other_q.y,
250                a * self.q.z + b * other_q.z,
251            )
252            .normalize(),
253        }
254    }
255
256    /// Convert to rotation matrix.
257    #[must_use]
258    pub fn to_rotation_matrix(&self) -> Matrix3<f64> {
259        let q = &self.q;
260        let xx = q.x * q.x;
261        let yy = q.y * q.y;
262        let zz = q.z * q.z;
263        let xy = q.x * q.y;
264        let xz = q.x * q.z;
265        let yz = q.y * q.z;
266        let wx = q.w * q.x;
267        let wy = q.w * q.y;
268        let wz = q.w * q.z;
269
270        let mut m = Matrix3::zeros();
271        m.data[0][0] = 1.0 - 2.0 * (yy + zz);
272        m.data[0][1] = 2.0 * (xy - wz);
273        m.data[0][2] = 2.0 * (xz + wy);
274        m.data[1][0] = 2.0 * (xy + wz);
275        m.data[1][1] = 1.0 - 2.0 * (xx + zz);
276        m.data[1][2] = 2.0 * (yz - wx);
277        m.data[2][0] = 2.0 * (xz - wy);
278        m.data[2][1] = 2.0 * (yz + wx);
279        m.data[2][2] = 1.0 - 2.0 * (xx + yy);
280        m
281    }
282}
283
284/// `UnitQuaternion * Vector3` — rotate a vector.
285impl Mul<Vector3<f64>> for UnitQuaternion<f64> {
286    type Output = Vector3<f64>;
287    fn mul(self, v: Vector3<f64>) -> Vector3<f64> {
288        let qv = Quaternion::new(0.0, v.x, v.y, v.z);
289        let result = self.q * qv * self.q.conjugate();
290        Vector3::new(result.x, result.y, result.z)
291    }
292}
293
294/// `&UnitQuaternion * Vector3`
295impl Mul<Vector3<f64>> for &UnitQuaternion<f64> {
296    type Output = Vector3<f64>;
297    fn mul(self, v: Vector3<f64>) -> Vector3<f64> {
298        (*self) * v
299    }
300}
301
302/// `UnitQuaternion * UnitQuaternion` — compose rotations.
303impl Mul for UnitQuaternion<f64> {
304    type Output = Self;
305    fn mul(self, rhs: Self) -> Self {
306        Self {
307            q: (self.q * rhs.q).normalize(),
308        }
309    }
310}
311
312/// `UnitQuaternion *= UnitQuaternion`
313impl std::ops::MulAssign for UnitQuaternion<f64> {
314    fn mul_assign(&mut self, rhs: Self) {
315        self.q = (self.q * rhs.q).normalize();
316    }
317}
318
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
322
#[cfg(test)]
mod tests {
    // Unit tests for `Quaternion` and `UnitQuaternion`.
    use super::*;
    use std::f64::consts::{FRAC_PI_2, PI};

    /// Tolerance for component-wise floating-point comparisons.
    const EPS: f64 = 1e-6;

    /// Unit vector along `(x, y, z)`, wrapped for `from_axis_angle`.
    fn axis(x: f64, y: f64, z: f64) -> Unit<Vector3<f64>> {
        Unit::new_normalize(Vector3::new(x, y, z))
    }

    /// Asserts that `a` and `b` encode the same rotation.
    ///
    /// `q` and `-q` are the same rotation, so the check is `|q_a · q_b| ≈ 1`
    /// rather than component-wise equality.
    fn assert_same_rotation(a: &UnitQuaternion<f64>, b: &UnitQuaternion<f64>) {
        let (p, q) = (a.quaternion(), b.quaternion());
        let dot = p.w * q.w + p.x * q.x + p.y * q.y + p.z * q.z;
        assert!((dot.abs() - 1.0).abs() < 1e-9, "{:?} vs {:?}", a, b);
    }

    #[test]
    fn test_identity_rotation() {
        let q = UnitQuaternion::identity();
        let v = Vector3::new(1.0, 0.0, 0.0);
        let rotated = q * v;
        assert!((rotated.x - 1.0).abs() < 1e-10);
        assert!(rotated.y.abs() < 1e-10);
        assert!(rotated.z.abs() < 1e-10);
    }

    #[test]
    fn test_hamilton_product_is_not_commutative() {
        let i = Quaternion::new(0.0, 1.0, 0.0, 0.0);
        let j = Quaternion::new(0.0, 0.0, 1.0, 0.0);
        assert_eq!(i * j, Quaternion::new(0.0, 0.0, 0.0, 1.0));
        assert_eq!(j * i, Quaternion::new(0.0, 0.0, 0.0, -1.0));
    }

    #[test]
    fn test_normalize() {
        let zero = Quaternion::new(0.0, 0.0, 0.0, 0.0);
        assert_eq!(zero.normalize(), zero);
        let n = Quaternion::new(1.0, 2.0, 2.0, 4.0).normalize();
        assert!((n.norm() - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_slerp_endpoints() {
        let a = UnitQuaternion::identity();
        let b = UnitQuaternion::from_euler_angles(0.5, 0.0, 0.0);
        let at0 = a.slerp(&b, 0.0);
        let at1 = a.slerp(&b, 1.0);
        assert!((at0.q.w - a.q.w).abs() < 1e-6);
        assert!((at0.q.x - a.q.x).abs() < 1e-6);
        assert!((at1.q.w - b.q.w).abs() < 1e-6);
        assert!((at1.q.x - b.q.x).abs() < 1e-6);
    }

    #[test]
    fn test_slerp_midpoint_halves_angle() {
        let z = axis(0.0, 0.0, 1.0);
        let full = UnitQuaternion::from_axis_angle(&z, 1.0);
        let half = UnitQuaternion::from_axis_angle(&z, 0.5);
        let mid = UnitQuaternion::identity().slerp(&full, 0.5);
        assert!((mid.q.w - half.q.w).abs() < EPS);
        assert!((mid.q.z - half.q.z).abs() < EPS);
    }

    #[test]
    fn test_slerp_takes_shorter_path() {
        let z = axis(0.0, 0.0, 1.0);
        let t = *UnitQuaternion::from_axis_angle(&z, 1.0).quaternion();
        // Same rotation as `t`, stored on the opposite hemisphere.
        let flipped = UnitQuaternion::from_quaternion(Quaternion::new(-t.w, -t.x, -t.y, -t.z));
        let mid = UnitQuaternion::identity().slerp(&flipped, 0.5);
        let half = UnitQuaternion::from_axis_angle(&z, 0.5);
        assert!((mid.q.w - half.q.w).abs() < EPS);
        assert!((mid.q.z - half.q.z).abs() < EPS);
    }

    #[test]
    fn test_slerp_nearly_parallel_branch() {
        // dot = cos(0.005) > 0.9995, so the linear-interpolation branch runs.
        let z = axis(0.0, 0.0, 1.0);
        let a = UnitQuaternion::from_axis_angle(&z, 0.01);
        let b = UnitQuaternion::from_axis_angle(&z, 0.02);
        let mid = a.slerp(&b, 0.5);
        let expected = UnitQuaternion::from_axis_angle(&z, 0.015);
        assert!((mid.q.w - expected.q.w).abs() < 1e-9);
        assert!((mid.q.z - expected.q.z).abs() < 1e-9);
    }

    #[test]
    fn test_euler_roundtrip() {
        let q = UnitQuaternion::from_euler_angles(0.1, 0.2, 0.3);
        let (r, p, y) = q.euler_angles();
        assert!((r - 0.1).abs() < 1e-6);
        assert!((p - 0.2).abs() < 1e-6);
        assert!((y - 0.3).abs() < 1e-6);
    }

    #[test]
    fn test_euler_matches_zyx_composition() {
        let (roll, pitch, yaw) = (0.4, -0.25, 1.1);
        let composed = UnitQuaternion::from_axis_angle(&axis(0.0, 0.0, 1.0), yaw)
            * UnitQuaternion::from_axis_angle(&axis(0.0, 1.0, 0.0), pitch)
            * UnitQuaternion::from_axis_angle(&axis(1.0, 0.0, 0.0), roll);
        assert_same_rotation(&UnitQuaternion::from_euler_angles(roll, pitch, yaw), &composed);
    }

    #[test]
    fn test_euler_gimbal_lock_pitch() {
        let q = UnitQuaternion::from_euler_angles(0.0, FRAC_PI_2, 0.0);
        let (_, pitch, _) = q.euler_angles();
        assert!((pitch - FRAC_PI_2).abs() < EPS);
    }

    #[test]
    fn test_from_matrix_identity() {
        let m = Matrix3::identity();
        let q = UnitQuaternion::from_matrix(&m);
        assert!((q.q.w - 1.0).abs() < 1e-6);
        assert!(q.q.x.abs() < 1e-6);
        assert!(q.q.y.abs() < 1e-6);
        assert!(q.q.z.abs() < 1e-6);
    }

    #[test]
    fn test_from_matrix_roundtrip_all_branches() {
        // One rotation per branch of `from_matrix`: positive trace first, then
        // 180° turns about x, y and z (trace = -1, dominant diagonal x/y/z).
        let cases = [
            UnitQuaternion::from_euler_angles(0.3, -0.4, 0.5),
            UnitQuaternion::from_axis_angle(&axis(1.0, 0.0, 0.0), PI),
            UnitQuaternion::from_axis_angle(&axis(0.0, 1.0, 0.0), PI),
            UnitQuaternion::from_axis_angle(&axis(0.0, 0.0, 1.0), PI),
        ];
        for q in cases.iter() {
            let back = UnitQuaternion::from_matrix(&q.to_rotation_matrix());
            assert_same_rotation(&back, q);
        }
    }

    #[test]
    fn test_rotation_matrix_agrees_with_vector_rotation() {
        let q = UnitQuaternion::from_euler_angles(0.7, -0.3, 1.9);
        let m = q.to_rotation_matrix();
        let (vx, vy, vz) = (0.5, -1.5, 2.0);
        let d = &m.data;
        let expected = [
            d[0][0] * vx + d[0][1] * vy + d[0][2] * vz,
            d[1][0] * vx + d[1][1] * vy + d[1][2] * vz,
            d[2][0] * vx + d[2][1] * vy + d[2][2] * vz,
        ];
        let r = q * Vector3::new(vx, vy, vz);
        assert!((r.x - expected[0]).abs() < EPS);
        assert!((r.y - expected[1]).abs() < EPS);
        assert!((r.z - expected[2]).abs() < EPS);
    }

    #[test]
    fn test_axis_angle_90_deg() {
        let axis = Unit::new_normalize(Vector3::new(0.0, 0.0, 1.0));
        let q = UnitQuaternion::from_axis_angle(&axis, FRAC_PI_2);
        let v = q * Vector3::new(1.0, 0.0, 0.0);
        assert!(v.x.abs() < 1e-6);
        assert!((v.y - 1.0).abs() < 1e-6);
        assert!(v.z.abs() < 1e-6);
    }

    #[test]
    fn test_compose_rotations() {
        let z = axis(0.0, 0.0, 1.0);
        let quarter = UnitQuaternion::from_axis_angle(&z, FRAC_PI_2);
        let half = quarter * quarter;
        let v = half * Vector3::new(1.0, 0.0, 0.0);
        assert!((v.x + 1.0).abs() < EPS);
        assert!(v.y.abs() < EPS);

        let mut acc = quarter;
        acc *= quarter;
        assert_same_rotation(&acc, &half);
    }

    #[test]
    fn test_borrowed_rotation_matches_owned() {
        let q = UnitQuaternion::from_euler_angles(0.2, 0.4, -0.6);
        let a = &q * Vector3::new(1.0, 2.0, 3.0);
        let b = q * Vector3::new(1.0, 2.0, 3.0);
        assert!((a.x - b.x).abs() < 1e-12);
        assert!((a.y - b.y).abs() < 1e-12);
        assert!((a.z - b.z).abs() < 1e-12);
    }
}