use super::matrix::Matrix3;
use super::vector::Vector3;
use serde::{Deserialize, Serialize};
use std::ops::Mul;
/// A general quaternion `w + x·i + y·j + z·k`; it is not required to have unit norm.
///
/// `w` is the scalar part and `(x, y, z)` the vector part, which is the layout the
/// multiplication and conjugation code in this file relies on.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct Quaternion<T> {
/// Scalar (real) component.
pub w: T,
/// Coefficient of the `i` basis element.
pub x: T,
/// Coefficient of the `j` basis element.
pub y: T,
/// Coefficient of the `k` basis element.
pub z: T,
}
impl Quaternion<f64> {
#[must_use]
pub fn new(w: f64, x: f64, y: f64, z: f64) -> Self {
Self { w, x, y, z }
}
#[must_use]
pub fn norm(&self) -> f64 {
(self.w * self.w + self.x * self.x + self.y * self.y + self.z * self.z).sqrt()
}
#[must_use]
pub fn normalize(&self) -> Self {
let n = self.norm();
if n < 1e-15 {
return *self;
}
Self::new(self.w / n, self.x / n, self.y / n, self.z / n)
}
#[must_use]
pub fn conjugate(&self) -> Self {
Self::new(self.w, -self.x, -self.y, -self.z)
}
}
impl Mul for Quaternion<f64> {
    type Output = Self;

    /// Hamilton product `self * rhs`.
    ///
    /// Quaternion multiplication is not commutative, so operand order matters.
    fn mul(self, rhs: Self) -> Self {
        let (a, b) = (self, rhs);
        let w = a.w * b.w - a.x * b.x - a.y * b.y - a.z * b.z;
        let x = a.w * b.x + a.x * b.w + a.y * b.z - a.z * b.y;
        let y = a.w * b.y - a.x * b.z + a.y * b.w + a.z * b.x;
        let z = a.w * b.z + a.x * b.y - a.y * b.x + a.z * b.w;
        Self::new(w, x, y, z)
    }
}
/// Thin wrapper around a value of type `T`.
///
/// NOTE(review): the name and the `new_normalize` constructor suggest the wrapped
/// value is meant to have unit length, but the tuple field is public, so a `Unit`
/// can also be built directly from any value without normalization — confirm
/// whether callers rely on that.
#[derive(Debug, Clone, Copy)]
pub struct Unit<T>(pub T);
impl Unit<Vector3<f64>> {
#[must_use]
pub fn new_normalize(v: Vector3<f64>) -> Self {
Unit(v.normalize())
}
}
/// A quaternion used to represent a 3D rotation.
///
/// The wrapped quaternion is private. The constructors in this file pass their
/// result through `Quaternion::normalize` before storing it (the identity
/// constructor stores a literal unit quaternion).
///
/// NOTE(review): the derived `Deserialize` presumably fills `q` straight from the
/// input without re-normalizing — confirm that deserialized data is trusted or
/// validated by the caller.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct UnitQuaternion<T> {
// Underlying quaternion; kept private so it is only set by code in this module.
q: Quaternion<T>,
}
impl UnitQuaternion<f64> {
    /// Norm below which a quaternion is treated as having no usable direction.
    ///
    /// Same cutoff that `Quaternion::normalize` uses to leave its input unscaled.
    const DEGENERATE_NORM: f64 = 1e-15;

    /// Returns the identity rotation `(w, x, y, z) = (1, 0, 0, 0)`.
    #[must_use]
    pub fn identity() -> Self {
        Self {
            q: Quaternion::new(1.0, 0.0, 0.0, 0.0),
        }
    }

    /// Builds a unit quaternion by normalizing `q`.
    ///
    /// A quaternion whose norm is below `1e-15` cannot be scaled to unit length;
    /// it maps to [`Self::identity`] so the result always has unit norm.
    /// Non-finite components are not sanitized and propagate into the result.
    #[must_use]
    pub fn from_quaternion(q: Quaternion<f64>) -> Self {
        // `Quaternion::normalize` hands a near-zero quaternion back unscaled.
        // Storing that would yield a "unit" quaternion of norm ~0 whose rotation
        // collapses every vector to zero.
        if q.norm() < Self::DEGENERATE_NORM {
            return Self::identity();
        }
        Self { q: q.normalize() }
    }

