use crate::error::{SpatialError, SpatialResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::simd_ops::SimdUnifiedOps;
pub fn simd_euclidean_distance(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> SpatialResult<f64> {
if a.len() != b.len() {
return Err(SpatialError::ValueError(
"Points must have the same dimension".to_string(),
));
}
let diff = f64::simd_sub(a, b);
let squared = f64::simd_mul(&diff.view(), &diff.view());
let sum = f64::simd_sum(&squared.view());
Ok(sum.sqrt())
}
/// Squared Euclidean distance `sum_i (a_i - b_i)^2` between two points.
///
/// Avoids the final square root, which is handy when only comparing distances.
///
/// # Errors
///
/// Returns `SpatialError::ValueError` when `a` and `b` have different lengths.
pub fn simd_squared_euclidean_distance(
    a: &ArrayView1<f64>,
    b: &ArrayView1<f64>,
) -> SpatialResult<f64> {
    if a.len() == b.len() {
        let delta = f64::simd_sub(a, b);
        let delta_sq = f64::simd_mul(&delta.view(), &delta.view());
        Ok(f64::simd_sum(&delta_sq.view()))
    } else {
        Err(SpatialError::ValueError(
            "Points must have the same dimension".to_string(),
        ))
    }
}
/// Manhattan (L1) distance `sum_i |a_i - b_i|` between two points.
///
/// # Errors
///
/// Returns `SpatialError::ValueError` when `a` and `b` have different lengths.
pub fn simd_manhattan_distance(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> SpatialResult<f64> {
    if a.len() == b.len() {
        let magnitudes = f64::simd_abs(&f64::simd_sub(a, b).view());
        let total = f64::simd_sum(&magnitudes.view());
        Ok(total)
    } else {
        Err(SpatialError::ValueError(
            "Points must have the same dimension".to_string(),
        ))
    }
}
/// Chebyshev (L∞) distance `max_i |a_i - b_i|` between two points.
///
/// Two zero-dimensional points coincide, so their distance is `0.0`.
///
/// # Errors
///
/// Returns `SpatialError::ValueError` when `a` and `b` have different lengths.
pub fn simd_chebyshev_distance(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> SpatialResult<f64> {
    if a.len() != b.len() {
        return Err(SpatialError::ValueError(
            "Points must have the same dimension".to_string(),
        ));
    }
    // Handled explicitly: the result of `simd_max_element` on an empty input is not
    // visible from here (it may panic or return a sentinel such as -inf).
    if a.is_empty() {
        return Ok(0.0);
    }
    let diff = f64::simd_sub(a, b);
    let abs_diff = f64::simd_abs(&diff.view());
    Ok(f64::simd_max_element(&abs_diff.view()))
}
pub fn simd_minkowski_distance(
a: &ArrayView1<f64>,
b: &ArrayView1<f64>,
p: f64,
) -> SpatialResult<f64> {
if a.len() != b.len() {
return Err(SpatialError::ValueError(
"Points must have the same dimension".to_string(),
));
}
if p < 1.0 {
return Err(SpatialError::ValueError(
"Minkowski p must be >= 1.0".to_string(),
));
}
if (p - 1.0).abs() < 1e-10 {
return simd_manhattan_distance(a, b);
}
if (p - 2.0).abs() < 1e-10 {
return simd_euclidean_distance(a, b);
}
let diff = f64::simd_sub(a, b);
let abs_diff = f64::simd_abs(&diff.view());
let powered = f64::simd_powf(&abs_diff.view(), p);
let sum = f64::simd_sum(&powered.view());
Ok(sum.powf(1.0 / p))
}
pub fn simd_cosine_distance(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> SpatialResult<f64> {
if a.len() != b.len() {
return Err(SpatialError::ValueError(
"Points must have the same dimension".to_string(),
));
}
let dot_product = f64::simd_dot(a, b);
let norm_a = f64::simd_norm(a);
let norm_b = f64::simd_norm(b);
if norm_a == 0.0 || norm_b == 0.0 {
return Err(SpatialError::ValueError(
"Cannot compute cosine distance for zero vectors".to_string(),
));
}
let cosine_similarity = dot_product / (norm_a * norm_b);
Ok(1.0 - cosine_similarity)
}
pub fn simd_point_to_box_min_distance_squared(
point: &ArrayView1<f64>,
box_min: &ArrayView1<f64>,
box_max: &ArrayView1<f64>,
) -> SpatialResult<f64> {
if point.len() != box_min.len() || point.len() != box_max.len() {
return Err(SpatialError::ValueError(
"Point and box dimensions must match".to_string(),
));
}
let clamped = f64::simd_clamp(
point,
*box_min
.first()
.ok_or_else(|| SpatialError::ValueError("Empty array".to_string()))?,
*box_max
.first()
.ok_or_else(|| SpatialError::ValueError("Empty array".to_string()))?,
);
let mut closest_point = Array1::zeros(point.len());
for i in 0..point.len() {
closest_point[i] = point[i].clamp(box_min[i], box_max[i]);
}
let diff = f64::simd_sub(point, &closest_point.view());
let squared = f64::simd_mul(&diff.view(), &diff.view());
Ok(f64::simd_sum(&squared.view()))
}
/// Tests whether two axis-aligned boxes overlap.
///
/// Boxes that merely touch (share a face, edge or corner) count as intersecting.
/// The boxes are disjoint only if some axis separates them.
///
/// # Errors
///
/// Returns `SpatialError::ValueError` when the four corner arrays differ in length.
pub fn simd_box_box_intersection(
    box1_min: &ArrayView1<f64>,
    box1_max: &ArrayView1<f64>,
    box2_min: &ArrayView1<f64>,
    box2_max: &ArrayView1<f64>,
) -> SpatialResult<bool> {
    let dim = box1_min.len();
    let dims_match = [box1_max.len(), box2_min.len(), box2_max.len()]
        .iter()
        .all(|&len| len == dim);
    if !dims_match {
        return Err(SpatialError::ValueError(
            "All box dimensions must match".to_string(),
        ));
    }
    let separated = box1_min
        .iter()
        .zip(box1_max.iter())
        .zip(box2_min.iter().zip(box2_max.iter()))
        .any(|((&lo1, &hi1), (&lo2, &hi2))| hi1 < lo2 || lo1 > hi2);
    Ok(!separated)
}
/// Squared Euclidean distance from `query_point` to every row of `data_points`.
///
/// Element `i` of the result corresponds to row `i` of `data_points`.
///
/// # Errors
///
/// Returns `SpatialError::ValueError` when `query_point.len()` differs from the number
/// of columns of `data_points`.
pub fn simd_batch_squared_distances(
    query_point: &ArrayView1<f64>,
    data_points: &ArrayView2<f64>,
) -> SpatialResult<Array1<f64>> {
    if query_point.len() != data_points.ncols() {
        return Err(SpatialError::ValueError(
            "Query point dimension must match data points".to_string(),
        ));
    }
    let result: Array1<f64> = data_points
        .outer_iter()
        .map(|row| {
            let delta = f64::simd_sub(query_point, &row);
            let delta_sq = f64::simd_mul(&delta.view(), &delta.view());
            f64::simd_sum(&delta_sq.view())
        })
        .collect();
    Ok(result)
}
/// Row-wise Euclidean distances between two equally shaped point sets.
///
/// Element `i` of the result is the distance between row `i` of `points1` and row `i`
/// of `points2`.
///
/// # Errors
///
/// Returns `SpatialError::ValueError` when the two arrays have different shapes.
pub fn simd_batch_distances(
    points1: &ArrayView2<f64>,
    points2: &ArrayView2<f64>,
) -> SpatialResult<Array1<f64>> {
    if points1.shape() != points2.shape() {
        return Err(SpatialError::ValueError(
            "Point arrays must have the same shape".to_string(),
        ));
    }
    let result: Array1<f64> = points1
        .outer_iter()
        .zip(points2.outer_iter())
        .map(|(lhs, rhs)| {
            let delta = f64::simd_sub(&lhs, &rhs);
            let delta_sq = f64::simd_mul(&delta.view(), &delta.view());
            f64::simd_sum(&delta_sq.view()).sqrt()
        })
        .collect();
    Ok(result)
}
/// Brute-force k-nearest-neighbour search under Euclidean distance.
///
/// Returns `(indices, distances)` for the `k` rows of `data_points` closest to
/// `query_point`, sorted by ascending distance. Rows at equal distance are ordered by
/// ascending row index. NaN distances sort after every finite or infinite distance.
///
/// # Errors
///
/// Returns `SpatialError::ValueError` when:
/// - the query dimension does not match the data,
/// - `k == 0`, or
/// - `k` exceeds the number of data points.
pub fn simd_knn_search(
    query_point: &ArrayView1<f64>,
    data_points: &ArrayView2<f64>,
    k: usize,
) -> SpatialResult<(Array1<usize>, Array1<f64>)> {
    // Total order for (squared distance, row index) pairs. `f64::total_cmp` is a genuine
    // total order even with NaN, which the select/sort routines require; the previous
    // `partial_cmp(..).unwrap_or(Equal)` was not. The index tie-break keeps results
    // deterministic despite the unstable algorithms.
    fn by_distance_then_index(x: &(f64, usize), y: &(f64, usize)) -> std::cmp::Ordering {
        x.0.total_cmp(&y.0).then(x.1.cmp(&y.1))
    }

    if query_point.len() != data_points.ncols() {
        return Err(SpatialError::ValueError(
            "Query point dimension must match data points".to_string(),
        ));
    }
    let n_points = data_points.nrows();
    if k == 0 {
        return Err(SpatialError::ValueError(
            "k must be greater than 0".to_string(),
        ));
    }
    if k > n_points {
        return Err(SpatialError::ValueError(format!(
            "k ({}) cannot be larger than number of data points ({})",
            k, n_points
        )));
    }
    let squared_distances = simd_batch_squared_distances(query_point, data_points)?;
    let mut candidates: Vec<(f64, usize)> = squared_distances
        .iter()
        .enumerate()
        .map(|(idx, &dist_sq)| (dist_sq, idx))
        .collect();
    // Partition so the k smallest occupy the front, then order just those k.
    candidates.select_nth_unstable_by(k - 1, by_distance_then_index);
    let nearest = &mut candidates[..k];
    nearest.sort_unstable_by(by_distance_then_index);
    let indices: Array1<usize> = nearest.iter().map(|&(_, idx)| idx).collect();
    // Distances were compared in squared form; take the root only for the k survivors.
    let distances: Array1<f64> = nearest.iter().map(|&(dist_sq, _)| dist_sq.sqrt()).collect();
    Ok((indices, distances))
}
/// Finds all rows of `data_points` within Euclidean distance `radius` of `query_point`.
///
/// Returns `(indices, distances)` in ascending row-index order. A point exactly at
/// `radius` is included. The comparison is done on squared distances.
///
/// # Errors
///
/// Returns `SpatialError::ValueError` when:
/// - the query dimension does not match the data, or
/// - `radius` is negative or NaN.
pub fn simd_radius_search(
    query_point: &ArrayView1<f64>,
    data_points: &ArrayView2<f64>,
    radius: f64,
) -> SpatialResult<(Array1<usize>, Array1<f64>)> {
    if query_point.len() != data_points.ncols() {
        return Err(SpatialError::ValueError(
            "Query point dimension must match data points".to_string(),
        ));
    }
    // Reject NaN explicitly. Otherwise every `dist_sq <= NaN` test is false and the
    // search silently returns nothing.
    if radius.is_nan() || radius < 0.0 {
        return Err(SpatialError::ValueError(
            "Radius must be non-negative".to_string(),
        ));
    }
    let squared_distances = simd_batch_squared_distances(query_point, data_points)?;
    let radius_squared = radius * radius;
    let (indices, distances): (Vec<usize>, Vec<f64>) = squared_distances
        .iter()
        .enumerate()
        .filter(|&(_, &dist_sq)| dist_sq <= radius_squared)
        .map(|(idx, &dist_sq)| (idx, dist_sq.sqrt()))
        .unzip();
    Ok((Array1::from(indices), Array1::from(distances)))
}
/// Symmetric `n x n` matrix of Euclidean distances between all rows of `points`.
///
/// Only the upper triangle is computed; each value is mirrored into the lower triangle.
/// The diagonal is left at zero.
pub fn simd_pairwise_distance_matrix(points: &ArrayView2<f64>) -> SpatialResult<Array2<f64>> {
    let n = points.nrows();
    let mut matrix = Array2::zeros((n, n));
    for (row, lhs) in points.outer_iter().enumerate() {
        for col in (row + 1)..n {
            let rhs = points.row(col);
            let delta = f64::simd_sub(&lhs, &rhs);
            let delta_sq = f64::simd_mul(&delta.view(), &delta.view());
            let value = f64::simd_sum(&delta_sq.view()).sqrt();
            matrix[[row, col]] = value;
            matrix[[col, row]] = value;
        }
    }
    Ok(matrix)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_simd_euclidean_distance() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![4.0, 5.0, 6.0];
        let dist =
            simd_euclidean_distance(&a.view(), &b.view()).expect("Distance computation failed");
        assert_relative_eq!(dist, 5.196152422706632, epsilon = 1e-10);
    }

    #[test]
    fn test_simd_squared_euclidean_distance() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![4.0, 5.0, 6.0];
        let dist_sq = simd_squared_euclidean_distance(&a.view(), &b.view())
            .expect("Distance computation failed");
        assert_relative_eq!(dist_sq, 27.0, epsilon = 1e-10);
    }

    #[test]
    fn test_simd_manhattan_distance() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![4.0, 5.0, 6.0];
        let dist =
            simd_manhattan_distance(&a.view(), &b.view()).expect("Distance computation failed");
        assert_eq!(dist, 9.0);
    }

    #[test]
    fn test_simd_chebyshev_distance() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![4.0, 6.0, 5.0];
        let dist =
            simd_chebyshev_distance(&a.view(), &b.view()).expect("Distance computation failed");
        assert_eq!(dist, 4.0);
    }

    #[test]
    fn test_simd_minkowski_distance() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![4.0, 5.0, 6.0];
        let dist_p1 = simd_minkowski_distance(&a.view(), &b.view(), 1.0)
            .expect("Distance computation failed");
        assert_eq!(dist_p1, 9.0);
        let dist_p2 = simd_minkowski_distance(&a.view(), &b.view(), 2.0)
            .expect("Distance computation failed");
        assert_relative_eq!(dist_p2, 5.196152422706632, epsilon = 1e-10);
        let dist_p3 = simd_minkowski_distance(&a.view(), &b.view(), 3.0)
            .expect("Distance computation failed");
        assert_relative_eq!(dist_p3, 4.3267487109222245, epsilon = 1e-10);
    }

    #[test]
    fn test_simd_minkowski_rejects_small_p() {
        let a = array![1.0, 2.0];
        let b = array![3.0, 4.0];
        assert!(simd_minkowski_distance(&a.view(), &b.view(), 0.5).is_err());
    }

    #[test]
    fn test_simd_cosine_distance() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![4.0, 5.0, 6.0];
        let dist = simd_cosine_distance(&a.view(), &b.view()).expect("Distance computation failed");
        assert!(dist < 0.03);
        assert!(dist >= 0.0);
    }

    #[test]
    fn test_simd_cosine_distance_parallel_vectors() {
        // Parallel vectors have cosine distance 0 up to rounding error.
        let a = array![1.0, 2.0, 3.0];
        let b = array![2.0, 4.0, 6.0];
        let dist = simd_cosine_distance(&a.view(), &b.view()).expect("Distance computation failed");
        assert!(dist.abs() < 1e-12);
    }

    #[test]
    fn test_simd_batch_distances() {
        let points1 = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let points2 = array![[2.0, 3.0], [4.0, 5.0], [6.0, 7.0]];
        let distances = simd_batch_distances(&points1.view(), &points2.view())
            .expect("Batch distance computation failed");
        assert_eq!(distances.len(), 3);
        for &dist in distances.iter() {
            assert_relative_eq!(dist, std::f64::consts::SQRT_2, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_simd_batch_squared_distances() {
        let data_points = array![[0.0, 0.0], [3.0, 4.0], [1.0, 1.0]];
        let query = array![0.0, 0.0];
        let distances = simd_batch_squared_distances(&query.view(), &data_points.view())
            .expect("Batch squared distance computation failed");
        assert_eq!(distances.len(), 3);
        assert_relative_eq!(distances[0], 0.0, epsilon = 1e-10);
        assert_relative_eq!(distances[1], 25.0, epsilon = 1e-10);
        assert_relative_eq!(distances[2], 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_simd_pairwise_distance_matrix() {
        let points = array![[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]];
        let matrix = simd_pairwise_distance_matrix(&points.view())
            .expect("Pairwise distance computation failed");
        assert_eq!(matrix.shape(), &[3, 3]);
        for i in 0..3 {
            assert_eq!(matrix[[i, i]], 0.0);
            for j in 0..3 {
                assert_eq!(matrix[[i, j]], matrix[[j, i]]);
            }
        }
        assert_relative_eq!(matrix[[0, 1]], 5.0, epsilon = 1e-10);
        assert_relative_eq!(matrix[[0, 2]], 10.0, epsilon = 1e-10);
        assert_relative_eq!(matrix[[1, 2]], 5.0, epsilon = 1e-10);
    }

    #[test]
    fn test_simd_knn_search() {
        let data_points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 2.0]];
        let query = array![0.5, 0.5];
        let (indices, distances) =
            simd_knn_search(&query.view(), &data_points.view(), 3).expect("k-NN search failed");
        assert_eq!(indices.len(), 3);
        assert_eq!(distances.len(), 3);
        for i in 1..distances.len() {
            assert!(distances[i] >= distances[i - 1]);
        }
        // The four unit-square corners are all sqrt(0.5) away, and (2, 2) is farther.
        // The 3 neighbours must therefore be distinct corners at exactly that distance.
        for &dist in distances.iter() {
            assert_relative_eq!(dist, std::f64::consts::FRAC_1_SQRT_2, epsilon = 1e-10);
        }
        let mut seen = indices.to_vec();
        seen.sort_unstable();
        seen.dedup();
        assert_eq!(seen.len(), 3);
        assert!(seen.iter().all(|&idx| idx < 4));
    }

    #[test]
    fn test_simd_knn_search_invalid_k() {
        let data_points = array![[0.0, 0.0], [1.0, 0.0]];
        let query = array![0.5, 0.5];
        assert!(simd_knn_search(&query.view(), &data_points.view(), 0).is_err());
        assert!(simd_knn_search(&query.view(), &data_points.view(), 3).is_err());
    }

    #[test]
    fn test_simd_radius_search() {
        let data_points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [5.0, 5.0]];
        let query = array![0.5, 0.5];
        let radius = 1.0;
        let (indices, distances) = simd_radius_search(&query.view(), &data_points.view(), radius)
            .expect("Radius search failed");
        assert_eq!(indices.len(), 4);
        assert_eq!(indices.to_vec(), vec![0, 1, 2, 3]);
        for &dist in distances.iter() {
            assert!(dist <= radius);
        }
    }

    #[test]
    fn test_simd_radius_search_negative_radius() {
        let data_points = array![[0.0, 0.0]];
        let query = array![0.0, 0.0];
        assert!(simd_radius_search(&query.view(), &data_points.view(), -1.0).is_err());
    }

    #[test]
    fn test_simd_point_to_box_distance() {
        let point = array![2.0, 2.0];
        let box_min = array![0.0, 0.0];
        let box_max = array![1.0, 1.0];
        let dist_sq =
            simd_point_to_box_min_distance_squared(&point.view(), &box_min.view(), &box_max.view())
                .expect("Point-to-box distance failed");
        assert_relative_eq!(dist_sq, 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_simd_point_inside_box_distance() {
        let point = array![0.5, 0.5];
        let box_min = array![0.0, 0.0];
        let box_max = array![1.0, 1.0];
        let dist_sq =
            simd_point_to_box_min_distance_squared(&point.view(), &box_min.view(), &box_max.view())
                .expect("Point-to-box distance failed");
        assert_relative_eq!(dist_sq, 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_simd_box_intersection() {
        let box1_min = array![0.0, 0.0];
        let box1_max = array![2.0, 2.0];
        let box2_min = array![1.0, 1.0];
        let box2_max = array![3.0, 3.0];
        let intersects = simd_box_box_intersection(
            &box1_min.view(),
            &box1_max.view(),
            &box2_min.view(),
            &box2_max.view(),
        )
        .expect("Box intersection test failed");
        assert!(intersects);
        let box3_min = array![10.0, 10.0];
        let box3_max = array![20.0, 20.0];
        let no_intersect = simd_box_box_intersection(
            &box1_min.view(),
            &box1_max.view(),
            &box3_min.view(),
            &box3_max.view(),
        )
        .expect("Box intersection test failed");
        assert!(!no_intersect);
    }

    #[test]
    fn test_simd_box_intersection_touching() {
        // Boxes sharing a face count as intersecting.
        let box1_min = array![0.0, 0.0];
        let box1_max = array![1.0, 1.0];
        let box2_min = array![1.0, 0.0];
        let box2_max = array![2.0, 1.0];
        let touches = simd_box_box_intersection(
            &box1_min.view(),
            &box1_max.view(),
            &box2_min.view(),
            &box2_max.view(),
        )
        .expect("Box intersection test failed");
        assert!(touches);
    }

    #[test]
    fn test_dimension_mismatch_errors() {
        let a = array![1.0, 2.0];
        let b = array![1.0, 2.0, 3.0];
        assert!(simd_euclidean_distance(&a.view(), &b.view()).is_err());
        assert!(simd_squared_euclidean_distance(&a.view(), &b.view()).is_err());
        assert!(simd_manhattan_distance(&a.view(), &b.view()).is_err());
        assert!(simd_chebyshev_distance(&a.view(), &b.view()).is_err());
        assert!(simd_minkowski_distance(&a.view(), &b.view(), 3.0).is_err());
        assert!(simd_cosine_distance(&a.view(), &b.view()).is_err());
    }

    #[test]
    fn test_zero_vector_cosine() {
        let a = array![0.0, 0.0, 0.0];
        let b = array![1.0, 2.0, 3.0];
        assert!(simd_cosine_distance(&a.view(), &b.view()).is_err());
    }
}