use serde::{Deserialize, Serialize};
use std::{
fmt,
ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign},
};
use approxim::approx_derive::RelativeEq;
use rand::{
Rng, RngExt,
distr::{Distribution, StandardUniform},
};
use rand_distr::StandardNormal;
use crate::{Cartesian, Cross, Error, InnerProduct, Rotate, Rotation, RotationMatrix, Unit};
/// A quaternion stored as a scalar part plus a three-component vector part.
///
/// Arithmetic operators for this type are implemented elsewhere in this file.
#[derive(Clone, Copy, Debug, PartialEq, RelativeEq, Serialize, Deserialize)]
pub struct Quaternion {
/// Scalar component of the quaternion.
pub scalar: f64,
/// Vector component of the quaternion.
pub vector: Cartesian<3>,
}
impl Quaternion {
#[inline]
#[must_use]
pub fn norm_squared(&self) -> f64 {
self.scalar * self.scalar + self.vector.dot(&self.vector)
}
#[inline]
#[must_use]
pub fn norm(&self) -> f64 {
self.norm_squared().sqrt()
}
#[inline]
#[must_use]
pub fn conjugate(self) -> Self {
Self {
scalar: self.scalar,
vector: -self.vector,
}
}
#[inline]
pub fn to_versor(self) -> Result<Versor, Error> {
let mag = self.norm();
if mag == 0.0 {
Err(Error::InvalidQuaternionMagnitude)
} else {
Ok(Versor(self / mag))
}
}
#[inline]
#[must_use]
pub fn to_versor_unchecked(self) -> Versor {
Versor(self / self.norm())
}
}
impl From<[f64; 4]> for Quaternion {
    /// Build a quaternion from `[scalar, x, y, z]`: the first element becomes
    /// the scalar part and the remaining three become the vector part.
    #[inline]
    fn from(value: [f64; 4]) -> Self {
        let [scalar, x, y, z] = value;
        Self {
            scalar,
            vector: [x, y, z].into(),
        }
    }
}
impl fmt::Display for Quaternion {
    /// Write the quaternion as `[scalar, vector]`, using the `Display`
    /// output of each part.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Self { scalar, vector } = self;
        f.write_fmt(format_args!("[{}, {}]", scalar, vector))
    }
}
impl Add for Quaternion {
    type Output = Self;

    /// Component-wise sum of two quaternions.
    #[inline]
    fn add(self, rhs: Self) -> Self {
        let Self { scalar: lhs_scalar, vector: lhs_vector } = self;
        let Self { scalar: rhs_scalar, vector: rhs_vector } = rhs;
        Self {
            scalar: lhs_scalar + rhs_scalar,
            vector: lhs_vector + rhs_vector,
        }
    }
}
impl AddAssign for Quaternion {
    /// Add `rhs` to this quaternion in place, part by part.
    #[inline]
    fn add_assign(&mut self, rhs: Self) {
        let Self { scalar, vector } = self;
        *scalar += rhs.scalar;
        *vector += rhs.vector;
    }
}
impl Div<f64> for Quaternion {
    type Output = Self;

    /// Divide both the scalar and the vector part by the real number `rhs`.
    #[inline]
    fn div(self, rhs: f64) -> Self {
        let Self { scalar, vector } = self;
        Self {
            scalar: scalar / rhs,
            vector: vector / rhs,
        }
    }
}
impl DivAssign<f64> for Quaternion {
    /// Divide this quaternion in place by the real number `rhs`.
    #[inline]
    fn div_assign(&mut self, rhs: f64) {
        let Self { scalar, vector } = self;
        *scalar /= rhs;
        *vector /= rhs;
    }
}
impl Mul<f64> for Quaternion {
    type Output = Self;

    /// Multiply both the scalar and the vector part by the real number `rhs`.
    #[inline]
    fn mul(self, rhs: f64) -> Self {
        let Self { scalar, vector } = self;
        Self {
            scalar: scalar * rhs,
            vector: vector * rhs,
        }
    }
}
impl MulAssign<f64> for Quaternion {
    /// Multiply this quaternion in place by the real number `rhs`.
    #[inline]
    fn mul_assign(&mut self, rhs: f64) {
        let Self { scalar, vector } = self;
        *scalar *= rhs;
        *vector *= rhs;
    }
}
impl Mul<Quaternion> for Quaternion {
    type Output = Self;

