use crate::graph::Graph;
use crate::{Error, Result};
/// A node embedding: a point (`center`) together with a `dim x dim` matrix (`shape`).
///
/// `shape` is flattened row-major (`shape[row * dim + col]`) and is expected to
/// hold `dim * dim` entries, where `dim = center.len()`; the distance and
/// overlap functions return `Error::DimensionMismatch` otherwise.
#[derive(Debug, Clone)]
pub struct Ellipsoid {
    /// Location of the ellipsoid; its length is the embedding dimension.
    pub center: Vec<f64>,
    /// Flattened `dim x dim` shape matrix.
    pub shape: Vec<f64>,
}

impl Ellipsoid {
    /// Method form of [`ellipsoid_distance`]`(self, other)`.
    pub fn distance(&self, other: &Self) -> Result<f64> {
        ellipsoid_distance(self, other)
    }

    /// Method form of [`ellipsoid_overlap`]`(self, other)`.
    pub fn overlap(&self, other: &Self) -> Result<f64> {
        ellipsoid_overlap(self, other)
    }

    /// Embedding dimension, i.e. the number of center coordinates.
    pub fn dim(&self) -> usize {
        self.center.len()
    }
}
/// Tuning knobs for [`ellipsoidal_embedding`].
#[derive(Debug, Clone)]
pub struct EllipsoidalConfig {
    /// Number of Laplacian eigenvectors (after the smallest one) used per node.
    /// `ellipsoidal_embedding` asserts `0 < dim < node_count`.
    pub dim: usize,
    /// Floor applied to each selected eigenvalue before the `1 / sqrt(lambda)`
    /// scaling, so (near-)zero eigenvalues cannot blow up the coordinates.
    pub regularization: f64,
}

