1use crate::{Mat3, Vec3};
6
7#[derive(Debug, Clone, Copy)]
9pub struct Quat {
10 pub w: f64,
12 pub v: Vec3,
14}
15
16impl Quat {
17 pub fn new(w: f64, x: f64, y: f64, z: f64) -> Self {
19 Self {
20 w,
21 v: Vec3::new(x, y, z),
22 }
23 }
24
25 pub fn identity() -> Self {
27 Self {
28 w: 1.0,
29 v: Vec3::zeros(),
30 }
31 }
32
33 pub fn from_axis_angle(axis: &Vec3, angle: f64) -> Self {
36 let half_angle = angle * 0.5;
37 let (s, c) = half_angle.sin_cos();
38 Self { w: c, v: axis * s }
39 }
40
41 pub fn normalize(&self) -> Self {
43 let norm = (self.w * self.w + self.v.norm_squared()).sqrt();
44 if norm < 1e-12 {
45 return Self::identity();
46 }
47 Self {
48 w: self.w / norm,
49 v: self.v / norm,
50 }
51 }
52
53 pub fn mul(&self, other: &Quat) -> Quat {
55 Quat {
56 w: self.w * other.w - self.v.dot(&other.v),
57 v: self.v.cross(&other.v) + other.v * self.w + self.v * other.w,
58 }
59 }
60
61 pub fn conjugate(&self) -> Quat {
63 Quat {
64 w: self.w,
65 v: -self.v,
66 }
67 }
68
69 pub fn to_matrix(&self) -> Mat3 {
71 let w = self.w;
72 let x = self.v.x;
73 let y = self.v.y;
74 let z = self.v.z;
75
76 let x2 = x * x;
77 let y2 = y * y;
78 let z2 = z * z;
79 let xy = x * y;
80 let xz = x * z;
81 let yz = y * z;
82 let wx = w * x;
83 let wy = w * y;
84 let wz = w * z;
85
86 Mat3::new(
87 1.0 - 2.0 * (y2 + z2),
88 2.0 * (xy - wz),
89 2.0 * (xz + wy),
90 2.0 * (xy + wz),
91 1.0 - 2.0 * (x2 + z2),
92 2.0 * (yz - wx),
93 2.0 * (xz - wy),
94 2.0 * (yz + wx),
95 1.0 - 2.0 * (x2 + y2),
96 )
97 }
98
99 pub fn from_matrix(m: &Mat3) -> Quat {
102 let trace = m[(0, 0)] + m[(1, 1)] + m[(2, 2)];
103
104 if trace > 0.0 {
105 let s = (trace + 1.0).sqrt() * 2.0; Quat {
107 w: 0.25 * s,
108 v: Vec3::new(
109 (m[(2, 1)] - m[(1, 2)]) / s,
110 (m[(0, 2)] - m[(2, 0)]) / s,
111 (m[(1, 0)] - m[(0, 1)]) / s,
112 ),
113 }
114 } else if m[(0, 0)] > m[(1, 1)] && m[(0, 0)] > m[(2, 2)] {
115 let s = (1.0 + m[(0, 0)] - m[(1, 1)] - m[(2, 2)]).sqrt() * 2.0; Quat {
117 w: (m[(2, 1)] - m[(1, 2)]) / s,
118 v: Vec3::new(
119 0.25 * s,
120 (m[(0, 1)] + m[(1, 0)]) / s,
121 (m[(0, 2)] + m[(2, 0)]) / s,
122 ),
123 }
124 } else if m[(1, 1)] > m[(2, 2)] {
125 let s = (1.0 + m[(1, 1)] - m[(0, 0)] - m[(2, 2)]).sqrt() * 2.0; Quat {
127 w: (m[(0, 2)] - m[(2, 0)]) / s,
128 v: Vec3::new(
129 (m[(0, 1)] + m[(1, 0)]) / s,
130 0.25 * s,
131 (m[(1, 2)] + m[(2, 1)]) / s,
132 ),
133 }
134 } else {
135 let s = (1.0 + m[(2, 2)] - m[(0, 0)] - m[(1, 1)]).sqrt() * 2.0; Quat {
137 w: (m[(1, 0)] - m[(0, 1)]) / s,
138 v: Vec3::new(
139 (m[(0, 2)] + m[(2, 0)]) / s,
140 (m[(1, 2)] + m[(2, 1)]) / s,
141 0.25 * s,
142 ),
143 }
144 }
145 }
146
147 pub fn exp(w: &Vec3) -> Quat {
151 let theta = w.norm();
152 if theta < 1e-10 {
153 Quat {
155 w: 1.0,
156 v: *w * 0.5,
157 }
158 .normalize()
159 } else {
160 let half_theta = theta * 0.5;
161 Quat {
162 w: half_theta.cos(),
163 v: *w * (half_theta.sin() / theta),
164 }
165 }
166 }
167
168 pub fn log(&self) -> Vec3 {
170 let v_norm = self.v.norm();
171 if v_norm < 1e-10 {
172 return Vec3::zeros();
173 }
174 let angle = 2.0 * v_norm.atan2(self.w);
177 self.v * (angle / v_norm)
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Asserts that `a` and `b` are equal up to the quaternion double cover
    /// (`q` and `-q` encode the same rotation).
    fn assert_same_rotation(a: &Quat, b: &Quat) {
        let same = (a.w - b.w).abs() < 1e-10 && (a.v - b.v).norm() < 1e-10;
        let negated = (a.w + b.w).abs() < 1e-10 && (a.v + b.v).norm() < 1e-10;
        assert!(same || negated, "{a:?} and {b:?} are not the same rotation");
    }

    #[test]
    fn test_identity() {
        let q = Quat::identity();
        assert_eq!(q.w, 1.0);
        assert_eq!(q.v, Vec3::zeros());
    }

    #[test]
    fn test_axis_angle() {
        let axis = Vec3::new(0.0, 0.0, 1.0);
        let angle = std::f64::consts::FRAC_PI_2;
        let q = Quat::from_axis_angle(&axis, angle);

        let expected_w = (angle / 2.0).cos();
        let expected_z = (angle / 2.0).sin();

        assert_relative_eq!(q.w, expected_w, epsilon = 1e-10);
        assert_relative_eq!(q.v.x, 0.0, epsilon = 1e-10);
        assert_relative_eq!(q.v.y, 0.0, epsilon = 1e-10);
        assert_relative_eq!(q.v.z, expected_z, epsilon = 1e-10);
    }

    #[test]
    fn test_normalize() {
        let q = Quat::new(1.0, 2.0, 3.0, 4.0);
        let normalized = q.normalize();
        let norm = (normalized.w * normalized.w + normalized.v.norm_squared()).sqrt();
        assert_relative_eq!(norm, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_multiplication() {
        // Two quarter turns about the same axis compose to a half turn.
        let axis = Vec3::new(0.0, 0.0, 1.0);
        let q1 = Quat::from_axis_angle(&axis, std::f64::consts::FRAC_PI_2);
        let q2 = Quat::from_axis_angle(&axis, std::f64::consts::FRAC_PI_2);
        let result = q1.mul(&q2);

        let expected = Quat::from_axis_angle(&axis, std::f64::consts::PI);

        assert_relative_eq!(result.w, expected.w, epsilon = 1e-10);
        assert_relative_eq!(result.v, expected.v, epsilon = 1e-10);
    }

    #[test]
    fn test_to_matrix() {
        // A quarter turn about z maps the x axis onto the y axis.
        let axis = Vec3::new(0.0, 0.0, 1.0);
        let angle = std::f64::consts::FRAC_PI_2;
        let q = Quat::from_axis_angle(&axis, angle);
        let m = q.to_matrix();

        let x = Vec3::new(1.0, 0.0, 0.0);
        let y = m * x;
        assert_relative_eq!(y, Vec3::new(0.0, 1.0, 0.0), epsilon = 1e-10);
    }

    #[test]
    fn test_matrix_roundtrip() {
        let axis = Vec3::new(1.0, 2.0, 3.0).normalize();
        let q = Quat::from_axis_angle(&axis, 0.7);
        let q2 = Quat::from_matrix(&q.to_matrix());
        assert_same_rotation(&q, &q2);
    }

    #[test]
    fn test_matrix_roundtrip_all_branches() {
        // Half turns about each coordinate axis give a negative trace and
        // select the x-, y- and z-dominant branches of `from_matrix`.
        let half_turn_axes = [
            Vec3::new(1.0, 0.0, 0.0),
            Vec3::new(0.0, 1.0, 0.0),
            Vec3::new(0.0, 0.0, 1.0),
        ];
        for axis in &half_turn_axes {
            let q = Quat::from_axis_angle(axis, std::f64::consts::PI);
            assert_same_rotation(&q, &Quat::from_matrix(&q.to_matrix()));
        }

        // A large rotation about an oblique axis (trace < 0).
        let axis = Vec3::new(1.0, 2.0, 3.0).normalize();
        let q = Quat::from_axis_angle(&axis, 2.5);
        assert_same_rotation(&q, &Quat::from_matrix(&q.to_matrix()));
    }

    #[test]
    fn test_exp_log() {
        let w = Vec3::new(0.1, 0.2, 0.3);
        let q = Quat::exp(&w);
        let w2 = q.log();
        assert_relative_eq!(w, w2, epsilon = 1e-10);
    }

    #[test]
    fn test_conjugate() {
        // q * conj(q) is the identity for a unit quaternion.
        let q = Quat::new(0.5, 0.5, 0.5, 0.5).normalize();
        let conj = q.conjugate();
        let result = q.mul(&conj);
        assert_relative_eq!(result.w, 1.0, epsilon = 1e-10);
        assert_relative_eq!(result.v.norm(), 0.0, epsilon = 1e-10);
    }
}