    /// Quaternion (Hamilton) product `self * rhs`.
    ///
    /// With `self = (s1, v1)` and `rhs = (s2, v2)` the result is
    /// `(s1 * s2 - v1 . v2, s1 * v2 + s2 * v1 + v1 x v2)`.
    /// The product is not commutative because of the cross-product term.
    #[inline]
    fn mul(self, rhs: Quaternion) -> Self {
        let (s1, v1) = (self.scalar, self.vector);
        let (s2, v2) = (rhs.scalar, rhs.vector);

        let scalar = s1 * s2 - v1.dot(&v2);
        let vector = v2 * s1 + v1 * s2 + v1.cross(&v2);

        Self { scalar, vector }
    }
}
impl MulAssign<Quaternion> for Quaternion {
#[inline]
fn mul_assign(&mut self, rhs: Quaternion) {
let result = *self * rhs;
self.scalar = result.scalar;
self.vector = result.vector;
}
}
impl Sub for Quaternion {
    type Output = Self;

    /// Component-wise difference of two quaternions.
    #[inline]
    fn sub(self, rhs: Self) -> Self {
        let Self { scalar: lhs_scalar, vector: lhs_vector } = self;
        let Self { scalar: rhs_scalar, vector: rhs_vector } = rhs;
        Self {
            scalar: lhs_scalar - rhs_scalar,
            vector: lhs_vector - rhs_vector,
        }
    }
}
impl SubAssign for Quaternion {
    /// Subtract `rhs` from this quaternion in place, part by part.
    #[inline]
    fn sub_assign(&mut self, rhs: Self) {
        let Self { scalar, vector } = self;
        *scalar -= rhs.scalar;
        *vector -= rhs.vector;
    }
}
/// Newtype wrapper around a [`Quaternion`] that is used in this file to
/// represent rotations.
///
/// The wrapped field is private. The constructors visible in this file either
/// divide a quaternion by its norm or build it from an axis and an angle, so
/// the wrapped quaternion is intended to have unit norm.
#[derive(Clone, Copy, Debug, PartialEq, RelativeEq, Serialize, Deserialize)]
pub struct Versor(Quaternion);
impl Versor {
    /// Dot product of two versors treated as plain four-component vectors
    /// `(scalar, x, y, z)`.
    #[inline]
    fn dot_as_cartesian(&self, other: &Self) -> f64 {
        let (lhs, rhs) = (self.get(), other.get());
        lhs.scalar * rhs.scalar + lhs.vector.dot(&rhs.vector)
    }

    /// Build the versor for a rotation of `angle` about the unit vector `axis`.
    ///
    /// The scalar part is `cos(angle / 2)` and the vector part is
    /// `axis * sin(angle / 2)`.
    #[inline]
    #[must_use]
    pub fn from_axis_angle(axis: Unit<Cartesian<3>>, angle: f64) -> Self {
        let Unit(axis_vector) = axis;
        Versor(Quaternion {
            scalar: (angle / 2.0).cos(),
            vector: axis_vector * (angle / 2.0).sin(),
        })
    }

    /// Return a copy rescaled by the reciprocal of the wrapped quaternion's norm.
    ///
    /// No zero check is performed on the norm.
    #[inline]
    #[must_use]
    pub fn normalized(self) -> Self {
        let Versor(q) = self;
        // Multiply by the reciprocal rather than dividing each component.
        let f = 1.0 / q.norm();
        Self(Quaternion {
            scalar: q.scalar * f,
            vector: q.vector * f,
        })
    }

    /// Borrow the wrapped [`Quaternion`].
    #[inline]
    #[must_use]
    pub fn get(&self) -> &Quaternion {
        &self.0
    }

