use crate::{Mat3, Vec3};
/// A quaternion stored as a scalar part `w` and a vector part `v`.
///
/// The type does not enforce unit norm; constructors store the components
/// they are given, and `Quat::normalize` rescales when a unit quaternion is
/// required.
#[derive(Debug, Clone, Copy)]
pub struct Quat {
/// Scalar (real) component.
pub w: f64,
/// Vector (imaginary) components.
pub v: Vec3,
}
impl Quat {
    /// Builds a quaternion from its scalar part `w` and vector components
    /// `x`, `y`, `z`.
    ///
    /// The components are stored as given; no normalization is performed.
    pub fn new(w: f64, x: f64, y: f64, z: f64) -> Self {
        Self {
            w,
            v: Vec3::new(x, y, z),
        }
    }

    /// Returns the quaternion with `w = 1` and a zero vector part.
    pub fn identity() -> Self {
        Self {
            w: 1.0,
            v: Vec3::zero(),
        }
    }

    /// Builds the quaternion `(cos(angle / 2), axis * sin(angle / 2))`.
    ///
    /// `axis` is used as given and is not normalized here, so the result has
    /// unit norm only when `axis` does.
    pub fn from_axis_angle(axis: &Vec3, angle: f64) -> Self {
        let half_angle = angle * 0.5;
        let (sin_half, cos_half) = half_angle.sin_cos();
        Self {
            w: cos_half,
            v: *axis * sin_half,
        }
    }

    /// Returns this quaternion scaled to unit norm.
    ///
    /// When the norm is below `1e-12` there is no usable direction to
    /// preserve, and the identity quaternion is returned instead.
    pub fn normalize(&self) -> Self {
        let norm = (self.w * self.w + self.v.norm_sq()).sqrt();
        if norm < 1e-12 {
            return Self::identity();
        }
        Self {
            w: self.w / norm,
            v: self.v / norm,
        }
    }

    /// Returns the Hamilton product `self * other`.
    ///
    /// The product is not commutative: the cross-product term changes sign
    /// when the operands are swapped.
    pub fn mul(&self, other: &Quat) -> Quat {
        Quat {
            w: self.w * other.w - self.v.dot(other.v),
            v: self.v.cross(other.v) + other.v * self.w + self.v * other.w,
        }
    }

    /// Returns the conjugate: same scalar part, negated vector part.
    pub fn conjugate(&self) -> Quat {
        Quat {
            w: self.w,
            v: -self.v,
        }
    }

    /// Converts this quaternion to a 3x3 matrix using the unit-quaternion
    /// rotation formula.
    ///
    /// The quaternion is not normalized here, so the result is a rotation
    /// matrix only when `self` has unit norm. The nine entries are handed to
    /// `Mat3::new` in the order r00, r01, r02, r10, ..., r22; this presumes
    /// `Mat3::new` is row-major, which matches the `(row, col)` indexing used
    /// by [`Quat::from_matrix`].
    pub fn to_matrix(&self) -> Mat3 {
        let w = self.w;
        let x = self.v.x;
        let y = self.v.y;
        let z = self.v.z;

        let x2 = x * x;
        let y2 = y * y;
        let z2 = z * z;
        let xy = x * y;
        let xz = x * z;
        let yz = y * z;
        let wx = w * x;
        let wy = w * y;
        let wz = w * z;

        Mat3::new(
            1.0 - 2.0 * (y2 + z2),
            2.0 * (xy - wz),
            2.0 * (xz + wy),
            2.0 * (xy + wz),
            1.0 - 2.0 * (x2 + z2),
            2.0 * (yz - wx),
            2.0 * (xz - wy),
            2.0 * (yz + wx),
            1.0 - 2.0 * (x2 + y2),
        )
    }

