use num_traits::Float;
use std::iter::Sum;
/// Error returned when two points passed to a distance function have a different
/// number of coordinates (or angles), so no distance between them is defined.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct IncompatiblePointsError;

impl std::fmt::Display for IncompatiblePointsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("points have incompatible dimensions")
    }
}

// Implementing `Error` lets callers propagate this with `?` into
// `Box<dyn std::error::Error>` and other generic error containers.
impl std::error::Error for IncompatiblePointsError {}
pub fn lp_distance<T>(x: &[T], y: &[T], p: i32) -> Result<T, IncompatiblePointsError>
where
T: Float + Sum + From<i32>,
{
if x.len() != y.len() {
Err(IncompatiblePointsError {})
} else {
Ok(x.iter()
.zip(y.iter())
.map(|(&a, &b)| T::powi(T::abs(b - a), p))
.sum::<T>()
.powf(T::one() / p.into()))
}
}
pub fn euclidean_distance<T>(x: &[T], y: &[T]) -> Result<T, IncompatiblePointsError>
where
T: Float + Sum + From<i32>,
{
lp_distance(x, y, 2)
}
/// Computes the maximum (L^∞ / Chebyshev) distance `max_i |y_i - x_i|` between two points.
///
/// Two empty points have distance 0. If any coordinate difference is NaN, the result is
/// NaN. This matches how [`lp_distance`] and [`euclidean_distance`] propagate NaN.
///
/// # Errors
///
/// Returns [`IncompatiblePointsError`] if `x` and `y` have different lengths.
pub fn maximum_distance<T: Float>(x: &[T], y: &[T]) -> Result<T, IncompatiblePointsError> {
    if x.len() != y.len() {
        return Err(IncompatiblePointsError {});
    }
    // `Float::max` returns the non-NaN operand, which would silently drop NaN
    // coordinates (and only for some positions). Compare explicitly instead: a NaN
    // difference wins, and once the accumulator is NaN no `d > acc` test can replace it.
    // Starting from zero is safe because every `|d|` is non-negative.
    Ok(x.iter().zip(y).fold(T::zero(), |acc, (&a, &b)| {
        let d = (b - a).abs();
        if d.is_nan() || d > acc { d } else { acc }
    }))
}
/// Dot product of two unit vectors given in hyperspherical (angular) coordinates.
///
/// Angles `θ_1, …, θ_d` describe the unit vector with Cartesian components
/// `x_k = sin θ_1 ⋯ sin θ_{k-1} · cos θ_k` for `k = 1..=d`, plus a final component
/// `x_{d+1} = sin θ_1 ⋯ sin θ_d`. The dot product is evaluated term by term from that
/// representation, accumulating the running product of sines.
///
/// With no angles, the final component is the empty product 1. Two such points are
/// therefore identical, with dot product 1.
///
/// Callers must pass slices of equal length; the public functions check this first.
fn euclidean_dot_product<T: Float>(angles1: &[T], angles2: &[T]) -> T {
    debug_assert_eq!(angles1.len(), angles2.len());
    let (partial, sin_prod) = angles1.iter().zip(angles2).fold(
        (T::zero(), T::one()),
        |(total, sin_prod), (&t1, &t2)| {
            (
                // Both updates use the sine product *before* this angle (j < k).
                T::mul_add(sin_prod, t1.cos() * t2.cos(), total),
                sin_prod * t1.sin() * t2.sin(),
            )
        },
    );
    // Contribution of the last Cartesian component (product of all sines).
    partial + sin_prod
}
/// Great-circle (angular) distance, in radians, between two points on a unit sphere
/// given in hyperspherical coordinates.
///
/// The result lies in `[0, π]`, or is NaN if the input angles produce NaN.
///
/// # Errors
///
/// Returns [`IncompatiblePointsError`] if `angles1` and `angles2` have different lengths.
pub fn angular_distance<T: Float>(
    angles1: &[T],
    angles2: &[T],
) -> Result<T, IncompatiblePointsError> {
    if angles1.len() != angles2.len() {
        return Err(IncompatiblePointsError {});
    }
    let dot = euclidean_dot_product(angles1, angles2);
    // Rounding can push the dot product of two unit vectors slightly outside
    // [-1, 1] (e.g. identical points giving 1 + ε), where `acos` returns NaN.
    // Clamp with plain comparisons rather than `Float::max`/`min`. Those ignore NaN;
    // comparisons leave a NaN dot product untouched, so invalid input still yields NaN.
    let one = T::one();
    let dot = if dot > one {
        one
    } else if dot < -one {
        -one
    } else {
        dot
    };
    // `acos` already returns a value in [0, π], so no `abs` is needed.
    Ok(dot.acos())
}
/// Hyperbolic distance (curvature -1) between two points given in polar form: a radius
/// `r` plus hyperspherical angles for the direction.
///
/// Uses the hyperbolic law of cosines in the rearranged form
/// `cosh d = cosh(r1 - r2) + (1 - cos Δ) · sinh r1 · sinh r2`.
/// Here `cos Δ` is the dot product of the two unit direction vectors. This is
/// algebraically equal to `cosh r1 cosh r2 - cos Δ sinh r1 sinh r2`.
///
/// Rounding can leave the argument slightly below 1, where `acosh` is undefined; such
/// values are reported as distance 0. A NaN argument propagates as NaN.
///
/// # Errors
///
/// Returns [`IncompatiblePointsError`] if `angles1` and `angles2` have different lengths.
pub fn polar_hyperbolic_distance<T: Float>(
    r1: T,
    angles1: &[T],
    r2: T,
    angles2: &[T],
) -> Result<T, IncompatiblePointsError> {
    if angles1.len() != angles2.len() {
        return Err(IncompatiblePointsError {});
    }
    let cos_between = euclidean_dot_product(angles1, angles2);
    let radial_term = (r1 - r2).cosh();
    let cosh_dist = radial_term + (T::one() - cos_between) * r1.sinh() * r2.sinh();
    // Keep the `< 1` form: a NaN argument must reach `acosh` and stay NaN.
    let dist = if cosh_dist < T::one() {
        T::zero()
    } else {
        cosh_dist.acosh()
    };
    Ok(dist)
}
/// Hyperbolic distance between two points of the hyperboloid model.
///
/// Each slice holds the spatial coordinates of a point on the sheet
/// `x_0^2 - |x|^2 = 1`; the remaining coordinate is reconstructed as
/// `x_0 = sqrt(1 + |x|^2)`. The distance is `acosh(x_0 y_0 - Σ x_i y_i)`.
///
/// Rounding can leave the argument slightly below 1, where `acosh` is undefined; such
/// values are reported as distance 0. A NaN argument (e.g. from infinite coordinates)
/// propagates as NaN.
///
/// # Errors
///
/// Returns [`IncompatiblePointsError`] if `x` and `y` have different lengths.
pub fn hyperboloid_hyperbolic_distance<T: Float>(
    x: &[T],
    y: &[T],
) -> Result<T, IncompatiblePointsError> {
    if x.len() != y.len() {
        return Err(IncompatiblePointsError {});
    }
    // One pass accumulating |x|^2, |y|^2 and the Euclidean inner product with fused
    // multiply-adds.
    let (norm_x_sq, norm_y_sq, inner) = x.iter().zip(y).fold(
        (T::zero(), T::zero(), T::zero()),
        |(xx, yy, xy), (&xi, &yi)| (xi.mul_add(xi, xx), yi.mul_add(yi, yy), xi.mul_add(yi, xy)),
    );
    let x0 = (T::one() + norm_x_sq).sqrt();
    let y0 = (T::one() + norm_y_sq).sqrt();
    let cosh_dist = x0 * y0 - inner;
    // Keep the `< 1` form: a NaN argument must reach `acosh` and stay NaN.
    let dist = if cosh_dist < T::one() {
        T::zero()
    } else {
        cosh_dist.acosh()
    };
    Ok(dist)
}
#[cfg(test)]
mod tests {
    use std::f64::consts::PI;

