1use crate::Vec3;
2use serde::{Deserialize, Serialize};
3use std::ops::Mul;
4
5#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
7#[repr(C, align(16))]
8pub struct Quat {
9 pub x: f32,
10 pub y: f32,
11 pub z: f32,
12 pub w: f32,
13}
14
15impl Default for Quat {
16 #[inline(always)]
17 fn default() -> Self {
18 Self::IDENTITY
19 }
20}
21
22impl Quat {
23 pub const IDENTITY: Self = Self {
25 x: 0.0,
26 y: 0.0,
27 z: 0.0,
28 w: 1.0,
29 };
30
31 #[inline(always)]
33 pub const fn from_xyzw(x: f32, y: f32, z: f32, w: f32) -> Self {
34 Self { x, y, z, w }
35 }
36
37 #[inline]
39 pub fn from_axis_angle(axis: Vec3, angle: f32) -> Self {
40 let half = angle * 0.5;
41 let s = half.sin();
42 let c = half.cos();
43 let a = axis.normalize();
44 Self {
45 x: a.x * s,
46 y: a.y * s,
47 z: a.z * s,
48 w: c,
49 }
50 }
51
52 #[inline]
54 pub fn from_euler(yaw: f32, pitch: f32, roll: f32) -> Self {
55 let (sy, cy) = (yaw * 0.5).sin_cos();
56 let (sp, cp) = (pitch * 0.5).sin_cos();
57 let (sr, cr) = (roll * 0.5).sin_cos();
58
59 Self {
60 x: cy * sp * cr + sy * cp * sr,
61 y: sy * cp * cr - cy * sp * sr,
62 z: cy * cp * sr - sy * sp * cr,
63 w: cy * cp * cr + sy * sp * sr,
64 }
65 }
66
67 #[inline(always)]
69 pub fn length(self) -> f32 {
70 (self.x * self.x + self.y * self.y + self.z * self.z + self.w * self.w).sqrt()
71 }
72
73 #[inline]
75 pub fn normalize(self) -> Self {
76 let inv = 1.0 / self.length();
77 Self {
78 x: self.x * inv,
79 y: self.y * inv,
80 z: self.z * inv,
81 w: self.w * inv,
82 }
83 }
84
85 #[inline]
87 pub fn inverse(self) -> Self {
88 Self {
89 x: -self.x,
90 y: -self.y,
91 z: -self.z,
92 w: self.w,
93 }
94 }
95
96 #[inline]
98 pub fn slerp(self, mut end: Self, t: f32) -> Self {
99 let mut dot = self.x * end.x + self.y * end.y + self.z * end.z + self.w * end.w;
100
101 if dot < 0.0 {
102 end = Self {
103 x: -end.x,
104 y: -end.y,
105 z: -end.z,
106 w: -end.w,
107 };
108 dot = -dot;
109 }
110
111 if dot > 0.9995 {
112 return Self {
113 x: self.x + (end.x - self.x) * t,
114 y: self.y + (end.y - self.y) * t,
115 z: self.z + (end.z - self.z) * t,
116 w: self.w + (end.w - self.w) * t,
117 }
118 .normalize();
119 }
120
121 let theta = dot.acos();
122 let sin_theta = theta.sin();
123 let s0 = ((1.0 - t) * theta).sin() / sin_theta;
124 let s1 = (t * theta).sin() / sin_theta;
125
126 Self {
127 x: self.x * s0 + end.x * s1,
128 y: self.y * s0 + end.y * s1,
129 z: self.z * s0 + end.z * s1,
130 w: self.w * s0 + end.w * s1,
131 }
132 }
133}
134
135impl Mul for Quat {
137 type Output = Self;
138 #[inline]
139 fn mul(self, rhs: Self) -> Self {
140 Self {
141 x: self.w * rhs.x + self.x * rhs.w + self.y * rhs.z - self.z * rhs.y,
142 y: self.w * rhs.y - self.x * rhs.z + self.y * rhs.w + self.z * rhs.x,
143 z: self.w * rhs.z + self.x * rhs.y - self.y * rhs.x + self.z * rhs.w,
144 w: self.w * rhs.w - self.x * rhs.x - self.y * rhs.y - self.z * rhs.z,
145 }
146 }
147}
148
149impl Mul<Vec3> for Quat {
151 type Output = Vec3;
152 #[inline]
153 fn mul(self, v: Vec3) -> Vec3 {
154 let u = Vec3::new(self.x, self.y, self.z);
155 let s = self.w;
156 u * (2.0 * u.dot(v)) + v * (s * s - u.dot(u)) + u.cross(v) * (2.0 * s)
157 }
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use std::f32::consts::FRAC_PI_2;

    /// Asserts that every component of `actual` lies within `tol` of `expected`,
    /// naming the failing axis in the panic message.
    fn assert_vec_close(actual: Vec3, expected: Vec3, tol: f32) {
        assert!(
            (actual.x - expected.x).abs() < tol,
            "x: got {}, expected {}",
            actual.x,
            expected.x
        );
        assert!(
            (actual.y - expected.y).abs() < tol,
            "y: got {}, expected {}",
            actual.y,
            expected.y
        );
        assert!(
            (actual.z - expected.z).abs() < tol,
            "z: got {}, expected {}",
            actual.z,
            expected.z
        );
    }

    #[test]
    fn identity_rotation() {
        let v = Vec3::new(1.0, 2.0, 3.0);
        assert_vec_close(Quat::IDENTITY * v, v, 1e-6);
    }

    #[test]
    fn rotate_90_around_z() {
        let q = Quat::from_axis_angle(Vec3::Z, FRAC_PI_2);
        // All three axes are checked: a rotation about Z must leave z untouched.
        assert_vec_close(q * Vec3::X, Vec3::new(0.0, 1.0, 0.0), 1e-5);
    }

    #[test]
    fn inverse_undoes_rotation() {
        let q = Quat::from_axis_angle(Vec3::Y, 1.0);
        let v = Vec3::new(1.0, 2.0, 3.0);
        assert_vec_close(q.inverse() * (q * v), v, 1e-4);
    }

    #[test]
    fn slerp_halfway() {
        let a = Quat::IDENTITY;
        let b = Quat::from_axis_angle(Vec3::Z, FRAC_PI_2);
        let mid = a.slerp(b, 0.5);
        let half = FRAC_PI_2 / 2.0;
        assert_vec_close(mid * Vec3::X, Vec3::new(half.cos(), half.sin(), 0.0), 1e-4);
    }

    #[test]
    fn slerp_takes_shortest_arc() {
        // -b encodes the same rotation as b, so interpolating toward either must agree.
        let b = Quat::from_axis_angle(Vec3::Z, FRAC_PI_2);
        let neg_b = Quat::from_xyzw(-b.x, -b.y, -b.z, -b.w);
        let via_b = Quat::IDENTITY.slerp(b, 0.5) * Vec3::X;
        let via_neg_b = Quat::IDENTITY.slerp(neg_b, 0.5) * Vec3::X;
        assert_vec_close(via_neg_b, via_b, 1e-4);
    }

    #[test]
    fn from_euler_yaw_turns_about_y() {
        // A pure 90° yaw about +Y carries +X onto -Z.
        let q = Quat::from_euler(FRAC_PI_2, 0.0, 0.0);
        assert_vec_close(q * Vec3::X, Vec3::new(0.0, 0.0, -1.0), 1e-5);
    }
}