    /// Recovers a quaternion from a matrix produced by the formula in
    /// [`Quat::to_matrix`].
    ///
    /// Four algebraically equivalent formulas exist, each dividing by four
    /// times one quaternion component. The branch is chosen from the trace
    /// and the largest diagonal entry so that the divisor `s` is built from a
    /// large component and stays away from zero.
    ///
    /// The sign is not canonicalised: outside the positive-trace branch the
    /// returned `w` may be negative. `q` and `-q` describe the same rotation.
    /// The input is not checked for being a rotation matrix.
    pub fn from_matrix(m: &Mat3) -> Quat {
        let trace = m[(0, 0)] + m[(1, 1)] + m[(2, 2)];
        if trace > 0.0 {
            // s = 4 * w
            let s = (trace + 1.0).sqrt() * 2.0;
            Quat {
                w: 0.25 * s,
                v: Vec3::new(
                    (m[(2, 1)] - m[(1, 2)]) / s,
                    (m[(0, 2)] - m[(2, 0)]) / s,
                    (m[(1, 0)] - m[(0, 1)]) / s,
                ),
            }
        } else if m[(0, 0)] > m[(1, 1)] && m[(0, 0)] > m[(2, 2)] {
            // s = 4 * x
            let s = (1.0 + m[(0, 0)] - m[(1, 1)] - m[(2, 2)]).sqrt() * 2.0;
            Quat {
                w: (m[(2, 1)] - m[(1, 2)]) / s,
                v: Vec3::new(
                    0.25 * s,
                    (m[(0, 1)] + m[(1, 0)]) / s,
                    (m[(0, 2)] + m[(2, 0)]) / s,
                ),
            }
        } else if m[(1, 1)] > m[(2, 2)] {
            // s = 4 * y
            let s = (1.0 + m[(1, 1)] - m[(0, 0)] - m[(2, 2)]).sqrt() * 2.0;
            Quat {
                w: (m[(0, 2)] - m[(2, 0)]) / s,
                v: Vec3::new(
                    (m[(0, 1)] + m[(1, 0)]) / s,
                    0.25 * s,
                    (m[(1, 2)] + m[(2, 1)]) / s,
                ),
            }
        } else {
            // s = 4 * z
            let s = (1.0 + m[(2, 2)] - m[(0, 0)] - m[(1, 1)]).sqrt() * 2.0;
            Quat {
                w: (m[(1, 0)] - m[(0, 1)]) / s,
                v: Vec3::new(
                    (m[(0, 2)] + m[(2, 0)]) / s,
                    (m[(1, 2)] + m[(2, 1)]) / s,
                    0.25 * s,
                ),
            }
        }
    }

    /// Exponential map from a rotation vector `w` (axis scaled by angle) to
    /// the quaternion `(cos(theta / 2), w * sin(theta / 2) / theta)`, where
    /// `theta` is the length of `w`.
    ///
    /// For `theta < 1e-10` the division by `theta` is avoided: the
    /// first-order form `(1, w / 2)` is built and then normalized.
    pub fn exp(w: &Vec3) -> Quat {
        let theta = w.norm();
        if theta < 1e-10 {
            Quat {
                w: 1.0,
                v: *w * 0.5,
            }
            .normalize()
        } else {
            let half_theta = theta * 0.5;
            Quat {
                w: half_theta.cos(),
                v: *w * (half_theta.sin() / theta),
            }
        }
    }

