use std::ops::{Mul, MulAssign};
/// A quaternion `w + x·i + y·j + z·k` that is meant to have unit norm.
///
/// The fields are private, so values are only created through the constructors
/// of this type. The derived `PartialEq` compares the four components exactly,
/// which means `q` and `-q` compare as different values.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct UnitQuaternion {
/// Scalar (real) component.
w: f64,
/// Coefficient of `i` (first component of the vector part).
x: f64,
/// Coefficient of `j` (second component of the vector part).
y: f64,
/// Coefficient of `k` (third component of the vector part).
z: f64,
}
impl UnitQuaternion {
    /// Returns the scalar (real) component `w`.
    #[inline]
    #[must_use]
    pub fn w(&self) -> f64 {
        let Self { w, .. } = *self;
        w
    }

    /// Returns the `x` component (coefficient of `i`).
    #[inline]
    #[must_use]
    pub fn x(&self) -> f64 {
        let Self { x, .. } = *self;
        x
    }

    /// Returns the `y` component (coefficient of `j`).
    #[inline]
    #[must_use]
    pub fn y(&self) -> f64 {
        let Self { y, .. } = *self;
        y
    }

    /// Returns the `z` component (coefficient of `k`).
    #[inline]
    #[must_use]
    pub fn z(&self) -> f64 {
        let Self { z, .. } = *self;
        z
    }

    /// Returns all four components as the tuple `(w, x, y, z)`.
    #[inline]
    #[must_use]
    pub fn components(&self) -> (f64, f64, f64, f64) {
        let Self { w, x, y, z } = *self;
        (w, x, y, z)
    }

    /// Returns all four components as the array `[w, x, y, z]`.
    #[inline]
    #[must_use]
    pub fn to_array(&self) -> [f64; 4] {
        let Self { w, x, y, z } = *self;
        [w, x, y, z]
    }
}
/// Norm threshold below which an input quaternion, rotation axis, or tangent
/// vector is treated as zero and mapped to the identity quaternion.
const NORM_EPSILON: f64 = 1e-10;
/// Threshold on the magnitude of `sin(angle / 2)` below which a quaternion is
/// treated as the identity when extracting an axis-angle pair or a logarithm.
const IDENTITY_EPSILON: f64 = 1e-10;
// Not referenced by any code visible in this file; the `dead_code` allowance
// suppresses the unused-constant warning.
// NOTE(review): intended use is not shown here — confirm before removing.
#[allow(dead_code)]
const SINGULARITY_EPSILON: f64 = 1e-6;
impl UnitQuaternion {
    /// Euclidean dot product of two quaternions viewed as 4-vectors.
    #[inline]
    fn dot(&self, other: &Self) -> f64 {
        self.w * other.w + self.x * other.x + self.y * other.y + self.z * other.z
    }

    /// Length of the vector part `(x, y, z)`.
    ///
    /// For a unit quaternion this equals `|sin(angle / 2)|`. Computing it from
    /// the vector part avoids the cancellation of `sqrt(1 - w²)` near `w ≈ ±1`.
    #[inline]
    fn vector_norm(&self) -> f64 {
        (self.x * self.x + self.y * self.y + self.z * self.z).sqrt()
    }

    /// Returns the identity quaternion `(w, x, y, z) = (1, 0, 0, 0)`.
    #[must_use]
    pub fn identity() -> Self {
        Self {
            w: 1.0,
            x: 0.0,
            y: 0.0,
            z: 0.0,
        }
    }

    /// Builds a quaternion from raw components, scaling them to unit norm.
    ///
    /// If the Euclidean norm of the input is below `NORM_EPSILON` the identity
    /// is returned instead. NaN inputs are not filtered: the norm becomes NaN,
    /// the guard does not trigger, and NaN propagates into the result.
    #[must_use]
    pub fn new(w: f64, x: f64, y: f64, z: f64) -> Self {
        let norm = (w * w + x * x + y * y + z * z).sqrt();
        if norm < NORM_EPSILON {
            return Self::identity();
        }
        Self {
            w: w / norm,
            x: x / norm,
            y: y / norm,
            z: z / norm,
        }
    }

