use super::kernel::{eval_kernel, so3_geodesic_dist, GeometricKernel};
use crate::InterpolateError;
/// Radial-basis-function interpolator for scalar data sampled at rotations
/// that are given as 4-component quaternions.
///
/// Instances are built by [`So3RbfInterpolator::new`], which normalizes the
/// input quaternions and solves a regularized kernel system for the weights.
pub struct So3RbfInterpolator {
/// Normalized copies of the training quaternions; these are the RBF centers.
centers: Vec<[f64; 4]>,
/// One weight per center, in the same order as `centers`.
weights: Vec<f64>,
/// Kernel evaluated on the geodesic distance between a query and a center.
kernel: GeometricKernel,
}
impl So3RbfInterpolator {
pub fn new(
quaternions: &[[f64; 4]],
values: &[f64],
kernel: GeometricKernel,
lambda: f64,
) -> Result<Self, InterpolateError> {
let n = quaternions.len();
if n == 0 {
return Err(InterpolateError::invalid_input(
"at least 1 quaternion required for So3RbfInterpolator",
));
}
if n != values.len() {
return Err(InterpolateError::shape_mismatch(
n.to_string(),
values.len().to_string(),
"So3RbfInterpolator: quaternions vs values",
));
}
let centers: Vec<[f64; 4]> = quaternions.iter().map(|q| normalize4(q)).collect();
let mut k_mat = vec![0.0_f64; n * n];
for i in 0..n {
for j in 0..n {
let d = so3_geodesic_dist(¢ers[i], ¢ers[j]);
let mut kij = eval_kernel(d, &kernel);
if i == j {
kij += lambda;
}
k_mat[i * n + j] = kij;
}
}
let weights = solve_system(&k_mat, values, n)?;
Ok(Self {
centers,
weights,
kernel,
})
}
pub fn eval(&self, q: &[f64; 4]) -> f64 {
let qn = normalize4(q);
self.centers
.iter()
.zip(self.weights.iter())
.map(|(c, &w)| {
let d = so3_geodesic_dist(&qn, c);
w * eval_kernel(d, &self.kernel)
})
.sum()
}
pub fn eval_batch(&self, quaternions: &[[f64; 4]]) -> Vec<f64> {
quaternions.iter().map(|q| self.eval(q)).collect()
}
}
/// Scales a 4-component vector to unit Euclidean length.
///
/// When the length is below `f64::EPSILON` the direction is considered
/// undefined and `[1.0, 0.0, 0.0, 0.0]` is returned instead.
fn normalize4(q: &[f64; 4]) -> [f64; 4] {
    let length = q.iter().fold(0.0_f64, |acc, c| acc + c * c).sqrt();
    if length < f64::EPSILON {
        return [1.0, 0.0, 0.0, 0.0];
    }
    let [a, b, c, d] = *q;
    [a / length, b / length, c / length, d / length]
}
/// Solves the dense row-major `n × n` system `a · x = b`.
///
/// Cholesky factorization is attempted first; if it reports any error, that
/// error is discarded and Gaussian elimination is tried instead. Only the
/// error from the Gaussian fallback can reach the caller.
fn solve_system(a: &[f64], b: &[f64], n: usize) -> Result<Vec<f64>, InterpolateError> {
    cholesky_solve(a, b, n).or_else(|_| gauss_solve(a, b, n))
}
/// Solves `a · x = b` through a Cholesky factorization `a = L · Lᵀ`.
///
/// `a` is a row-major `n × n` matrix of which only the lower triangle
/// (including the diagonal) is read; `b` holds the `n` right-hand-side values.
///
/// # Errors
///
/// Returns `InterpolateError::ComputationError` when a diagonal residual is
/// not strictly positive, or when a diagonal entry of `L` is smaller in
/// magnitude than `f64::EPSILON` at a point where it is used as a divisor.
fn cholesky_solve(a: &[f64], b: &[f64], n: usize) -> Result<Vec<f64>, InterpolateError> {
    // Lower-triangular factor `L`, stored row-major like `a`.
    let mut factor = vec![0.0_f64; n * n];
    for row in 0..n {
        let row_start = row * n;

        // Entries left of the diagonal, filled from left to right.
        for col in 0..row {
            let col_start = col * n;
            let residual = factor[row_start..row_start + col]
                .iter()
                .zip(&factor[col_start..col_start + col])
                .fold(a[row_start + col], |acc, (p, q)| acc - p * q);
            let pivot = factor[col_start + col];
            if pivot.abs() < f64::EPSILON {
                return Err(InterpolateError::ComputationError(
                    "Cholesky: zero diagonal".into(),
                ));
            }
            factor[row_start + col] = residual / pivot;
        }

        // Diagonal entry: what is left after removing the squared row prefix.
        let residual = factor[row_start..row_start + row]
            .iter()
            .fold(a[row_start + row], |acc, p| acc - p * p);
        if residual <= 0.0 {
            return Err(InterpolateError::ComputationError(
                "Cholesky: not positive definite".into(),
            ));
        }
        factor[row_start + row] = residual.sqrt();
    }

    // Forward substitution: L · y = b.
    let mut intermediate = vec![0.0_f64; n];
    for row in 0..n {
        let row_start = row * n;
        let residual = factor[row_start..row_start + row]
            .iter()
            .zip(&intermediate[..row])
            .fold(b[row], |acc, (coef, known)| acc - coef * known);
        let pivot = factor[row_start + row];
        if pivot.abs() < f64::EPSILON {
            return Err(InterpolateError::ComputationError(
                "Cholesky forward sub: zero diagonal".into(),
            ));
        }
        intermediate[row] = residual / pivot;
    }

    // Back substitution: Lᵀ · x = y, walking column `row` of `L` below the diagonal.
    let mut solution = vec![0.0_f64; n];
    for row in (0..n).rev() {
        let residual = ((row + 1)..n).fold(intermediate[row], |acc, below| {
            acc - factor[below * n + row] * solution[below]
        });
        let pivot = factor[row * n + row];
        if pivot.abs() < f64::EPSILON {
            return Err(InterpolateError::ComputationError(
                "Cholesky back sub: zero diagonal".into(),
            ));
        }
        solution[row] = residual / pivot;
    }
    Ok(solution)
}
/// Solves the dense row-major `n × n` system `a · x = b` by Gaussian
/// elimination with partial (row) pivoting.
///
/// The inputs are copied, so `a` and `b` are left untouched.
///
/// # Errors
///
/// Returns `InterpolateError::ComputationError` when the chosen pivot of some
/// column has magnitude below `1e-14`, or when a diagonal entry met during
/// back substitution has magnitude below `f64::EPSILON`.
fn gauss_solve(a: &[f64], b: &[f64], n: usize) -> Result<Vec<f64>, InterpolateError> {
    let mut work = a.to_vec();
    let mut target = b.to_vec();

    for step in 0..n {
        // Pick the row at or below `step` with the largest magnitude in this
        // column. `max_by` keeps the last candidate when several compare equal.
        let column_magnitude = |r: usize| work[r * n + step].abs();
        let best = (step..n)
            .max_by(|&p, &q| {
                column_magnitude(p)
                    .partial_cmp(&column_magnitude(q))
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .ok_or_else(|| InterpolateError::ComputationError("empty matrix".into()))?;

        if best != step {
            // `best > step`, so the two rows lie on opposite sides of the split.
            let (upper, lower) = work.split_at_mut(best * n);
            upper[step * n..(step + 1) * n].swap_with_slice(&mut lower[..n]);
            target.swap(step, best);
        }

        let pivot = work[step * n + step];
        if pivot.abs() < 1e-14 {
            return Err(InterpolateError::ComputationError(
                "singular kernel matrix — increase lambda".into(),
            ));
        }

        // Eliminate column `step` from every row below the pivot row.
        let (done, pending) = work.split_at_mut((step + 1) * n);
        let pivot_tail = &done[step * n + step..];
        let (rhs_done, rhs_pending) = target.split_at_mut(step + 1);
        let pivot_rhs = rhs_done[step];
        for (row, rhs) in pending.chunks_exact_mut(n).zip(rhs_pending.iter_mut()) {
            let factor = row[step] / pivot;
            for (cell, &above) in row[step..].iter_mut().zip(pivot_tail) {
                *cell -= factor * above;
            }
            *rhs -= factor * pivot_rhs;
        }
    }

    // Back substitution on the resulting upper-triangular system.
    let mut solution = vec![0.0_f64; n];
    for row in (0..n).rev() {
        let base = row * n;
        let residual = work[base + row + 1..base + n]
            .iter()
            .zip(&solution[row + 1..])
            .fold(target[row], |acc, (coef, known)| acc - coef * known);
        let diag = work[base + row];
        if diag.abs() < f64::EPSILON {
            return Err(InterpolateError::ComputationError(
                "back-substitution: zero diagonal".into(),
            ));
        }
        solution[row] = residual / diag;
    }
    Ok(solution)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lie_group::kernel::{so3_geodesic_dist, GeometricKernel};
    use std::f64::consts::{FRAC_PI_2, PI};

    /// Quaternion `[1, 0, 0, 0]`, used as the identity rotation in these tests.
    fn identity() -> [f64; 4] {
        [1.0, 0.0, 0.0, 0.0]
    }

    /// Quaternion `[0, 1, 0, 0]`, used as a 180° rotation in these tests.
    fn rot_x_180() -> [f64; 4] {
        [0.0, 1.0, 0.0, 0.0]
    }

    /// Quaternion `[cos(π/4), 0, 0, sin(π/4)]`, used as a 90° rotation.
    fn rot_z_90() -> [f64; 4] {
        let s = (PI / 4.0).sin();
        let c = (PI / 4.0).cos();
        [c, 0.0, 0.0, s]
    }

    #[test]
    fn test_so3_geodesic_dist_identity() {
        let id = identity();
        let d = so3_geodesic_dist(&id, &id);
        assert!(
            d.abs() < 1e-12,
            "identity self-distance should be 0, got {d}"
        );
    }

    #[test]
    fn test_so3_geodesic_dist_180deg() {
        let id = identity();
        let r180 = rot_x_180();
        let d = so3_geodesic_dist(&id, &r180);
        assert!(
            (d - PI).abs() < 1e-10,
            "180° rotation distance should be π, got {d}"
        );
    }

    #[test]
    fn test_so3_geodesic_dist_double_cover() {
        let q = [0.5_f64, 0.5, 0.5, 0.5];
        let neg_q = [-0.5_f64, -0.5, -0.5, -0.5];
        let d = so3_geodesic_dist(&q, &neg_q);
        assert!(d.abs() < 1e-10, "q and -q should have distance 0, got {d}");
    }

    #[test]
    fn test_so3_rbf_constant_function() {
        let s = 1.0_f64 / 2.0_f64.sqrt();
        let pts = vec![
            identity(),
            rot_x_180(),
            rot_z_90(),
            [s, 0.0_f64, s, 0.0],
            [s, 0.0_f64, 0.0, -s],
        ];
        let vals = vec![1.0_f64; pts.len()];
        let interp =
            So3RbfInterpolator::new(&pts, &vals, GeometricKernel::Heat { sigma: 2.0 }, 1e-8)
                .expect("construction should succeed");
        let v = interp.eval(&identity());
        assert!(
            (v - 1.0).abs() < 0.3,
            "constant function should return ≈1 at identity, got {v}"
        );
        let v2 = interp.eval(&[0.5_f64, 0.5, 0.5, 0.5]);
        assert!(
            v2.is_finite() && v2 > 0.0,
            "constant function eval should be finite and positive, got {v2}"
        );
    }

    #[test]
    fn test_so3_rbf_reproduces_training_points() {
        let pts = vec![identity(), rot_x_180(), rot_z_90()];
        // Named constants instead of the literals 3.14 / 1.57, which trip
        // Clippy's `approx_constant` lint.
        let vals = vec![0.0_f64, PI, FRAC_PI_2];
        let interp =
            So3RbfInterpolator::new(&pts, &vals, GeometricKernel::Heat { sigma: 0.5 }, 1e-10)
                .expect("construction should succeed");
        for (q, &expected) in pts.iter().zip(vals.iter()) {
            let got = interp.eval(q);
            assert!(
                (got - expected).abs() < 0.5,
                "at training quaternion {:?}, expected {expected}, got {got}",
                q
            );
        }
    }

    #[test]
    fn test_so3_rbf_empty_input_error() {
        let result = So3RbfInterpolator::new(&[], &[], GeometricKernel::Heat { sigma: 1.0 }, 1e-6);
        assert!(result.is_err(), "empty input should return error");
    }

    #[test]
    fn test_so3_rbf_mismatched_lengths_error() {
        let pts = vec![identity(), rot_x_180()];
        let vals = vec![1.0_f64];
        let result =
            So3RbfInterpolator::new(&pts, &vals, GeometricKernel::Heat { sigma: 1.0 }, 1e-6);
        assert!(result.is_err(), "mismatched lengths should return error");
    }

    #[test]
    fn test_so3_rbf_batch_eval() {
        let pts = vec![identity(), rot_x_180()];
        let vals = vec![1.0_f64, 2.0];
        let interp =
            So3RbfInterpolator::new(&pts, &vals, GeometricKernel::Heat { sigma: 1.0 }, 1e-6)
                .expect("construction should succeed");
        let batch = interp.eval_batch(&[identity(), rot_x_180()]);
        assert_eq!(batch.len(), 2);
        assert!(batch.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_so3_rbf_matern_kernel() {
        let pts = vec![identity(), rot_z_90()];
        let vals = vec![0.0_f64, 1.0];
        let interp = So3RbfInterpolator::new(
            &pts,
            &vals,
            GeometricKernel::Matern {
                nu: 2.5,
                length_scale: 1.0,
            },
            1e-6,
        )
        .expect("construction with Matern kernel should succeed");
        let v = interp.eval(&identity());
        assert!(
            v.is_finite(),
            "Matern kernel eval should be finite, got {v}"
        );
    }
}