use crate::error::{SpatialError, SpatialResult};
use crate::transform::Rotation;
use scirs2_core::ndarray::{array, Array1};
/// Pack three Euler angles into a length-3 `Array1<f64>`, preserving the order `[x, y, z]`.
// In this file it is only reached via `rotation_from_euler`, which the unit tests use;
// `allow(dead_code)` keeps non-test builds warning-free.
#[allow(dead_code)]
fn euler_array(x: f64, y: f64, z: f64) -> Array1<f64> {
array![x, y, z]
}
/// Convenience wrapper: build a [`Rotation`] from three scalar Euler angles.
///
/// The angles are packed with `euler_array` and forwarded, together with the
/// `convention` string, to `Rotation::from_euler`; any error it reports is
/// returned unchanged.
// Appears to be used only by the unit tests in this file, hence `allow(dead_code)`.
#[allow(dead_code)]
fn rotation_from_euler(x: f64, y: f64, z: f64, convention: &str) -> SpatialResult<Rotation> {
    Rotation::from_euler(&euler_array(x, y, z).view(), convention)
}
/// Spherical linear interpolation (Slerp) between two rotations.
///
/// Both endpoints are stored as quaternions obtained from [`Rotation::as_quat`]. If the two
/// quaternions have a negative dot product, the end quaternion is negated so that the
/// interpolation follows the shorter arc (`q` and `-q` represent the same rotation).
#[derive(Clone, Debug)]
pub struct Slerp {
    /// Rotation returned verbatim for `t <= 0`.
    start: Rotation,
    /// Rotation returned verbatim for `t >= 1`.
    end: Rotation,
    /// Quaternion of `start`.
    q1: Array1<f64>,
    /// Quaternion of `end`, sign-flipped when needed so that `q1 · q2 >= 0`.
    q2: Array1<f64>,
    /// Dot product `q1 · q2`; finite and non-negative after construction.
    dot: f64,
}

impl Slerp {
    /// Dot products above this value are treated as (nearly) parallel quaternions.
    ///
    /// In that regime `sin(theta)` in the Slerp weights approaches zero, so
    /// [`Slerp::interpolate`] switches to normalized linear interpolation, which is
    /// numerically stable and a close approximation of Slerp for such small angles.
    const NEAR_PARALLEL_DOT: f64 = 0.9999;

    /// Create an interpolator from `start` (at `t = 0`) to `end` (at `t = 1`).
    ///
    /// Identical or nearly identical rotations are accepted; they are interpolated
    /// linearly (see [`Slerp::interpolate`]).
    ///
    /// # Errors
    ///
    /// Returns [`SpatialError::ComputationError`] if the dot product of the two
    /// quaternions is not finite, i.e. a quaternion component is NaN or infinite.
    pub fn new(start: Rotation, end: Rotation) -> SpatialResult<Self> {
        let q1 = start.as_quat();
        let mut q2 = end.as_quat();
        let mut dot: f64 = q1.iter().zip(q2.iter()).map(|(a, b)| a * b).sum();

        if !dot.is_finite() {
            // Catch corrupt input here instead of panicking later in `interpolate`.
            return Err(SpatialError::ComputationError(
                "Rotation quaternions contain non-finite values".into(),
            ));
        }

        // q and -q encode the same rotation; flip `end` so we travel the short arc.
        if dot < 0.0 {
            q2.mapv_inplace(|v| -v);
            dot = -dot;
        }

        Ok(Slerp {
            start,
            end,
            q1,
            q2,
            dot,
        })
    }

    /// Rotation at interpolation parameter `t`.
    ///
    /// `t` is clamped to `[0, 1]`: `t <= 0` returns a clone of the start rotation and
    /// `t >= 1` a clone of the end rotation. For nearly parallel quaternions (dot product
    /// above `NEAR_PARALLEL_DOT`) normalized linear interpolation is used instead of the
    /// spherical formula.
    ///
    /// NOTE: a NaN `t` is not special-cased; it survives `clamp` and propagates into the
    /// quaternion passed to `Rotation::from_quat`.
    pub fn interpolate(&self, t: f64) -> Rotation {
        let t = t.clamp(0.0, 1.0);
        if t <= 0.0 {
            return self.start.clone();
        }
        if t >= 1.0 {
            return self.end.clone();
        }

        let (scale1, scale2) = if self.dot > Self::NEAR_PARALLEL_DOT {
            // sin(theta) ~ 0: use linear weights; the renormalization below turns this
            // into nlerp, which is accurate for such tiny angles.
            (1.0 - t, t)
        } else {
            let theta = self.dot.acos();
            let sin_theta = theta.sin();
            (
                ((1.0 - t) * theta).sin() / sin_theta,
                (t * theta).sin() / sin_theta,
            )
        };

        let mut result: Array1<f64> = self
            .q1
            .iter()
            .zip(self.q2.iter())
            .map(|(a, b)| scale1 * a + scale2 * b)
            .collect();

        // Renormalize to counter rounding drift (and required for the linear branch).
        let norm = result.iter().map(|&x| x * x).sum::<f64>().sqrt();
        result /= norm;

        // Both weights are non-negative and q1 · q2 >= 0, so `norm` is non-zero and the
        // quaternion is unit-length; conversion is not expected to fail.
        Rotation::from_quat(&result.view()).expect("Operation failed")
    }