    /// Builds the quaternion `(cos(angle/2), sin(angle/2) · axis/|axis|)`.
    ///
    /// `angle` is in radians. `axis` does not need to be normalized; if its
    /// norm is below `NORM_EPSILON` the identity is returned.
    #[must_use]
    pub fn from_axis_angle(axis: [f64; 3], angle: f64) -> Self {
        let [nx, ny, nz] = axis;
        let axis_norm = (nx * nx + ny * ny + nz * nz).sqrt();
        if axis_norm < NORM_EPSILON {
            return Self::identity();
        }
        let half_angle = angle / 2.0;
        let (sin_half, cos_half) = half_angle.sin_cos();
        Self {
            w: cos_half,
            x: sin_half * nx / axis_norm,
            y: sin_half * ny / axis_norm,
            z: sin_half * nz / axis_norm,
        }
    }

    /// Decomposes the quaternion into `(unit axis, angle in radians)`.
    ///
    /// The angle is `2 · atan2(|(x, y, z)|, w)`, which lies in `[0, 2π]`.
    /// When the vector part is shorter than `IDENTITY_EPSILON` the axis is
    /// undefined and `([1, 0, 0], 0)` is returned.
    #[must_use]
    pub fn to_axis_angle(&self) -> ([f64; 3], f64) {
        let sin_half = self.vector_norm();
        if sin_half < IDENTITY_EPSILON {
            return ([1.0, 0.0, 0.0], 0.0);
        }
        // atan2 stays accurate for tiny and near-2π angles, where acos(w)
        // loses roughly half of the significant digits.
        let angle = 2.0 * sin_half.atan2(self.w);
        let axis = [self.x / sin_half, self.y / sin_half, self.z / sin_half];
        (axis, angle)
    }

    /// Rotation by `theta` radians about the x axis.
    #[must_use]
    pub fn rotation_x(theta: f64) -> Self {
        Self::from_axis_angle([1.0, 0.0, 0.0], theta)
    }

    /// Rotation by `theta` radians about the y axis.
    #[must_use]
    pub fn rotation_y(theta: f64) -> Self {
        Self::from_axis_angle([0.0, 1.0, 0.0], theta)
    }

    /// Rotation by `theta` radians about the z axis.
    #[must_use]
    pub fn rotation_z(theta: f64) -> Self {
        Self::from_axis_angle([0.0, 0.0, 1.0], theta)
    }

    /// Returns the conjugate `(w, -x, -y, -z)`.
    #[must_use]
    pub fn conjugate(&self) -> Self {
        Self {
            w: self.w,
            x: -self.x,
            y: -self.y,
            z: -self.z,
        }
    }

    /// Returns the inverse, computed as the conjugate (valid for unit norm).
    #[must_use]
    pub fn inverse(&self) -> Self {
        self.conjugate()
    }

    /// Returns `w² + x² + y² + z²`.
    #[must_use]
    pub fn norm_squared(&self) -> f64 {
        self.dot(self)
    }

    /// Returns the Euclidean norm `sqrt(w² + x² + y² + z²)`.
    #[must_use]
    pub fn norm(&self) -> f64 {
        self.norm_squared().sqrt()
    }

    /// Returns a re-normalized copy, with the same degenerate-input handling
    /// as [`UnitQuaternion::new`].
    #[must_use]
    pub fn normalize(&self) -> Self {
        Self::new(self.w, self.x, self.y, self.z)
    }

    /// Returns `2 · acos(|w|)`, a value in `[0, π]`.
    ///
    /// Using `|w|` makes `q` and `-q` yield the same distance.
    #[must_use]
    pub fn distance_to_identity(&self) -> f64 {
        // The clamp guards against |w| drifting slightly above 1 by rounding.
        2.0 * self.w.abs().clamp(-1.0, 1.0).acos()
    }

