#![no_std]
#![forbid(unsafe_code)]
#![warn(missing_docs)]
use core::ops::Mul;
use embedded_f32_sqrt::sqrt;
/// Errors produced when building or composing a [`Rotation`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum So3Error {
    /// The quaternion's norm is below the normalization threshold (for
    /// example the all-zero quaternion), or its square root could not be
    /// computed, so it cannot be scaled to unit length.
    InvalidNorm,
    /// A NaN or infinite value was encountered while normalizing the
    /// quaternion.
    NonFiniteValue,
}

impl core::fmt::Display for So3Error {
    /// Writes a short, lowercase, human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match self {
            Self::InvalidNorm => "quaternion norm is too small or could not be computed",
            Self::NonFiniteValue => {
                "non-finite value (NaN or infinity) encountered while normalizing quaternion"
            }
        };
        f.write_str(message)
    }
}
/// A 3-D rotation (an element of SO(3)) stored as a unit quaternion
/// `w + x·i + y·j + z·k`.
///
/// The fields are private so that values can only be obtained through this
/// type's constructors and operations, which keep the quaternion at unit
/// length.
///
/// Note: `q` and `-q` describe the same rotation, but the derived
/// `PartialEq` compares components and therefore treats them as different.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Rotation {
    /// Scalar (real) part.
    w: f32,
    /// `i` component of the vector part.
    x: f32,
    /// `j` component of the vector part.
    y: f32,
    /// `k` component of the vector part.
    z: f32,
}
impl Rotation {
    /// The identity rotation, i.e. the quaternion `1 + 0i + 0j + 0k`.
    pub const IDENTITY: Self = Self { w: 1.0, x: 0.0, y: 0.0, z: 0.0 };

    /// Creates a rotation from the quaternion `w + x·i + y·j + z·k`,
    /// normalizing it to unit length.
    ///
    /// The input does not need to be normalized: only its direction matters,
    /// and components of any finite magnitude are accepted.
    ///
    /// # Errors
    ///
    /// - [`So3Error::NonFiniteValue`] if any component is NaN or infinite.
    /// - [`So3Error::InvalidNorm`] if the quaternion's norm is below `1e-10`
    ///   (for example the all-zero quaternion), or if its square root cannot
    ///   be computed.
    pub fn new(w: f32, x: f32, y: f32, z: f32) -> Result<Self, So3Error> {
        Self::normalize_raw(w, x, y, z)
    }

    /// Scales `(w, x, y, z)` to unit length; shared by `new` and `compose`.
    ///
    /// See [`Rotation::new`] for the error conditions.
    fn normalize_raw(w: f32, x: f32, y: f32, z: f32) -> Result<Self, So3Error> {
        // Quaternions shorter than this are treated as degenerate.
        const MIN_NORM: f32 = 1e-10;

        // Reject NaN/inf up front so that an overflowing sum of squares below
        // is not mistaken for non-finite input.
        if !(w.is_finite() && x.is_finite() && y.is_finite() && z.is_finite()) {
            return Err(So3Error::NonFiniteValue);
        }

        let norm_sq = w * w + x * x + y * y + z * z;
        if norm_sq.is_finite() {
            let n = sqrt(norm_sq).map_err(|_| So3Error::InvalidNorm)?;
            if n < MIN_NORM {
                return Err(So3Error::InvalidNorm);
            }
            return Ok(Self {
                w: w / n,
                x: x / n,
                y: y / n,
                z: z / n,
            });
        }

        // Every component is finite, but squaring overflowed f32 (a component
        // magnitude beyond roughly 1.8e19). Dividing by the largest magnitude
        // first puts each scaled component in [-1, 1] and the scaled norm in
        // [1, 2], so nothing overflows and the threshold check cannot trigger.
        // `abs` is avoided because it is not available in `core` on all
        // toolchains.
        let magnitude = |v: f32| if v < 0.0 { -v } else { v };
        let m = magnitude(w)
            .max(magnitude(x))
            .max(magnitude(y))
            .max(magnitude(z));
        let (w, x, y, z) = (w / m, x / m, y / m, z / m);
        let n = sqrt(w * w + x * x + y * y + z * z).map_err(|_| So3Error::InvalidNorm)?;
        Ok(Self {
            w: w / n,
            x: x / n,
            y: y / n,
            z: z / n,
        })
    }