impl Default for EllipsoidalConfig {
    /// Two dimensions with an eigenvalue floor of `1e-10`.
    fn default() -> Self {
        let regularization = 1e-10;
        let dim = 2;
        EllipsoidalConfig { dim, regularization }
    }
}
/// Embeds every node of `graph` as an [`Ellipsoid`] derived from the spectrum
/// of a dense graph Laplacian.
///
/// The Laplacian has each node's neighbor count on the diagonal and `-1` for
/// every listed neighbor, stored as a dense `n x n` matrix. It is diagonalized
/// with `symmetric_eigen`, which assumes symmetry — NOTE(review): this presumes
/// neighbor lists are symmetric (undirected graph); confirm for `Graph` impls.
///
/// Eigenpairs are ranked by ascending eigenvalue; the smallest is skipped and
/// the next `config.dim` are kept. With `u_j(v)` the entry of eigenvector `j`
/// at node `v`, and `λ_j` its eigenvalue floored at `config.regularization`:
///
/// * `center[j] = u_j(v) / sqrt(λ_j)`
/// * `shape[j * dim + l] = u_j(v) u_l(v) / (sqrt(λ_j) sqrt(λ_l))`, i.e. the
///   outer product `center · centerᵀ` (rank one, up to rounding).
///
/// # Panics
/// If `config.dim == 0`, if `config.dim >= graph.node_count()`, or if an
/// eigenvalue is NaN (the ranking comparator unwraps `partial_cmp`).
pub fn ellipsoidal_embedding<G: Graph>(graph: &G, config: &EllipsoidalConfig) -> Vec<Ellipsoid> {
    let n = graph.node_count();
    let dim = config.dim;
    assert!(dim > 0, "embedding dimension must be positive");
    assert!(
        dim < n,
        "embedding dimension must be < node_count (at most n-1 non-trivial eigenvectors)"
    );

    // Dense Laplacian, row-major: degree on the diagonal, -1 per neighbor.
    let mut laplacian = vec![0.0_f64; n * n];
    for u in 0..n {
        let neighbors = graph.neighbors(u);
        let row_start = u * n;
        laplacian[row_start + u] = neighbors.len() as f64;
        for v in neighbors {
            laplacian[row_start + v] -= 1.0;
        }
    }

    let (eigenvalues, eigenvectors) = symmetric_eigen(n, &mut laplacian);

    // Eigenpair indices by ascending eigenvalue; drop the smallest, keep `dim`.
    let mut ranked: Vec<usize> = (0..n).collect();
    ranked.sort_by(|&x, &y| eigenvalues[x].partial_cmp(&eigenvalues[y]).unwrap());
    let chosen = &ranked[1..=dim];

    // sqrt of each chosen (floored) eigenvalue; identical for every node.
    let root_lambda: Vec<f64> = chosen
        .iter()
        .map(|&e| eigenvalues[e].max(config.regularization).sqrt())
        .collect();

    (0..n)
        .map(|v| {
            // Eigenvector `k` is stored contiguously at `eigenvectors[k * n..]`.
            let coords: Vec<f64> = chosen.iter().map(|&e| eigenvectors[e * n + v]).collect();
            let center: Vec<f64> = coords
                .iter()
                .zip(&root_lambda)
                .map(|(&u, &r)| u / r)
                .collect();
            let mut shape = vec![0.0_f64; dim * dim];
            for (j, row) in shape.chunks_exact_mut(dim).enumerate() {
                for (l, cell) in row.iter_mut().enumerate() {
                    *cell = coords[j] * coords[l] / (root_lambda[j] * root_lambda[l]);
                }
            }
            Ellipsoid { center, shape }
        })
        .collect()
}
pub fn ellipsoid_distance(a: &Ellipsoid, b: &Ellipsoid) -> Result<f64> {
let dim = a.center.len();
if dim != b.center.len() {
return Err(Error::DimensionMismatch(dim, b.center.len()));
}
if a.shape.len() != dim * dim {
return Err(Error::DimensionMismatch(a.shape.len(), dim * dim));
}
if b.shape.len() != dim * dim {
return Err(Error::DimensionMismatch(b.shape.len(), dim * dim));
}
let center_dist_sq: f64 = a
.center
.iter()
.zip(b.center.iter())
.map(|(x, y)| (x - y).powi(2))
.sum();
let tr_a = trace(dim, &a.shape);
let tr_b = trace(dim, &b.shape);
let sqrt_a = matrix_sqrt_psd(dim, &a.shape);
let m = mat_mul(dim, &mat_mul(dim, &sqrt_a, &b.shape), &sqrt_a);
let sqrt_m = matrix_sqrt_psd(dim, &m);
let tr_cross = trace(dim, &sqrt_m);
let w2_sq = center_dist_sq + tr_a + tr_b - 2.0 * tr_cross;
Ok(w2_sq.max(0.0).sqrt())
}
/// Overlap in `[0, 1]` between two ellipsoids, each read as a Gaussian
/// `N(center, shape)`.
///
/// Evaluates the Bhattacharyya coefficient
/// `BC = (det Σa · det Σb)^{1/4} / (det Σ)^{1/2} · exp(−Δᵀ Σ⁻¹ Δ / 8)`,
/// with `Σ = (Σa + Σb) / 2` and `Δ = μa − μb`. A small ridge (`1e-8`) is added
/// to both diagonals first, so rank-deficient shapes stay invertible. The
/// result is clamped to `[0, 1]`.
///
/// # Errors
/// [`Error::DimensionMismatch`] when the centers differ in length, or when
/// either `shape` does not have `dim * dim` entries (`dim = a.center.len()`).
pub fn ellipsoid_overlap(a: &Ellipsoid, b: &Ellipsoid) -> Result<f64> {
    // ln|det m|, floored at the smallest positive normal f64 so the logarithm
    // stays finite for an exactly singular (or underflowing) determinant.
    fn ln_abs_det(n: usize, m: &[f64]) -> f64 {
        matrix_det(n, m).abs().max(f64::MIN_POSITIVE).ln()
    }

    let dim = a.center.len();
    if dim != b.center.len() {
        return Err(Error::DimensionMismatch(dim, b.center.len()));
    }
    if a.shape.len() != dim * dim {
        return Err(Error::DimensionMismatch(a.shape.len(), dim * dim));
    }
    if b.shape.len() != dim * dim {
        return Err(Error::DimensionMismatch(b.shape.len(), dim * dim));
    }

    // Ridge on both diagonals. Shapes produced by `ellipsoidal_embedding` are
    // rank-one outer products, which would otherwise be singular.
    const RIDGE: f64 = 1e-8;
    let mut sa = a.shape.clone();
    let mut sb = b.shape.clone();
    for i in 0..dim {
        sa[i * dim + i] += RIDGE;
        sb[i * dim + i] += RIDGE;
    }
    let s_avg: Vec<f64> = sa.iter().zip(&sb).map(|(x, y)| (x + y) / 2.0).collect();

    // Determinant factor in log space. For PSD shapes, det(S + RIDGE·I) scales
    // roughly like RIDGE^(dim - rank(S)). That is tiny but meaningful, so it must
    // not be floored at a fixed tolerance; only exact zero/underflow is guarded.
    let log_det_factor = 0.25 * (ln_abs_det(dim, &sa) + ln_abs_det(dim, &sb))
        - 0.5 * ln_abs_det(dim, &s_avg);

    // Squared Mahalanobis distance Δᵀ Σ⁻¹ Δ under the averaged shape.
    let s_avg_inv = matrix_inverse(dim, &s_avg);
    let diff: Vec<f64> = a
        .center
        .iter()
        .zip(&b.center)
        .map(|(x, y)| x - y)
        .collect();
    let mahal: f64 = (0..dim)
        .map(|i| {
            let row = &s_avg_inv[i * dim..(i + 1) * dim];
            let projected: f64 = row.iter().zip(&diff).map(|(m, d)| m * d).sum();
            diff[i] * projected
        })
        .sum();

    let overlap = (log_det_factor - mahal / 8.0).exp();
    Ok(overlap.clamp(0.0, 1.0))
}
/// Sum of the diagonal of a flattened `n x n` matrix (entries `m[i * n + i]`).
fn trace(n: usize, m: &[f64]) -> f64 {
    // Consecutive diagonal entries are `n + 1` apart in row-major storage.
    let stride = n + 1;
    (0..n).map(|i| m[i * stride]).sum()
}
/// Product `a · b` of two flattened row-major `n x n` matrices.
///
/// Each output entry accumulates `a[i][k] * b[k][j]` in ascending `k`, starting
/// from `0.0`.
fn mat_mul(n: usize, a: &[f64], b: &[f64]) -> Vec<f64> {
    if n == 0 {
        return Vec::new();
    }
    let mut product = vec![0.0; n * n];
    for (i, out_row) in product.chunks_exact_mut(n).enumerate() {
        for k in 0..n {
            let scale = a[i * n + k];
            let b_row = &b[k * n..(k + 1) * n];
            for (dst, &src) in out_row.iter_mut().zip(b_row) {
                *dst += scale * src;
            }
        }
    }
    product
}
/// Square root of a symmetric positive semi-definite `n x n` matrix.
///
/// Diagonalizes a copy of `m` with [`symmetric_eigen`] and rebuilds
/// `Σ_k sqrt(max(λ_k, 0)) · v_k v_kᵀ`. Negative eigenvalues (e.g. from
/// round-off) are treated as zero. `m` itself is left untouched.
fn matrix_sqrt_psd(n: usize, m: &[f64]) -> Vec<f64> {
    let mut scratch = m.to_vec();
    let (eigvals, eigvecs) = symmetric_eigen(n, &mut scratch);
    let mut root = vec![0.0; n * n];
    for (k, &lambda) in eigvals.iter().enumerate() {
        let sigma = lambda.max(0.0).sqrt();
        // Eigenvector `k` is stored contiguously.
        let v = &eigvecs[k * n..(k + 1) * n];
        for (i, &vi) in v.iter().enumerate() {
            let weighted = vi * sigma;
            for (j, &vj) in v.iter().enumerate() {
                root[i * n + j] += weighted * vj;
            }
        }
    }
    root
}
/// Determinant of a flattened row-major `n x n` matrix, computed by Gaussian
/// elimination with partial pivoting on a copy of `m`.
///
/// Returns `1.0` for `n == 0`. Returns `0.0` as soon as the largest remaining
/// |entry| in a pivot column is below the absolute tolerance `1e-15`.
fn matrix_det(n: usize, m: &[f64]) -> f64 {
    if n == 0 {
        return 1.0;
    }
    let mut lu = m.to_vec();
    let mut sign = 1.0_f64;
    for k in 0..n {
        // Earliest row (at or below k) holding the largest |entry| in column k.
        let (pivot_row, pivot_mag) =
            ((k + 1)..n).fold((k, lu[k * n + k].abs()), |best, r| {
                let mag = lu[r * n + k].abs();
                if mag > best.1 {
                    (r, mag)
                } else {
                    best
                }
            });
        if pivot_mag < 1e-15 {
            return 0.0;
        }
        if pivot_row != k {
            for j in 0..n {
                lu.swap(k * n + j, pivot_row * n + j);
            }
            sign = -sign;
        }
        // Eliminate column k from every row below the pivot row.
        let (upper, lower) = lu.split_at_mut((k + 1) * n);
        let pivot_vals = &upper[k * n..];
        let pivot = pivot_vals[k];
        for row in lower.chunks_exact_mut(n).take(n - k - 1) {
            let factor = row[k] / pivot;
            for (x, &p) in row[k..].iter_mut().zip(&pivot_vals[k..]) {
                *x -= factor * p;
            }
        }
    }
    // det = sign · product of the (now upper-triangular) diagonal.
    (0..n).fold(sign, |det, i| det * lu[i * n + i])
}
/// Inverse of a flattened row-major `n x n` matrix, computed by Gauss-Jordan
/// elimination with partial pivoting on the augmented matrix `[m | I]`.
///
/// Caveat: if a pivot's magnitude falls below `1e-15` (singular or nearly
/// singular input), this silently returns the `n x n` identity instead of
/// signalling an error.
fn matrix_inverse(n: usize, m: &[f64]) -> Vec<f64> {
    let width = 2 * n;
    let mut aug = vec![0.0; n * width];
    for r in 0..n {
        aug[r * width..r * width + n].copy_from_slice(&m[r * n..(r + 1) * n]);
        aug[r * width + n + r] = 1.0;
    }

    for k in 0..n {
        // Earliest row (at or below k) holding the largest |entry| in column k.
        let mut best = k;
        let mut best_mag = aug[k * width + k].abs();
        for r in (k + 1)..n {
            let mag = aug[r * width + k].abs();
            if mag > best_mag {
                best_mag = mag;
                best = r;
            }
        }
        if best != k {
            for j in 0..width {
                aug.swap(k * width + j, best * width + j);
            }
        }

        let pivot = aug[k * width + k];
        if pivot.abs() < 1e-15 {
            let mut identity = vec![0.0; n * n];
            for d in 0..n {
                identity[d * n + d] = 1.0;
            }
            return identity;
        }

        // Scale the pivot row so the pivot becomes 1, then clear column k elsewhere.
        for x in &mut aug[k * width..(k + 1) * width] {
            *x /= pivot;
        }
        let pivot_row = aug[k * width..(k + 1) * width].to_vec();
        for (r, row) in aug.chunks_exact_mut(width).enumerate() {
            if r == k {
                continue;
            }
            let factor = row[k];
            for (x, &p) in row.iter_mut().zip(&pivot_row) {
                *x -= factor * p;
            }
        }
    }

    // The right half of the augmented matrix now holds the inverse.
    (0..n * n)
        .map(|idx| aug[(idx / n) * width + n + idx % n])
        .collect()
}
/// Eigendecomposition of a real symmetric `n x n` row-major matrix by the
/// classical Jacobi method: each step rotates away the largest off-diagonal entry.
///
/// Only the strict upper triangle is searched for the pivot, so `a` is assumed
/// symmetric. Iteration stops once that largest |entry| is below `1e-12`
/// (absolute), or after `100 * n * n` rotations.
///
/// On return, `a` holds the rotated matrix; its diagonal is the eigenvalues.
/// Returns `(eigenvalues, eigenvectors)`. Eigenvector `k` occupies
/// `eigenvectors[k * n..(k + 1) * n]` and pairs with `eigenvalues[k]`.
/// Eigenvalues are not sorted.
fn symmetric_eigen(n: usize, a: &mut [f64]) -> (Vec<f64>, Vec<f64>) {
    // Apply the plane rotation [c s; -s c] to the entry pair (m[i], m[j]).
    fn rotate(m: &mut [f64], i: usize, j: usize, c: f64, s: f64) {
        let (x, y) = (m[i], m[j]);
        m[i] = c * x + s * y;
        m[j] = -s * x + c * y;
    }

    // Accumulated rotations, starting from the identity. Column k converges to
    // the eigenvector for diagonal entry k.
    let mut basis = vec![0.0; n * n];
    for d in 0..n {
        basis[d * (n + 1)] = 1.0;
    }

    for _ in 0..100 * n * n {
        // Largest |a[i][j]| with i < j; the first maximum encountered wins ties.
        let mut largest = 0.0_f64;
        let (mut p, mut q) = (0, 1);
        for i in 0..n {
            for j in (i + 1)..n {
                let magnitude = a[i * n + j].abs();
                if magnitude > largest {
                    largest = magnitude;
                    p = i;
                    q = j;
                }
            }
        }
        if largest < 1e-12 {
            break;
        }

        let gap = a[p * n + p] - a[q * n + q];
        let theta = if gap.abs() < 1e-15 {
            std::f64::consts::FRAC_PI_4
        } else {
            0.5 * (2.0 * a[p * n + q] / gap).atan()
        };
        let c = theta.cos();
        let s = theta.sin();

        // Mix rows p and q. Each column's (p, q) pair is independent of the
        // others, so updating in place matches a buffered update.
        for k in 0..n {
            rotate(a, p * n + k, q * n + k, c, s);
        }
        // Mix columns p and q (right-multiplication by the transposed rotation).
        for k in 0..n {
            rotate(a, k * n + p, k * n + q, c, s);
        }
        for k in 0..n {
            rotate(&mut basis, k * n + p, k * n + q, c, s);
        }
    }

    let eigenvalues = (0..n).map(|i| a[i * (n + 1)]).collect();
    // Transpose so eigenvector k (column k of `basis`) becomes contiguous.
    let eigenvectors = (0..n * n)
        .map(|idx| basis[(idx % n) * n + idx / n])
        .collect();
    (eigenvalues, eigenvectors)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal adjacency-list graph used to drive the embedding in tests.
    struct TestGraph {
        adj: Vec<Vec<usize>>,
    }

    impl TestGraph {
        /// Every node is adjacent to every other node.
        fn complete(n: usize) -> Self {
            let adj = (0..n)
                .map(|i| (0..n).filter(|&j| j != i).collect())
                .collect();
            Self { adj }
        }

        /// Node 0 is the hub; nodes `1..n` are leaves attached only to it.
        fn star(n: usize) -> Self {
            let mut adj = vec![vec![]; n];
            for i in 1..n {
                adj[0].push(i);
                adj[i].push(0);
            }
            Self { adj }
        }

        /// Nodes `0 - 1 - ... - (n-1)` in a line. Each node lists its lower
        /// neighbor first. `n == 0` yields an empty graph instead of underflowing.
        fn path(n: usize) -> Self {
            let mut adj = vec![vec![]; n];
            for i in 1..n {
                adj[i - 1].push(i);
                adj[i].push(i - 1);
            }
            Self { adj }
        }
    }

    impl Graph for TestGraph {
        fn node_count(&self) -> usize {
            self.adj.len()
        }

        fn neighbors(&self, node: usize) -> Vec<usize> {
            self.adj[node].clone()
        }
    }

    #[test]
    fn embedding_dimensions_match() {
        let g = TestGraph::path(6);
        let config = EllipsoidalConfig {
            dim: 3,
            ..Default::default()
        };
        let embs = ellipsoidal_embedding(&g, &config);
        assert_eq!(embs.len(), 6);
        for e in &embs {
            assert_eq!(e.center.len(), 3);
            assert_eq!(e.shape.len(), 9);
        }
    }

    #[test]
    fn shape_matrices_are_psd() {
        let g = TestGraph::path(8);
        let config = EllipsoidalConfig {
            dim: 3,
            ..Default::default()
        };
        let embs = ellipsoidal_embedding(&g, &config);
        for e in &embs {
            let mut m = e.shape.clone();
            let (vals, _) = symmetric_eigen(3, &mut m);
            for &v in &vals {
                assert!(v >= -1e-10, "shape matrix eigenvalue is negative: {v}");
            }
        }
    }

    #[test]
    fn hub_and_leaf_have_different_shapes() {
        let n = 6;
        let g = TestGraph::star(n);
        let config = EllipsoidalConfig {
            dim: n - 1,
            ..Default::default()
        };
        let embs = ellipsoidal_embedding(&g, &config);
        let dim = n - 1;
        let hub_trace = trace(dim, &embs[0].shape);
        let leaf_traces: Vec<f64> = (1..n).map(|i| trace(dim, &embs[i].shape)).collect();
        let leaf_avg: f64 = leaf_traces.iter().sum::<f64>() / leaf_traces.len() as f64;
        for &lt in &leaf_traces {
            assert!(
                (lt - leaf_avg).abs() < 1e-6,
                "leaf traces differ: {lt} vs avg {leaf_avg}"
            );
        }
        assert!(
            (hub_trace - leaf_avg).abs() > 1e-4,
            "hub trace ({hub_trace}) and leaf trace ({leaf_avg}) should differ"
        );
    }

    #[test]
    fn distance_symmetry_and_self() {
        let g = TestGraph::path(5);
        let config = EllipsoidalConfig {
            dim: 2,
            ..Default::default()
        };
        let embs = ellipsoidal_embedding(&g, &config);
        for e in &embs {
            let d = ellipsoid_distance(e, e).unwrap();
            assert!(d < 1e-6, "self-distance should be ~0, got {d}");
        }
        for i in 0..embs.len() {
            for j in (i + 1)..embs.len() {
                let d1 = ellipsoid_distance(&embs[i], &embs[j]).unwrap();
                let d2 = ellipsoid_distance(&embs[j], &embs[i]).unwrap();
                assert!(
                    (d1 - d2).abs() < 1e-6,
                    "distance not symmetric: {d1} vs {d2}"
                );
                assert!(d1 >= 0.0, "distance should be non-negative");
            }
        }
    }

    #[test]
    fn distance_is_nonnegative() {
        let g = TestGraph::star(7);
        let config = EllipsoidalConfig {
            dim: 3,
            ..Default::default()
        };
        let embs = ellipsoidal_embedding(&g, &config);
        for i in 0..embs.len() {
            for j in 0..embs.len() {
                let d = ellipsoid_distance(&embs[i], &embs[j]).unwrap();
                assert!(d >= -1e-10, "distance should be non-negative, got {d}");
            }
        }
    }

    #[test]
    fn complete_graph_identical_ellipsoids() {
        let n = 5;
        let g = TestGraph::complete(n);
        let config = EllipsoidalConfig {
            dim: n - 1,
            ..Default::default()
        };
        let embs = ellipsoidal_embedding(&g, &config);
        let dim = n - 1;
        let traces: Vec<f64> = embs.iter().map(|e| trace(dim, &e.shape)).collect();
        let first = traces[0];
        for &tr in &traces[1..] {
            assert!(
                (tr - first).abs() < 1e-6,
                "complete graph: traces differ: {tr} vs {first}"
            );
        }
    }

    #[test]
    fn overlap_self_is_one() {
        let g = TestGraph::path(5);
        let config = EllipsoidalConfig {
            dim: 2,
            ..Default::default()
        };
        let embs = ellipsoidal_embedding(&g, &config);
        for e in &embs {
            let o = ellipsoid_overlap(e, e).unwrap();
            assert!(
                (o - 1.0).abs() < 1e-3,
                "self-overlap should be ~1.0, got {o}"
            );
        }
    }

    #[test]
    fn overlap_is_symmetric() {
        let g = TestGraph::star(6);
        let config = EllipsoidalConfig {
            dim: 2,
            ..Default::default()
        };
        let embs = ellipsoidal_embedding(&g, &config);
        for i in 0..embs.len() {
            for j in (i + 1)..embs.len() {
                let o1 = ellipsoid_overlap(&embs[i], &embs[j]).unwrap();
                let o2 = ellipsoid_overlap(&embs[j], &embs[i]).unwrap();
                assert!(
                    (o1 - o2).abs() < 1e-8,
                    "overlap not symmetric: {o1} vs {o2}"
                );
            }
        }
    }

    #[test]
    fn overlap_in_unit_range() {
        let g = TestGraph::path(6);
        let config = EllipsoidalConfig {
            dim: 2,
            ..Default::default()
        };
        let embs = ellipsoidal_embedding(&g, &config);
        for i in 0..embs.len() {
            for j in 0..embs.len() {
                let o = ellipsoid_overlap(&embs[i], &embs[j]).unwrap();
                assert!(
                    (0.0..=1.0 + 1e-10).contains(&o),
                    "overlap out of range: {o}"
                );
            }
        }
    }
}