1use super::matrix::Matrix3;
4use super::vector::Vector3;
5use serde::{Deserialize, Serialize};
6use std::ops::Mul;
7
/// A general quaternion `w + x·i + y·j + z·k`.
///
/// `w` is the scalar (real) part and `(x, y, z)` the vector (imaginary) part.
/// No normalization is enforced; see `UnitQuaternion` for rotations.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct Quaternion<T> {
    /// Scalar (real) part.
    pub w: T,
    /// Coefficient of `i`.
    pub x: T,
    /// Coefficient of `j`.
    pub y: T,
    /// Coefficient of `k`.
    pub z: T,
}
20
21impl Quaternion<f64> {
22 #[must_use]
24 pub fn new(w: f64, x: f64, y: f64, z: f64) -> Self {
25 Self { w, x, y, z }
26 }
27
28 #[must_use]
30 pub fn norm(&self) -> f64 {
31 (self.w * self.w + self.x * self.x + self.y * self.y + self.z * self.z).sqrt()
32 }
33
34 #[must_use]
36 pub fn normalize(&self) -> Self {
37 let n = self.norm();
38 if n < 1e-15 {
39 return *self;
40 }
41 Self::new(self.w / n, self.x / n, self.y / n, self.z / n)
42 }
43
44 #[must_use]
46 pub fn conjugate(&self) -> Self {
47 Self::new(self.w, -self.x, -self.y, -self.z)
48 }
49}
50
51impl Mul for Quaternion<f64> {
52 type Output = Self;
53 fn mul(self, rhs: Self) -> Self {
54 Self::new(
55 self.w * rhs.w - self.x * rhs.x - self.y * rhs.y - self.z * rhs.z,
56 self.w * rhs.x + self.x * rhs.w + self.y * rhs.z - self.z * rhs.y,
57 self.w * rhs.y - self.x * rhs.z + self.y * rhs.w + self.z * rhs.x,
58 self.w * rhs.z + self.x * rhs.y - self.y * rhs.x + self.z * rhs.w,
59 )
60 }
61}
62
/// Wrapper for a value that is meant to be normalized (unit length).
///
/// NOTE(review): the inner field is `pub`, so `Unit(v)` wraps `v` without
/// normalizing it; `Unit::new_normalize` is the normalizing constructor in this
/// file. `UnitQuaternion::from_axis_angle` uses the wrapped components as-is,
/// so a non-unit axis silently changes the effective rotation angle.
#[derive(Debug, Clone, Copy)]
pub struct Unit<T>(pub T);
70
71impl Unit<Vector3<f64>> {
72 #[must_use]
74 pub fn new_normalize(v: Vector3<f64>) -> Self {
75 Unit(v.normalize())
76 }
77}
78
/// A quaternion intended to have unit norm, used to represent a 3-D rotation.
///
/// The inner quaternion is private; the constructors in this file produce
/// normalized values.
///
/// NOTE(review): the derived `Deserialize` builds the value directly from the
/// serialized `q` without normalizing it, so deserialized data can violate the
/// unit-norm invariant. The derived `PartialEq` compares components, so `q`
/// and `-q` (which encode the same rotation) compare unequal.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct UnitQuaternion<T> {
    q: Quaternion<T>,
}
88
89impl UnitQuaternion<f64> {
90 #[must_use]
92 pub fn identity() -> Self {
93 Self {
94 q: Quaternion::new(1.0, 0.0, 0.0, 0.0),
95 }
96 }
97
98 #[must_use]
100 pub fn from_quaternion(q: Quaternion<f64>) -> Self {
101 Self { q: q.normalize() }
102 }
103
104 #[must_use]
106 pub fn quaternion(&self) -> &Quaternion<f64> {
107 &self.q
108 }
109
110 #[must_use]
112 pub fn from_euler_angles(roll: f64, pitch: f64, yaw: f64) -> Self {
113 let (sr, cr) = (roll / 2.0).sin_cos();
114 let (sp, cp) = (pitch / 2.0).sin_cos();
115 let (sy, cy) = (yaw / 2.0).sin_cos();
116
117 let w = cr * cp * cy + sr * sp * sy;
118 let x = sr * cp * cy - cr * sp * sy;
119 let y = cr * sp * cy + sr * cp * sy;
120 let z = cr * cp * sy - sr * sp * cy;
121
122 Self {
123 q: Quaternion::new(w, x, y, z).normalize(),
124 }
125 }
126
127 #[must_use]
129 pub fn euler_angles(&self) -> (f64, f64, f64) {
130 let q = &self.q;
131 let sinr_cosp = 2.0 * (q.w * q.x + q.y * q.z);
133 let cosr_cosp = 1.0 - 2.0 * (q.x * q.x + q.y * q.y);
134 let roll = sinr_cosp.atan2(cosr_cosp);
135
136 let sinp = 2.0 * (q.w * q.y - q.z * q.x);
138 let pitch = if sinp.abs() >= 1.0 {
139 std::f64::consts::FRAC_PI_2.copysign(sinp)
140 } else {
141 sinp.asin()
142 };
143
144 let siny_cosp = 2.0 * (q.w * q.z + q.x * q.y);
146 let cosy_cosp = 1.0 - 2.0 * (q.y * q.y + q.z * q.z);
147 let yaw = siny_cosp.atan2(cosy_cosp);
148
149 (roll, pitch, yaw)
150 }
151
152 #[must_use]
154 pub fn from_axis_angle(axis: &Unit<Vector3<f64>>, angle: f64) -> Self {
155 let half = angle / 2.0;
156 let s = half.sin();
157 let c = half.cos();
158 let a = &axis.0;
159 Self {
160 q: Quaternion::new(c, a.x * s, a.y * s, a.z * s).normalize(),
161 }
162 }
163
164 #[must_use]
166 pub fn from_matrix(m: &Matrix3<f64>) -> Self {
167 let d = &m.data;
168 let trace = d[0][0] + d[1][1] + d[2][2];
169
170 let (w, x, y, z) = if trace > 0.0 {
171 let s = (trace + 1.0).sqrt() * 2.0; (
173 0.25 * s,
174 (d[2][1] - d[1][2]) / s,
175 (d[0][2] - d[2][0]) / s,
176 (d[1][0] - d[0][1]) / s,
177 )
178 } else if d[0][0] > d[1][1] && d[0][0] > d[2][2] {
179 let s = (1.0 + d[0][0] - d[1][1] - d[2][2]).sqrt() * 2.0;
180 (
181 (d[2][1] - d[1][2]) / s,
182 0.25 * s,
183 (d[0][1] + d[1][0]) / s,
184 (d[0][2] + d[2][0]) / s,
185 )
186 } else if d[1][1] > d[2][2] {
187 let s = (1.0 + d[1][1] - d[0][0] - d[2][2]).sqrt() * 2.0;
188 (
189 (d[0][2] - d[2][0]) / s,
190 (d[0][1] + d[1][0]) / s,
191 0.25 * s,
192 (d[1][2] + d[2][1]) / s,
193 )
194 } else {
195 let s = (1.0 + d[2][2] - d[0][0] - d[1][1]).sqrt() * 2.0;
196 (
197 (d[1][0] - d[0][1]) / s,
198 (d[0][2] + d[2][0]) / s,
199 (d[1][2] + d[2][1]) / s,
200 0.25 * s,
201 )
202 };
203
204 Self {
205 q: Quaternion::new(w, x, y, z).normalize(),
206 }
207 }
208
209 #[must_use]
211 pub fn slerp(&self, other: &Self, t: f64) -> Self {
212 let mut dot = self.q.w * other.q.w
213 + self.q.x * other.q.x
214 + self.q.y * other.q.y
215 + self.q.z * other.q.z;
216
217 let mut other_q = other.q;
219 if dot < 0.0 {
220 other_q = Quaternion::new(-other_q.w, -other_q.x, -other_q.y, -other_q.z);
221 dot = -dot;
222 }
223
224 dot = dot.min(1.0);
226
227 if dot > 0.9995 {
228 let result = Quaternion::new(
230 self.q.w + t * (other_q.w - self.q.w),
231 self.q.x + t * (other_q.x - self.q.x),
232 self.q.y + t * (other_q.y - self.q.y),
233 self.q.z + t * (other_q.z - self.q.z),
234 );
235 return Self {
236 q: result.normalize(),
237 };
238 }
239
240 let theta = dot.acos();
241 let sin_theta = theta.sin();
242 let a = ((1.0 - t) * theta).sin() / sin_theta;
243 let b = (t * theta).sin() / sin_theta;
244
245 Self {
246 q: Quaternion::new(
247 a * self.q.w + b * other_q.w,
248 a * self.q.x + b * other_q.x,
249 a * self.q.y + b * other_q.y,
250 a * self.q.z + b * other_q.z,
251 )
252 .normalize(),
253 }
254 }
255
256 #[must_use]
258 pub fn to_rotation_matrix(&self) -> Matrix3<f64> {
259 let q = &self.q;
260 let xx = q.x * q.x;
261 let yy = q.y * q.y;
262 let zz = q.z * q.z;
263 let xy = q.x * q.y;
264 let xz = q.x * q.z;
265 let yz = q.y * q.z;
266 let wx = q.w * q.x;
267 let wy = q.w * q.y;
268 let wz = q.w * q.z;
269
270 let mut m = Matrix3::zeros();
271 m.data[0][0] = 1.0 - 2.0 * (yy + zz);
272 m.data[0][1] = 2.0 * (xy - wz);
273 m.data[0][2] = 2.0 * (xz + wy);
274 m.data[1][0] = 2.0 * (xy + wz);
275 m.data[1][1] = 1.0 - 2.0 * (xx + zz);
276 m.data[1][2] = 2.0 * (yz - wx);
277 m.data[2][0] = 2.0 * (xz - wy);
278 m.data[2][1] = 2.0 * (yz + wx);
279 m.data[2][2] = 1.0 - 2.0 * (xx + yy);
280 m
281 }
282}
283
284impl Mul<Vector3<f64>> for UnitQuaternion<f64> {
286 type Output = Vector3<f64>;
287 fn mul(self, v: Vector3<f64>) -> Vector3<f64> {
288 let qv = Quaternion::new(0.0, v.x, v.y, v.z);
289 let result = self.q * qv * self.q.conjugate();
290 Vector3::new(result.x, result.y, result.z)
291 }
292}
293
294impl Mul<Vector3<f64>> for &UnitQuaternion<f64> {
296 type Output = Vector3<f64>;
297 fn mul(self, v: Vector3<f64>) -> Vector3<f64> {
298 (*self) * v
299 }
300}
301
302impl Mul for UnitQuaternion<f64> {
304 type Output = Self;
305 fn mul(self, rhs: Self) -> Self {
306 Self {
307 q: (self.q * rhs.q).normalize(),
308 }
309 }
310}
311
312impl std::ops::MulAssign for UnitQuaternion<f64> {
314 fn mul_assign(&mut self, rhs: Self) {
315 self.q = (self.q * rhs.q).normalize();
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::{FRAC_PI_2, FRAC_PI_4};

    /// Default tolerance: each check below accumulates only a few rounding steps.
    const TOL: f64 = 1e-9;

    /// Asserts `|actual - expected| < tol`, printing both values on failure.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {}, got {} (tol {})",
            expected,
            actual,
            tol
        );
    }

    /// Asserts each component of `v` matches `(x, y, z)` within `TOL`.
    fn assert_vec_close(v: &Vector3<f64>, x: f64, y: f64, z: f64) {
        assert_close(v.x, x, TOL);
        assert_close(v.y, y, TOL);
        assert_close(v.z, z, TOL);
    }

    /// Asserts `a` and `b` encode the same rotation, treating `q` and `-q` as equal.
    fn assert_same_rotation(a: &UnitQuaternion<f64>, b: &UnitQuaternion<f64>) {
        let (p, q) = (a.quaternion(), b.quaternion());
        let dot = p.w * q.w + p.x * q.x + p.y * q.y + p.z * q.z;
        let sign = if dot < 0.0 { -1.0 } else { 1.0 };
        assert_close(sign * q.w, p.w, TOL);
        assert_close(sign * q.x, p.x, TOL);
        assert_close(sign * q.y, p.y, TOL);
        assert_close(sign * q.z, p.z, TOL);
    }

    /// The unit +Z axis.
    fn z_axis() -> Unit<Vector3<f64>> {
        Unit::new_normalize(Vector3::new(0.0, 0.0, 1.0))
    }

    #[test]
    fn test_identity_rotation() {
        let q = UnitQuaternion::identity();
        let rotated = q * Vector3::new(1.0, 2.0, 3.0);
        assert_vec_close(&rotated, 1.0, 2.0, 3.0);
    }

    #[test]
    fn test_slerp_endpoints() {
        let a = UnitQuaternion::identity();
        let b = UnitQuaternion::from_euler_angles(0.5, 0.0, 0.0);
        assert_same_rotation(&a.slerp(&b, 0.0), &a);
        assert_same_rotation(&a.slerp(&b, 1.0), &b);
    }

    #[test]
    fn test_slerp_midpoint() {
        let a = UnitQuaternion::identity();
        let b = UnitQuaternion::from_axis_angle(&z_axis(), FRAC_PI_2);
        let expected = UnitQuaternion::from_axis_angle(&z_axis(), FRAC_PI_4);
        assert_same_rotation(&a.slerp(&b, 0.5), &expected);
    }

    #[test]
    fn test_slerp_takes_shortest_path() {
        // -q encodes the same rotation as q, so interpolating towards it must not move.
        let a = UnitQuaternion::from_euler_angles(0.2, 0.1, -0.3);
        let q = a.quaternion();
        let negated =
            UnitQuaternion::from_quaternion(Quaternion::new(-q.w, -q.x, -q.y, -q.z));
        assert_same_rotation(&a.slerp(&negated, 0.5), &a);
    }

    #[test]
    fn test_euler_roundtrip() {
        let q = UnitQuaternion::from_euler_angles(0.1, 0.2, 0.3);
        let (r, p, y) = q.euler_angles();
        assert_close(r, 0.1, TOL);
        assert_close(p, 0.2, TOL);
        assert_close(y, 0.3, TOL);
    }

    #[test]
    fn test_euler_gimbal_lock_clamps_pitch() {
        let q = UnitQuaternion::from_euler_angles(0.0, FRAC_PI_2, 0.0);
        let (_, p, _) = q.euler_angles();
        // asin is ill-conditioned near 1, so allow a looser tolerance here.
        assert_close(p, FRAC_PI_2, 1e-6);
    }

    #[test]
    fn test_from_matrix_identity() {
        let m = Matrix3::identity();
        let q = UnitQuaternion::from_matrix(&m);
        assert_same_rotation(&q, &UnitQuaternion::identity());
    }

    #[test]
    fn test_matrix_roundtrip() {
        // Cases chosen to exercise both the positive-trace and diagonal branches.
        for &(r, p, y) in &[(0.3, -0.4, 1.2), (3.0, 0.1, -2.9), (0.0, 0.0, 3.1)] {
            let q = UnitQuaternion::from_euler_angles(r, p, y);
            let back = UnitQuaternion::from_matrix(&q.to_rotation_matrix());
            assert_same_rotation(&q, &back);
        }
    }

    #[test]
    fn test_matrix_matches_vector_rotation() {
        let q = UnitQuaternion::from_euler_angles(0.7, -0.2, 2.1);
        let m = q.to_rotation_matrix();
        let v = Vector3::new(0.2, -1.3, 0.7);
        let rotated = q * Vector3::new(0.2, -1.3, 0.7);
        let row = |i: usize| m.data[i][0] * v.x + m.data[i][1] * v.y + m.data[i][2] * v.z;
        assert_vec_close(&rotated, row(0), row(1), row(2));
    }

    #[test]
    fn test_axis_angle_90_deg() {
        let q = UnitQuaternion::from_axis_angle(&z_axis(), FRAC_PI_2);
        let v = q * Vector3::new(1.0, 0.0, 0.0);
        assert_vec_close(&v, 0.0, 1.0, 0.0);
    }

    #[test]
    fn test_composition_and_mul_assign() {
        let quarter = UnitQuaternion::from_axis_angle(&z_axis(), FRAC_PI_2);
        let half = quarter * quarter;
        assert_vec_close(&(half * Vector3::new(1.0, 0.0, 0.0)), -1.0, 0.0, 0.0);

        let mut acc = UnitQuaternion::identity();
        acc *= quarter;
        acc *= quarter;
        assert_same_rotation(&acc, &half);
    }

    #[test]
    fn test_ref_mul_matches_owned_mul() {
        let q = UnitQuaternion::from_euler_angles(0.4, 0.5, 0.6);
        let by_ref = &q * Vector3::new(1.0, 2.0, 3.0);
        let by_val = q * Vector3::new(1.0, 2.0, 3.0);
        assert_vec_close(&by_ref, by_val.x, by_val.y, by_val.z);
    }

    #[test]
    fn test_conjugate_is_inverse() {
        let q = UnitQuaternion::from_euler_angles(1.0, -0.5, 0.25);
        let p = *q.quaternion() * q.quaternion().conjugate();
        assert_close(p.w, 1.0, TOL);
        assert_close(p.x, 0.0, TOL);
        assert_close(p.y, 0.0, TOL);
        assert_close(p.z, 0.0, TOL);
    }
}