    /// Composes two rotations via the Hamilton product `self * other`.
    ///
    /// The resulting rotation applies `other` first and then `self`. The
    /// product is renormalized so that rounding error does not accumulate
    /// over long chains of compositions.
    ///
    /// # Errors
    ///
    /// Returns the normalization errors described on [`Rotation::new`]. For
    /// two valid unit rotations the product has norm ≈ 1, so failure is not
    /// expected in practice.
    pub fn compose(&self, other: &Rotation) -> Result<Self, So3Error> {
        let w = self.w * other.w - self.x * other.x - self.y * other.y - self.z * other.z;
        let x = self.w * other.x + self.x * other.w + self.y * other.z - self.z * other.y;
        let y = self.w * other.y - self.x * other.z + self.y * other.w + self.z * other.x;
        let z = self.w * other.z + self.x * other.y - self.y * other.x + self.z * other.w;
        Self::normalize_raw(w, x, y, z)
    }

    /// Rotates the 3-D vector `v` by this rotation.
    ///
    /// Evaluates `q v q*` through the expansion `v' = v + w·t + q_vec × t`
    /// with `t = 2 (q_vec × v)`, which avoids building a rotation matrix.
    /// For example, the quarter turn `(cos 45°, sin 45°, 0, 0)` about +x maps
    /// `[0, 1, 0]` to `[0, 0, 1]`.
    #[must_use]
    pub fn rotate_vector(&self, v: [f32; 3]) -> [f32; 3] {
        let [vx, vy, vz] = v;
        let tx = 2.0 * (self.y * vz - self.z * vy);
        let ty = 2.0 * (self.z * vx - self.x * vz);
        let tz = 2.0 * (self.x * vy - self.y * vx);
        [
            vx + self.w * tx + (self.y * tz - self.z * ty),
            vy + self.w * ty + (self.z * tx - self.x * tz),
            vz + self.w * tz + (self.x * ty - self.y * tx),
        ]
    }

    /// Returns the inverse rotation.
    ///
    /// Every `Rotation` is kept at unit length, so the inverse is simply the
    /// quaternion conjugate (vector part negated); no division is needed.
    #[must_use]
    pub fn inverse(&self) -> Self {
        Self {
            w: self.w,
            x: -self.x,
            y: -self.y,
            z: -self.z,
        }
    }

    /// Returns the components as `[w, x, y, z]` (scalar part first).
    #[must_use]
    pub fn as_array(&self) -> [f32; 4] {
        [self.w, self.x, self.y, self.z]
    }
}
impl Mul for Rotation {
type Output = Result<Rotation, So3Error>;
fn mul(self, rhs: Self) -> Self::Output {
self.compose(&rhs)
}
}
impl Mul for &Rotation {
type Output = Result<Rotation, So3Error>;
fn mul(self, rhs: Self) -> Self::Output {
self.compose(rhs)
}
}
// Unit tests for `Rotation`: vector rotation, composition (method and `*`
// operators), inversion, norm stability, and input validation.
// Some assertion messages are in French; English translations are given in
// the comments next to them.
#[cfg(test)]
mod tests {
    use super::*;

    // The identity rotation must leave a vector unchanged.
    #[test]
    fn test_identity_stable() {
        let v = [1.0, 2.0, 3.0];
        let res = Rotation::IDENTITY.rotate_vector(v);
        assert!((res[0] - 1.0).abs() < 1e-7);
        assert!((res[1] - 2.0).abs() < 1e-7);
        assert!((res[2] - 3.0).abs() < 1e-7);
    }