    /// Returns `acos(|self · other|)`, a value in `[0, π/2]`.
    ///
    /// This is the angle between the two quaternions as 4-vectors, up to sign.
    /// Note that there is no factor of two here, so
    /// `q.distance_to(&UnitQuaternion::identity())` is half of
    /// `q.distance_to_identity()`.
    #[must_use]
    pub fn distance_to(&self, other: &Self) -> f64 {
        self.dot(other).abs().clamp(-1.0, 1.0).acos()
    }

    /// Spherical linear interpolation from `self` (`t = 0`) to `other` (`t = 1`).
    ///
    /// `other` is negated first when the dot product is negative, so the
    /// interpolation follows the shorter arc. At `t = 1` the result may
    /// therefore be `-other`, which encodes the same rotation. When the
    /// sign-corrected inputs are nearly parallel (`cos θ > 0.9995`), a
    /// normalized linear interpolation is used to avoid dividing by a tiny
    /// `sin θ`. `t` is not clamped.
    #[must_use]
    pub fn slerp(&self, other: &Self, t: f64) -> Self {
        let dot = self.dot(other);
        // The sign flip must come before the near-parallel shortcut. Otherwise
        // `other ≈ -self` would be interpolated linearly through the origin.
        let (cos_theta, other_w, other_x, other_y, other_z) = if dot < 0.0 {
            (-dot, -other.w, -other.x, -other.y, -other.z)
        } else {
            (dot, other.w, other.x, other.y, other.z)
        };
        if cos_theta > 0.9995 {
            return Self::new(
                self.w + t * (other_w - self.w),
                self.x + t * (other_x - self.x),
                self.y + t * (other_y - self.y),
                self.z + t * (other_z - self.z),
            );
        }
        let theta = cos_theta.acos();
        let sin_theta = theta.sin();
        let a = ((1.0 - t) * theta).sin() / sin_theta;
        let b = (t * theta).sin() / sin_theta;
        Self::new(
            a * self.w + b * other_w,
            a * self.x + b * other_x,
            a * self.y + b * other_y,
            a * self.z + b * other_z,
        )
    }

    /// Exponential map: sends `v` to `(cos(|v|/2), sin(|v|/2) · v/|v|)`.
    ///
    /// Returns the identity when `|v|` is below `NORM_EPSILON`.
    #[must_use]
    pub fn exp(v: [f64; 3]) -> Self {
        let norm = (v[0] * v[0] + v[1] * v[1] + v[2] * v[2]).sqrt();
        if norm < NORM_EPSILON {
            return Self::identity();
        }
        let half_norm = norm / 2.0;
        let (sin_half, cos_half) = half_norm.sin_cos();
        Self {
            w: cos_half,
            x: sin_half * v[0] / norm,
            y: sin_half * v[1] / norm,
            z: sin_half * v[2] / norm,
        }
    }

    /// Logarithm map, the inverse of [`UnitQuaternion::exp`].
    ///
    /// Returns `θ · n̂`, where `θ = 2 · atan2(|(x, y, z)|, w)` and `n̂` is the
    /// normalized vector part. Returns the zero vector when the vector part is
    /// shorter than `IDENTITY_EPSILON`.
    #[must_use]
    pub fn log(&self) -> [f64; 3] {
        let sin_half = self.vector_norm();
        if sin_half < IDENTITY_EPSILON {
            return [0.0, 0.0, 0.0];
        }
        let theta = 2.0 * sin_half.atan2(self.w);
        let scale = theta / sin_half;
        [scale * self.x, scale * self.y, scale * self.z]
    }