    /// `n` evenly spaced interpolation parameters covering `[0, 1]`, endpoints included.
    ///
    /// For example `times(5)` is `[0.0, 0.25, 0.5, 0.75, 1.0]`. For `n <= 1`
    /// (including `n == 0`) the single value `[0.0]` is returned.
    pub fn times(n: usize) -> Vec<f64> {
        if n <= 1 {
            return vec![0.0];
        }
        let last = (n - 1) as f64;
        (0..n).map(|i| i as f64 / last).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use std::f64::consts::PI;

    #[test]
    fn test_slerp_identity() {
        let rot1 = Rotation::identity();
        let rot2 = rotation_from_euler(0.0, 0.0, PI / 2.0, "xyz").expect("Operation failed");
        let slerp = Slerp::new(rot1.clone(), rot2.clone()).expect("Operation failed");
        let interp_0 = slerp.interpolate(0.0);
        assert_eq!(interp_0.as_quat(), rot1.as_quat());
        let interp_1 = slerp.interpolate(1.0);
        assert_eq!(interp_1.as_quat(), rot2.as_quat());
    }

    #[test]
    fn test_slerp_halfway() {
        let rot1 = Rotation::identity();
        let angles = array![0.0, 0.0, PI];
        let rot2 = Rotation::from_euler(&angles.view(), "xyz").expect("Operation failed");
        let slerp = Slerp::new(rot1, rot2).expect("Operation failed");
        let interp_half = slerp.interpolate(0.5);
        let point = array![1.0, 0.0, 0.0];
        let rotated = interp_half.apply(&point.view()).expect("Operation failed");
        let magnitude =
            (rotated[0] * rotated[0] + rotated[1] * rotated[1] + rotated[2] * rotated[2]).sqrt();
        assert_relative_eq!(magnitude, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_slerp_quarter_turn_midpoint() {
        // Halfway between identity and a +90 degree turn about z is a +45 degree turn about z.
        let rot1 = Rotation::identity();
        let rot2 = rotation_from_euler(0.0, 0.0, PI / 2.0, "xyz").expect("Operation failed");
        let slerp = Slerp::new(rot1, rot2).expect("Operation failed");
        let point = array![1.0, 0.0, 0.0];
        let rotated = slerp
            .interpolate(0.5)
            .apply(&point.view())
            .expect("Operation failed");
        assert_relative_eq!(rotated[0], std::f64::consts::FRAC_1_SQRT_2, epsilon = 1e-10);
        assert_relative_eq!(rotated[1], std::f64::consts::FRAC_1_SQRT_2, epsilon = 1e-10);
        assert_relative_eq!(rotated[2], 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_slerp_at_values() {
        let rot1 = Rotation::identity();
        let angles = array![0.0, 0.0, PI];
        let rot2 = Rotation::from_euler(&angles.view(), "xyz").expect("Operation failed");
        let slerp = Slerp::new(rot1, rot2).expect("Operation failed");
        let values = [0.25, 0.5, 0.75];
        for t in values.iter() {
            let interp = slerp.interpolate(*t);
            let point = array![1.0, 0.0, 0.0];
            let rotated = interp.apply(&point.view()).expect("Operation failed");
            let magnitude =
                (rotated[0] * rotated[0] + rotated[1] * rotated[1] + rotated[2] * rotated[2])
                    .sqrt();
            assert_relative_eq!(magnitude, 1.0, epsilon = 1e-10);
            assert!(rotated[1] >= 0.0);
        }
    }

    #[test]
    fn test_slerp_negative_dot() {
        let rot1 =
            Rotation::from_quat(&array![1.0, 0.0, 0.0, 0.0].view()).expect("Operation failed");
        let rot2 = Rotation::from_quat(
            &array![
                -std::f64::consts::FRAC_1_SQRT_2,
                0.0,
                0.0,
                std::f64::consts::FRAC_1_SQRT_2
            ]
            .view(),
        )
        .expect("Operation failed");
        let slerp = Slerp::new(rot1, rot2).expect("Operation failed");
        let interp = slerp.interpolate(0.5);
        let point = array![1.0, 0.0, 0.0];
        let rotated = interp.apply(&point.view()).expect("Operation failed");
        let magnitude =
            (rotated[0] * rotated[0] + rotated[1] * rotated[1] + rotated[2] * rotated[2]).sqrt();
        assert_relative_eq!(magnitude, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_slerp_times() {
        // `times` is an associated function; no interpolator instance is needed.
        let times = Slerp::times(5);
        assert_eq!(times.len(), 5);
        let expected = [0.0, 0.25, 0.5, 0.75, 1.0];
        for (got, want) in times.iter().zip(expected.iter()) {
            assert_relative_eq!(*got, *want, epsilon = 1e-10);
        }
        // A single sample collapses to the start of the interval.
        assert_eq!(Slerp::times(1), vec![0.0]);
    }

    #[test]
    fn test_slerp_boundary_values() {
        let rot1 = Rotation::identity();
        let angles = array![0.0, 0.0, PI];
        let rot2 = Rotation::from_euler(&angles.view(), "xyz").expect("Operation failed");
        let slerp = Slerp::new(rot1, rot2).expect("Operation failed");
        // (requested t, t it should clamp to)
        let tests = [(-0.5, 0.0), (0.0, 0.0), (1.0, 1.0), (1.5, 1.0)];
        for (t, expected_t) in &tests {
            let interp = slerp.interpolate(*t);
            let expected = slerp.interpolate(*expected_t);
            let point = array![1.0, 0.0, 0.0];
            let rotated = interp.apply(&point.view()).expect("Operation failed");
            let expected_rotated = expected.apply(&point.view()).expect("Operation failed");
            assert_relative_eq!(rotated[0], expected_rotated[0], epsilon = 1e-10);
            assert_relative_eq!(rotated[1], expected_rotated[1], epsilon = 1e-10);
            assert_relative_eq!(rotated[2], expected_rotated[2], epsilon = 1e-10);
        }
    }
}