    /// Angle, in radians, between the two versors viewed as four-component
    /// unit vectors: `acos(self . other)`.
    ///
    /// The dot product is clamped to `[-1, 1]` before `acos` is applied.
    /// Floating-point rounding can push the dot product of two unit versors
    /// marginally outside that interval (for example when comparing a versor
    /// with itself), and `acos` would then return NaN. A NaN dot product
    /// still yields NaN.
    ///
    /// A versor and its negation are treated as distinct points, so their
    /// distance is `PI`.
    #[inline]
    #[must_use]
    pub fn arc_distance(&self, other: &Self) -> f64 {
        self.dot_as_cartesian(other).clamp(-1.0, 1.0).acos()
    }

    /// Return `1 - self . other`, where `.` is the four-component dot product.
    ///
    /// For unit-norm versors this equals half of the squared Euclidean
    /// distance between the two four-component vectors.
    #[inline]
    #[must_use]
    pub fn half_euclidean_norm_squared(&self, other: &Self) -> f64 {
        1.0 - self.dot_as_cartesian(other)
    }
}
impl From<Versor> for RotationMatrix<3> {
#[inline]
fn from(versor: Versor) -> RotationMatrix<3> {
let Versor(quaternion) = versor;
let a = quaternion.scalar;
let b = quaternion.vector[0];
let c = quaternion.vector[1];
let d = quaternion.vector[2];
RotationMatrix {
rows: [
[
a * a + b * b - c * c - d * d,
2.0 * b * c - 2.0 * a * d,
2.0 * b * d + 2.0 * a * c,
]
.into(),
[
2.0 * b * c + 2.0 * a * d,
a * a - b * b + c * c - d * d,
2.0 * c * d - 2.0 * a * b,
]
.into(),
[
2.0 * b * d - 2.0 * a * c,
2.0 * c * d + 2.0 * a * b,
a * a - b * b - c * c + d * d,
]
.into(),
],
}
}
}
impl Default for Versor {
#[inline]
fn default() -> Self {
Self(Quaternion {
scalar: 1.0,
vector: [0.0, 0.0, 0.0].into(),
})
}
}
impl Rotate<Cartesian<3>> for Versor {
    type Matrix = RotationMatrix<3>;