    /// Borrows the underlying quaternion.
    #[must_use]
    pub fn quaternion(&self) -> &Quaternion<f64> {
        &self.q
    }

    /// Builds a rotation from Euler angles given in radians.
    ///
    /// The result equals `q_z(yaw) * q_y(pitch) * q_x(roll)`: roll about x is
    /// applied first, then pitch about y, then yaw about z.
    #[must_use]
    pub fn from_euler_angles(roll: f64, pitch: f64, yaw: f64) -> Self {
        let (sr, cr) = (roll / 2.0).sin_cos();
        let (sp, cp) = (pitch / 2.0).sin_cos();
        let (sy, cy) = (yaw / 2.0).sin_cos();

        let w = cr * cp * cy + sr * sp * sy;
        let x = sr * cp * cy - cr * sp * sy;
        let y = cr * sp * cy + sr * cp * sy;
        let z = cr * cp * sy - sr * sp * cy;

        Self {
            q: Quaternion::new(w, x, y, z).normalize(),
        }
    }

    /// Returns `(roll, pitch, yaw)` in radians, using the same convention as
    /// [`Self::from_euler_angles`].
    ///
    /// Roll and yaw come from `atan2`; pitch lies in `[-π/2, π/2]` and is pinned to
    /// `±π/2` when the sine term reaches magnitude 1 (gimbal lock).
    #[must_use]
    pub fn euler_angles(&self) -> (f64, f64, f64) {
        let q = &self.q;

        let sinr_cosp = 2.0 * (q.w * q.x + q.y * q.z);
        let cosr_cosp = 1.0 - 2.0 * (q.x * q.x + q.y * q.y);
        let roll = sinr_cosp.atan2(cosr_cosp);

        let sinp = 2.0 * (q.w * q.y - q.z * q.x);
        // Rounding can push |sinp| to 1 or slightly past it, where `asin` would
        // return NaN; clamp to ±π/2 instead.
        let pitch = if sinp.abs() >= 1.0 {
            std::f64::consts::FRAC_PI_2.copysign(sinp)
        } else {
            sinp.asin()
        };

        let siny_cosp = 2.0 * (q.w * q.z + q.x * q.y);
        let cosy_cosp = 1.0 - 2.0 * (q.y * q.y + q.z * q.z);
        let yaw = siny_cosp.atan2(cosy_cosp);

        (roll, pitch, yaw)
    }

    /// Builds the rotation of `angle` radians about `axis`.
    ///
    /// The quaternion is `(cos(angle/2), axis * sin(angle/2))`, normalized afterwards.
    /// `axis.0` is used as given; it is assumed to already have unit length.
    #[must_use]
    pub fn from_axis_angle(axis: &Unit<Vector3<f64>>, angle: f64) -> Self {
        let (s, c) = (angle / 2.0).sin_cos();
        let a = &axis.0;
        Self {
            q: Quaternion::new(c, a.x * s, a.y * s, a.z * s).normalize(),
        }
    }

    /// Recovers the quaternion for the rotation matrix `m`.
    ///
    /// `m.data` is read with the same index layout that
    /// [`Self::to_rotation_matrix`] writes. `m` is assumed to be a proper
    /// rotation matrix; other inputs are not rejected.
    #[must_use]
    pub fn from_matrix(m: &Matrix3<f64>) -> Self {
        let d = &m.data;
        let trace = d[0][0] + d[1][1] + d[2][2];

        // Each branch solves first for one component (w when the trace is positive,
        // otherwise the one tied to the largest diagonal entry) and divides the
        // remaining terms by `s`, which is four times that component.
        let (w, x, y, z) = if trace > 0.0 {
            let s = (trace + 1.0).sqrt() * 2.0;
            (
                0.25 * s,
                (d[2][1] - d[1][2]) / s,
                (d[0][2] - d[2][0]) / s,
                (d[1][0] - d[0][1]) / s,
            )
        } else if d[0][0] > d[1][1] && d[0][0] > d[2][2] {
            let s = (1.0 + d[0][0] - d[1][1] - d[2][2]).sqrt() * 2.0;
            (
                (d[2][1] - d[1][2]) / s,
                0.25 * s,
                (d[0][1] + d[1][0]) / s,
                (d[0][2] + d[2][0]) / s,
            )
        } else if d[1][1] > d[2][2] {
            let s = (1.0 + d[1][1] - d[0][0] - d[2][2]).sqrt() * 2.0;
            (
                (d[0][2] - d[2][0]) / s,
                (d[0][1] + d[1][0]) / s,
                0.25 * s,
                (d[1][2] + d[2][1]) / s,
            )
        } else {
            let s = (1.0 + d[2][2] - d[0][0] - d[1][1]).sqrt() * 2.0;
            (
                (d[1][0] - d[0][1]) / s,
                (d[0][2] + d[2][0]) / s,
                (d[1][2] + d[2][1]) / s,
                0.25 * s,
            )
        };

        Self {
            q: Quaternion::new(w, x, y, z).normalize(),
        }
    }