    /// Returns the 2×2 complex matrix
    /// `[[w + i·x, -y + i·z], [y + i·z, w - i·x]]`.
    #[must_use]
    pub fn to_matrix(&self) -> [[num_complex::Complex64; 2]; 2] {
        use num_complex::Complex64;
        [
            [
                Complex64::new(self.w, self.x),
                Complex64::new(-self.y, self.z),
            ],
            [
                Complex64::new(self.y, self.z),
                Complex64::new(self.w, -self.x),
            ],
        ]
    }

    /// Rebuilds a quaternion from a matrix laid out as in
    /// [`UnitQuaternion::to_matrix`].
    ///
    /// Only the first column is read: `matrix[0][0]` gives `(w, x)` and
    /// `matrix[1][0]` gives `(y, z)`. The result is normalized through
    /// [`UnitQuaternion::new`]; the second column is not validated.
    #[must_use]
    pub fn from_matrix(matrix: [[num_complex::Complex64; 2]; 2]) -> Self {
        let a = matrix[0][0].re;
        let b = matrix[0][0].im;
        let c = matrix[1][0].re;
        let d = matrix[1][0].im;
        Self::new(a, b, c, d)
    }

    /// Rotates `v` by this quaternion.
    ///
    /// Uses `v' = v + 2·w·(q × v) + 2·q × (q × v)`, where `q` is the vector
    /// part `(x, y, z)`.
    #[must_use]
    pub fn rotate_vector(&self, v: [f64; 3]) -> [f64; 3] {
        let (w, qx, qy, qz) = (self.w, self.x, self.y, self.z);
        let (vx, vy, vz) = (v[0], v[1], v[2]);
        // c = q × v
        let cx = qy * vz - qz * vy;
        let cy = qz * vx - qx * vz;
        let cz = qx * vy - qy * vx;
        // cc = q × (q × v)
        let ccx = qy * cz - qz * cy;
        let ccy = qz * cx - qx * cz;
        let ccz = qx * cy - qy * cx;
        [
            vx + 2.0 * (w * cx + ccx),
            vy + 2.0 * (w * cy + ccy),
            vz + 2.0 * (w * cz + ccz),
        ]
    }

    /// Returns `true` when `|norm² - 1|` is strictly below `tolerance`.
    #[must_use]
    pub fn verify_unit(&self, tolerance: f64) -> bool {
        (self.norm_squared() - 1.0).abs() < tolerance
    }

    /// Returns the equivalent 3×3 rotation matrix, indexed as `m[row][col]`.
    ///
    /// Multiplying this matrix by a column vector gives the same result as
    /// [`UnitQuaternion::rotate_vector`] for a unit quaternion.
    #[must_use]
    pub fn to_rotation_matrix(&self) -> [[f64; 3]; 3] {
        let (w, x, y, z) = (self.w, self.x, self.y, self.z);
        [
            [
                1.0 - 2.0 * (y * y + z * z),
                2.0 * (x * y - w * z),
                2.0 * (x * z + w * y),
            ],
            [
                2.0 * (x * y + w * z),
                1.0 - 2.0 * (x * x + z * z),
                2.0 * (y * z - w * x),
            ],
            [
                2.0 * (x * z - w * y),
                2.0 * (y * z + w * x),
                1.0 - 2.0 * (x * x + y * y),
            ],
        ]
    }
}
impl Mul for UnitQuaternion {
    type Output = Self;

