Skip to main content

phyz_math/
quaternion.rs

//! Quaternion utilities for 3D rotations.
//!
//! Convention: `q = [w; x; y; z]`, where `w` is the scalar part and `(x, y, z)` is the vector part.

5use crate::{Mat3, Vec3};
6
7/// A unit quaternion representing a 3D rotation.
8#[derive(Debug, Clone, Copy)]
9pub struct Quat {
10    /// Scalar part (w).
11    pub w: f64,
12    /// Vector part (x, y, z).
13    pub v: Vec3,
14}
15
16impl Quat {
17    /// Create a new quaternion from scalar and vector parts.
18    pub fn new(w: f64, x: f64, y: f64, z: f64) -> Self {
19        Self {
20            w,
21            v: Vec3::new(x, y, z),
22        }
23    }
24
25    /// Identity quaternion (no rotation).
26    pub fn identity() -> Self {
27        Self {
28            w: 1.0,
29            v: Vec3::zeros(),
30        }
31    }
32
33    /// Create quaternion from axis-angle representation.
34    /// axis should be a unit vector, angle in radians.
35    pub fn from_axis_angle(axis: &Vec3, angle: f64) -> Self {
36        let half_angle = angle * 0.5;
37        let (s, c) = half_angle.sin_cos();
38        Self { w: c, v: axis * s }
39    }
40
41    /// Normalize this quaternion to unit length.
42    pub fn normalize(&self) -> Self {
43        let norm = (self.w * self.w + self.v.norm_squared()).sqrt();
44        if norm < 1e-12 {
45            return Self::identity();
46        }
47        Self {
48            w: self.w / norm,
49            v: self.v / norm,
50        }
51    }
52
53    /// Quaternion multiplication: self * other.
54    pub fn mul(&self, other: &Quat) -> Quat {
55        Quat {
56            w: self.w * other.w - self.v.dot(&other.v),
57            v: self.v.cross(&other.v) + other.v * self.w + self.v * other.w,
58        }
59    }
60
61    /// Conjugate of the quaternion (inverse for unit quaternions).
62    pub fn conjugate(&self) -> Quat {
63        Quat {
64            w: self.w,
65            v: -self.v,
66        }
67    }
68
69    /// Convert quaternion to 3x3 rotation matrix.
70    pub fn to_matrix(&self) -> Mat3 {
71        let w = self.w;
72        let x = self.v.x;
73        let y = self.v.y;
74        let z = self.v.z;
75
76        let x2 = x * x;
77        let y2 = y * y;
78        let z2 = z * z;
79        let xy = x * y;
80        let xz = x * z;
81        let yz = y * z;
82        let wx = w * x;
83        let wy = w * y;
84        let wz = w * z;
85
86        Mat3::new(
87            1.0 - 2.0 * (y2 + z2),
88            2.0 * (xy - wz),
89            2.0 * (xz + wy),
90            2.0 * (xy + wz),
91            1.0 - 2.0 * (x2 + z2),
92            2.0 * (yz - wx),
93            2.0 * (xz - wy),
94            2.0 * (yz + wx),
95            1.0 - 2.0 * (x2 + y2),
96        )
97    }
98
99    /// Convert rotation matrix to quaternion.
100    /// Reference: Shepperd's method (stable for all rotation matrices).
101    pub fn from_matrix(m: &Mat3) -> Quat {
102        let trace = m[(0, 0)] + m[(1, 1)] + m[(2, 2)];
103
104        if trace > 0.0 {
105            let s = (trace + 1.0).sqrt() * 2.0; // s = 4*w
106            Quat {
107                w: 0.25 * s,
108                v: Vec3::new(
109                    (m[(2, 1)] - m[(1, 2)]) / s,
110                    (m[(0, 2)] - m[(2, 0)]) / s,
111                    (m[(1, 0)] - m[(0, 1)]) / s,
112                ),
113            }
114        } else if m[(0, 0)] > m[(1, 1)] && m[(0, 0)] > m[(2, 2)] {
115            let s = (1.0 + m[(0, 0)] - m[(1, 1)] - m[(2, 2)]).sqrt() * 2.0; // s = 4*x
116            Quat {
117                w: (m[(2, 1)] - m[(1, 2)]) / s,
118                v: Vec3::new(
119                    0.25 * s,
120                    (m[(0, 1)] + m[(1, 0)]) / s,
121                    (m[(0, 2)] + m[(2, 0)]) / s,
122                ),
123            }
124        } else if m[(1, 1)] > m[(2, 2)] {
125            let s = (1.0 + m[(1, 1)] - m[(0, 0)] - m[(2, 2)]).sqrt() * 2.0; // s = 4*y
126            Quat {
127                w: (m[(0, 2)] - m[(2, 0)]) / s,
128                v: Vec3::new(
129                    (m[(0, 1)] + m[(1, 0)]) / s,
130                    0.25 * s,
131                    (m[(1, 2)] + m[(2, 1)]) / s,
132                ),
133            }
134        } else {
135            let s = (1.0 + m[(2, 2)] - m[(0, 0)] - m[(1, 1)]).sqrt() * 2.0; // s = 4*z
136            Quat {
137                w: (m[(1, 0)] - m[(0, 1)]) / s,
138                v: Vec3::new(
139                    (m[(0, 2)] + m[(2, 0)]) / s,
140                    (m[(1, 2)] + m[(2, 1)]) / s,
141                    0.25 * s,
142                ),
143            }
144        }
145    }
146
147    /// Exponential map: exp(θ u) where θ u is axis-angle representation.
148    /// Converts axis-angle to quaternion via q = [cos(θ/2), sin(θ/2) * u].
149    /// For small angles, uses first-order approximation.
150    pub fn exp(w: &Vec3) -> Quat {
151        let theta = w.norm();
152        if theta < 1e-10 {
153            // First-order approximation for small angles
154            Quat {
155                w: 1.0,
156                v: *w * 0.5,
157            }
158            .normalize()
159        } else {
160            let half_theta = theta * 0.5;
161            Quat {
162                w: half_theta.cos(),
163                v: *w * (half_theta.sin() / theta),
164            }
165        }
166    }
167
168    /// Logarithmic map: log(q) returns the axis-angle vector θ u such that q = exp(θ u).
169    pub fn log(&self) -> Vec3 {
170        let v_norm = self.v.norm();
171        if v_norm < 1e-10 {
172            return Vec3::zeros();
173        }
174        // For quaternion q = [cos(θ/2), sin(θ/2) * u], we have:
175        // θ = 2 * atan2(sin(θ/2), cos(θ/2)) = 2 * atan2(|v|, w)
176        let angle = 2.0 * v_norm.atan2(self.w);
177        self.v * (angle / v_norm)
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use std::f64::consts::{FRAC_PI_2, PI};

    /// True when `a` and `b` describe the same rotation, i.e. `a ≈ b` or `a ≈ -b`.
    fn same_rotation(a: &Quat, b: &Quat, tol: f64) -> bool {
        let same = (a.w - b.w).abs() < tol && (a.v - b.v).norm() < tol;
        let negated = (a.w + b.w).abs() < tol && (a.v + b.v).norm() < tol;
        same || negated
    }

    #[test]
    fn test_identity() {
        let q = Quat::identity();
        assert_eq!(q.w, 1.0);
        assert_eq!(q.v, Vec3::zeros());
    }

    #[test]
    fn test_axis_angle() {
        let axis = Vec3::new(0.0, 0.0, 1.0);
        let angle = FRAC_PI_2; // 90 degrees
        let q = Quat::from_axis_angle(&axis, angle);

        assert_relative_eq!(q.w, (angle / 2.0).cos(), epsilon = 1e-10);
        assert_relative_eq!(q.v.x, 0.0, epsilon = 1e-10);
        assert_relative_eq!(q.v.y, 0.0, epsilon = 1e-10);
        assert_relative_eq!(q.v.z, (angle / 2.0).sin(), epsilon = 1e-10);
    }

    #[test]
    fn test_normalize() {
        let normalized = Quat::new(1.0, 2.0, 3.0, 4.0).normalize();
        let norm = (normalized.w * normalized.w + normalized.v.norm_squared()).sqrt();
        assert_relative_eq!(norm, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_normalize_degenerate_is_identity() {
        // A (near-)zero quaternion has no direction; normalize falls back to the identity.
        let n = Quat::new(0.0, 0.0, 0.0, 0.0).normalize();
        assert_eq!(n.w, 1.0);
        assert_eq!(n.v, Vec3::zeros());
    }

    #[test]
    fn test_multiplication() {
        // Two 90 degree rotations about Z compose to a 180 degree rotation about Z.
        let axis = Vec3::new(0.0, 0.0, 1.0);
        let q1 = Quat::from_axis_angle(&axis, FRAC_PI_2);
        let q2 = Quat::from_axis_angle(&axis, FRAC_PI_2);
        let result = q1.mul(&q2);

        let expected = Quat::from_axis_angle(&axis, PI);

        assert_relative_eq!(result.w, expected.w, epsilon = 1e-10);
        assert_relative_eq!(result.v, expected.v, epsilon = 1e-10);
    }

    #[test]
    fn test_to_matrix() {
        let q = Quat::from_axis_angle(&Vec3::new(0.0, 0.0, 1.0), FRAC_PI_2);
        let m = q.to_matrix();

        // A 90 degree rotation about Z maps X onto Y.
        let y = m * Vec3::new(1.0, 0.0, 0.0);
        assert_relative_eq!(y, Vec3::new(0.0, 1.0, 0.0), epsilon = 1e-10);
    }

    #[test]
    fn test_matrix_roundtrip() {
        let axis = Vec3::new(1.0, 2.0, 3.0).normalize();
        let q = Quat::from_axis_angle(&axis, 0.7);
        let q2 = Quat::from_matrix(&q.to_matrix());

        // Quaternions q and -q represent the same rotation.
        assert!(same_rotation(&q, &q2, 1e-10));
    }

    #[test]
    fn test_matrix_roundtrip_all_branches() {
        // Rotations close to π make the trace negative, so from_matrix takes the
        // x-, y- and z-pivot branches instead of the trace branch.
        let angle = 3.0;
        let axes = [
            Vec3::new(1.0, 0.0, 0.0),
            Vec3::new(0.0, 1.0, 0.0),
            Vec3::new(0.0, 0.0, 1.0),
            Vec3::new(1.0, 2.0, 3.0).normalize(),
        ];
        for axis in &axes {
            let q = Quat::from_axis_angle(axis, angle);
            let m = q.to_matrix();
            assert!(m[(0, 0)] + m[(1, 1)] + m[(2, 2)] < 0.0);

            let q2 = Quat::from_matrix(&m);
            assert!(same_rotation(&q, &q2, 1e-10));
            assert_relative_eq!(q2.to_matrix(), m, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_exp_log() {
        let w = Vec3::new(0.1, 0.2, 0.3);
        let w2 = Quat::exp(&w).log();
        assert_relative_eq!(w, w2, epsilon = 1e-10);
    }

    #[test]
    fn test_exp_log_large_angle() {
        // |w| ≈ 2.69 rad: well past the small-angle regime but below 2π.
        let w = Vec3::new(1.0, -2.0, 1.5);
        let w2 = Quat::exp(&w).log();
        assert_relative_eq!(w, w2, epsilon = 1e-10);
    }

    #[test]
    fn test_log_identity_is_zero() {
        assert_eq!(Quat::identity().log(), Vec3::zeros());
    }

    #[test]
    fn test_conjugate() {
        let q = Quat::new(0.5, 0.5, 0.5, 0.5).normalize();
        let result = q.mul(&q.conjugate());
        assert_relative_eq!(result.w, 1.0, epsilon = 1e-10);
        assert_relative_eq!(result.v.norm(), 0.0, epsilon = 1e-10);
    }
}