    /// Spherical linear interpolation from `self` (`t = 0`) to `other` (`t = 1`).
    ///
    /// `other` is negated when the dot product is negative so the interpolation
    /// follows the shorter arc. `t` is not clamped.
    #[must_use]
    pub fn slerp(&self, other: &Self, t: f64) -> Self {
        let mut dot = self.q.w * other.q.w
            + self.q.x * other.q.x
            + self.q.y * other.q.y
            + self.q.z * other.q.z;

        let mut other_q = other.q;
        if dot < 0.0 {
            other_q = Quaternion::new(-other_q.w, -other_q.x, -other_q.y, -other_q.z);
            dot = -dot;
        }
        // Guard `acos` against a dot product that rounding pushed above 1.
        dot = dot.min(1.0);

        // Nearly parallel inputs: sin(theta) is tiny, so use normalized linear
        // interpolation instead of dividing by it.
        if dot > 0.9995 {
            let result = Quaternion::new(
                self.q.w + t * (other_q.w - self.q.w),
                self.q.x + t * (other_q.x - self.q.x),
                self.q.y + t * (other_q.y - self.q.y),
                self.q.z + t * (other_q.z - self.q.z),
            );
            return Self {
                q: result.normalize(),
            };
        }

        let theta = dot.acos();
        let sin_theta = theta.sin();
        let a = ((1.0 - t) * theta).sin() / sin_theta;
        let b = (t * theta).sin() / sin_theta;

        Self {
            q: Quaternion::new(
                a * self.q.w + b * other_q.w,
                a * self.q.x + b * other_q.x,
                a * self.q.y + b * other_q.y,
                a * self.q.z + b * other_q.z,
            )
            .normalize(),
        }
    }

    /// Converts this rotation into a 3×3 matrix, starting from `Matrix3::zeros()`
    /// and filling all nine entries of `data`.
    #[must_use]
    pub fn to_rotation_matrix(&self) -> Matrix3<f64> {
        let q = &self.q;

        let xx = q.x * q.x;
        let yy = q.y * q.y;
        let zz = q.z * q.z;
        let xy = q.x * q.y;
        let xz = q.x * q.z;
        let yz = q.y * q.z;
        let wx = q.w * q.x;
        let wy = q.w * q.y;
        let wz = q.w * q.z;

        let mut m = Matrix3::zeros();
        m.data[0][0] = 1.0 - 2.0 * (yy + zz);
        m.data[0][1] = 2.0 * (xy - wz);
        m.data[0][2] = 2.0 * (xz + wy);
        m.data[1][0] = 2.0 * (xy + wz);
        m.data[1][1] = 1.0 - 2.0 * (xx + zz);
        m.data[1][2] = 2.0 * (yz - wx);
        m.data[2][0] = 2.0 * (xz - wy);
        m.data[2][1] = 2.0 * (yz + wx);
        m.data[2][2] = 1.0 - 2.0 * (xx + yy);
        m
    }
}
impl Mul<Vector3<f64>> for UnitQuaternion<f64> {
    type Output = Vector3<f64>;

