use core::marker::PhantomData;
use glam::{DMat3, DQuat, DVec3};
use crate::sealed::QuatSealed;
/// Type-level marker describing the order in which a quaternion's four
/// components are stored in [`Quat::data`].
///
/// The `QuatSealed` bound (from `crate::sealed`) restricts which types may
/// implement this trait — presumably only this crate's markers; confirm in
/// `crate::sealed`.
pub trait Layout
where
    Self: QuatSealed + 'static,
{
    /// Human-readable name of the layout.
    const NAME: &'static str;
}

/// Layout with the scalar part stored first: `[s, x, y, z]`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ScalarFirst;

impl QuatSealed for ScalarFirst {}

impl Layout for ScalarFirst {
    const NAME: &'static str = "ScalarFirst";
}

/// Layout with the scalar part stored last: `[x, y, z, s]`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ScalarLast;

impl QuatSealed for ScalarLast {}

impl Layout for ScalarLast {
    const NAME: &'static str = "ScalarLast";
}
/// Type-level marker describing which transformation convention a quaternion
/// follows.
///
/// The `QuatSealed` bound (from `crate::sealed`) restricts which types may
/// implement this trait — presumably only this crate's markers; confirm in
/// `crate::sealed`.
pub trait Transform
where
    Self: QuatSealed + 'static,
{
    /// Human-readable name of the convention.
    const NAME: &'static str;
}

/// Left-transformation convention (the one used by `JeodQuat`).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct LeftTransform;

impl QuatSealed for LeftTransform {}

impl Transform for LeftTransform {
    const NAME: &'static str = "LeftTransform";
}

/// Right-transformation convention.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct RightTransform;

impl QuatSealed for RightTransform {}

impl Transform for RightTransform {
    const NAME: &'static str = "RightTransform";
}
/// Four `f64` quaternion components tagged at the type level with their
/// storage [`Layout`] and [`Transform`] convention.
///
/// Both marker fields are zero-sized `PhantomData`, so with
/// `repr(transparent)` the struct has the same layout as `[f64; 4]`.
#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(transparent)]
pub struct Quat<L: Layout, T: Transform> {
    /// Raw components, ordered according to `L`.
    pub data: [f64; 4],
    _l: PhantomData<L>,
    _t: PhantomData<T>,
}

impl<L: Layout, T: Transform> Quat<L, T> {
    /// Wraps `data` as-is: no reordering and no normalization.
    #[inline]
    pub const fn from_array(data: [f64; 4]) -> Self {
        Quat { data, _l: PhantomData, _t: PhantomData }
    }

    /// Sum of the squared components, `|q|²`.
    #[inline]
    pub fn norm_squared(&self) -> f64 {
        // Accumulates strictly left to right: ((a² + b²) + c²) + d².
        self.data.iter().fold(0.0, |acc, &c| acc + c * c)
    }

    /// Euclidean norm `|q|`.
    #[inline]
    pub fn norm(&self) -> f64 {
        f64::sqrt(self.norm_squared())
    }
}
/// JEOD-style quaternion: scalar stored first, left-transformation convention.
pub type JeodQuat = Quat<ScalarFirst, LeftTransform>;

impl<T: Transform> Quat<ScalarFirst, T> {
    /// Reorders `[s, x, y, z]` into `[x, y, z, s]`; the transform convention is kept.
    #[inline]
    pub fn to_scalar_last(self) -> Quat<ScalarLast, T> {
        let [s, x, y, z] = self.data;
        Quat::from_array([x, y, z, s])
    }
}

impl<T: Transform> Quat<ScalarLast, T> {
    /// Reorders `[x, y, z, s]` into `[s, x, y, z]`; the transform convention is kept.
    #[inline]
    pub fn to_scalar_first(self) -> Quat<ScalarFirst, T> {
        let [x, y, z, s] = self.data;
        Quat::from_array([s, x, y, z])
    }
}
impl Quat<ScalarLast, LeftTransform> {
    /// Converts to a `glam` quaternion. Scalar-last storage already matches
    /// glam's `x, y, z, w` order, so the components are copied unchanged.
    #[inline]
    pub fn to_glam(self) -> DQuat {
        let [x, y, z, w] = self.data;
        DQuat::from_xyzw(x, y, z, w)
    }
}