    use super::{
        IncompatiblePointsError, angular_distance, euclidean_distance,
        hyperboloid_hyperbolic_distance, lp_distance, maximum_distance, polar_hyperbolic_distance,
    };

    #[test]
    fn test_l4_dist() {
        assert!(
            (lp_distance(&[1., 2., 3.], &[5., 3., 1.], 4).unwrap()
                - (256_f64 + 1. + 16.).sqrt().sqrt())
            .abs()
                < 1e-15
        );
    }

    #[test]
    fn test_l1_dist() {
        assert!(
            (lp_distance(&[1., 2., 3.], &[5., 3., 1.], 1).unwrap() - (4. + 1. + 2_f64)).abs()
                < 1e-15
        );
    }

    #[test]
    fn test_lp_dist_incompatible_error() {
        assert_eq!(
            lp_distance(&[0., 0.], &[0., 0., 0.], 1),
            Err(IncompatiblePointsError {})
        );
        assert_eq!(
            lp_distance(&[0., 0.], &[0., 0., 0.], 4),
            Err(IncompatiblePointsError {})
        );
    }

    #[test]
    fn test_euclidean_dist() {
        assert!(
            (euclidean_distance(&[1., 2., 3.], &[5., 3., 1.]).unwrap() - (16_f64 + 1. + 4.).sqrt())
                .abs()
                < 1e-15
        );
    }