    /// Rotates `v` with the sandwich product `q * (0, v) * conjugate(q)` and returns
    /// the vector part of the result.
    fn mul(self, v: Vector3<f64>) -> Vector3<f64> {
        let rotation = self.q;
        // Embed the vector as a pure quaternion (zero scalar part).
        let pure = Quaternion::new(0.0, v.x, v.y, v.z);
        let sandwiched = (rotation * pure) * rotation.conjugate();
        Vector3::new(sandwiched.x, sandwiched.y, sandwiched.z)
    }
}
impl Mul<Vector3<f64>> for &UnitQuaternion<f64> {
    type Output = Vector3<f64>;

    /// Rotates `v` through a shared reference by copying the quaternion and
    /// delegating to the by-value implementation.
    fn mul(self, v: Vector3<f64>) -> Vector3<f64> {
        let owned: UnitQuaternion<f64> = *self;
        owned * v
    }
}
impl Mul for UnitQuaternion<f64> {
    type Output = Self;

    /// Composes two rotations as the quaternion product `self * rhs`, then
    /// re-normalizes the product.
    fn mul(self, rhs: Self) -> Self {
        let product = self.q * rhs.q;
        Self {
            q: product.normalize(),
        }
    }
}
impl std::ops::MulAssign for UnitQuaternion<f64> {
    /// Replaces `self` with the re-normalized quaternion product `self * rhs`.
    fn mul_assign(&mut self, rhs: Self) {
        let product = self.q * rhs.q;
        self.q = product.normalize();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that all four components of two unit quaternions agree within `tol`.
    fn assert_components_close(
        actual: &UnitQuaternion<f64>,
        expected: &UnitQuaternion<f64>,
        tol: f64,
    ) {
        assert!((actual.q.w - expected.q.w).abs() < tol);
        assert!((actual.q.x - expected.q.x).abs() < tol);
        assert!((actual.q.y - expected.q.y).abs() < tol);
        assert!((actual.q.z - expected.q.z).abs() < tol);
    }

    /// The identity rotation must leave every component of a vector unchanged.
    #[test]
    fn test_identity_rotation() {
        let q = UnitQuaternion::identity();
        let v = Vector3::new(1.0, 0.0, 0.0);
        let rotated = q * v;
        assert!((rotated.x - 1.0).abs() < 1e-10);
        assert!(rotated.y.abs() < 1e-10);
        assert!(rotated.z.abs() < 1e-10);
    }

    /// `slerp` must reproduce both endpoints in every component, not just `w`.
    #[test]
    fn test_slerp_endpoints() {
        let a = UnitQuaternion::identity();
        let b = UnitQuaternion::from_euler_angles(0.5, 0.0, 0.0);
        let at0 = a.slerp(&b, 0.0);
        let at1 = a.slerp(&b, 1.0);
        assert_components_close(&at0, &a, 1e-6);
        assert_components_close(&at1, &b, 1e-6);
    }

    /// Euler angles away from gimbal lock must survive a round trip.
    #[test]
    fn test_euler_roundtrip() {
        let q = UnitQuaternion::from_euler_angles(0.1, 0.2, 0.3);
        let (r, p, y) = q.euler_angles();
        assert!((r - 0.1).abs() < 1e-6);
        assert!((p - 0.2).abs() < 1e-6);
        assert!((y - 0.3).abs() < 1e-6);
    }

    /// The identity matrix must map to the identity quaternion.
    #[test]
    fn test_from_matrix_identity() {
        let m = Matrix3::identity();
        let q = UnitQuaternion::from_matrix(&m);
        assert_components_close(&q, &UnitQuaternion::identity(), 1e-6);
    }

    /// A quarter turn about +z must carry +x onto +y and stay in the xy-plane.
    #[test]
    fn test_axis_angle_90_deg() {
        let axis = Unit::new_normalize(Vector3::new(0.0, 0.0, 1.0));
        let q = UnitQuaternion::from_axis_angle(&axis, std::f64::consts::FRAC_PI_2);
        let v = q * Vector3::new(1.0, 0.0, 0.0);
        assert!(v.x.abs() < 1e-6);
        assert!((v.y - 1.0).abs() < 1e-6);
        assert!(v.z.abs() < 1e-6);
    }
}