impl From<DQuat> for Quat<ScalarLast, LeftTransform> {
    /// Copies glam's `x, y, z, w` components, unchanged, into scalar-last storage.
    #[inline]
    fn from(q: DQuat) -> Self {
        let xyzw = [q.x, q.y, q.z, q.w];
        Self::from_array(xyzw)
    }
}
#[derive(Debug, thiserror::Error)]
#[error("quaternion norm {norm} deviates from 1 by {deviation:.3e}, which exceeds tolerance {tolerance:.3e}")]
pub struct NotNormalized {
pub norm: f64,
pub deviation: f64,
pub tolerance: f64,
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct NormalizedQuat<L: Layout, T: Transform>(Quat<L, T>);
impl<L: Layout, T: Transform> NormalizedQuat<L, T> {
    /// Tolerance on `| |q| - 1 |` used by [`NormalizedQuat::new`].
    pub const DEFAULT_TOLERANCE: f64 = 1e-12;

    /// Wraps `q` after checking its norm is within [`Self::DEFAULT_TOLERANCE`] of 1.
    ///
    /// # Errors
    ///
    /// Returns [`NotNormalized`] when the check fails, including when the norm is NaN.
    #[inline]
    pub fn new(q: Quat<L, T>) -> Result<Self, NotNormalized> {
        Self::new_with_tolerance(q, Self::DEFAULT_TOLERANCE)
    }

    /// Wraps `q` after checking `| |q| - 1 | <= tolerance`.
    ///
    /// # Errors
    ///
    /// Returns [`NotNormalized`] (carrying the norm, deviation and tolerance) when
    /// the check fails. A NaN norm or NaN tolerance always fails the check.
    #[inline]
    pub fn new_with_tolerance(q: Quat<L, T>, tolerance: f64) -> Result<Self, NotNormalized> {
        let norm = q.norm();
        let deviation = (norm - 1.0).abs();
        // `<=` (rather than a negated `>`) so NaN deviation/tolerance are rejected.
        if deviation <= tolerance {
            Ok(Self(q))
        } else {
            Err(NotNormalized {
                norm,
                deviation,
                tolerance,
            })
        }
    }

    /// Scales `q` to unit norm.
    ///
    /// Returns `None` if any component is NaN or infinite, or if every component
    /// is zero. The components are first divided by their largest magnitude, so
    /// finite non-zero quaternions whose squared norm would overflow or underflow
    /// `f64` (e.g. components near `1e200` or `1e-200`, or subnormal ones) are
    /// still normalized instead of being rejected or producing non-finite output.
    #[inline]
    pub fn renormalize(q: Quat<L, T>) -> Option<Self> {
        if !q.data.iter().all(|c| c.is_finite()) {
            return None;
        }
        let scale = q.data.iter().fold(0.0_f64, |m, c| m.max(c.abs()));
        if scale == 0.0 {
            return None;
        }
        let scaled = q.data.map(|c| c / scale);
        // The largest scaled component is ±1, so this norm lies in [1, 2] and
        // neither the squares nor the division below can overflow or underflow badly.
        let n = scaled.iter().map(|c| c * c).sum::<f64>().sqrt();
        Some(Self(Quat::from_array(scaled.map(|c| c / n))))
    }

    /// Returns the wrapped quaternion.
    #[inline]
    pub const fn inner(self) -> Quat<L, T> {
        self.0
    }
}
/// Threshold on `|1 - |q|²|` below which [`JeodQuat::normalize`] uses the
/// first-order factor `2 / (1 + |q|²)` instead of `1 / |q|`.
///
/// The value is approximately `sqrt(2^-51)`. By a Taylor expansion, for
/// `δ = |q|² - 1` the two factors differ by about `δ² / 8`, which at this limit
/// is about `2^-54`, below half an ulp of 1.0.
pub const NORM_LIMIT: f64 = 2.107_342e-8;

impl JeodQuat {
    /// The identity quaternion `[1, 0, 0, 0]`.
    #[inline]
    pub const fn identity() -> Self {
        Self::from_array([1.0, 0.0, 0.0, 0.0])
    }

    /// Builds a quaternion from its scalar part and vector components (not normalized).
    #[inline]
    pub const fn new(scalar: f64, vx: f64, vy: f64, vz: f64) -> Self {
        Self::from_array([scalar, vx, vy, vz])
    }

    /// Scalar part (component 0).
    #[inline]
    pub fn scalar(&self) -> f64 {
        self.data[0]
    }

    /// Vector part (components 1..=3).
    #[inline]
    pub fn vector(&self) -> DVec3 {
        DVec3::new(self.data[1], self.data[2], self.data[3])
    }

    /// Alias for [`Quat::norm_squared`].
    #[inline]
    pub fn norm_sq(&self) -> f64 {
        self.norm_squared()
    }

    /// Converts to a `glam` quaternion with the same component values, moving the
    /// scalar from the front to glam's `w` slot.
    #[inline]
    pub fn to_glam(&self) -> DQuat {
        DQuat::from_xyzw(self.data[1], self.data[2], self.data[3], self.data[0])
    }

