use super::ProductManifold;
use crate::error::{MathError, Result};
use crate::utils::{norm, EPS};
#[cfg(feature = "parallel")]
use rayon::prelude::*;
impl ProductManifold {
/// Builds the full matrix of pairwise distances between `points`.
///
/// The work is delegated to the rayon-backed helper when the `parallel`
/// feature is enabled and to the single-threaded helper otherwise. The
/// result of the selected helper is returned unchanged.
pub fn pairwise_distances(&self, points: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
    let count = points.len();

    #[cfg(feature = "parallel")]
    let matrix = self.pairwise_distances_parallel(points, count);
    #[cfg(not(feature = "parallel"))]
    let matrix = self.pairwise_distances_sequential(points, count);

    matrix
}
/// Single-threaded pairwise distance matrix.
///
/// Returns an `n x n` matrix whose diagonal is zero and whose entry
/// `[i][j]` (mirrored into `[j][i]`) is `self.distance(&points[i], &points[j])`.
/// Only the strict upper triangle is computed; the first failing distance
/// aborts the computation and its error is returned.
#[inline]
fn pairwise_distances_sequential(
    &self,
    points: &[Vec<f64>],
    n: usize,
) -> Result<Vec<Vec<f64>>> {
    let mut matrix = vec![vec![0.0; n]; n];

    // Row-major walk over the strict upper triangle: (0,1), (0,2), ..., (n-2,n-1).
    let upper_triangle = (0..n).flat_map(|row| ((row + 1)..n).map(move |col| (row, col)));

    for (row, col) in upper_triangle {
        let dist = self.distance(&points[row], &points[col])?;
        matrix[row][col] = dist;
        matrix[col][row] = dist;
    }

    Ok(matrix)
}
/// Rayon-backed pairwise distance matrix.
///
/// Returns an `n x n` matrix with a zero diagonal where entry `[i][j]`
/// (mirrored into `[j][i]`) is `self.distance(&points[i], &points[j])`.
///
/// # Errors
///
/// If any distance computation fails, one of the encountered errors is
/// returned (which one is unspecified under parallel execution). A failed
/// pair is never silently reported as a distance of `0.0`, which keeps this
/// variant consistent with the sequential one.
#[cfg(feature = "parallel")]
fn pairwise_distances_parallel(&self, points: &[Vec<f64>], n: usize) -> Result<Vec<Vec<f64>>> {
    // Enumerate the strict upper triangle up front so rayon can split it evenly.
    let pairs: Vec<(usize, usize)> = (0..n)
        .flat_map(|i| ((i + 1)..n).map(move |j| (i, j)))
        .collect();

    // Collecting into `Result` short-circuits on failure instead of dropping
    // the failed pair (requires the crate error type to be `Send`).
    let results = pairs
        .par_iter()
        .map(|&(i, j)| self.distance(&points[i], &points[j]).map(|d| (i, j, d)))
        .collect::<Result<Vec<(usize, usize, f64)>>>()?;

    let mut distances = vec![vec![0.0; n]; n];
    for (i, j, d) in results {
        distances[i][j] = d;
        distances[j][i] = d;
    }
    Ok(distances)
}
/// Finds up to `k` points of `points` closest to `query`.
///
/// Each returned pair is `(index into points, distance to query)`. The
/// search runs on the rayon-backed helper when the `parallel` feature is
/// enabled and on the single-threaded helper otherwise; the helper's result
/// is returned unchanged.
pub fn knn(&self, query: &[f64], points: &[Vec<f64>], k: usize) -> Result<Vec<(usize, f64)>> {
    #[cfg(feature = "parallel")]
    let neighbors = self.knn_parallel(query, points, k);
    #[cfg(not(feature = "parallel"))]
    let neighbors = self.knn_sequential(query, points, k);

    neighbors
}
/// Single-threaded k-nearest-neighbour search.
///
/// Returns at most `k` pairs `(index into points, distance to query)` sorted
/// by ascending distance. Points whose distance computation fails are
/// skipped rather than reported as an error, so this function currently
/// never returns `Err`.
///
/// NaN distances are ordered after every finite or infinite distance, so
/// they are only returned when fewer than `k` comparable candidates exist.
#[inline]
fn knn_sequential(
    &self,
    query: &[f64],
    points: &[Vec<f64>],
    k: usize,
) -> Result<Vec<(usize, f64)>> {
    let mut distances: Vec<(usize, f64)> = points
        .iter()
        .enumerate()
        .filter_map(|(i, p)| self.distance(query, p).ok().map(|d| (i, d)))
        .collect();

    // `partial_cmp` is `None` only when a NaN is involved. Mapping that case
    // to `Equal` would break transitivity (not a total order), which
    // `sort_unstable_by` does not tolerate; rank NaN as the greatest value instead.
    distances.sort_unstable_by(|a, b| {
        a.1.partial_cmp(&b.1)
            .unwrap_or_else(|| a.1.is_nan().cmp(&b.1.is_nan()))
    });
    distances.truncate(k);
    Ok(distances)
}
/// Rayon-backed k-nearest-neighbour search.
///
/// Distances from `query` to every point are computed in parallel; the
/// result holds at most `k` pairs `(index into points, distance to query)`
/// sorted by ascending distance. Points whose distance computation fails are
/// skipped rather than reported as an error, so this function currently
/// never returns `Err`.
///
/// NaN distances are ordered after every finite or infinite distance, so
/// they are only returned when fewer than `k` comparable candidates exist.
#[cfg(feature = "parallel")]
fn knn_parallel(
    &self,
    query: &[f64],
    points: &[Vec<f64>],
    k: usize,
) -> Result<Vec<(usize, f64)>> {
    let mut distances: Vec<(usize, f64)> = points
        .par_iter()
        .enumerate()
        .filter_map(|(i, p)| self.distance(query, p).ok().map(|d| (i, d)))
        .collect();

    // `partial_cmp` is `None` only when a NaN is involved. Mapping that case
    // to `Equal` would break transitivity (not a total order), which
    // `sort_unstable_by` does not tolerate; rank NaN as the greatest value instead.
    distances.sort_unstable_by(|a, b| {
        a.1.partial_cmp(&b.1)
            .unwrap_or_else(|| a.1.is_nan().cmp(&b.1.is_nan()))
    });
    distances.truncate(k);
    Ok(distances)
}
/// Evaluates the geodesic from `x` towards `y` at parameter `t`.
///
/// `t` is clamped to `[0, 1]`. The point is obtained by taking the log map
/// of `y` at `x`, scaling that tangent vector by the clamped `t`, and
/// applying the exp map at `x`. Errors from either map are propagated.
pub fn geodesic(&self, x: &[f64], y: &[f64], t: f64) -> Result<Vec<f64>> {
    let fraction = t.clamp(0.0, 1.0);
    let tangent = self.log_map(x, y)?;
    let scaled: Vec<f64> = tangent
        .iter()
        .map(|&component| fraction * component)
        .collect();
    self.exp_map(x, &scaled)
}
/// Samples `num_points` points along the geodesic from `x` to `y`.
///
/// Sample `i` is evaluated at `t = i / max(num_points - 1, 1)`, so the first
/// sample is at `t = 0` and, for two or more samples, the last is at `t = 1`.
/// Zero samples yields an empty path. The first geodesic evaluation that
/// fails aborts the sampling and its error is returned.
pub fn geodesic_path(&self, x: &[f64], y: &[f64], num_points: usize) -> Result<Vec<Vec<f64>>> {
    // `saturating_sub` only matters for `num_points == 0`, where the range
    // below is empty and the denominator is never used.
    let last_index = num_points.saturating_sub(1).max(1) as f64;

    (0..num_points)
        .map(|step| self.geodesic(x, y, step as f64 / last_index))
        .collect()
}
/// Transports the vector `v` from the point `x` to the point `y`, component by component.
///
/// * Euclidean coordinates of `v` are copied unchanged.
/// * Hyperbolic coordinates are handled by `poincare_parallel_transport`.
/// * Spherical coordinates are handled by `spherical_parallel_transport`.
///
/// The coordinate ranges of the three components come from
/// `self.config().component_ranges()`.
///
/// # Errors
///
/// Returns a dimension-mismatch error if `x`, `y` or `v` does not have
/// length `self.dim()`; the error carries the length of the first argument
/// (in that order) that is wrong. Errors from the component transports are
/// propagated.
pub fn parallel_transport(&self, x: &[f64], y: &[f64], v: &[f64]) -> Result<Vec<f64>> {
    let dim = self.dim();
    // Report the length that is actually wrong, not unconditionally `x.len()`.
    if let Some(actual) = [x.len(), y.len(), v.len()]
        .iter()
        .copied()
        .find(|&len| len != dim)
    {
        return Err(MathError::dimension_mismatch(dim, actual));
    }

    let mut result = vec![0.0; dim];
    let (e_range, h_range, s_range) = self.config().component_ranges();

    // Euclidean part: the vector is carried over as is.
    for i in e_range {
        result[i] = v[i];
    }

    if !h_range.is_empty() {
        let x_h = &x[h_range.clone()];
        let y_h = &y[h_range.clone()];
        let v_h = &v[h_range.clone()];
        let pt_h = self.poincare_parallel_transport(x_h, y_h, v_h)?;
        for (i, val) in h_range.zip(pt_h.iter()) {
            result[i] = *val;
        }
    }

    if !s_range.is_empty() {
        let x_s = &x[s_range.clone()];
        let y_s = &y[s_range.clone()];
        let v_s = &v[s_range.clone()];
        let pt_s = self.spherical_parallel_transport(x_s, y_s, v_s)?;
        for (i, val) in s_range.zip(pt_s.iter()) {
            result[i] = *val;
        }
    }

    Ok(result)
}
/// Transports `v` from `x` to `y` on the hyperbolic component by conformal rescaling.
///
/// With `c = -hyperbolic_curvature` and `lambda(p) = 2 / max(1 - c * |p|^2, EPS)`,
/// every coordinate of `v` is multiplied by `lambda(x) / lambda(y)`.
///
/// NOTE(review): only this scalar rescaling is applied. The exact transport
/// on the Poincare ball presumably also involves a gyration (rotation) term;
/// a factor for it used to be computed here but was never used, so it has
/// been dropped without changing any result. Confirm whether the
/// approximation is intended.
///
/// This function currently never returns `Err`.
fn poincare_parallel_transport(&self, x: &[f64], y: &[f64], v: &[f64]) -> Result<Vec<f64>> {
    let c = -self.config().hyperbolic_curvature;

    // The denominator is clamped to EPS so points on or beyond the boundary
    // of the ball do not divide by zero or flip sign.
    let conformal_factor = |p: &[f64]| -> f64 {
        let norm_sq: f64 = p.iter().map(|&pi| pi * pi).sum();
        2.0 / (1.0 - c * norm_sq).max(EPS)
    };

    let scale = conformal_factor(x) / conformal_factor(y);
    Ok(v.iter().map(|&vi| scale * vi).collect())
}
/// Transports `v` from `x` to `y` on the spherical component.
///
/// The angle between `x` and `y` is taken from their clamped dot product.
/// `v` is split into its part along `x`, its part along the unit direction
/// `u` (the normalised `y - cos(theta) * x`), and the remainder. The
/// remainder is kept, and the `u` part is rotated by `theta` in the
/// `(x, u)` plane.
///
/// `v` is returned unchanged when `x` and `y` coincide (cosine within `EPS`
/// of 1) or when the direction `u` is degenerate (norm below `EPS`).
///
/// NOTE(review): the part of `v` along `x` is mapped with a negative sign
/// (`-(v.x) * (cos(theta) * x + sin(theta) * u)`). For a true tangent vector
/// `v.x` is zero and the term vanishes; confirm the sign is intended for
/// inputs that are not exactly tangent.
fn spherical_parallel_transport(&self, x: &[f64], y: &[f64], v: &[f64]) -> Result<Vec<f64>> {
    use crate::utils::dot;

    let cos_theta = dot(x, y).clamp(-1.0, 1.0);
    if (cos_theta - 1.0).abs() < EPS {
        return Ok(v.to_vec());
    }

    // Unnormalised direction from x towards y, orthogonal to x.
    let raw_dir: Vec<f64> = x
        .iter()
        .zip(y.iter())
        .map(|(&xi, &yi)| yi - cos_theta * xi)
        .collect();
    let raw_len = norm(&raw_dir);
    if raw_len < EPS {
        return Ok(v.to_vec());
    }
    let dir: Vec<f64> = raw_dir.iter().map(|&di| di / raw_len).collect();

    let theta = cos_theta.acos();
    let sin_t = theta.sin();
    let cos_t = theta.cos();

    let along_dir = dot(v, &dir);
    let along_x = dot(v, x);

    let mut transported = Vec::with_capacity(x.len());
    for i in 0..x.len() {
        let orthogonal = v[i] - along_dir * dir[i] - along_x * x[i];
        let rotated_dir = -sin_t * x[i] + cos_t * dir[i];
        let rotated_x = cos_t * x[i] + sin_t * dir[i];
        transported.push(orthogonal + along_dir * rotated_dir - along_x * rotated_x);
    }
    Ok(transported)
}
/// Mean squared distance of `points` from a centre point.
///
/// The centre is `mean` when supplied; otherwise it is computed with
/// `self.frechet_mean(points, None)`. An empty `points` slice yields `0.0`
/// without computing a centre. Errors from the mean or from any distance
/// computation are propagated.
pub fn variance(&self, points: &[Vec<f64>], mean: Option<&[f64]>) -> Result<f64> {
    if points.is_empty() {
        return Ok(0.0);
    }

    let center = if let Some(given) = mean {
        given.to_vec()
    } else {
        self.frechet_mean(points, None)?
    };

    // Accumulate squared distances in input order, stopping at the first error.
    let sum_sq = points.iter().try_fold(0.0_f64, |acc, p| {
        self.distance(&center, p).map(|d| acc + d * d)
    })?;

    Ok(sum_sq / points.len() as f64)
}
/// Converts an ambient `gradient` at `point` into a per-component adjusted gradient.
///
/// * Euclidean coordinates are returned unchanged.
/// * Hyperbolic coordinates are multiplied by `1 / lambda^2`, where
///   `lambda = 2 / max(1 - c * |x_h|^2, EPS)` and `c = -hyperbolic_curvature`.
/// * Spherical coordinates have their component along the point removed:
///   `g_s - (g_s . x_s) * x_s`.
///
/// The coordinate ranges of the components come from
/// `self.config().component_ranges()`.
///
/// # Errors
///
/// Returns a dimension-mismatch error if `point` or `gradient` does not have
/// length `self.dim()`; the error carries the length of the first argument
/// (in that order) that is wrong.
pub fn project_gradient(&self, point: &[f64], gradient: &[f64]) -> Result<Vec<f64>> {
    let dim = self.dim();
    // Report the length that is actually wrong, not unconditionally `point.len()`.
    if let Some(actual) = [point.len(), gradient.len()]
        .iter()
        .copied()
        .find(|&len| len != dim)
    {
        return Err(MathError::dimension_mismatch(dim, actual));
    }

    let mut result = gradient.to_vec();
    let (_, h_range, s_range) = self.config().component_ranges();

    if !h_range.is_empty() {
        let x_norm_sq: f64 = point[h_range.clone()].iter().map(|&xi| xi * xi).sum();
        let c = -self.config().hyperbolic_curvature;
        let lambda = 2.0 / (1.0 - c * x_norm_sq).max(EPS);
        let scale = 1.0 / (lambda * lambda);
        for g in result[h_range].iter_mut() {
            *g *= scale;
        }
    }

    if !s_range.is_empty() {
        let x_s = &point[s_range.clone()];
        let g_s = &gradient[s_range.clone()];
        // Component of the gradient along the point itself.
        let normal_component: f64 = g_s.iter().zip(x_s.iter()).map(|(&gi, &xi)| gi * xi).sum();
        for (g, &xi) in result[s_range].iter_mut().zip(x_s.iter()) {
            *g -= normal_component * xi;
        }
    }

    Ok(result)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` is strictly within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!((actual - expected).abs() < tol);
    }

    /// Purely Euclidean 2-D manifold shared by all tests.
    fn euclidean_plane() -> ProductManifold {
        ProductManifold::new(2, 0, 0)
    }

    // Distances from the origin to the two unit axis points are 1; the diagonal is 0.
    #[test]
    fn test_pairwise_distances() {
        let plane = euclidean_plane();
        let samples = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]];
        let matrix = plane.pairwise_distances(&samples).unwrap();
        assert_close(matrix[0][0], 0.0, 1e-10);
        assert_close(matrix[0][1], 1.0, 1e-10);
        assert_close(matrix[0][2], 1.0, 1e-10);
    }

    // The query sits midway between points 0 and 1, so either may come first.
    #[test]
    fn test_knn() {
        let plane = euclidean_plane();
        let samples: Vec<Vec<f64>> = [0.0, 1.0, 2.0, 3.0]
            .iter()
            .map(|&first| vec![first, 0.0])
            .collect();
        let target = vec![0.5, 0.0];
        let nearest = plane.knn(&target, &samples, 2).unwrap();
        assert_eq!(nearest.len(), 2);
        let closest_index = nearest[0].0;
        assert!(closest_index == 0 || closest_index == 1);
    }

    // With five samples, the middle one is the midpoint of the segment.
    #[test]
    fn test_geodesic_path() {
        let plane = euclidean_plane();
        let start = vec![0.0, 0.0];
        let end = vec![2.0, 2.0];
        let samples = plane.geodesic_path(&start, &end, 5).unwrap();
        assert_eq!(samples.len(), 5);
        let midpoint = &samples[2];
        assert_close(midpoint[0], 1.0, 1e-6);
        assert_close(midpoint[1], 1.0, 1e-6);
    }

    // Four points at unit distance from the supplied mean give variance 1.
    #[test]
    fn test_variance() {
        let plane = euclidean_plane();
        let samples = vec![
            vec![1.0, 0.0],
            vec![-1.0, 0.0],
            vec![0.0, 1.0],
            vec![0.0, -1.0],
        ];
        let origin = vec![0.0, 0.0];
        let spread = plane.variance(&samples, Some(&origin[..])).unwrap();
        assert_close(spread, 1.0, 1e-10);
    }
}