    // A quarter turn about +x, q = (cos 45°, sin 45°, 0, 0), maps +y onto +z.
    #[test]
    fn test_rotation_90_x() {
        let s = core::f32::consts::FRAC_1_SQRT_2;
        let rot = Rotation::new(s, s, 0.0, 0.0).unwrap();
        let res = rot.rotate_vector([0.0, 1.0, 0.0]);
        // Messages: "x should be 0, got {}", "y should be 0, got {}", "z should be 1, got {}".
        assert!(res[0].abs() < 1e-6, "x devrait être 0, obtenu {}", res[0]);
        assert!(res[1].abs() < 1e-6, "y devrait être 0, obtenu {}", res[1]);
        assert!((res[2] - 1.0).abs() < 1e-6, "z devrait être 1, obtenu {}", res[2]);
    }

    // Squaring the quarter turn about +x with the by-value `*` operator gives a
    // half turn, which maps +y to -y.
    #[test]
    fn test_mul_operator() {
        let s = core::f32::consts::FRAC_1_SQRT_2;
        let r90 = Rotation::new(s, s, 0.0, 0.0).unwrap();
        let r180 = (r90 * r90).unwrap();
        let res = r180.rotate_vector([0.0, 1.0, 0.0]);
        assert!(res[0].abs() < 1e-6);
        // Message: "y should be -1, got {}".
        assert!((res[1] + 1.0).abs() < 1e-6, "y devrait être -1, obtenu {}", res[1]);
        assert!(res[2].abs() < 1e-6);
    }

    // The by-reference `&a * &b` operator must produce exactly the same
    // components as calling `compose` directly (quarter turn about +y).
    #[test]
    fn test_mul_ref_operator() {
        let s = core::f32::consts::FRAC_1_SQRT_2;
        let r = Rotation::new(s, 0.0, s, 0.0).unwrap();
        let via_compose = r.compose(&r).unwrap();
        let via_mul = (&r * &r).unwrap();
        assert_eq!(via_compose.as_array(), via_mul.as_array());
    }

    // Chaining 100 quarter turns about +y must not let the quaternion norm
    // drift away from 1, because `compose` renormalizes every product.
    #[test]
    fn test_composition_stability() {
        let s = core::f32::consts::FRAC_1_SQRT_2;
        let r1 = Rotation::new(s, 0.0, s, 0.0).unwrap(); let mut current = Rotation::IDENTITY;
        for _ in 0..100 {
            current = current.compose(&r1).unwrap();
        }
        let [w, x, y, z] = current.as_array();
        let norm_sq = w * w + x * x + y * y + z * z;
        // Message: "The norm drifted: {}".
        assert!(
            (norm_sq - 1.0).abs() < 1e-6,
            "La norme a dérivé : {}",
            norm_sq
        );
    }

    // q composed with its inverse must be the identity quaternion (1, 0, 0, 0).
    // The input (1, 2, 3, 4) is not unit length; `new` normalizes it first.
    #[test]
    fn test_inverse_property() {
        let rot = Rotation::new(1.0, 2.0, 3.0, 4.0).unwrap();
        let res = rot.compose(&rot.inverse()).unwrap();
        assert!((res.w - 1.0).abs() < 1e-6);
        assert!(res.x.abs() < 1e-6);
        assert!(res.y.abs() < 1e-6);
        assert!(res.z.abs() < 1e-6);
    }

    // Same property as above, exercised through the by-value `*` operator.
    #[test]
    fn test_inverse_via_mul() {
        let rot = Rotation::new(1.0, 2.0, 3.0, 4.0).unwrap();
        let res = (rot * rot.inverse()).unwrap();
        assert!((res.w - 1.0).abs() < 1e-6);
        assert!(res.x.abs() < 1e-6);
        assert!(res.y.abs() < 1e-6);
        assert!(res.z.abs() < 1e-6);
    }

    // The zero quaternion is rejected as `InvalidNorm`; NaN and infinite
    // components are rejected as `NonFiniteValue`.
    #[test]
    fn test_invalid_inputs() {
        assert_eq!(
            Rotation::new(0.0, 0.0, 0.0, 0.0),
            Err(So3Error::InvalidNorm)
        );
        assert_eq!(
            Rotation::new(f32::NAN, 0.0, 0.0, 0.0),
            Err(So3Error::NonFiniteValue)
        );
        assert_eq!(
            Rotation::new(f32::INFINITY, 0.0, 0.0, 0.0),
            Err(So3Error::NonFiniteValue)
        );
    }
}