    /// Rotate `vector` by this versor.
    ///
    /// With scalar part `s` and vector part `u`, the result is
    /// `(s^2 - u . u) * v + 2 * s * (u x v) + 2 * (u . v) * u`.
    #[inline]
    fn rotate(&self, vector: &Cartesian<3>) -> Cartesian<3> {
        let Versor(q) = self;
        let (s, u) = (q.scalar, q.vector);

        let scaled_input = *vector * (s * s - u.dot(&u));
        let cross_term = u.cross(vector) * (2.0 * s);
        let axis_term = u * (2.0 * u.dot(vector));

        scaled_input + cross_term + axis_term
    }
}
impl Rotation for Versor {
#[inline]
fn combine(&self, other: &Self) -> Self {
let Versor(a) = self;
let Versor(b) = other;
Versor(a.mul(*b))
}
#[inline]
fn identity() -> Self {
Self::default()
}
#[inline]
fn inverted(self) -> Self {
let Versor(quaternion) = self;
Versor(quaternion.conjugate())
}
}
impl fmt::Display for Versor {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl Distribution<Versor> for StandardUniform {
    /// Draw a random versor.
    ///
    /// Four standard-normal values are sampled (the scalar part first, then
    /// the three vector components) and the resulting quaternion is divided
    /// by its norm.
    #[inline]
    fn sample<R>(&self, rng: &mut R) -> Versor
    where
        R: Rng + ?Sized,
    {
        let scalar: f64 = rng.sample(StandardNormal);
        let vector = Cartesian::<3>::from(std::array::from_fn(|_| rng.sample(StandardNormal)));
        let norm = (vector.norm_squared() + (scalar * scalar)).sqrt();
        Versor(Quaternion { scalar, vector } / norm)
    }
}
// Unit tests for `Quaternion` and `Versor`.
#[cfg(test)]
mod tests {
use super::*;
use approxim::{assert_abs_diff_eq, assert_relative_eq};
use rand::{SeedableRng, rngs::StdRng};
use rstest::*;
use std::f64::consts::PI;
// Tests for the `Quaternion` type and its operators.
mod quaternion {
use super::*;
// The first array element becomes the scalar part, the rest the vector part.
#[test]
fn from_array() {
let q = Quaternion::from([2.0, -3.0, 4.0, 7.0]);
assert!(q.scalar == 2.0);
assert!(q.vector == [-3.0, 4.0, 7.0].into());
}
// 1 + 16 + 9 + 4 = 30.
#[test]
fn norm() {
let q = Quaternion::from([1.0, 4.0, -3.0, -2.0]);
assert_eq!(q.norm_squared(), 30.0);
assert_eq!(q.norm(), 30.0_f64.sqrt());
}
// Conjugation negates the vector part; the product of a quaternion with its
// conjugate has only a scalar part, equal to the product of the norms.
#[test]
fn conjugate() {
let q1 = Quaternion::from([1.0, -2.0, 4.0, -0.5]);
let q2 = q1.conjugate();
assert_eq!(q2, [1.0, 2.0, -4.0, 0.5].into());
assert_relative_eq!(q2 * q1, [q2.norm() * q1.norm(), 0.0, 0.0, 0.0].into());
}
// The norm of [5, 3, -1, 1] is sqrt(25 + 9 + 1 + 1) = 6, so each component is
// divided by 6. The zero quaternion must be rejected by the checked conversion.
#[test]
fn to_versor() {
let q = Quaternion::from([5.0, 3.0, -1.0, 1.0]);
assert_relative_eq!(
q.to_versor()
.expect("hard-coded quatnernion should be non zero"),
Versor(Quaternion {
scalar: 5.0 / 6.0,
vector: [3.0 / 6.0, -1.0 / 6.0, 1.0 / 6.0].into()
})
);
assert_relative_eq!(
q.to_versor_unchecked(),
Versor(Quaternion {
scalar: 5.0 / 6.0,
vector: [3.0 / 6.0, -1.0 / 6.0, 1.0 / 6.0].into()
})
);
let zero = Quaternion::from([0.0, 0.0, 0.0, 0.0]);
assert!(matches!(
zero.to_versor(),
Err(Error::InvalidQuaternionMagnitude)
));
}
// Each binary operator is checked together with its compound-assignment form.
#[test]
fn ops() {
let a = Quaternion::from([1.0, -2.0, 6.0, -4.0]);
let b = Quaternion::from([-2.0, 6.0, 4.0, 1.0]);
// Addition.
assert_eq!(a + b, [-1.0, 4.0, 10.0, -3.0].into());
let mut c = a;
c += b;
assert_eq!(c, [-1.0, 4.0, 10.0, -3.0].into());
// Subtraction.
assert_eq!(a - b, [3.0, -8.0, 2.0, -5.0].into());
let mut c = a;
c -= b;
assert_eq!(c, [3.0, -8.0, 2.0, -5.0].into());
// Multiplication by a real number.
assert_eq!(a * 2.0, [2.0, -4.0, 12.0, -8.0].into());
let mut c = a;
c *= 2.0;
assert_eq!(c, [2.0, -4.0, 12.0, -8.0].into());
// Division by a real number.
assert_eq!(a / 2.0, [0.5, -1.0, 3.0, -2.0].into());
let mut c = a;
c /= 2.0;
assert_eq!(c, [0.5, -1.0, 3.0, -2.0].into());
// Quaternion product.
assert_eq!(a * b, [-10.0, 32.0, -30.0, -35.0].into());
let mut c = a;
c *= b;
assert_eq!(c, [-10.0, 32.0, -30.0, -35.0].into());
}
// `Display` prints the scalar followed by the vector's own `Display` output.
#[test]
fn display() {
let q = Quaternion {
scalar: 0.5,
vector: [0.125, -0.875, 2.125].into(),
};
let s = format!("{q}");
assert_eq!(s, "[0.5, [0.125, -0.875, 2.125]]");
}
}
// Tests for the `Versor` type and its trait implementations.
mod versor {
use super::*;
// The default versor has scalar part 1 and a zero vector part.
#[test]
fn default() {
let a = Versor::default();
assert!(a.get() == &[1.0, 0.0, 0.0, 0.0].into());
}
// The identity rotation is the same value as the default versor.
#[test]
fn identity() {
let a = Versor::identity();
assert!(a.get() == &[1.0, 0.0, 0.0, 0.0].into());
}
// Runs once per (theta, axis) combination and checks the half-angle form
// (cos(theta / 2), axis * sin(theta / 2)).
#[rstest(
theta => [0.0, PI / 2.0, 1e-12 * PI, -3.0, 12345.6],
axis => [[1.0, 0.0, 0.0].try_into().expect("hard-coded vector should have non-zero length"), [1.0, -1.0, 1.0].try_into().expect("hard-coded vector should have non-zero length")],
)]
fn from_axis_angle(theta: f64, axis: Unit<Cartesian<3>>) {
let Unit(axis_vector) = axis;
let Versor(q) = Versor::from_axis_angle(axis, theta);
assert_relative_eq!(q.scalar, (theta / 2.0).cos());
assert_relative_eq!(q.vector, axis_vector * (theta / 2.0).sin());
}
// Combining two rotations about the same axis must give the rotation by the
// summed angle about that axis.
#[rstest(
theta_1 => [0.0, PI / 2.0, -3.0],
theta_2 => [-0.0, -PI / 3.0, PI, 2.0 * PI]
)]
fn combine_same_axis(theta_1: f64, theta_2: f64) {
let axis = [1.0, 0.0, 0.0]
.try_into()
.expect("hard-coded vector should have non-zero length");
let Unit(axis_vector) = axis;
let a = Versor::from_axis_angle(axis, theta_1);
let b = Versor::from_axis_angle(axis, theta_2);
let c = a.combine(&b);
let theta = theta_1 + theta_2;
let Versor(q) = c;
assert_relative_eq!(q.scalar, (theta / 2.0).cos());
assert_relative_eq!(q.vector, axis_vector * (theta / 2.0).sin());
}
// Shared checks for any `Rotate` implementation. `z_pi_2` is expected to be a
// rotation by PI / 2 about the z axis and `y_pi_4` a rotation by PI / 4 about
// the y axis.
fn validate_rotations<R: Rotate<Cartesian<3>>>(z_pi_2: &R, y_pi_4: &R) {
// A vector along the rotation axis is unchanged.
assert_relative_eq!(
z_pi_2.rotate(&[0.0, 0.0, 1.0].into()),
[0.0, 0.0, 1.0].into()
);
// x maps to y and y maps to -x; the z component is preserved.
assert_relative_eq!(
z_pi_2.rotate(&[1.0, 0.0, 4.25].into()),
[0.0, 1.0, 4.25].into()
);
assert_relative_eq!(
z_pi_2.rotate(&[0.0, 1.0, -8.75].into()),
[-1.0, 0.0, -8.75].into()
);
let sqrt_2_2 = 2.0_f64.sqrt() / 2.0;
// A vector along the rotation axis is unchanged.
assert_relative_eq!(
y_pi_4.rotate(&[0.0, -10.0, 0.0].into()),
[0.0, -10.0, 0.0].into()
);
// Two successive applications take x to -z; the y component is preserved.
assert_relative_eq!(
y_pi_4.rotate(&[1.0, -15.0, 0.0].into()),
[sqrt_2_2, -15.0, -sqrt_2_2].into()
);
assert_relative_eq!(
y_pi_4.rotate(&[sqrt_2_2, -15.0, -sqrt_2_2].into()),
[0.0, -15.0, -1.0].into()
);
}
// Rotate vectors directly with versors.
#[test]
fn rotate() {
let z_pi_2 = Versor::from_axis_angle(
[0.0, 0.0, 1.0]
.try_into()
.expect("hard-coded vector should have non-zero length"),
PI / 2.0,
);
let y_pi_4 = Versor::from_axis_angle(
[0.0, 1.0, 0.0]
.try_into()
.expect("hard-coded vector should have non-zero length"),
PI / 4.0,
);
validate_rotations(&z_pi_2, &y_pi_4);
}
// Same checks as `rotate`, but with the versors first converted to rotation
// matrices.
#[test]
fn precompute() {
let z_pi_2 = RotationMatrix::from(Versor::from_axis_angle(
[0.0, 0.0, 1.0]
.try_into()
.expect("hard-coded vector should have non-zero length"),
PI / 2.0,
));
let y_pi_4 = RotationMatrix::from(Versor::from_axis_angle(
[0.0, 1.0, 0.0]
.try_into()
.expect("hard-coded vector should have non-zero length"),
PI / 4.0,
));
validate_rotations(&z_pi_2, &y_pi_4);
}
// `a.combine(&b)` applied to a vector acts like `b` followed by `a`: the z
// rotation takes x to y, then the x rotation tilts y by PI / 4 towards z.
#[test]
fn combine_different_axis() {
let a = Versor::from_axis_angle(
[1.0, 0.0, 0.0]
.try_into()
.expect("hard-coded vector should have non-zero length"),
PI / 4.0,
);
let b = Versor::from_axis_angle(
[0.0, 0.0, 1.0]
.try_into()
.expect("hard-coded vector should have non-zero length"),
PI / 2.0,
);
let q = a.combine(&b);
let v = q.rotate(&[1.0, 0.0, 0.0].into());
assert_relative_eq!(v, [0.0, 2.0_f64.sqrt() / 2.0, 2.0_f64.sqrt() / 2.0].into());
}
// A versor combined with its inverse must give the identity.
#[rstest(theta => [0.0, 1.0, 2.125])]
fn inverted(theta: f64) {
let q1 = Versor::from_axis_angle(
[1.0, 0.5, -2.0]
.try_into()
.expect("hard-coded vector should have non-zero length"),
theta,
);
let q2 = q1.inverted();
assert_relative_eq!(q1.combine(&q2), Versor::identity());
}
// `Display` output matches that of the wrapped quaternion.
#[test]
fn display() {
let v = Versor(Quaternion {
scalar: 0.5,
vector: [0.125, -0.875, 2.125].into(),
});
let s = format!("{v}");
assert_eq!(s, "[0.5, [0.125, -0.875, 2.125]]");
}
// A deliberately non-unit versor (norm 6) is rescaled to unit norm.
#[test]
fn normalized() {
let v = Versor(Quaternion {
scalar: 5.0,
vector: [3.0, -1.0, 1.0].into(),
});
assert_relative_eq!(
v.normalized(),
Versor(Quaternion {
scalar: 5.0 / 6.0,
vector: [3.0 / 6.0, -1.0 / 6.0, 1.0 / 6.0].into()
})
);
}
// Statistical check of random sampling: every sample must have unit norm, and
// the rotated reference vector must average to roughly zero along each check
// direction.
#[test]
fn random() {
const CHECK_VECTORS: [Cartesian<3>; 3] = [
Cartesian {
coordinates: [1.0, 0.0, 0.0],
},
Cartesian {
coordinates: [0.0, 1.0, 0.0],
},
Cartesian {
coordinates: [1.0, 0.0, 1.0],
},
];
let samples: u32 = 20_000;
let reference = Cartesian::from([1.0, 0.0, 0.0]);
let mut dot_sums = [0.0; CHECK_VECTORS.len()];
// Fixed seed so the test is reproducible.
let mut rng = StdRng::seed_from_u64(1);
for _ in 0..samples {
let q: Versor = rng.random();
assert_relative_eq!(q.get().norm_squared(), 1.0, max_relative = 1e-15);
let v = q.rotate(&reference);
for i in 0..CHECK_VECTORS.len() {
dot_sums[i] += v.dot(&CHECK_VECTORS[i]);
}
}
// The mean projection onto each check vector must be within 0.01 of zero.
for dot_sum in dot_sums {
assert_abs_diff_eq!(dot_sum / f64::from(samples), 0.0, epsilon = 0.01);
}
}
}
}