    #[test]
    fn test_euclidean_dist_incompatible_error() {
        assert_eq!(
            euclidean_distance(&[0., 0.], &[0., 0., 0.]),
            Err(IncompatiblePointsError {})
        );
    }

    #[test]
    fn test_maximum_dist() {
        assert!((maximum_distance(&[1., 2., 3.], &[5., 3., 1.]).unwrap() - 4_f64).abs() < 1e-15);
    }

    #[test]
    fn test_maximum_dist_incompatible_error() {
        assert_eq!(
            maximum_distance(&[0., 0.], &[0., 0., 0.]),
            Err(IncompatiblePointsError {})
        );
    }

    #[test]
    fn test_angular_dist() {
        assert!((angular_distance(&[0.3], &[0.5]).unwrap() - 0.2_f64).abs() < 1e-15);
        assert!((angular_distance(&[0.5 * PI, PI], &[0.5 * PI, 0.]).unwrap() - PI).abs() < 1e-15);
        assert!((angular_distance(&[0.2, 0., 1.], &[0., 0., 1.]).unwrap() - 0.2_f64).abs() < 1e-15);
        assert!(
            (angular_distance(&[2. * PI, 0., 1.], &[0., 0., 1.]).unwrap() - 0_f64).abs() < 1e-15
        );
    }

    #[test]
    fn test_angular_dist_incompatible_error() {
        assert_eq!(
            angular_distance(&[0.], &[0., 0.]),
            Err(IncompatiblePointsError {})
        );
    }

    #[test]
    fn test_polar_hyperbolic_dist() {
        assert_eq!(
            polar_hyperbolic_distance(3., &[0.], 0.5, &[PI]).unwrap(),
            3.5
        );
    }

    #[test]
    fn test_polar_hyperbolic_dist_incompatible_error() {
        assert_eq!(
            polar_hyperbolic_distance(1., &[0.], 1., &[0., 0.]),
            Err(IncompatiblePointsError {})
        );
    }

    #[test]
    fn test_hyperboloid_dist() {
        assert_eq!(
            hyperboloid_hyperbolic_distance(&[3_f64.sinh(), 0.], &[-0.5_f64.sinh(), 0.]).unwrap(),
            3.5
        );
    }

    #[test]
    fn test_hyperboloid_dist_inf() {
        assert!(
            hyperboloid_hyperbolic_distance(&[f64::INFINITY, 0.], &[0., 0.])
                .unwrap()
                .is_nan()
        );
    }

    #[test]
    fn test_hyperboloid_dist_length_error() {
        // Exercise the hyperboloid function's own length guard, in both argument orders.
        assert_eq!(
            hyperboloid_hyperbolic_distance(&[0., 0.], &[0., 0., 0.]),
            Err(IncompatiblePointsError {})
        );
        assert_eq!(
            hyperboloid_hyperbolic_distance(&[0., 0., 0.], &[0., 0.]),
            Err(IncompatiblePointsError {})
        );
    }
}