use super::{InitialGuess, Manifold, MeanError};
use nalgebra::{
allocator::Allocator, DefaultAllocator, OMatrix, RealField, UnitQuaternion, Vector3, U3,
};
/// A point on the rotation manifold SO(3), stored as a unit quaternion.
///
/// Tangent vectors are 3-vectors; see the [`Manifold`] implementation for how they are
/// mapped to and from rotations.
///
/// Note: the derived `PartialEq` compares quaternion components, so `q` and `-q` (which
/// encode the same rotation) compare unequal. Compare rotations with `a.local(&b).norm()`
/// when sign-insensitive equality is wanted.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct UnitQuaternionManifold<T: RealField + Copy> {
    /// The wrapped rotation.
    quaternion: UnitQuaternion<T>,
}
impl<T> UnitQuaternionManifold<T>
where
    T: RealField + Copy,
{
    /// Wraps `quaternion` as a manifold point (stored as-is, without renormalisation).
    pub fn new(quaternion: UnitQuaternion<T>) -> Self {
        Self { quaternion }
    }

    /// Borrows the wrapped unit quaternion.
    pub fn as_quaternion(&self) -> &UnitQuaternion<T> {
        let Self { quaternion } = self;
        quaternion
    }

    /// Mutably borrows the wrapped unit quaternion.
    pub fn as_quaternion_mut(&mut self) -> &mut UnitQuaternion<T> {
        let Self { quaternion } = self;
        quaternion
    }

    /// Consumes the wrapper and returns the unit quaternion it held.
    pub fn into_quaternion(self) -> UnitQuaternion<T> {
        let Self { quaternion } = self;
        quaternion
    }

    /// The identity rotation.
    pub fn identity() -> Self {
        Self {
            quaternion: UnitQuaternion::identity(),
        }
    }

    /// Builds a rotation from roll, pitch and yaw angles by delegating to
    /// `UnitQuaternion::from_euler_angles` (see nalgebra's documentation for the axis
    /// convention and composition order).
    pub fn from_euler_angles(roll: T, pitch: T, yaw: T) -> Self {
        let quaternion = UnitQuaternion::from_euler_angles(roll, pitch, yaw);
        Self { quaternion }
    }
}
impl<T> Manifold<U3, T> for UnitQuaternionManifold<T>
where
T: RealField + Copy,
DefaultAllocator: Allocator<U3> + Allocator<U3, U3>,
{
fn retract(&self, delta: &Vector3<T>) -> Self {
let delta_quat = UnitQuaternion::from_scaled_axis(*delta);
let mut result = self.quaternion * delta_quat;
result.renormalize(); Self::new(result)
}
fn local(&self, other: &Self) -> Vector3<T> {
let relative = self.quaternion.inverse() * other.quaternion;
relative.scaled_axis()
}
fn weighted_mean(
points: &[Self],
weights: &[T],
tolerance: T,
initial_guess: InitialGuess<Self>,
max_iterations: usize,
) -> Result<Self, MeanError>
where
DefaultAllocator: Allocator<U3>,
{
if points.is_empty() || weights.is_empty() {
return Err(MeanError::EmptyInput);
}
if points.len() != weights.len() {
return Err(MeanError::LengthMismatch);
}
if tolerance < T::zero() {
return Err(MeanError::InvalidTolerance);
}
if max_iterations == 0 {
return Err(MeanError::NotConverged);
}
let mut mean = match initial_guess {
InitialGuess::First => points[0],
InitialGuess::Index(idx) => {
if idx >= points.len() {
return Err(MeanError::IndexOutOfBounds);
}
points[idx]
}
InitialGuess::MaxWeight => {
let mut best_idx = None;
let mut best_weight = T::zero();
for (i, &w) in weights.iter().enumerate() {
if w > T::zero() && (best_idx.is_none() || w > best_weight) {
best_idx = Some(i);
best_weight = w;
}
}
match best_idx {
Some(i) => points[i],
None => return Err(MeanError::NoPositiveWeights),
}
}
InitialGuess::Provided(m) => m,
};
let tolerance_sq = tolerance * tolerance;
let mut delta = Vector3::<T>::zeros();
for _ in 0..max_iterations {
delta.fill(T::zero());
let mut total_weight = T::zero();
for (point, &weight) in points.iter().zip(weights.iter()) {
if weight > T::zero() {
delta += mean.local(point) * weight;
total_weight += weight;
}
}
if total_weight <= T::zero() {
return Err(MeanError::NoPositiveWeights);
}
delta /= total_weight;
if delta.dot(&delta) <= tolerance_sq {
return Ok(mean);
}
mean = mean.retract(&delta);
}
Err(MeanError::NotConverged)
}
fn batch_retract(points: &[Self], deltas: &[Vector3<T>], output: &mut [Self])
where
DefaultAllocator: Allocator<U3>,
{
assert_eq!(points.len(), deltas.len(), "points/deltas length mismatch");
assert_eq!(points.len(), output.len(), "points/output length mismatch");
for ((point, delta), out) in points.iter().zip(deltas.iter()).zip(output.iter_mut()) {
*out = point.retract(delta);
}
}
fn batch_local(points_a: &[Self], points_b: &[Self], output: &mut [Vector3<T>])
where
DefaultAllocator: Allocator<U3>,
{
assert_eq!(
points_a.len(),
points_b.len(),
"base/target length mismatch"
);
assert_eq!(points_a.len(), output.len(), "base/output length mismatch");
for ((base, target), out) in points_a.iter().zip(points_b.iter()).zip(output.iter_mut()) {
*out = base.local(target);
}
}
fn batch_local_from_base(base_point: &Self, target_points: &[Self], output: &mut [Vector3<T>])
where
DefaultAllocator: Allocator<U3>,
{
assert_eq!(
target_points.len(),
output.len(),
"target/output length mismatch"
);
for (target, out) in target_points.iter().zip(output.iter_mut()) {
*out = base_point.local(target);
}
}
fn batch_local_into_matrix<C>(
base_point: &Self,
target_points: &[Self],
output_matrix: &mut OMatrix<T, U3, C>,
) where
C: nalgebra::DimName,
DefaultAllocator: Allocator<U3> + Allocator<U3, C>,
{
assert_eq!(
target_points.len(),
output_matrix.ncols(),
"target length mismatch with matrix columns"
);
for (i, target) in target_points.iter().enumerate() {
let tangent = base_point.local(target);
output_matrix.column_mut(i).copy_from(&tangent);
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use core::f64::consts::PI;
    use nalgebra::Vector3;

    /// Angular distance (radians) between two rotations, via the manifold's `local` map.
    ///
    /// Preferred over `2 * acos(|dot|)`: `acos` has an unbounded derivative at 1, so a dot
    /// product that rounds to one ulp below 1 already reports ~3e-8 rad for identical inputs.
    fn rotation_distance(a: &UnitQuaternionManifold<f64>, b: &UnitQuaternionManifold<f64>) -> f64 {
        a.local(b).norm()
    }

    /// Weighted average of the tangent vectors from `at` to each point, skipping non-positive
    /// weights: the update step `weighted_mean` would take from `at`.
    fn weighted_tangent_average(
        at: &UnitQuaternionManifold<f64>,
        points: &[UnitQuaternionManifold<f64>],
        weights: &[f64],
    ) -> Vector3<f64> {
        let (sum, total) = points
            .iter()
            .zip(weights)
            .filter(|&(_, &w)| w > 0.0)
            .fold((Vector3::<f64>::zeros(), 0.0_f64), |(acc, tw), (p, &w)| {
                (acc + at.local(p) * w, tw + w)
            });
        sum / total
    }

    #[test]
    fn identity_is_neutral() {
        let q = UnitQuaternionManifold::<f64>::identity();
        let delta = Vector3::zeros();
        let result = q.retract(&delta);
        assert!((q.local(&result)).norm() < 1e-12);
    }

    #[test]
    fn retract_then_local_round_trip() {
        let q = UnitQuaternionManifold::<f64>::identity();
        let delta = Vector3::new(0.1, -0.2, 0.3);
        let q2 = q.retract(&delta);
        let recovered = q.local(&q2);
        assert!((delta - recovered).norm() < 1e-10);
    }

    #[test]
    fn weighted_mean_between_rotations() {
        let q1 = UnitQuaternionManifold::<f64>::identity();
        let q2 = UnitQuaternionManifold::from_euler_angles(0.2, 0.0, 0.0);
        let q3 = UnitQuaternionManifold::from_euler_angles(0.0, 0.2, 0.0);
        let points = vec![q1, q2, q3];
        let weights = vec![0.5, 0.25, 0.25];
        let mean =
            UnitQuaternionManifold::weighted_mean(&points, &weights, 1e-9, InitialGuess::First, 50)
                .unwrap();
        let axis_angle = mean.as_quaternion().scaled_axis();
        assert!(axis_angle.norm() < 0.3);
        // A Karcher mean is a stationary point: the weighted average of the tangent vectors
        // from the mean to every sample must (numerically) vanish.
        let residual = weighted_tangent_average(&mean, &points, &weights);
        assert!(residual.norm() < 1e-8, "residual = {:.3e}", residual.norm());
    }

    #[test]
    fn weighted_mean_of_coaxial_rotations_is_midpoint() {
        let a = UnitQuaternionManifold::<f64>::from_euler_angles(0.1, 0.0, 0.0);
        let b = UnitQuaternionManifold::from_euler_angles(0.3, 0.0, 0.0);
        let mean = UnitQuaternionManifold::weighted_mean(
            &[a, b],
            &[1.0, 1.0],
            1e-12,
            InitialGuess::First,
            50,
        )
        .expect("coaxial mean should converge");
        let expected = UnitQuaternionManifold::from_euler_angles(0.2, 0.0, 0.0);
        assert!(rotation_distance(&mean, &expected) < 1e-10);
    }

    #[test]
    fn weighted_mean_ignores_non_positive_weights() {
        let a = UnitQuaternionManifold::<f64>::from_euler_angles(0.1, 0.0, 0.0);
        let b = UnitQuaternionManifold::from_euler_angles(0.0, 1.0, 0.0);
        let c = UnitQuaternionManifold::from_euler_angles(0.0, 0.0, 1.0);
        let init = InitialGuess::Provided(UnitQuaternionManifold::identity());
        let mean = UnitQuaternionManifold::weighted_mean(
            &[a, b, c],
            &[1.0, -5.0, f64::NAN],
            1e-12,
            init,
            50,
        )
        .expect("mean with one positive weight should converge");
        assert!(rotation_distance(&mean, &a) < 1e-10);
    }

    #[test]
    fn weighted_mean_rejects_invalid_input() {
        let q = UnitQuaternionManifold::<f64>::identity();
        let empty: [UnitQuaternionManifold<f64>; 0] = [];
        assert_eq!(
            UnitQuaternionManifold::weighted_mean(&empty, &[], 1e-9, InitialGuess::First, 10),
            Err(MeanError::EmptyInput)
        );
        assert_eq!(
            UnitQuaternionManifold::weighted_mean(&[q, q], &[1.0], 1e-9, InitialGuess::First, 10),
            Err(MeanError::LengthMismatch)
        );
        assert_eq!(
            UnitQuaternionManifold::weighted_mean(&[q], &[1.0], -1.0, InitialGuess::First, 10),
            Err(MeanError::InvalidTolerance)
        );
        assert_eq!(
            UnitQuaternionManifold::weighted_mean(&[q], &[1.0], 1e-9, InitialGuess::First, 0),
            Err(MeanError::NotConverged)
        );
        assert_eq!(
            UnitQuaternionManifold::weighted_mean(&[q], &[1.0], 1e-9, InitialGuess::Index(1), 10),
            Err(MeanError::IndexOutOfBounds)
        );
        assert_eq!(
            UnitQuaternionManifold::weighted_mean(&[q], &[0.0], 1e-9, InitialGuess::MaxWeight, 10),
            Err(MeanError::NoPositiveWeights)
        );
    }

    #[test]
    fn batch_operations_behave() {
        let base = UnitQuaternionManifold::<f64>::identity();
        let other = UnitQuaternionManifold::from_euler_angles(0.1, 0.0, 0.0);
        let points = vec![base, other];
        let deltas = vec![Vector3::new(0.05, 0.0, 0.0), Vector3::new(0.0, 0.02, 0.0)];

        let mut out = vec![UnitQuaternionManifold::identity(); points.len()];
        UnitQuaternionManifold::batch_retract(&points, &deltas, &mut out);
        for ((p, d), o) in points.iter().zip(&deltas).zip(&out) {
            assert!(rotation_distance(&p.retract(d), o) < 1e-12);
        }

        let mut locals = vec![Vector3::zeros(); points.len()];
        UnitQuaternionManifold::batch_local_from_base(&base, &points, &mut locals);
        assert!(locals[0].norm() < 1e-12);
        assert!((locals[1] - Vector3::new(0.1, 0.0, 0.0)).norm() < 1e-12);

        let mut pairwise = vec![Vector3::zeros(); points.len()];
        UnitQuaternionManifold::batch_local(&points, &points, &mut pairwise);
        assert!(pairwise.iter().all(|v| v.norm() < 1e-12));

        let mut matrix = OMatrix::<f64, U3, nalgebra::Const<2>>::zeros();
        UnitQuaternionManifold::batch_local_into_matrix(&base, &points, &mut matrix);
        assert!(matrix.column(0).norm() < 1e-12);
        assert!((matrix.column(1).clone_owned() - locals[1]).norm() < 1e-12);
    }

    #[test]
    fn large_rotation_handled() {
        let q = UnitQuaternionManifold::<f64>::identity();
        let delta = Vector3::new(PI - 0.1, 0.0, 0.0);
        let q2 = q.retract(&delta);
        let recovered = q.local(&q2);
        assert!((recovered - delta).norm() < 1e-10);
    }

    #[test]
    fn weighted_mean_no_positive_weights() {
        let points = vec![UnitQuaternionManifold::identity()];
        let weights = vec![0.0];
        let result =
            UnitQuaternionManifold::weighted_mean(&points, &weights, 1e-9, InitialGuess::First, 50);
        assert_eq!(result, Err(MeanError::NoPositiveWeights));
    }

    #[test]
    fn hemisphere_sign_invariance() {
        use alloc::vec::Vec;
        let angles = [0.20, 0.21, 0.19, 0.205, 0.195];
        let points: Vec<UnitQuaternionManifold<f64>> = angles
            .iter()
            .map(|&a| UnitQuaternionManifold::from_euler_angles(a, 0.0, 0.0))
            .collect();
        // Same rotations, but with some quaternions replaced by their antipodes `-q`.
        let mut points_flipped = points.clone();
        for &idx in &[0, 2, 4] {
            let neg = -*points_flipped[idx].as_quaternion().quaternion();
            points_flipped[idx] = UnitQuaternionManifold::new(UnitQuaternion::new_unchecked(neg));
        }
        let weights = vec![0.5, 0.2, 0.15, 0.1, 0.05];
        let init = InitialGuess::Provided(UnitQuaternionManifold::identity());
        let mean_orig = UnitQuaternionManifold::weighted_mean(&points, &weights, 1e-12, init, 200)
            .expect("mean on original should converge");
        let init = InitialGuess::Provided(UnitQuaternionManifold::identity());
        let mean_flipped =
            UnitQuaternionManifold::weighted_mean(&points_flipped, &weights, 1e-12, init, 200)
                .expect("mean on flipped should converge");
        let angle = rotation_distance(&mean_orig, &mean_flipped);
        assert!(
            angle < 1e-12,
            "Means differ after sign flips; angular diff = {:.3e} rad",
            angle
        );
    }
}