use scirs2_core::ndarray::{Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::collections::BinaryHeap;
use std::fmt::Debug;
use crate::error::{ClusteringError, Result};
#[derive(Debug, Clone)]
pub struct NeighborSearchConfig {
pub algorithm: NeighborSearchAlgorithm,
pub leaf_size: usize,
pub n_hash_tables: usize,
pub n_hash_functions: usize,
pub parallel: bool,
}
impl Default for NeighborSearchConfig {
fn default() -> Self {
Self {
algorithm: NeighborSearchAlgorithm::Auto,
leaf_size: 30,
n_hash_tables: 10,
n_hash_functions: 4,
parallel: true,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NeighborSearchAlgorithm {
Auto,
BruteForce,
KDTree,
BallTree,
LSH,
}
#[derive(Debug, Clone)]
pub struct NeighborResult {
pub indices: Vec<usize>,
pub distances: Vec<f64>,
}
pub trait NeighborSearcher<F: Float> {
fn fit(&mut self, data: ArrayView2<F>) -> Result<()>;
fn kneighbors(&self, query: ArrayView1<F>, k: usize) -> Result<NeighborResult>;
fn radius_neighbors(&self, query: ArrayView1<F>, radius: F) -> Result<NeighborResult>;
fn kneighbors_batch(&self, queries: ArrayView2<F>, k: usize) -> Result<Vec<NeighborResult>> {
let mut results = Vec::new();
for query in queries.outer_iter() {
results.push(self.kneighbors(query, k)?);
}
Ok(results)
}
}
#[derive(Debug)]
pub struct KDTree<F: Float> {
data: Option<Array2<F>>,
tree: Option<KDNode>,
leaf_size: usize,
}
#[derive(Debug, Clone)]
struct KDNode {
indices: Vec<usize>,
split_dim: usize,
split_val: f64,
left: Option<Box<KDNode>>,
right: Option<Box<KDNode>>,
is_leaf: bool,
}
impl<F: Float + FromPrimitive + Debug> KDTree<F> {
pub fn new(leaf_size: usize) -> Self {
Self {
data: None,
tree: None,
leaf_size,
}
}
}
impl<F: Float + FromPrimitive + Debug> NeighborSearcher<F> for KDTree<F> {
fn fit(&mut self, data: ArrayView2<F>) -> Result<()> {
let n_samples = data.shape()[0];
let n_features = data.shape()[1];
if n_samples == 0 {
return Err(ClusteringError::InvalidInput(
"Cannot fit on empty data".into(),
));
}
self.data = Some(data.to_owned());
let indices: Vec<usize> = (0..n_samples).collect();
self.tree = Some(self.build_tree(indices, 0, n_features)?);
Ok(())
}
fn kneighbors(&self, query: ArrayView1<F>, k: usize) -> Result<NeighborResult> {
let data = self
.data
.as_ref()
.ok_or_else(|| ClusteringError::InvalidInput("Tree not fitted yet".into()))?;
let tree = self
.tree
.as_ref()
.ok_or_else(|| ClusteringError::InvalidInput("Tree not built yet".into()))?;
if k == 0 {
return Ok(NeighborResult {
indices: vec![],
distances: vec![],
});
}
let mut heap = BinaryHeap::new();
self.search_knn(tree, query, k, data.view(), &mut heap);
let mut indices = Vec::new();
let mut distances = Vec::new();
while let Some(neighbor) = heap.pop() {
indices.push(neighbor.index);
distances.push(neighbor.distance);
}
indices.reverse();
distances.reverse();
Ok(NeighborResult { indices, distances })
}
fn radius_neighbors(&self, query: ArrayView1<F>, radius: F) -> Result<NeighborResult> {
let data = self
.data
.as_ref()
.ok_or_else(|| ClusteringError::InvalidInput("Tree not fitted yet".into()))?;
let tree = self
.tree
.as_ref()
.ok_or_else(|| ClusteringError::InvalidInput("Tree not built yet".into()))?;
let mut result = NeighborResult {
indices: Vec::new(),
distances: Vec::new(),
};
let radius_f64 = radius.to_f64().unwrap_or(0.0);
self.search_radius(tree, query, radius_f64, data.view(), &mut result);
Ok(result)
}
}
impl<F: Float + FromPrimitive + Debug> KDTree<F> {
fn build_tree(
&self,
mut indices: Vec<usize>,
depth: usize,
n_features: usize,
) -> Result<KDNode> {
if indices.len() <= self.leaf_size {
return Ok(KDNode {
indices,
split_dim: 0,
split_val: 0.0,
left: None,
right: None,
is_leaf: true,
});
}
let data = self.data.as_ref().expect("Operation failed");
let split_dim = depth % n_features;
indices.sort_by(|&a, &b| {
let val_a = data[[a, split_dim]].to_f64().unwrap_or(0.0);
let val_b = data[[b, split_dim]].to_f64().unwrap_or(0.0);
val_a
.partial_cmp(&val_b)
.unwrap_or(std::cmp::Ordering::Equal)
});
let median_idx = indices.len() / 2;
let split_val = data[[indices[median_idx], split_dim]]
.to_f64()
.unwrap_or(0.0);
let left_indices = indices[..median_idx].to_vec();
let right_indices = indices[median_idx..].to_vec();
let left = if !left_indices.is_empty() {
Some(Box::new(self.build_tree(
left_indices,
depth + 1,
n_features,
)?))
} else {
None
};
let right = if !right_indices.is_empty() {
Some(Box::new(self.build_tree(
right_indices,
depth + 1,
n_features,
)?))
} else {
None
};
Ok(KDNode {
indices: vec![], split_dim,
split_val,
left,
right,
is_leaf: false,
})
}
#[allow(clippy::only_used_in_recursion)]
fn search_knn(
&self,
node: &KDNode,
query: ArrayView1<F>,
k: usize,
data: ArrayView2<F>,
heap: &mut BinaryHeap<NeighborCandidate>,
) {
if node.is_leaf {
for &idx in &node.indices {
let dist = euclidean_distance(query, data.row(idx));
if heap.len() < k {
heap.push(NeighborCandidate {
distance: dist,
index: idx,
});
} else if let Some(top) = heap.peek() {
if dist < top.distance {
heap.pop();
heap.push(NeighborCandidate {
distance: dist,
index: idx,
});
}
}
}
} else {
let query_val = query[node.split_dim].to_f64().unwrap_or(0.0);
let (first_child, second_child) = if query_val < node.split_val {
(&node.left, &node.right)
} else {
(&node.right, &node.left)
};
if let Some(child) = first_child {
self.search_knn(child, query, k, data, heap);
}
let split_dist = (query_val - node.split_val).abs();
if heap.len() < k || heap.peek().is_none_or(|top| split_dist < top.distance) {
if let Some(child) = second_child {
self.search_knn(child, query, k, data, heap);
}
}
}
}
#[allow(clippy::only_used_in_recursion)]
fn search_radius(
&self,
node: &KDNode,
query: ArrayView1<F>,
radius: f64,
data: ArrayView2<F>,
result: &mut NeighborResult,
) {
if node.is_leaf {
for &idx in &node.indices {
let dist = euclidean_distance(query, data.row(idx));
if dist <= radius {
result.indices.push(idx);
result.distances.push(dist);
}
}
} else {
let query_val = query[node.split_dim].to_f64().unwrap_or(0.0);
let split_dist = (query_val - node.split_val).abs();
if query_val < node.split_val {
if let Some(child) = &node.left {
self.search_radius(child, query, radius, data, result);
}
if split_dist <= radius {
if let Some(child) = &node.right {
self.search_radius(child, query, radius, data, result);
}
}
} else {
if let Some(child) = &node.right {
self.search_radius(child, query, radius, data, result);
}
if split_dist <= radius {
if let Some(child) = &node.left {
self.search_radius(child, query, radius, data, result);
}
}
}
}
}
}
#[derive(Debug, Clone, PartialEq)]
struct NeighborCandidate {
distance: f64,
index: usize,
}
impl Eq for NeighborCandidate {}
impl Ord for NeighborCandidate {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.distance
.partial_cmp(&other.distance)
.unwrap_or(std::cmp::Ordering::Equal)
}
}
impl PartialOrd for NeighborCandidate {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
#[derive(Debug)]
pub struct BruteForceSearch<F: Float> {
data: Option<Array2<F>>,
}
impl<F: Float + FromPrimitive + Debug> BruteForceSearch<F> {
pub fn new() -> Self {
Self { data: None }
}
}
impl<F: Float + FromPrimitive + Debug> Default for BruteForceSearch<F> {
fn default() -> Self {
Self::new()
}
}
impl<F: Float + FromPrimitive + Debug> NeighborSearcher<F> for BruteForceSearch<F> {
fn fit(&mut self, data: ArrayView2<F>) -> Result<()> {
if data.shape()[0] == 0 {
return Err(ClusteringError::InvalidInput(
"Cannot fit on empty data".into(),
));
}
self.data = Some(data.to_owned());
Ok(())
}
fn kneighbors(&self, query: ArrayView1<F>, k: usize) -> Result<NeighborResult> {
let data = self
.data
.as_ref()
.ok_or_else(|| ClusteringError::InvalidInput("Searcher not fitted yet".into()))?;
if k == 0 {
return Ok(NeighborResult {
indices: vec![],
distances: vec![],
});
}
let n_samples = data.shape()[0];
let k_actual = k.min(n_samples);
let mut candidates: Vec<NeighborCandidate> = (0..n_samples)
.map(|i| {
let dist = euclidean_distance(query, data.row(i));
NeighborCandidate {
distance: dist,
index: i,
}
})
.collect();
candidates.sort_by(|a, b| {
a.distance
.partial_cmp(&b.distance)
.expect("Operation failed")
});
candidates.truncate(k_actual);
let indices = candidates.iter().map(|c| c.index).collect();
let distances = candidates.iter().map(|c| c.distance).collect();
Ok(NeighborResult { indices, distances })
}
fn radius_neighbors(&self, query: ArrayView1<F>, radius: F) -> Result<NeighborResult> {
let data = self
.data
.as_ref()
.ok_or_else(|| ClusteringError::InvalidInput("Searcher not fitted yet".into()))?;
let radius_f64 = radius.to_f64().unwrap_or(0.0);
let n_samples = data.shape()[0];
let mut indices = Vec::new();
let mut distances = Vec::new();
for i in 0..n_samples {
let dist = euclidean_distance(query, data.row(i));
if dist <= radius_f64 {
indices.push(i);
distances.push(dist);
}
}
Ok(NeighborResult { indices, distances })
}
}
#[derive(Debug)]
pub struct BallTree<F: Float> {
data: Option<Array2<F>>,
tree: Option<BallNode>,
leaf_size: usize,
}
#[derive(Debug, Clone)]
struct BallNode {
center: Vec<f64>,
radius: f64,
indices: Vec<usize>,
left: Option<Box<BallNode>>,
right: Option<Box<BallNode>>,
is_leaf: bool,
}
impl<F: Float + FromPrimitive + Debug> BallTree<F> {
pub fn new(leaf_size: usize) -> Self {
Self {
data: None,
tree: None,
leaf_size,
}
}
}
impl<F: Float + FromPrimitive + Debug> NeighborSearcher<F> for BallTree<F> {
fn fit(&mut self, data: ArrayView2<F>) -> Result<()> {
let n_samples = data.shape()[0];
if n_samples == 0 {
return Err(ClusteringError::InvalidInput(
"Cannot fit on empty data".into(),
));
}
self.data = Some(data.to_owned());
let indices: Vec<usize> = (0..n_samples).collect();
self.tree = Some(self.build_ball_tree(indices, data.view())?);
Ok(())
}
fn kneighbors(&self, query: ArrayView1<F>, k: usize) -> Result<NeighborResult> {
let data = self
.data
.as_ref()
.ok_or_else(|| ClusteringError::InvalidInput("Tree not fitted yet".into()))?;
let tree = self
.tree
.as_ref()
.ok_or_else(|| ClusteringError::InvalidInput("Tree not built yet".into()))?;
if k == 0 {
return Ok(NeighborResult {
indices: vec![],
distances: vec![],
});
}
let mut heap = BinaryHeap::new();
self.search_ball_knn(tree, query, k, data.view(), &mut heap);
let mut indices = Vec::new();
let mut distances = Vec::new();
while let Some(neighbor) = heap.pop() {
indices.push(neighbor.index);
distances.push(neighbor.distance);
}
indices.reverse();
distances.reverse();
Ok(NeighborResult { indices, distances })
}
fn radius_neighbors(&self, query: ArrayView1<F>, radius: F) -> Result<NeighborResult> {
let data = self
.data
.as_ref()
.ok_or_else(|| ClusteringError::InvalidInput("Tree not fitted yet".into()))?;
let tree = self
.tree
.as_ref()
.ok_or_else(|| ClusteringError::InvalidInput("Tree not built yet".into()))?;
let mut result = NeighborResult {
indices: Vec::new(),
distances: Vec::new(),
};
let radius_f64 = radius.to_f64().unwrap_or(0.0);
self.search_ball_radius(tree, query, radius_f64, data.view(), &mut result);
Ok(result)
}
}
impl<F: Float + FromPrimitive + Debug> BallTree<F> {
fn build_ball_tree(&self, indices: Vec<usize>, data: ArrayView2<F>) -> Result<BallNode> {
if indices.len() <= self.leaf_size {
let (center, radius) = self.compute_ball(&indices, data);
return Ok(BallNode {
center,
radius,
indices,
left: None,
right: None,
is_leaf: true,
});
}
let n_features = data.shape()[1];
let mut best_dim = 0;
let mut best_spread = 0.0;
for dim in 0..n_features {
let mut min_val = f64::INFINITY;
let mut max_val = f64::NEG_INFINITY;
for &idx in &indices {
let val = data[[idx, dim]].to_f64().unwrap_or(0.0);
min_val = min_val.min(val);
max_val = max_val.max(val);
}
let spread = max_val - min_val;
if spread > best_spread {
best_spread = spread;
best_dim = dim;
}
}
let mut sorted_indices = indices;
sorted_indices.sort_by(|&a, &b| {
let val_a = data[[a, best_dim]].to_f64().unwrap_or(0.0);
let val_b = data[[b, best_dim]].to_f64().unwrap_or(0.0);
val_a
.partial_cmp(&val_b)
.unwrap_or(std::cmp::Ordering::Equal)
});
let split_idx = sorted_indices.len() / 2;
let left_indices = sorted_indices[..split_idx].to_vec();
let right_indices = sorted_indices[split_idx..].to_vec();
let left = if !left_indices.is_empty() {
Some(Box::new(self.build_ball_tree(left_indices, data)?))
} else {
None
};
let right = if !right_indices.is_empty() {
Some(Box::new(self.build_ball_tree(right_indices, data)?))
} else {
None
};
let (center, radius) = self.compute_ball(&sorted_indices, data);
Ok(BallNode {
center,
radius,
indices: vec![], left,
right,
is_leaf: false,
})
}
fn compute_ball(&self, indices: &[usize], data: ArrayView2<F>) -> (Vec<f64>, f64) {
if indices.is_empty() {
return (vec![], 0.0);
}
let n_features = data.shape()[1];
let mut center = vec![0.0; n_features];
for &idx in indices {
for j in 0..n_features {
center[j] += data[[idx, j]].to_f64().unwrap_or(0.0);
}
}
let n_points = indices.len() as f64;
for val in &mut center {
*val /= n_points;
}
let mut radius = 0.0;
for &idx in indices {
let mut dist_sq = 0.0;
for j in 0..n_features {
let diff = data[[idx, j]].to_f64().unwrap_or(0.0) - center[j];
dist_sq += diff * diff;
}
radius = radius.max(dist_sq.sqrt());
}
(center, radius)
}
#[allow(clippy::only_used_in_recursion)]
fn search_ball_knn(
&self,
node: &BallNode,
query: ArrayView1<F>,
k: usize,
data: ArrayView2<F>,
heap: &mut BinaryHeap<NeighborCandidate>,
) {
if node.is_leaf {
for &idx in &node.indices {
let dist = euclidean_distance(query, data.row(idx));
if heap.len() < k {
heap.push(NeighborCandidate {
distance: dist,
index: idx,
});
} else if let Some(top) = heap.peek() {
if dist < top.distance {
heap.pop();
heap.push(NeighborCandidate {
distance: dist,
index: idx,
});
}
}
}
} else {
let query_vec: Vec<f64> = query.iter().map(|&x| x.to_f64().unwrap_or(0.0)).collect();
let dist_to_center = euclidean_distance_vec(&query_vec, &node.center);
let min_dist_to_ball = (dist_to_center - node.radius).max(0.0);
if heap.len() < k
|| heap
.peek()
.is_none_or(|top| min_dist_to_ball < top.distance)
{
if let (Some(left), Some(right)) = (&node.left, &node.right) {
let left_dist = euclidean_distance_vec(&query_vec, &left.center);
let right_dist = euclidean_distance_vec(&query_vec, &right.center);
if left_dist <= right_dist {
self.search_ball_knn(left, query, k, data, heap);
self.search_ball_knn(right, query, k, data, heap);
} else {
self.search_ball_knn(right, query, k, data, heap);
self.search_ball_knn(left, query, k, data, heap);
}
} else if let Some(child) = &node.left {
self.search_ball_knn(child, query, k, data, heap);
} else if let Some(child) = &node.right {
self.search_ball_knn(child, query, k, data, heap);
}
}
}
}
#[allow(clippy::only_used_in_recursion)]
fn search_ball_radius(
&self,
node: &BallNode,
query: ArrayView1<F>,
radius: f64,
data: ArrayView2<F>,
result: &mut NeighborResult,
) {
if node.is_leaf {
for &idx in &node.indices {
let dist = euclidean_distance(query, data.row(idx));
if dist <= radius {
result.indices.push(idx);
result.distances.push(dist);
}
}
} else {
let query_vec: Vec<f64> = query.iter().map(|&x| x.to_f64().unwrap_or(0.0)).collect();
let dist_to_center = euclidean_distance_vec(&query_vec, &node.center);
if dist_to_center <= radius + node.radius {
if let Some(child) = &node.left {
self.search_ball_radius(child, query, radius, data, result);
}
if let Some(child) = &node.right {
self.search_ball_radius(child, query, radius, data, result);
}
}
}
}
}
#[allow(dead_code)]
pub fn create_neighbor_searcher<F: Float + FromPrimitive + Debug + 'static>(
config: NeighborSearchConfig,
) -> Box<dyn NeighborSearcher<F>> {
match config.algorithm {
NeighborSearchAlgorithm::BruteForce => Box::new(BruteForceSearch::new()),
NeighborSearchAlgorithm::KDTree => Box::new(KDTree::new(config.leaf_size)),
NeighborSearchAlgorithm::BallTree => Box::new(BallTree::new(config.leaf_size)),
NeighborSearchAlgorithm::Auto => {
Box::new(KDTree::new(config.leaf_size))
}
NeighborSearchAlgorithm::LSH => {
Box::new(KDTree::new(config.leaf_size))
}
}
}
#[allow(dead_code)]
fn euclidean_distance<F: Float + FromPrimitive>(a: ArrayView1<F>, b: ArrayView1<F>) -> f64 {
let mut sum = 0.0;
for (x, y) in a.iter().zip(b.iter()) {
let diff = x.to_f64().unwrap_or(0.0) - y.to_f64().unwrap_or(0.0);
sum += diff * diff;
}
sum.sqrt()
}
#[allow(dead_code)]
fn euclidean_distance_vec(a: &[f64], b: &[f64]) -> f64 {
let mut sum = 0.0;
for (x, y) in a.iter().zip(b.iter()) {
let diff = x - y;
sum += diff * diff;
}
sum.sqrt()
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
fn create_test_data() -> Array2<f64> {
Array2::from_shape_vec(
(6, 2),
vec![
0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 10.0, 10.0, 11.0, 10.0, 10.0, 11.0, ],
)
.expect("Operation failed")
}
#[test]
fn test_brute_force_search() {
let data = create_test_data();
let mut searcher = BruteForceSearch::new();
searcher.fit(data.view()).expect("Operation failed");
let query = Array1::from_vec(vec![0.0, 0.0]);
let result = searcher
.kneighbors(query.view(), 3)
.expect("Operation failed");
assert_eq!(result.indices.len(), 3);
assert_eq!(result.distances.len(), 3);
assert_eq!(result.indices[0], 0);
assert!(result.distances[0] < 1e-10);
let radius_result = searcher
.radius_neighbors(query.view(), 1.5)
.expect("Operation failed");
assert!(radius_result.indices.len() >= 3); }
#[test]
fn test_kdtree_search() {
let data = create_test_data();
let mut searcher = KDTree::new(2);
searcher.fit(data.view()).expect("Operation failed");
let query = Array1::from_vec(vec![0.0, 0.0]);
let result = searcher
.kneighbors(query.view(), 3)
.expect("Operation failed");
assert_eq!(result.indices.len(), 3);
assert_eq!(result.distances.len(), 3);
assert_eq!(result.indices[0], 0);
assert!(result.distances[0] < 1e-10);
}
#[test]
fn test_ball_tree_search() {
let data = create_test_data();
let mut searcher = BallTree::new(2);
searcher.fit(data.view()).expect("Operation failed");
let query = Array1::from_vec(vec![0.0, 0.0]);
let result = searcher
.kneighbors(query.view(), 3)
.expect("Operation failed");
assert_eq!(result.indices.len(), 3);
assert_eq!(result.distances.len(), 3);
assert_eq!(result.indices[0], 0);
assert!(result.distances[0] < 1e-10);
}
#[test]
fn test_neighbor_searcher_factory() {
let data = create_test_data();
let algorithms = vec![
NeighborSearchAlgorithm::BruteForce,
NeighborSearchAlgorithm::KDTree,
NeighborSearchAlgorithm::BallTree,
NeighborSearchAlgorithm::Auto,
];
for algorithm in algorithms {
let config = NeighborSearchConfig {
algorithm,
..Default::default()
};
let mut searcher = create_neighbor_searcher(config);
searcher.fit(data.view()).expect("Operation failed");
let query = Array1::from_vec(vec![0.0, 0.0]);
let result = searcher
.kneighbors(query.view(), 2)
.expect("Operation failed");
assert_eq!(result.indices.len(), 2);
assert_eq!(result.distances.len(), 2);
}
}
#[test]
fn test_empty_data_error() {
let empty_data: Array2<f64> = Array2::zeros((0, 2));
let mut searcher = BruteForceSearch::new();
let result = searcher.fit(empty_data.view());
assert!(result.is_err());
}
#[test]
fn test_k_zero() {
let data = create_test_data();
let mut searcher = BruteForceSearch::new();
searcher.fit(data.view()).expect("Operation failed");
let query = Array1::from_vec(vec![0.0, 0.0]);
let result = searcher
.kneighbors(query.view(), 0)
.expect("Operation failed");
assert_eq!(result.indices.len(), 0);
assert_eq!(result.distances.len(), 0);
}
#[test]
fn test_batch_queries() {
let data = create_test_data();
let mut searcher = BruteForceSearch::new();
searcher.fit(data.view()).expect("Operation failed");
let queries =
Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 10.0, 10.0]).expect("Operation failed");
let results = searcher
.kneighbors_batch(queries.view(), 2)
.expect("Operation failed");
assert_eq!(results.len(), 2);
assert_eq!(results[0].indices.len(), 2);
assert_eq!(results[1].indices.len(), 2);
assert_eq!(results[0].indices[0], 0);
assert_eq!(results[1].indices[0], 3);
}
}