    /// Hamilton product `self · rhs`, passed through `UnitQuaternion::new` so
    /// the result is re-normalized.
    fn mul(self, rhs: Self) -> Self {
        let Self {
            w: w1,
            x: x1,
            y: y1,
            z: z1,
        } = self;
        let Self {
            w: w2,
            x: x2,
            y: y2,
            z: z2,
        } = rhs;

        let scalar = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2;
        let i_part = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2;
        let j_part = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2;
        let k_part = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2;

        Self::new(scalar, i_part, j_part, k_part)
    }
}
impl MulAssign for UnitQuaternion {
fn mul_assign(&mut self, rhs: Self) {
*self = *self * rhs;
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// The identity has components (1, 0, 0, 0) and unit norm.
    #[test]
    fn test_identity() {
        let q = UnitQuaternion::identity();
        let UnitQuaternion { w, x, y, z } = q;
        assert_eq!(w, 1.0);
        assert_eq!(x, 0.0);
        assert_eq!(y, 0.0);
        assert_eq!(z, 0.0);
        assert!(q.verify_unit(1e-10));
    }

    /// `new` rescales arbitrary components to unit norm.
    #[test]
    fn test_normalization() {
        let q = UnitQuaternion::new(1.0, 2.0, 3.0, 4.0);
        let deviation = (q.norm() - 1.0).abs();
        assert!(deviation < 1e-10);
        assert!(q.verify_unit(1e-10));
    }

    /// Axis-angle to quaternion and back recovers the normalized axis and the angle.
    #[test]
    fn test_axis_angle_roundtrip() {
        let axis = [1.0, 2.0, 3.0];
        let angle = PI / 3.0;
        let q = UnitQuaternion::from_axis_angle(axis, angle);
        let (axis_out, angle_out) = q.to_axis_angle();
        let axis_norm = (axis[0] * axis[0] + axis[1] * axis[1] + axis[2] * axis[2]).sqrt();
        let axis_normalized: Vec<f64> = axis.iter().map(|c| c / axis_norm).collect();
        assert!((angle - angle_out).abs() < 1e-10);
        for (expected, actual) in axis_normalized.iter().zip(axis_out.iter()) {
            assert!((expected - actual).abs() < 1e-10);
        }
        println!(
            "✓ Axis-angle roundtrip: axis=[{:.3},{:.3},{:.3}], angle={:.3}",
            axis_out[0], axis_out[1], axis_out[2], angle_out
        );
    }

    /// A quarter turn about x maps the y axis onto the z axis.
    #[test]
    fn test_rotation_x() {
        let q = UnitQuaternion::rotation_x(PI / 2.0);
        let [rx, ry, rz] = q.rotate_vector([0.0, 1.0, 0.0]);
        assert!(rx.abs() < 1e-10);
        assert!(ry.abs() < 1e-10);
        assert!((rz - 1.0).abs() < 1e-10);
        println!(
            "✓ Rotation X by π/2: (0,1,0) → ({:.3},{:.3},{:.3})",
            rx, ry, rz
        );
    }

    /// Rotating by a product equals applying the right factor first, then the left.
    #[test]
    fn test_quaternion_multiplication() {
        let q1 = UnitQuaternion::rotation_x(PI / 4.0);
        let q2 = UnitQuaternion::rotation_y(PI / 4.0);
        let q3 = q1 * q2;
        assert!(q3.verify_unit(1e-10));
        let v = [1.0, 0.0, 0.0];
        let rotated_composed = q3.rotate_vector(v);
        let rotated_separate = q1.rotate_vector(q2.rotate_vector(v));
        for (composed, separate) in rotated_composed.iter().zip(rotated_separate.iter()) {
            assert!((composed - separate).abs() < 1e-10);
        }
        println!("✓ Quaternion multiplication preserves unitarity and composition");
    }

    /// A quaternion times its inverse is the identity.
    #[test]
    fn test_inverse() {
        let q = UnitQuaternion::rotation_z(PI / 3.0);
        let UnitQuaternion { w, x, y, z } = q * q.inverse();
        assert!((w - 1.0).abs() < 1e-10);
        assert!(x.abs() < 1e-10);
        assert!(y.abs() < 1e-10);
        assert!(z.abs() < 1e-10);
        println!("✓ Inverse: q · q⁻¹ = identity");
    }

    /// The distance to the identity equals the rotation angle for 0, π and π/2.
    #[test]
    fn test_distance_to_identity() {
        let at_identity = UnitQuaternion::identity().distance_to_identity();
        assert!(at_identity < 1e-10);
        let q_pi = UnitQuaternion::rotation_x(PI);
        assert!((q_pi.distance_to_identity() - PI).abs() < 1e-10);
        let q_half_pi = UnitQuaternion::rotation_y(PI / 2.0);
        assert!((q_half_pi.distance_to_identity() - PI / 2.0).abs() < 1e-10);
        println!(
            "✓ Distance to identity: d(I)=0, d(π)={:.3}, d(π/2)={:.3}",
            q_pi.distance_to_identity(),
            q_half_pi.distance_to_identity()
        );
    }

    /// `log` undoes `exp` for a moderately sized tangent vector.
    #[test]
    fn test_exp_log_roundtrip() {
        let v = [0.5, 0.3, 0.2];
        let v_out = UnitQuaternion::exp(v).log();
        for (input, output) in v.iter().zip(v_out.iter()) {
            assert!((input - output).abs() < 1e-10);
        }
        println!(
            "✓ Exp-log roundtrip: v=[{:.3},{:.3},{:.3}]",
            v[0], v[1], v[2]
        );
    }

    /// SLERP hits both endpoints and the halfway rotation.
    #[test]
    fn test_slerp() {
        let q1 = UnitQuaternion::rotation_x(0.0);
        let q2 = UnitQuaternion::rotation_x(PI / 2.0);
        let q_mid = q1.slerp(&q2, 0.5);
        assert!((q_mid.distance_to_identity() - PI / 4.0).abs() < 1e-10);
        let q_0 = q1.slerp(&q2, 0.0);
        assert!(q_0.distance_to(&q1) < 1e-10);
        let q_1 = q1.slerp(&q2, 1.0);
        assert!(q_1.distance_to(&q2) < 1e-10);
        println!("✓ SLERP: smooth interpolation on S³");
    }

    /// Axis-angle extraction and vector rotation stay well behaved just below π.
    #[test]
    fn test_near_pi_axis_angle_roundtrip() {
        let angle = PI - 1e-8;
        let q = UnitQuaternion::from_axis_angle([0.0, 0.0, 1.0], angle);
        assert!(q.verify_unit(1e-10));
        let (axis_out, angle_out) = q.to_axis_angle();
        let angle_error = (angle - angle_out).abs();
        assert!(
            angle_error < 1e-6,
            "Near-π angle roundtrip: got {}, expected {}",
            angle_out,
            angle
        );
        let axis_error = (axis_out[2] - 1.0).abs();
        assert!(
            axis_error < 1e-6,
            "Near-π axis roundtrip: got {:?}",
            axis_out
        );
        let rotated = q.rotate_vector([1.0, 0.0, 0.0]);
        assert!(
            (rotated[0] + 1.0).abs() < 1e-6,
            "Near-π rotation: got {:?}",
            rotated
        );
    }

    /// A tiny tangent vector maps to (nearly) the identity and back to (nearly) zero.
    #[test]
    fn test_exp_log_near_identity() {
        let q = UnitQuaternion::exp([1e-12, 0.0, 0.0]);
        assert!(q.verify_unit(1e-14));
        assert!(q.distance_to_identity() < 1e-10);
        let log_q = q.log();
        for c in log_q.iter() {
            assert!(c.abs() < 1e-8);
        }
    }

    /// Converting to the 2×2 complex matrix and back preserves all components.
    #[test]
    fn test_matrix_conversion() {
        let q = UnitQuaternion::rotation_z(PI / 3.0);
        let q_back = UnitQuaternion::from_matrix(q.to_matrix());
        let before = [q.w, q.x, q.y, q.z];
        let after = [q_back.w, q_back.x, q_back.y, q_back.z];
        for (lhs, rhs) in before.iter().zip(after.iter()) {
            assert!((lhs - rhs).abs() < 1e-10);
        }
        println!("✓ Matrix conversion roundtrip preserves quaternion");
    }
}