    /// Logarithm map from a quaternion to a rotation vector (axis scaled by
    /// angle); the counterpart of [`Quat::exp`].
    ///
    /// * When the vector part has length `>= 1e-10`, the angle is
    ///   `2 * atan2(|v|, w)`, which lies in `(0, 2π)`. A negative `w`
    ///   therefore yields an angle above `π` rather than the shorter
    ///   equivalent rotation.
    /// * When the vector part is shorter than that, the small rotation is
    ///   still reported, using `2 * atan(|v| / w)`, so that
    ///   `Quat::exp(&w).log()` round-trips for very short `w`. With a
    ///   negative `w` this gives the short rotation of the equivalent
    ///   quaternion `-self`.
    /// * When the vector part is exactly zero, or `|w|` is also below
    ///   `1e-10` (a near-zero quaternion that carries no rotation), the zero
    ///   vector is returned.
    pub fn log(&self) -> Vec3 {
        const SMALL: f64 = 1e-10;

        let v_norm = self.v.norm();
        if v_norm < SMALL {
            if v_norm == 0.0 || self.w.abs() < SMALL {
                return Vec3::zero();
            }
            // |v / w| < 1 here, so `atan` is well conditioned and the ratio
            // below is finite (close to 2 / w).
            let half_angle = (v_norm / self.w).atan();
            return self.v * (2.0 * half_angle / v_norm);
        }
        let angle = 2.0 * v_norm.atan2(self.w);
        self.v * (angle / v_norm)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for every floating-point comparison below.
    const EPS: f64 = 1e-10;

    /// Asserts that `a` and `b` agree within `EPS`, either directly or after
    /// negating one of them.
    ///
    /// `q` and `-q` describe the same rotation, so a matrix conversion may
    /// legitimately return either sign.
    fn assert_same_up_to_sign(a: &Quat, b: &Quat) {
        let same = (a.w - b.w).abs() < EPS && (a.v - b.v).norm() < EPS;
        let negated = (a.w + b.w).abs() < EPS && (a.v + b.v).norm() < EPS;
        assert!(same || negated, "quaternions differ: {:?} vs {:?}", a, b);
    }

    #[test]
    fn test_identity() {
        let q = Quat::identity();
        assert_eq!(q.w, 1.0);
        assert_eq!(q.v, Vec3::zero());
    }

    #[test]
    fn test_axis_angle() {
        let axis = Vec3::new(0.0, 0.0, 1.0);
        let angle = std::f64::consts::FRAC_PI_2;
        let q = Quat::from_axis_angle(&axis, angle);
        let expected_w = (angle / 2.0).cos();
        let expected_z = (angle / 2.0).sin();
        assert!((q.w - expected_w).abs() < EPS);
        assert!((q.v.z - expected_z).abs() < EPS);
    }

    #[test]
    fn test_normalize() {
        let q = Quat::new(1.0, 2.0, 3.0, 4.0);
        let normalized = q.normalize();
        let norm = (normalized.w * normalized.w + normalized.v.norm_sq()).sqrt();
        assert!((norm - 1.0).abs() < EPS);
    }

    /// Two quarter turns about the same axis compose to a half turn.
    #[test]
    fn test_multiplication() {
        let axis = Vec3::new(0.0, 0.0, 1.0);
        let q1 = Quat::from_axis_angle(&axis, std::f64::consts::FRAC_PI_2);
        let q2 = Quat::from_axis_angle(&axis, std::f64::consts::FRAC_PI_2);
        let result = q1.mul(&q2);
        let expected = Quat::from_axis_angle(&axis, std::f64::consts::PI);
        assert!((result.w - expected.w).abs() < EPS);
        assert!((result.v.x - expected.v.x).abs() < EPS);
        assert!((result.v.y - expected.v.y).abs() < EPS);
        assert!((result.v.z - expected.v.z).abs() < EPS);
    }

    /// A quarter turn about z maps the x axis onto the y axis.
    #[test]
    fn test_to_matrix() {
        let axis = Vec3::new(0.0, 0.0, 1.0);
        let angle = std::f64::consts::FRAC_PI_2;
        let q = Quat::from_axis_angle(&axis, angle);
        let m = q.to_matrix();
        let x = Vec3::new(1.0, 0.0, 0.0);
        let y = m * x;
        assert!((y.x - 0.0).abs() < EPS);
        assert!((y.y - 1.0).abs() < EPS);
        assert!((y.z - 0.0).abs() < EPS);
    }

    /// Angle 0.7 gives a positive trace, so this covers the first branch of
    /// `from_matrix`.
    #[test]
    fn test_matrix_roundtrip() {
        let axis = Vec3::new(1.0, 2.0, 3.0).normalize();
        let angle = 0.7;
        let q = Quat::from_axis_angle(&axis, angle);
        let m = q.to_matrix();
        let q2 = Quat::from_matrix(&m);
        assert_same_up_to_sign(&q, &q2);
    }

    /// A rotation by 3 rad has trace `1 + 2 * cos(3) < 0`, and the diagonal
    /// entry on the rotation axis is the largest, so rotating about x, y and
    /// z in turn exercises the three remaining branches of `from_matrix`.
    #[test]
    fn test_matrix_roundtrip_large_angle() {
        let axes = [
            Vec3::new(1.0, 0.0, 0.0),
            Vec3::new(0.0, 1.0, 0.0),
            Vec3::new(0.0, 0.0, 1.0),
        ];
        for axis in axes.iter() {
            let q = Quat::from_axis_angle(axis, 3.0);
            let m = q.to_matrix();
            let q2 = Quat::from_matrix(&m);
            assert_same_up_to_sign(&q, &q2);
        }
    }

    #[test]
    fn test_exp_log() {
        let w = Vec3::new(0.1, 0.2, 0.3);
        let q = Quat::exp(&w);
        let w2 = q.log();
        assert!((w.x - w2.x).abs() < EPS);
        assert!((w.y - w2.y).abs() < EPS);
        assert!((w.z - w2.z).abs() < EPS);
    }

    /// For a unit quaternion, `q * conj(q)` is the identity.
    #[test]
    fn test_conjugate() {
        let q = Quat::new(0.5, 0.5, 0.5, 0.5).normalize();
        let conj = q.conjugate();
        let result = q.mul(&conj);
        assert!((result.w - 1.0).abs() < EPS);
        assert!(result.v.norm() < EPS);
    }
}