    /// Inverse of [`JeodQuat::to_glam`]: stores glam's `w, x, y, z` scalar-first.
    #[inline]
    pub fn from_glam(q: DQuat) -> Self {
        Self::from_array([q.w, q.x, q.y, q.z])
    }

    /// Quaternion conjugate: scalar kept, vector part negated.
    #[inline]
    pub fn conjugate(&self) -> Self {
        Self::from_array([self.data[0], -self.data[1], -self.data[2], -self.data[3]])
    }

    /// Hamilton product `self * other`.
    ///
    /// Scalar part `s1 s2 - v1·v2`; vector part `s1 v2 + s2 v1 + v1 × v2`.
    pub fn multiply(&self, other: &Self) -> Self {
        let s1 = self.scalar();
        let v1 = self.vector();
        let s2 = other.scalar();
        let v2 = other.vector();
        let ps = s1 * s2 - v1.dot(v2);
        let pv = v2 * s1 + v1 * s2 + v1.cross(v2);
        Self::from_array([ps, pv.x, pv.y, pv.z])
    }

    /// Normalizes `self` to unit length in place, then flips the sign of every
    /// component if needed so the scalar part is non-negative.
    ///
    /// When `|q|²` is within [`NORM_LIMIT`] of 1, the cheaper first-order factor
    /// `2 / (1 + |q|²)` is used. When `|q|²` is not a normal float (zero,
    /// subnormal, or infinite because it overflowed), the components are first
    /// divided by their largest magnitude. As a result, finite non-zero
    /// quaternions of any magnitude are normalized correctly rather than
    /// collapsing to zero or panicking.
    ///
    /// # Panics
    ///
    /// Panics if any component is NaN or infinite, or if every component is zero.
    pub fn normalize(&mut self) {
        let qmagsq = self.norm_squared();
        let fact = if (1.0 - qmagsq).abs() < NORM_LIMIT {
            // Near-unit: first-order approximation of 1/sqrt(qmagsq), no sqrt needed.
            2.0 / (1.0 + qmagsq)
        } else if qmagsq.is_normal() {
            1.0 / qmagsq.sqrt()
        } else {
            // `qmagsq` is zero, subnormal, infinite or NaN. Either the quaternion is
            // genuinely zero/non-finite, or its squared norm merely overflowed or
            // underflowed; rescale so the largest component has magnitude 1.
            assert!(
                self.data.iter().all(|d| d.is_finite()),
                "cannot normalize a non-finite quaternion"
            );
            let scale = self.data.iter().fold(0.0_f64, |m, d| m.max(d.abs()));
            assert!(scale > 0.0, "cannot normalize a zero quaternion");
            for d in self.data.iter_mut() {
                *d /= scale;
            }
            // Now 1 <= |q|² <= 4, so the direct formula is safe.
            1.0 / self.norm_squared().sqrt()
        };
        for d in self.data.iter_mut() {
            *d *= fact;
        }
        // q and -q give the same transformation matrix (every matrix term is a
        // product of two components), so pick the representative with s >= 0.
        if self.data[0] < 0.0 {
            for d in self.data.iter_mut() {
                *d = -*d;
            }
        }
    }

    /// Builds the left quaternion for a rotation of `angle` radians about `axis`,
    /// `[cos(angle/2), -sin(angle/2) * axis]`, then normalizes it.
    ///
    /// `axis` is expected to be a unit vector. Because the result is normalized,
    /// a non-unit axis changes the angle that is actually encoded.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`JeodQuat::normalize`] (e.g. non-finite input).
    pub fn left_quat_from_eigen_rotation(angle: f64, axis: DVec3) -> Self {
        let half = angle * 0.5;
        let s = half.cos();
        let v = -half.sin() * axis;
        let mut q = Self::from_array([s, v.x, v.y, v.z]);
        q.normalize();
        q
    }

    /// 3×3 transformation matrix of this quaternion.
    ///
    /// For a unit quaternion, multiplying this matrix by `v` gives the same
    /// result as [`JeodQuat::left_quat_transform`]`(v)`. No normalization is
    /// applied first.
    pub fn left_quat_to_transformation(&self) -> DMat3 {
        left_quat_to_transformation_impl(self)
    }

