use super::projections::{alternating_projections, SdrConfig, SdrConvergence};
use nalgebra::DMatrix;
/// Outcome of [`sdr_embed`]: the embedded coordinates plus diagnostics
/// gathered while producing them.
#[derive(Debug, Clone)]
pub struct SdrResult {
/// Flat coordinate buffer with three consecutive entries per atom, i.e.
/// atom `i` occupies indices `3 * i`, `3 * i + 1` and `3 * i + 2`.
pub coords: Vec<f64>,
/// The atom count `n` that was passed to [`sdr_embed`].
pub num_atoms: usize,
/// Convergence report handed back by `alternating_projections`.
pub convergence: SdrConvergence,
/// Largest absolute difference between an embedded pair distance and its
/// requested target distance, taken over all supplied distance pairs.
pub max_distance_error: f64,
/// Number of eigenvalues of the warm-start Gram matrix that were below
/// `-1e-8` (counted before the alternating projections modify it).
pub retries_avoided: usize,
}
/// Builds an initial Gram matrix from pairwise distances by double-centering
/// the matrix of squared distances:
/// `G[i][j] = -0.5 * (D2[i][j] - row_mean[i] - col_mean[j] + grand_mean)`.
///
/// # Arguments
/// * `n` - number of atoms; the result is an `n x n` matrix.
/// * `distance_pairs` - `(i, j, d)` triples giving the distance `d` between
///   atoms `i` and `j`. Both `(i, j)` and `(j, i)` are filled, and a later
///   triple for the same pair overwrites an earlier one. Pairs that are never
///   listed keep a squared distance of zero.
///
/// # Panics
/// Panics if any pair references an index `>= n`, because the squared-distance
/// matrix is indexed directly.
pub fn warm_start_gram(n: usize, distance_pairs: &[(usize, usize, f64)]) -> DMatrix<f64> {
    let mut d_sq = DMatrix::<f64>::zeros(n, n);
    for &(i, j, d) in distance_pairs {
        let squared = d * d;
        d_sq[(i, j)] = squared;
        d_sq[(j, i)] = squared;
    }

    let one_n = 1.0 / n as f64;

    // The means depend only on the row, the column, or nothing at all, so they
    // are computed once up front instead of once per matrix entry. The
    // summation order matches a per-entry evaluation, so results are unchanged.
    let row_means: Vec<f64> = (0..n)
        .map(|i| (0..n).map(|k| d_sq[(i, k)]).sum::<f64>() * one_n)
        .collect();
    let col_means: Vec<f64> = (0..n)
        .map(|j| (0..n).map(|k| d_sq[(k, j)]).sum::<f64>() * one_n)
        .collect();
    let grand_mean: f64 = d_sq.iter().sum::<f64>() * one_n * one_n;

    DMatrix::from_fn(n, n, |i, j| {
        -0.5 * (d_sq[(i, j)] - row_means[i] - col_means[j] + grand_mean)
    })
}
/// Recovers 3-D coordinates from a Gram matrix via its symmetric
/// eigendecomposition.
///
/// The three largest eigenvalues are selected; each coordinate axis is the
/// matching eigenvector scaled by the square root of its eigenvalue.
/// Eigenvalues that are not strictly positive contribute a zero axis.
///
/// # Returns
/// A flat buffer of length `3 * gram.nrows()` laid out as
/// `[x0, y0, z0, x1, y1, z1, ...]`. Axes for which no eigenpair exists (fewer
/// than three rows) stay at zero.
pub fn extract_coordinates(gram: &DMatrix<f64>) -> Vec<f64> {
    let atom_count = gram.nrows();
    // `symmetric_eigen` consumes its receiver, hence the copy of the input.
    let decomposition = gram.clone().symmetric_eigen();

    // Pair each eigenvalue with its column index, then rank them descending.
    let mut ranked: Vec<(usize, f64)> = decomposition
        .eigenvalues
        .iter()
        .copied()
        .enumerate()
        .collect();
    ranked.sort_by(|lhs, rhs| {
        rhs.1
            .partial_cmp(&lhs.1)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let mut coords = vec![0.0; atom_count * 3];
    for (axis, &(column, eigenvalue)) in ranked.iter().take(3).enumerate() {
        let scale = if eigenvalue > 0.0 {
            eigenvalue.sqrt()
        } else {
            0.0
        };
        for (atom, xyz) in coords.chunks_exact_mut(3).enumerate() {
            xyz[axis] = decomposition.eigenvectors[(atom, column)] * scale;
        }
    }
    coords
}
/// Embeds `n` atoms in 3-D space from a set of target pair distances.
///
/// Steps:
/// 1. build a warm-start Gram matrix with [`warm_start_gram`],
/// 2. refine it with `alternating_projections`,
/// 3. read coordinates off the refined matrix with [`extract_coordinates`],
/// 4. collect diagnostics: the worst absolute distance error over the
///    supplied pairs, and how many eigenvalues of the warm-start matrix were
///    below `-1e-8`.
///
/// # Arguments
/// * `n` - number of atoms.
/// * `distance_pairs` - `(i, j, d)` triples with the target distance `d`
///   between atoms `i` and `j`.
/// * `config` - settings forwarded unchanged to `alternating_projections`.
///
/// # Panics
/// Panics if a pair references an atom index `>= n`, since the warm-start
/// matrix is indexed directly with those values.
pub fn sdr_embed(
    n: usize,
    distance_pairs: &[(usize, usize, f64)],
    config: &SdrConfig,
) -> SdrResult {
    let initial_gram = warm_start_gram(n, distance_pairs);
    let (refined_gram, convergence) =
        alternating_projections(&initial_gram, distance_pairs, config);
    let coords = extract_coordinates(&refined_gram);

    // Worst-case mismatch between embedded and requested distances.
    let max_distance_error = distance_pairs
        .iter()
        .map(|&(a, b, target)| {
            let dx = coords[a * 3] - coords[b * 3];
            let dy = coords[a * 3 + 1] - coords[b * 3 + 1];
            let dz = coords[a * 3 + 2] - coords[b * 3 + 2];
            let embedded = (dx * dx + dy * dy + dz * dz).sqrt();
            (embedded - target).abs()
        })
        .fold(0.0_f64, |worst, err| if err > worst { err } else { worst });

    // Count the clearly negative eigenvalues of the unrefined warm start.
    let retries_avoided = initial_gram
        .symmetric_eigen()
        .eigenvalues
        .iter()
        .filter(|&&value| value < -1e-8)
        .count();

    SdrResult {
        coords,
        num_atoms: n,
        convergence,
        max_distance_error,
        retries_avoided,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean distance between atoms `a` and `b` in a flat
    /// `[x0, y0, z0, x1, ...]` coordinate buffer.
    fn atom_distance(coords: &[f64], a: usize, b: usize) -> f64 {
        let dx = coords[a * 3] - coords[b * 3];
        let dy = coords[a * 3 + 1] - coords[b * 3 + 1];
        let dz = coords[a * 3 + 2] - coords[b * 3 + 2];
        (dx * dx + dy * dy + dz * dz).sqrt()
    }

    /// The warm-start matrix for three atoms has three rows.
    #[test]
    fn test_warm_start_gram_simple() {
        // Right triangle with unit legs.
        let hypotenuse = 2.0_f64.sqrt();
        let pairs = [(0, 1, 1.0), (0, 2, 1.0), (1, 2, hypotenuse)];
        let gram = warm_start_gram(3, &pairs);
        assert_eq!(gram.nrows(), 3);
    }

    /// Coordinates recovered from an exact Gram matrix keep the distance
    /// between the first two points.
    #[test]
    fn test_extract_coordinates_from_known() {
        let known: [[f64; 3]; 3] = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
        let n = known.len();
        // Gram entry (i, j) is the dot product of points i and j.
        let gram = DMatrix::from_fn(n, n, |i, j| {
            (0..3).map(|axis| known[i][axis] * known[j][axis]).sum::<f64>()
        });

        let coords = extract_coordinates(&gram);
        assert_eq!(coords.len(), 9);

        let d01 = (0..3)
            .map(|axis| (coords[axis] - coords[3 + axis]).powi(2))
            .sum::<f64>()
            .sqrt();
        assert!((d01 - 1.0).abs() < 0.1, "d01 should be ~1.0: {}", d01);
    }

    /// Embedding an equilateral triangle yields three atoms and nine
    /// coordinate entries.
    #[test]
    fn test_sdr_embed_triangle() {
        let side = 1.5;
        let pairs = vec![(0, 1, side), (0, 2, side), (1, 2, side)];
        let result = sdr_embed(3, &pairs, &SdrConfig::default());
        assert_eq!(result.num_atoms, 3);
        assert_eq!(result.coords.len(), 9);
    }

    /// Embedding a regular tetrahedron reproduces every edge length to
    /// within 0.5.
    #[test]
    fn test_sdr_embed_tetrahedron() {
        let d = 2.0;
        // All six unordered pairs of the four vertices, in lexicographic order.
        let mut pairs: Vec<(usize, usize, f64)> = Vec::with_capacity(6);
        for i in 0..4 {
            for j in (i + 1)..4 {
                pairs.push((i, j, d));
            }
        }

        let result = sdr_embed(4, &pairs, &SdrConfig::default());
        assert_eq!(result.num_atoms, 4);

        for &(i, j, d_target) in &pairs {
            let d_actual = atom_distance(&result.coords, i, j);
            assert!(
                (d_actual - d_target).abs() < 0.5,
                "Distance ({},{}) should be ~{}: got {}",
                i,
                j,
                d_target,
                d_actual
            );
        }
    }
}