use crate::error::{InterpolateError, InterpolateResult};
use std::f64::consts::{PI, SQRT_2};
fn assoc_legendre(l: usize, m: usize, x: f64) -> f64 {
debug_assert!(m <= l);
debug_assert!(x.abs() <= 1.0 + 1e-10);
let sin_theta = (1.0 - x.min(1.0).max(-1.0).powi(2)).sqrt();
let mut pmm = 1.0f64;
let mut factor = 1.0f64;
for _ in 0..m {
pmm *= -factor * sin_theta;
factor += 2.0;
}
if l == m {
return pmm;
}
let mut pm1 = (2 * m + 1) as f64 * x * pmm;
if l == m + 1 {
return pm1;
}
let mut p_prev2 = pmm;
let mut p_prev1 = pm1;
let mut p_cur = 0.0;
for ll in (m + 2)..=l {
let lf = ll as f64;
let mf = m as f64;
p_cur = ((2.0 * lf - 1.0) * x * p_prev1 - (lf + mf - 1.0) * p_prev2) / (lf - mf);
p_prev2 = p_prev1;
p_prev1 = p_cur;
}
p_cur
}
fn normalized_legendre_cs(l: usize, m: usize, x: f64) -> f64 {
if m > l {
return 0.0;
}
let x = x.clamp(-1.0, 1.0);
let sin_theta = ((1.0 - x) * (1.0 + x)).sqrt();
let inv_4pi = 1.0 / (4.0 * PI);
let mut bar_pmm = inv_4pi.sqrt(); for k in 1..=m {
let k_f = k as f64;
bar_pmm *= -((2.0 * k_f - 1.0) / (2.0 * k_f)).sqrt() * sin_theta;
}
bar_pmm *= ((2 * m + 1) as f64).sqrt();
if l == m {
return bar_pmm;
}
let mut bar_pm1 = ((2 * m + 3) as f64).sqrt() * x * bar_pmm;
if l == m + 1 {
return bar_pm1;
}
let mut bar_prev2 = bar_pmm;
let mut bar_prev1 = bar_pm1;
let mut bar_cur = 0.0;
let m2 = (m * m) as f64;
for ll in (m + 2)..=l {
let ll_f = ll as f64;
let ll2 = ll_f * ll_f;
let denom = ll2 - m2;
let alpha = ((4.0 * ll2 - 1.0) / denom).sqrt();
let beta = ((2.0 * ll_f + 1.0) * (ll_f + m as f64 - 1.0) * (ll_f - m as f64 - 1.0)
/ ((2.0 * ll_f - 3.0) * denom))
.sqrt();
bar_cur = alpha * x * bar_prev1 - beta * bar_prev2;
bar_prev2 = bar_prev1;
bar_prev1 = bar_cur;
}
bar_cur
}
pub fn real_sph_harm(l: usize, m: i64, theta: f64, phi: f64) -> f64 {
let abs_m = m.unsigned_abs() as usize;
debug_assert!(abs_m <= l);
let cos_theta = theta.cos().clamp(-1.0, 1.0);
let bar_plm = normalized_legendre_cs(l, abs_m, cos_theta);
if m == 0 {
bar_plm
} else if m > 0 {
SQRT_2 * bar_plm * (abs_m as f64 * phi).cos()
} else {
SQRT_2 * bar_plm * (abs_m as f64 * phi).sin()
}
}
pub fn n_sph_harm_basis(l_max: usize) -> usize {
(l_max + 1) * (l_max + 1)
}
pub fn eval_sph_basis(l_max: usize, theta: f64, phi: f64) -> Vec<f64> {
let nb = n_sph_harm_basis(l_max);
let mut basis = Vec::with_capacity(nb);
for l in 0..=l_max {
let li = l as i64;
for m in -li..=li {
basis.push(real_sph_harm(l, m, theta, phi));
}
}
basis
}
fn lstsq(a: &[Vec<f64>], b: &[f64], reg: f64) -> Vec<f64> {
let m = b.len();
let n = if m > 0 { a[0].len() } else { 0 };
let mut ata = vec![vec![0.0f64; n]; n];
let mut atb = vec![0.0f64; n];
for (i, row) in a.iter().enumerate() {
for j in 0..n {
atb[j] += row[j] * b[i];
for k in 0..n {
ata[j][k] += row[j] * row[k];
}
}
}
for j in 0..n {
ata[j][j] += reg;
}
let mut l = vec![vec![0.0f64; n]; n];
for i in 0..n {
for j in 0..=i {
let mut s = ata[i][j];
for k in 0..j {
s -= l[i][k] * l[j][k];
}
if i == j {
l[i][j] = if s > 0.0 { s.sqrt() } else { 0.0 };
} else if l[j][j].abs() > 0.0 {
l[i][j] = s / l[j][j];
}
}
}
let mut y = vec![0.0f64; n];
for i in 0..n {
let mut s = atb[i];
for k in 0..i {
s -= l[i][k] * y[k];
}
if l[i][i].abs() > 1e-30 {
y[i] = s / l[i][i];
}
}
let mut x = vec![0.0f64; n];
for i in (0..n).rev() {
let mut s = y[i];
for k in i + 1..n {
s -= l[k][i] * x[k];
}
if l[i][i].abs() > 1e-30 {
x[i] = s / l[i][i];
}
}
x
}
#[derive(Clone, Debug)]
pub struct SphericalHarmonicsInterpolator {
l_max: usize,
coefficients: Vec<f64>,
}
impl SphericalHarmonicsInterpolator {
pub fn fit(
coords: &[[f64; 2]],
values: &[f64],
l_max: usize,
) -> InterpolateResult<Self> {
let n = coords.len();
if n != values.len() {
return Err(InterpolateError::invalid_input(
"SphericalHarmonics: coords and values must have the same length".to_string(),
));
}
let nb = n_sph_harm_basis(l_max);
if n < nb {
return Err(InterpolateError::insufficient_points(
nb,
n,
"SphericalHarmonicsInterpolator",
));
}
let mut a = Vec::with_capacity(n);
for [theta, phi] in coords.iter() {
a.push(eval_sph_basis(l_max, *theta, *phi));
}
let data_scale = values.iter().map(|v| v.abs()).fold(0.0f64, f64::max).max(1.0);
let reg = 1e-12 * data_scale * data_scale;
let coefficients = lstsq(&a, values, reg);
Ok(Self { l_max, coefficients })
}
pub fn predict(&self, theta: f64, phi: f64) -> f64 {
let basis = eval_sph_basis(self.l_max, theta, phi);
basis.iter().zip(self.coefficients.iter()).map(|(b, c)| b * c).sum()
}
pub fn predict_batch(&self, coords: &[[f64; 2]]) -> Vec<f64> {
coords.iter().map(|[t, p]| self.predict(*t, *p)).collect()
}
pub fn power_spectrum(&self) -> Vec<f64> {
let mut ps = vec![0.0f64; self.l_max + 1];
let mut idx = 0usize;
for l in 0..=self.l_max {
for _m in -(l as i64)..=(l as i64) {
ps[l] += self.coefficients[idx] * self.coefficients[idx];
idx += 1;
}
}
ps
}
pub fn l_max(&self) -> usize {
self.l_max
}
pub fn n_basis(&self) -> usize {
n_sph_harm_basis(self.l_max)
}
pub fn coefficients(&self) -> &[f64] {
&self.coefficients
}
}
#[inline]
fn quat_dot(q: [f64; 4], r: [f64; 4]) -> f64 {
q[0] * r[0] + q[1] * r[1] + q[2] * r[2] + q[3] * r[3]
}
#[inline]
fn quat_mul(q: [f64; 4], r: [f64; 4]) -> [f64; 4] {
let [qw, qx, qy, qz] = q;
let [rw, rx, ry, rz] = r;
[
qw * rw - qx * rx - qy * ry - qz * rz,
qw * rx + qx * rw + qy * rz - qz * ry,
qw * ry - qx * rz + qy * rw + qz * rx,
qw * rz + qx * ry - qy * rx + qz * rw,
]
}
#[inline]
fn quat_conj(q: [f64; 4]) -> [f64; 4] {
[q[0], -q[1], -q[2], -q[3]]
}
#[inline]
fn quat_norm(q: [f64; 4]) -> [f64; 4] {
let n = (q[0] * q[0] + q[1] * q[1] + q[2] * q[2] + q[3] * q[3]).sqrt();
if n < 1e-30 {
[1.0, 0.0, 0.0, 0.0]
} else {
[q[0] / n, q[1] / n, q[2] / n, q[3] / n]
}
}
fn quat_log(q: [f64; 4]) -> [f64; 4] {
let w = q[0].clamp(-1.0, 1.0);
let sin_theta = (q[1] * q[1] + q[2] * q[2] + q[3] * q[3]).sqrt();
if sin_theta < 1e-14 {
return [0.0, 0.0, 0.0, 0.0];
}
let theta = w.acos();
let scale = theta / sin_theta;
[0.0, scale * q[1], scale * q[2], scale * q[3]]
}
fn quat_exp(v: [f64; 4]) -> [f64; 4] {
let theta = (v[1] * v[1] + v[2] * v[2] + v[3] * v[3]).sqrt();
if theta < 1e-14 {
return [1.0, 0.0, 0.0, 0.0];
}
let sinc = theta.sin() / theta;
[theta.cos(), sinc * v[1], sinc * v[2], sinc * v[3]]
}
#[inline]
fn quat_scale(q: [f64; 4], s: f64) -> [f64; 4] {
[q[0] * s, q[1] * s, q[2] * s, q[3] * s]
}
#[inline]
fn quat_add(a: [f64; 4], b: [f64; 4]) -> [f64; 4] {
[a[0] + b[0], a[1] + b[1], a[2] + b[2], a[3] + b[3]]
}
pub fn slerp(q1: [f64; 4], q2: [f64; 4], t: f64) -> [f64; 4] {
let q1 = quat_norm(q1);
let mut q2 = quat_norm(q2);
let mut dot = quat_dot(q1, q2);
if dot < 0.0 {
q2 = [-q2[0], -q2[1], -q2[2], -q2[3]];
dot = -dot;
}
if dot > 1.0 - 1e-10 {
let q = [
q1[0] + t * (q2[0] - q1[0]),
q1[1] + t * (q2[1] - q1[1]),
q1[2] + t * (q2[2] - q1[2]),
q1[3] + t * (q2[3] - q1[3]),
];
return quat_norm(q);
}
let theta = dot.clamp(-1.0, 1.0).acos();
let sin_theta = theta.sin();
let s1 = ((1.0 - t) * theta).sin() / sin_theta;
let s2 = (t * theta).sin() / sin_theta;
quat_norm([
s1 * q1[0] + s2 * q2[0],
s1 * q1[1] + s2 * q2[1],
s1 * q1[2] + s2 * q2[2],
s1 * q1[3] + s2 * q2[3],
])
}
fn squad_inner(q_prev: [f64; 4], q_i: [f64; 4], q_next: [f64; 4]) -> [f64; 4] {
let qi_inv = quat_conj(q_i);
let log1 = quat_log(quat_mul(qi_inv, q_next));
let log2 = quat_log(quat_mul(qi_inv, q_prev));
let sum = quat_add(log1, log2);
let term = quat_scale(sum, -0.25);
quat_mul(q_i, quat_exp(term))
}
pub fn squad(
q0: [f64; 4],
q1: [f64; 4],
q2: [f64; 4],
q3: [f64; 4],
t: f64,
) -> [f64; 4] {
let q0n = quat_norm(q0);
let q1n = quat_norm(q1);
let q2n = quat_norm(q2);
let q3n = quat_norm(q3);
let s1 = squad_inner(q0n, q1n, q2n);
let s2 = squad_inner(q1n, q2n, q3n);
let ab = slerp(q1n, q2n, t);
let cd = slerp(s1, s2, t);
slerp(ab, cd, 2.0 * t * (1.0 - t))
}
#[cfg(test)]
mod tests {
use super::*;
use std::f64::consts::PI;
#[test]
fn test_legendre_p00() {
for x in [-1.0, -0.5, 0.0, 0.5, 1.0] {
assert!((assoc_legendre(0, 0, x) - 1.0).abs() < 1e-12, "P00({x})");
}
}
#[test]
fn test_legendre_p10() {
for x in [-0.8, 0.0, 0.5, 1.0] {
assert!((assoc_legendre(1, 0, x) - x).abs() < 1e-10, "P10({x})");
}
}
#[test]
fn test_legendre_p20() {
for x in [-1.0, -0.5, 0.0, 0.5, 1.0] {
let expected = (3.0 * x * x - 1.0) / 2.0;
let got = assoc_legendre(2, 0, x);
assert!((got - expected).abs() < 1e-10, "P20({x}) got={got} expected={expected}");
}
}
#[test]
fn test_sph_harm_y00_normalization() {
let expected = (1.0 / (4.0 * PI)).sqrt();
let v = real_sph_harm(0, 0, 1.0, 2.0);
assert!((v - expected).abs() < 1e-10, "Y00={v} expected={expected}");
}
#[test]
fn test_sph_harm_y10_north_pole() {
let expected = (3.0 / (4.0 * PI)).sqrt();
let v = real_sph_harm(1, 0, 0.0, 0.0);
assert!((v - expected).abs() < 1e-8, "Y10(pole)={v} expected={expected}");
}
#[test]
fn test_sph_harm_nonzero() {
for l in 0..=4usize {
for m in -(l as i64)..=(l as i64) {
let v = real_sph_harm(l, m, 1.2, 0.7);
assert!(v.is_finite(), "Y_{l}^{m} is not finite");
}
}
}
#[test]
fn test_sph_interp_constant() {
let n = 25;
let coords: Vec<[f64; 2]> = (0..n)
.map(|i| {
let t = PI * (i as f64 + 0.5) / n as f64;
let p = 2.0 * PI * (i as f64) / n as f64;
[t, p]
})
.collect();
let values = vec![1.0f64; n];
let interp = SphericalHarmonicsInterpolator::fit(&coords, &values, 2).expect("test: should succeed");
let pred = interp.predict(PI * 0.3, 1.0);
assert!((pred - 1.0).abs() < 0.1, "pred={pred}");
}
#[test]
fn test_sph_interp_cos_theta() {
let n = 36;
let coords: Vec<[f64; 2]> = (0..n)
.map(|i| {
let t = PI * (i as f64 + 0.5) / n as f64;
let p = 2.0 * PI * (i as f64) / n as f64;
[t, p]
})
.collect();
let values: Vec<f64> = coords.iter().map(|[t, _]| t.cos()).collect();
let interp = SphericalHarmonicsInterpolator::fit(&coords, &values, 3).expect("test: should succeed");
let pred_np = interp.predict(0.01, 0.0);
assert!((pred_np - 1.0).abs() < 0.15, "north pole pred={pred_np}");
let pred_eq = interp.predict(PI * 0.5, 1.0);
assert!(pred_eq.abs() < 0.2, "equator pred={pred_eq}");
}
#[test]
fn test_sph_interp_power_spectrum_length() {
let n = 25;
let coords: Vec<[f64; 2]> = (0..n)
.map(|i| [PI * i as f64 / n as f64, 2.0 * PI * i as f64 / n as f64])
.collect();
let values = vec![1.0f64; n];
let interp = SphericalHarmonicsInterpolator::fit(&coords, &values, 2).expect("test: should succeed");
let ps = interp.power_spectrum();
assert_eq!(ps.len(), 3); for p in &ps {
assert!(p.is_finite() && *p >= 0.0);
}
}
#[test]
fn test_sph_interp_n_basis() {
let n = 25;
let coords: Vec<[f64; 2]> = (0..n)
.map(|i| [PI * i as f64 / n as f64, 2.0 * PI * i as f64 / n as f64])
.collect();
let values = vec![1.0f64; n];
let interp = SphericalHarmonicsInterpolator::fit(&coords, &values, 2).expect("test: should succeed");
assert_eq!(interp.n_basis(), 9); assert_eq!(interp.l_max(), 2);
}
#[test]
fn test_sph_interp_batch() {
let n = 25;
let coords: Vec<[f64; 2]> = (0..n)
.map(|i| [PI * i as f64 / n as f64, 2.0 * PI * i as f64 / n as f64])
.collect();
let values = vec![1.0f64; n];
let interp = SphericalHarmonicsInterpolator::fit(&coords, &values, 2).expect("test: should succeed");
let queries = vec![[0.5, 0.5], [1.0, 1.0], [1.5, 2.0]];
let preds = interp.predict_batch(&queries);
assert_eq!(preds.len(), 3);
for p in &preds {
assert!(p.is_finite());
}
}
#[test]
fn test_sph_interp_insufficient_points() {
let coords: Vec<[f64; 2]> = (0..4).map(|i| [i as f64, 0.0]).collect();
let values = vec![1.0f64; 4];
assert!(SphericalHarmonicsInterpolator::fit(&coords, &values, 2).is_err());
}
#[test]
fn test_slerp_endpoints() {
let q1 = [1.0f64, 0.0, 0.0, 0.0];
let q2 = [0.0f64, 1.0, 0.0, 0.0];
let s0 = slerp(q1, q2, 0.0);
let s1 = slerp(q1, q2, 1.0);
for i in 0..4 {
assert!((s0[i] - q1[i]).abs() < 1e-10, "slerp(0)[{i}]");
assert!((s1[i] - q2[i]).abs() < 1e-10, "slerp(1)[{i}]");
}
}
#[test]
fn test_slerp_unit_norm() {
let q1 = [1.0f64, 0.0, 0.0, 0.0];
let q2 = [0.0f64, 0.0, 0.0, 1.0];
for i in 0..=10 {
let t = i as f64 / 10.0;
let q = slerp(q1, q2, t);
let norm = (q[0] * q[0] + q[1] * q[1] + q[2] * q[2] + q[3] * q[3]).sqrt();
assert!((norm - 1.0).abs() < 1e-10, "norm={norm} at t={t}");
}
}
#[test]
fn test_slerp_halfway() {
let q1 = [1.0f64, 0.0, 0.0, 0.0];
let q2 = [0.0f64, 0.0, 0.0, 1.0];
let mid = slerp(q1, q2, 0.5);
let expected = (PI / 4.0).cos();
assert!((mid[0] - expected).abs() < 1e-10, "mid[0]={}", mid[0]);
assert!((mid[3] - expected).abs() < 1e-10, "mid[3]={}", mid[3]);
}
#[test]
fn test_slerp_shortest_path() {
let q1 = [1.0f64, 0.0, 0.0, 0.0];
let q2 = [0.0f64, 0.0, 0.0, 1.0];
let q2_neg = [-q2[0], -q2[1], -q2[2], -q2[3]];
let s_pos = slerp(q1, q2, 0.5);
let s_neg = slerp(q1, q2_neg, 0.5);
for i in 0..4 {
assert!(
(s_pos[i] - s_neg[i]).abs() < 1e-10
|| (s_pos[i] + s_neg[i]).abs() < 1e-10,
"slerp shortest path component {i}: pos={} neg={}",
s_pos[i],
s_neg[i]
);
}
}
#[test]
fn test_squad_endpoints() {
let q0 = [1.0f64, 0.0, 0.0, 0.0];
let q1 = [1.0f64, 0.0, 0.0, 0.0];
let q2 = [0.0f64, 0.0, 0.0, 1.0];
let q3 = [0.0f64, 0.0, 0.0, 1.0];
let start = squad(q0, q1, q2, q3, 0.0);
let end = squad(q0, q1, q2, q3, 1.0);
assert!((start[0] - 1.0).abs() < 1e-8, "squad start[0]={}", start[0]);
assert!((end[3] - 1.0).abs() < 1e-8, "squad end[3]={}", end[3]);
}
#[test]
fn test_squad_unit_norm() {
let q0 = [1.0f64, 0.0, 0.0, 0.0];
let q1 = [0.7071f64, 0.7071, 0.0, 0.0];
let q2 = [0.0f64, 1.0, 0.0, 0.0];
let q3 = [-0.7071f64, 0.7071, 0.0, 0.0];
for i in 0..=20 {
let t = i as f64 / 20.0;
let q = squad(q0, q1, q2, q3, t);
let norm = (q[0] * q[0] + q[1] * q[1] + q[2] * q[2] + q[3] * q[3]).sqrt();
assert!((norm - 1.0).abs() < 1e-8, "squad norm={norm} at t={t}");
}
}
#[test]
fn test_squad_smooth_midpoint() {
let q0 = [1.0f64, 0.0, 0.0, 0.0];
let q1 = [1.0f64, 0.0, 0.0, 0.0];
let q2 = [0.0f64, 0.0, 0.0, 1.0];
let q3 = [0.0f64, 0.0, 0.0, 1.0];
let mid = squad(q0, q1, q2, q3, 0.5);
let same_q1 = mid.iter().zip(q1.iter()).all(|(a, b)| (a - b).abs() < 1e-8);
let same_q2 = mid.iter().zip(q2.iter()).all(|(a, b)| (a - b).abs() < 1e-8);
assert!(!same_q1 && !same_q2, "squad midpoint should differ from endpoints");
}
#[test]
fn test_quat_mul_identity() {
let id = [1.0f64, 0.0, 0.0, 0.0];
let q = [0.7071f64, 0.7071, 0.0, 0.0];
let q_id = quat_mul(q, id);
for i in 0..4 {
assert!((q_id[i] - q[i]).abs() < 1e-10);
}
}
#[test]
fn test_quat_inv() {
let q = quat_norm([0.5f64, 0.5, 0.5, 0.5]);
let qi = quat_conj(q);
let prod = quat_mul(q, qi);
assert!((prod[0] - 1.0).abs() < 1e-10);
assert!(prod[1].abs() < 1e-10);
assert!(prod[2].abs() < 1e-10);
assert!(prod[3].abs() < 1e-10);
}
}