    /// Recovers the left quaternion from a transformation matrix. This is the
    /// inverse of [`JeodQuat::left_quat_to_transformation`] for rotation matrices.
    ///
    /// The component with the largest magnitude is solved for first (as in
    /// Shepperd's method), so the divisions below never use a small denominator.
    /// The result is normalized with its scalar part non-negative.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`JeodQuat::normalize`], e.g. when a
    /// badly non-orthonormal matrix produces NaN components.
    pub fn left_quat_from_transformation(mat: &DMat3) -> Self {
        // Element accessor in (row, column) order; `DMat3::col(c)` yields column `c`.
        let t = |r: usize, c: usize| -> f64 { mat.col(c)[r] };
        let tr = t(0, 0) + t(1, 1) + t(2, 2);
        // For a rotation matrix, comparing `tr` with each diagonal entry is
        // equivalent to comparing 4s² = 1 + tr with 4x² = 1 + 2 t00 - tr (and
        // likewise for y, z), i.e. it selects the largest-magnitude component.
        let vals = [tr, t(0, 0), t(1, 1), t(2, 2)];
        let max_idx = vals
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(core::cmp::Ordering::Equal))
            .expect("candidate array is non-empty")
            .0;
        let mut q = [0.0_f64; 4];
        match max_idx {
            0 => {
                q[0] = 0.5 * (1.0 + tr).sqrt();
                let inv4qs = 0.25 / q[0];
                q[1] = (t(2, 1) - t(1, 2)) * inv4qs;
                q[2] = (t(0, 2) - t(2, 0)) * inv4qs;
                q[3] = (t(1, 0) - t(0, 1)) * inv4qs;
            }
            1 => {
                q[1] = 0.5 * (1.0 + 2.0 * t(0, 0) - tr).sqrt();
                let inv4qv0 = 0.25 / q[1];
                q[0] = (t(2, 1) - t(1, 2)) * inv4qv0;
                q[2] = (t(0, 1) + t(1, 0)) * inv4qv0;
                q[3] = (t(0, 2) + t(2, 0)) * inv4qv0;
            }
            2 => {
                q[2] = 0.5 * (1.0 + 2.0 * t(1, 1) - tr).sqrt();
                let inv4qv1 = 0.25 / q[2];
                q[0] = (t(0, 2) - t(2, 0)) * inv4qv1;
                q[1] = (t(0, 1) + t(1, 0)) * inv4qv1;
                q[3] = (t(1, 2) + t(2, 1)) * inv4qv1;
            }
            3 => {
                q[3] = 0.5 * (1.0 + 2.0 * t(2, 2) - tr).sqrt();
                let inv4qv2 = 0.25 / q[3];
                q[0] = (t(1, 0) - t(0, 1)) * inv4qv2;
                q[1] = (t(0, 2) + t(2, 0)) * inv4qv2;
                q[2] = (t(1, 2) + t(2, 1)) * inv4qv2;
            }
            _ => unreachable!(),
        }
        // `normalize` also flips the sign so that the scalar part is non-negative.
        let mut result = Self::from_array(q);
        result.normalize();
        result
    }

    /// Transforms `v` by this quaternion using
    /// `v + 2 s (u × v) + 2 u × (u × v)`, where `s` is the scalar part and `u` the
    /// vector part. For a unit quaternion this equals `q [0, v] q*`.
    pub fn left_quat_transform(&self, v: DVec3) -> DVec3 {
        let qs = self.scalar();
        let qv = self.vector();
        let qv_cross_v = qv.cross(v);
        v + 2.0 * qs * qv_cross_v + 2.0 * qv.cross(qv_cross_v)
    }
}
impl NormalizedQuat<ScalarFirst, LeftTransform> {
#[inline]
pub fn left_quat_to_transformation(&self) -> DMat3 {
left_quat_to_transformation_impl(&self.inner())
}
#[inline]
pub fn left_quat_transform(&self, v: DVec3) -> DVec3 {
self.inner().left_quat_transform(v)
}
}
/// Builds the 3×3 transformation matrix of a scalar-first left quaternion
/// `q = [w, x, y, z]`.
///
/// The off-diagonal entries are the standard quaternion rotation-matrix terms.
/// Each diagonal entry is written as `2w² - 1 + 2x²` (etc.), which equals
/// `1 - 2(y² + z²)` only when `q` has unit norm. No normalization is done here.
#[inline]
fn left_quat_to_transformation_impl(q: &JeodQuat) -> DMat3 {
    let [w, x, y, z] = q.data;
    // Diagonal term shared by all three diagonal entries: 2w² - 1.
    let diag = 2.0 * w * w - 1.0;

    let first_col = DVec3::new(
        diag + 2.0 * x * x,
        2.0 * (y * x + w * z),
        2.0 * (z * x - w * y),
    );
    let second_col = DVec3::new(
        2.0 * (x * y - w * z),
        diag + 2.0 * y * y,
        2.0 * (z * y + w * x),
    );
    let third_col = DVec3::new(
        2.0 * (x * z + w * y),
        2.0 * (y * z - w * x),
        diag + 2.0 * z * z,
    );
    DMat3::from_cols(first_col, second_col, third_col)
}