scirs2_interpolate/spatial/
balltree.rs

1//! Ball Tree implementation for efficient nearest neighbor search
2//!
3//! A Ball Tree is a space-partitioning data structure for organizing points in a
4//! k-dimensional space. It divides points into nested hyperspheres, which makes it
5//! particularly effective for high-dimensional data.
6//!
7//! Advantages of Ball Trees over KD-Trees:
8//! - Better performance in high dimensions
9//! - More efficient for datasets with varying density
10//! - Handles elongated clusters well
11//!
12//! This implementation provides:
//! - Building Ball Trees from point data by recursively splitting around two seed points
14//! - Efficient exact nearest neighbor queries
15//! - k-nearest neighbor searches
16//! - Range queries for all points within a specified radius
17
18use ordered_float::OrderedFloat;
19use scirs2_core::ndarray::Array2;
20use scirs2_core::numeric::{Float, FromPrimitive};
21use std::cmp::Ordering;
22use std::fmt::Debug;
23use std::marker::PhantomData;
24
25use crate::error::{InterpolateError, InterpolateResult};
26use scirs2_core::ndarray::ArrayView1;
27
/// A node in the Ball Tree
///
/// A node describes a ball (`center`, `radius`) that is assumed to enclose every
/// point listed in `indices`; the search routines prune whole nodes based on
/// that assumption. A node whose `left` and `right` are both `None` is a leaf.
/// Internal nodes have both children set and still keep the indices of every
/// point in their subtree.
#[derive(Debug, Clone)]
struct BallNode<F: Float + ordered_float::FloatCore> {
    /// Indices (rows of the tree's point array) of the points in this node
    indices: Vec<usize>,

    /// Center of the ball (computed by `compute_centroid` during construction)
    center: Vec<F>,

    /// Radius of the ball around `center`
    radius: F,

    /// Position of the left child in the tree's node vector; `None` for leaves
    left: Option<usize>,

    /// Position of the right child in the tree's node vector; `None` for leaves
    right: Option<usize>,
}
46
/// Ball Tree for efficient nearest neighbor searches in high dimensions
///
/// The tree owns the input points and refers to them by row index: every query
/// result is a `(point_index, distance)` pair, where `point_index` is the row of
/// the point in the array passed to the constructor. When the whole data set
/// holds no more than `leafsize` points, the query methods fall back to a
/// linear scan.
///
/// # Examples
///
/// ```rust
/// use scirs2_core::ndarray::Array2;
/// use scirs2_interpolate::spatial::balltree::BallTree;
///
/// // Create sample 3D points
/// let points = Array2::from_shape_vec((5, 3), vec![
///     0.0, 0.0, 0.0,
///     1.0, 0.0, 0.0,
///     0.0, 1.0, 0.0,
///     0.0, 0.0, 1.0,
///     0.5, 0.5, 0.5,
/// ]).unwrap();
///
/// // Build Ball Tree
/// let ball_tree = BallTree::new(points).unwrap();
///
/// // Find the nearest neighbor to point (0.6, 0.6, 0.6)
/// let query = vec![0.6, 0.6, 0.6];
/// let (idx, distance) = ball_tree.nearest_neighbor(&query).unwrap();
///
/// // idx should be 4 (the point at (0.5, 0.5, 0.5))
/// assert_eq!(idx, 4);
/// // (0.5, 0.5, 0.5) lies well within 0.2 of the query
/// assert!(distance < 0.2);
/// ```
#[derive(Debug, Clone)]
pub struct BallTree<F>
where
    F: Float + FromPrimitive + Debug + std::cmp::PartialOrd + ordered_float::FloatCore,
{
    /// The original points used to build the tree, one point per row
    points: Array2<F>,

    /// Flat storage of all nodes; children are referenced by their position
    /// in this vector
    nodes: Vec<BallNode<F>>,

    /// Index of the root node in `nodes` (set by the constructors)
    root: Option<usize>,

    /// The dimension of the space (number of columns of `points`)
    dim: usize,

    /// Leaf size (max points in a leaf node)
    leafsize: usize,

    /// Marker for generic type parameter.
    /// NOTE(review): `F` already appears in `points`, so this marker looks
    /// redundant; it is kept so the struct layout and constructors stay unchanged.
    _phantom: PhantomData<F>,
}
97
98impl<F> BallTree<F>
99where
100    F: Float + FromPrimitive + Debug + std::cmp::PartialOrd + ordered_float::FloatCore,
101{
102    /// Create a new Ball Tree from points
103    ///
104    /// # Arguments
105    ///
106    /// * `points` - Point coordinates with shape (n_points, n_dims)
107    ///
108    /// # Returns
109    ///
110    /// A new Ball Tree for efficient nearest neighbor searches
111    pub fn new(points: Array2<F>) -> InterpolateResult<Self> {
112        Self::with_leafsize(points, 10)
113    }
114
115    /// Create a new Ball Tree with a specified leaf size
116    ///
117    /// # Arguments
118    ///
119    /// * `points` - Point coordinates with shape (n_points, n_dims)
120    /// * `leafsize` - Maximum number of points in a leaf node
121    ///
122    /// # Returns
123    ///
124    /// A new Ball Tree for efficient nearest neighbor searches
125    pub fn with_leafsize(points: Array2<F>, leafsize: usize) -> InterpolateResult<Self> {
126        if points.is_empty() {
127            return Err(InterpolateError::InvalidValue(
128                "Points array cannot be empty".to_string(),
129            ));
130        }
131
132        let n_points = points.shape()[0];
133        let dim = points.shape()[1];
134
135        // For very small datasets, just use a simple linear search
136        if n_points <= leafsize {
137            let indices: Vec<usize> = (0..n_points).collect();
138            let center = compute_centroid(&points, &indices);
139            let radius = compute_radius(&points, &indices, &center);
140
141            let mut tree = Self {
142                points,
143                nodes: Vec::new(),
144                root: None,
145                dim,
146                leafsize,
147                _phantom: PhantomData,
148            };
149
150            if n_points > 0 {
151                // Create a single root node
152                tree.nodes.push(BallNode {
153                    indices,
154                    center,
155                    radius,
156                    left: None,
157                    right: None,
158                });
159                tree.root = Some(0);
160            }
161
162            return Ok(tree);
163        }
164
165        // Pre-allocate nodes (approximately 2*n_points/leafsize)
166        let est_nodes = (2 * n_points / leafsize).max(16);
167
168        let mut tree = Self {
169            points,
170            nodes: Vec::with_capacity(est_nodes),
171            root: None,
172            dim,
173            leafsize,
174            _phantom: PhantomData,
175        };
176
177        // Build the tree
178        let indices: Vec<usize> = (0..n_points).collect();
179        tree.root = Some(tree.build_subtree(&indices));
180
181        Ok(tree)
182    }
183
184    /// Build a subtree recursively
185    fn build_subtree(&mut self, indices: &[usize]) -> usize {
186        let n_points = indices.len();
187
188        // Compute center and radius of this ball
189        let center = compute_centroid(&self.points, indices);
190        let radius = compute_radius(&self.points, indices, &center);
191
192        // If few enough points, create a leaf node
193        if n_points <= self.leafsize {
194            let node_idx = self.nodes.len();
195            self.nodes.push(BallNode {
196                indices: indices.to_vec(),
197                center,
198                radius,
199                left: None,
200                right: None,
201            });
202            return node_idx;
203        }
204
205        // Find the dimension with the largest spread
206        let (split_dim, _) = find_max_spread_dimension(&self.points, indices);
207
208        // Find the two points farthest apart along this dimension to use as seeds
209        let (seed1, seed2) = find_distant_points(&self.points, indices, split_dim);
210
211        // Partition points based on which seed they're closer to
212        let (left_indices, right_indices) = partition_by_seeds(&self.points, indices, seed1, seed2);
213
214        // Create node for this ball
215        let node_idx = self.nodes.len();
216        self.nodes.push(BallNode {
217            indices: indices.to_vec(),
218            center,
219            radius,
220            left: None,
221            right: None,
222        });
223
224        // Recursively build left and right subtrees
225        let left_idx = self.build_subtree(&left_indices);
226        let right_idx = self.build_subtree(&right_indices);
227
228        // Update node with child information
229        self.nodes[node_idx].left = Some(left_idx);
230        self.nodes[node_idx].right = Some(right_idx);
231
232        node_idx
233    }
234
235    /// Find the nearest neighbor to a query point
236    ///
237    /// # Arguments
238    ///
239    /// * `query` - Query point coordinates
240    ///
241    /// # Returns
242    ///
243    /// Tuple containing (point_index, distance) of the nearest neighbor
244    pub fn nearest_neighbor(&self, query: &[F]) -> InterpolateResult<(usize, F)> {
245        // Check query dimension
246        if query.len() != self.dim {
247            return Err(InterpolateError::DimensionMismatch(format!(
248                "Query dimension {} doesn't match Ball Tree dimension {}",
249                query.len(),
250                self.dim
251            )));
252        }
253
254        // Handle empty tree
255        if self.root.is_none() {
256            return Err(InterpolateError::InvalidState(
257                "Ball Tree is empty".to_string(),
258            ));
259        }
260
261        // Very small trees (just use linear search)
262        if self.points.shape()[0] <= self.leafsize {
263            return self.linear_nearest_neighbor(query);
264        }
265
266        // Initialize nearest neighbor search
267        let mut best_dist = <F as scirs2_core::numeric::Float>::infinity();
268        let mut best_idx = 0;
269
270        // Start recursive search
271        self.search_nearest(self.root.unwrap(), query, &mut best_dist, &mut best_idx);
272
273        Ok((best_idx, best_dist))
274    }
275
276    /// Find k nearest neighbors to a query point
277    ///
278    /// # Arguments
279    ///
280    /// * `query` - Query point coordinates
281    /// * `k` - Number of nearest neighbors to find
282    ///
283    /// # Returns
284    ///
285    /// Vector of (point_index, distance) tuples, sorted by distance
286    pub fn k_nearest_neighbors(&self, query: &[F], k: usize) -> InterpolateResult<Vec<(usize, F)>> {
287        // Check query dimension
288        if query.len() != self.dim {
289            return Err(InterpolateError::DimensionMismatch(format!(
290                "Query dimension {} doesn't match Ball Tree dimension {}",
291                query.len(),
292                self.dim
293            )));
294        }
295
296        // Handle empty tree
297        if self.root.is_none() {
298            return Err(InterpolateError::InvalidState(
299                "Ball Tree is empty".to_string(),
300            ));
301        }
302
303        // Limit k to the number of points
304        let k = k.min(self.points.shape()[0]);
305
306        if k == 0 {
307            return Ok(Vec::new());
308        }
309
310        // Very small trees (just use linear search)
311        if self.points.shape()[0] <= self.leafsize {
312            return self.linear_k_nearest_neighbors(query, k);
313        }
314
315        // Use a BinaryHeap as a priority queue to keep track of k nearest points
316        use std::collections::BinaryHeap;
317
318        let mut heap = BinaryHeap::with_capacity(k + 1);
319
320        // Start recursive search
321        self.search_k_nearest(self.root.unwrap(), query, k, &mut heap);
322
323        // Convert heap to sorted vector
324        let mut results: Vec<(usize, F)> = heap
325            .into_iter()
326            .map(|(dist, idx)| (idx, dist.into_inner()))
327            .collect();
328
329        // Sort by distance (since heap gives reverse order)
330        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
331
332        Ok(results)
333    }
334
335    /// Find all points within a specified radius of a query point
336    ///
337    /// # Arguments
338    ///
339    /// * `query` - Query point coordinates
340    /// * `radius` - Search radius
341    ///
342    /// # Returns
343    ///
344    /// Vector of (point_index, distance) tuples for all points within radius
345    pub fn points_within_radius(
346        &self,
347        query: &[F],
348        radius: F,
349    ) -> InterpolateResult<Vec<(usize, F)>> {
350        // Check query dimension
351        if query.len() != self.dim {
352            return Err(InterpolateError::DimensionMismatch(format!(
353                "Query dimension {} doesn't match Ball Tree dimension {}",
354                query.len(),
355                self.dim
356            )));
357        }
358
359        // Handle empty tree
360        if self.root.is_none() {
361            return Err(InterpolateError::InvalidState(
362                "Ball Tree is empty".to_string(),
363            ));
364        }
365
366        if radius <= F::zero() {
367            return Err(InterpolateError::InvalidValue(
368                "Radius must be positive".to_string(),
369            ));
370        }
371
372        // Very small trees (just use linear search)
373        if self.points.shape()[0] <= self.leafsize {
374            return self.linear_points_within_radius(query, radius);
375        }
376
377        // Store results
378        let mut results = Vec::new();
379
380        // Start recursive search
381        self.search_radius(self.root.unwrap(), query, radius, &mut results);
382
383        // Sort by distance
384        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
385
386        Ok(results)
387    }
388
389    /// Recursively search for the nearest neighbor
390    fn search_nearest(
391        &self,
392        node_idx: usize,
393        query: &[F],
394        best_dist: &mut F,
395        best_idx: &mut usize,
396    ) {
397        let node = &self.nodes[node_idx];
398
399        // Calculate distance from query to ball center
400        let center_dist = euclidean_distance(query, &node.center);
401
402        // If this ball is too far away to contain a better point, skip it
403        if center_dist > node.radius + *best_dist {
404            return;
405        }
406
407        // If this is a leaf node, check all points
408        if node.left.is_none() && node.right.is_none() {
409            for &idx in &node.indices {
410                let point = self.points.row(idx);
411                let dist = euclidean_distance(query, &point.to_vec());
412
413                if dist < *best_dist {
414                    *best_dist = dist;
415                    *best_idx = idx;
416                }
417            }
418            return;
419        }
420
421        // Process children
422        // Choose the closer ball first to potentially reduce the best_dist sooner
423        let left_idx = node.left.unwrap();
424        let right_idx = node.right.unwrap();
425
426        let left_node = &self.nodes[left_idx];
427        let right_node = &self.nodes[right_idx];
428
429        let left_dist = euclidean_distance(query, &left_node.center);
430        let right_dist = euclidean_distance(query, &right_node.center);
431
432        if left_dist < right_dist {
433            // Search left child first
434            self.search_nearest(left_idx, query, best_dist, best_idx);
435            self.search_nearest(right_idx, query, best_dist, best_idx);
436        } else {
437            // Search right child first
438            self.search_nearest(right_idx, query, best_dist, best_idx);
439            self.search_nearest(left_idx, query, best_dist, best_idx);
440        }
441    }
442
443    /// Recursively search for the k nearest neighbors
444    #[allow(clippy::type_complexity)]
445    fn search_k_nearest(
446        &self,
447        node_idx: usize,
448        query: &[F],
449        k: usize,
450        heap: &mut std::collections::BinaryHeap<(OrderedFloat<F>, usize)>,
451    ) {
452        let node = &self.nodes[node_idx];
453
454        // Calculate distance from query to ball center
455        let center_dist = euclidean_distance(query, &node.center);
456
457        // Get current kth distance (the farthest point in our current result set)
458        let kth_dist = if heap.len() < k {
459            <F as scirs2_core::numeric::Float>::infinity()
460        } else {
461            // Peek at the top of the max-heap to get the farthest point
462            match heap.peek() {
463                Some(&(dist_, _)) => dist_.into_inner(),
464                None => <F as scirs2_core::numeric::Float>::infinity(),
465            }
466        };
467
468        // If this ball is too far away to contain a better point, skip it
469        if center_dist > node.radius + kth_dist {
470            return;
471        }
472
473        // If this is a leaf node, check all points
474        if node.left.is_none() && node.right.is_none() {
475            for &idx in &node.indices {
476                let point = self.points.row(idx);
477                let dist = euclidean_distance(query, &point.to_vec());
478
479                // Add to heap
480                heap.push((OrderedFloat(dist), idx));
481
482                // If heap is too large, remove the farthest point
483                if heap.len() > k {
484                    heap.pop();
485                }
486            }
487            return;
488        }
489
490        // Process children
491        // Choose the closer ball first to potentially reduce the kth_dist sooner
492        let left_idx = node.left.unwrap();
493        let right_idx = node.right.unwrap();
494
495        let left_node = &self.nodes[left_idx];
496        let right_node = &self.nodes[right_idx];
497
498        let left_dist = euclidean_distance(query, &left_node.center);
499        let right_dist = euclidean_distance(query, &right_node.center);
500
501        if left_dist < right_dist {
502            // Search left child first
503            self.search_k_nearest(left_idx, query, k, heap);
504            self.search_k_nearest(right_idx, query, k, heap);
505        } else {
506            // Search right child first
507            self.search_k_nearest(right_idx, query, k, heap);
508            self.search_k_nearest(left_idx, query, k, heap);
509        }
510    }
511
512    /// Recursively search for all points within a radius
513    fn search_radius(
514        &self,
515        node_idx: usize,
516        query: &[F],
517        radius: F,
518        results: &mut Vec<(usize, F)>,
519    ) {
520        let node = &self.nodes[node_idx];
521
522        // Calculate distance from query to ball center
523        let center_dist = euclidean_distance(query, &node.center);
524
525        // If this ball is too far away to contain any points within radius, skip it
526        if center_dist > node.radius + radius {
527            return;
528        }
529
530        // If this is a leaf node, check all points
531        if node.left.is_none() && node.right.is_none() {
532            for &_idx in &node.indices {
533                let point = self.points.row(_idx);
534                let dist = euclidean_distance(query, &point.to_vec());
535
536                if dist <= radius {
537                    results.push((_idx, dist));
538                }
539            }
540            return;
541        }
542
543        // Process children
544        if let Some(left_idx) = node.left {
545            self.search_radius(left_idx, query, radius, results);
546        }
547
548        if let Some(right_idx) = node.right {
549            self.search_radius(right_idx, query, radius, results);
550        }
551    }
552
553    /// Linear search for the nearest neighbor (for small datasets or leaf nodes)
554    fn linear_nearest_neighbor(&self, query: &[F]) -> InterpolateResult<(usize, F)> {
555        let n_points = self.points.shape()[0];
556
557        let mut min_dist = <F as scirs2_core::numeric::Float>::infinity();
558        let mut min_idx = 0;
559
560        for i in 0..n_points {
561            let point = self.points.row(i);
562            let dist = euclidean_distance(query, &point.to_vec());
563
564            if dist < min_dist {
565                min_dist = dist;
566                min_idx = i;
567            }
568        }
569
570        Ok((min_idx, min_dist))
571    }
572
573    /// Linear search for k nearest neighbors (for small datasets or leaf nodes)
574    fn linear_k_nearest_neighbors(
575        &self,
576        query: &[F],
577        k: usize,
578    ) -> InterpolateResult<Vec<(usize, F)>> {
579        let n_points = self.points.shape()[0];
580        let k = k.min(n_points); // Limit k to the number of points
581
582        // Calculate all distances
583        let mut distances: Vec<(usize, F)> = (0..n_points)
584            .map(|i| {
585                let point = self.points.row(i);
586                let dist = euclidean_distance(query, &point.to_vec());
587                (i, dist)
588            })
589            .collect();
590
591        // Sort by distance
592        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
593
594        // Return k nearest
595        distances.truncate(k);
596        Ok(distances)
597    }
598
599    /// Linear search for points within radius (for small datasets or leaf nodes)
600    fn linear_points_within_radius(
601        &self,
602        query: &[F],
603        radius: F,
604    ) -> InterpolateResult<Vec<(usize, F)>> {
605        let n_points = self.points.shape()[0];
606
607        // Calculate all distances and filter by radius
608        let mut results: Vec<(usize, F)> = (0..n_points)
609            .filter_map(|i| {
610                let point = self.points.row(i);
611                let dist = euclidean_distance(query, &point.to_vec());
612                if dist <= radius {
613                    Some((i, dist))
614                } else {
615                    None
616                }
617            })
618            .collect();
619
620        // Sort by distance
621        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
622
623        Ok(results)
624    }
625
626    /// Get the number of points in the Ball Tree
627    pub fn len(&self) -> usize {
628        self.points.shape()[0]
629    }
630
631    /// Check if the Ball Tree is empty
632    pub fn is_empty(&self) -> bool {
633        self.len() == 0
634    }
635
636    /// Get the dimension of points in the Ball Tree
637    pub fn dim(&self) -> usize {
638        self.dim
639    }
640
641    /// Get a reference to the points in the Ball Tree
642    pub fn points(&self) -> &Array2<F> {
643        &self.points
644    }
645
646    /// Find all points within a specified radius (alias for points_within_radius)
647    ///
648    /// # Arguments
649    ///
650    /// * `query` - Query point coordinates
651    /// * `radius` - Search radius
652    ///
653    /// # Returns
654    ///
655    /// Vector of (point_index, distance) tuples for all points within radius
656    pub fn radius_neighbors(&self, query: &[F], radius: F) -> InterpolateResult<Vec<(usize, F)>> {
657        self.points_within_radius(query, radius)
658    }
659
660    /// Find all points within a specified radius using an array view
661    ///
662    /// # Arguments
663    ///
664    /// * `query` - Query point coordinates as an array view
665    /// * `radius` - Search radius
666    ///
667    /// # Returns
668    ///
669    /// Vector of (point_index, distance) tuples for all points within radius
670    pub fn radius_neighbors_view(
671        &self,
672        query: &scirs2_core::ndarray::ArrayView1<F>,
673        radius: F,
674    ) -> InterpolateResult<Vec<(usize, F)>> {
675        let query_slice = query.as_slice().ok_or_else(|| {
676            InterpolateError::InvalidValue("Query must be contiguous".to_string())
677        })?;
678        self.points_within_radius(query_slice, radius)
679    }
680
681    /// Enhanced k-nearest neighbor search with distance bounds optimization
682    ///
683    /// This method provides improved performance for k-NN queries by using
684    /// tighter distance bounds and more efficient ball pruning strategies.
685    ///
686    /// # Arguments
687    ///
688    /// * `query` - Query point coordinates
689    /// * `k` - Number of nearest neighbors to find
690    /// * `max_distance` - Optional maximum search distance for early termination
691    ///
692    /// # Returns
693    ///
694    /// Vector of (point_index, distance) tuples, sorted by distance
695    pub fn k_nearest_neighbors_optimized(
696        &self,
697        query: &[F],
698        k: usize,
699        max_distance: Option<F>,
700    ) -> InterpolateResult<Vec<(usize, F)>> {
701        // Check query dimension
702        if query.len() != self.dim {
703            return Err(InterpolateError::DimensionMismatch(format!(
704                "Query dimension {} doesn't match Ball Tree dimension {}",
705                query.len(),
706                self.dim
707            )));
708        }
709
710        // Handle empty tree
711        if self.root.is_none() {
712            return Err(InterpolateError::InvalidState(
713                "Ball Tree is empty".to_string(),
714            ));
715        }
716
717        // Limit k to the number of points
718        let k = k.min(self.points.shape()[0]);
719
720        if k == 0 {
721            return Ok(Vec::new());
722        }
723
724        // Very small trees (just use linear search)
725        if self.points.shape()[0] <= self.leafsize {
726            return self.linear_k_nearest_neighbors_optimized(query, k, max_distance);
727        }
728
729        use std::collections::BinaryHeap;
730
731        let mut heap = BinaryHeap::with_capacity(k + 1);
732        let mut search_radius =
733            max_distance.unwrap_or(<F as scirs2_core::numeric::Float>::infinity());
734
735        // Start recursive search with adaptive bounds
736        self.search_k_nearest_optimized(
737            self.root.unwrap(),
738            query,
739            k,
740            &mut heap,
741            &mut search_radius,
742        );
743
744        // Convert heap to sorted vector
745        let mut results: Vec<(usize, F)> = heap
746            .into_iter()
747            .map(|(dist, idx)| (idx, dist.into_inner()))
748            .collect();
749
750        // Sort by _distance (since heap gives reverse order)
751        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
752
753        Ok(results)
754    }
755
756    /// Optimized linear k-nearest neighbors search with early termination
757    fn linear_k_nearest_neighbors_optimized(
758        &self,
759        query: &[F],
760        k: usize,
761        max_distance: Option<F>,
762    ) -> InterpolateResult<Vec<(usize, F)>> {
763        let n_points = self.points.shape()[0];
764        let k = k.min(n_points);
765        let max_dist = max_distance.unwrap_or(<F as scirs2_core::numeric::Float>::infinity());
766
767        let mut distances: Vec<(usize, F)> = Vec::with_capacity(n_points);
768
769        for i in 0..n_points {
770            let point = self.points.row(i);
771            let dist = euclidean_distance(query, &point.to_vec());
772
773            // Early termination if _distance exceeds maximum
774            if dist <= max_dist {
775                distances.push((i, dist));
776            }
777        }
778
779        // Sort by _distance
780        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
781
782        // Return k nearest within max _distance
783        distances.truncate(k);
784        Ok(distances)
785    }
786
787    /// Optimized recursive k-nearest search with enhanced ball pruning
788    #[allow(clippy::type_complexity)]
789    fn search_k_nearest_optimized(
790        &self,
791        node_idx: usize,
792        query: &[F],
793        k: usize,
794        heap: &mut std::collections::BinaryHeap<(OrderedFloat<F>, usize)>,
795        search_radius: &mut F,
796    ) {
797        let node = &self.nodes[node_idx];
798
799        // Calculate distance from query to ball center
800        let center_dist = euclidean_distance(query, &node.center);
801
802        // Enhanced distance bounds checking
803        let min_possible_dist = if center_dist > node.radius {
804            center_dist - node.radius
805        } else {
806            F::zero()
807        };
808
809        // Get current kth distance for pruning
810        let kth_dist = if heap.len() < k {
811            *search_radius
812        } else {
813            match heap.peek() {
814                Some(&(dist_, _)) => dist_.into_inner(),
815                None => *search_radius,
816            }
817        };
818
819        // Early pruning: if the closest possible point in this ball is farther than kth distance
820        if min_possible_dist > kth_dist {
821            return;
822        }
823
824        // If this is a leaf node, check all points
825        if node.left.is_none() && node.right.is_none() {
826            for &idx in &node.indices {
827                let point = self.points.row(idx);
828                let dist = euclidean_distance(query, &point.to_vec());
829
830                // Only consider points within search _radius
831                if dist <= *search_radius {
832                    heap.push((OrderedFloat(dist), idx));
833
834                    // If heap is too large, remove the farthest point and update search _radius
835                    if heap.len() > k {
836                        heap.pop();
837                    }
838
839                    // Update search _radius to the farthest point in current k-nearest set
840                    if heap.len() == k {
841                        if let Some(&(max_dist_, _)) = heap.peek() {
842                            *search_radius = max_dist_.into_inner();
843                        }
844                    }
845                }
846            }
847            return;
848        }
849
850        // Process children with improved ordering strategy
851        let left_idx = node.left.unwrap();
852        let right_idx = node.right.unwrap();
853
854        let left_node = &self.nodes[left_idx];
855        let right_node = &self.nodes[right_idx];
856
857        // Calculate minimum possible distances to each child ball
858        let left_center_dist = euclidean_distance(query, &left_node.center);
859        let right_center_dist = euclidean_distance(query, &right_node.center);
860
861        let left_min_dist = if left_center_dist > left_node.radius {
862            left_center_dist - left_node.radius
863        } else {
864            F::zero()
865        };
866
867        let right_min_dist = if right_center_dist > right_node.radius {
868            right_center_dist - right_node.radius
869        } else {
870            F::zero()
871        };
872
873        // Order children by minimum possible distance
874        let (first_idx, second_idx, second_min_dist) = if left_min_dist < right_min_dist {
875            (left_idx, right_idx, right_min_dist)
876        } else {
877            (right_idx, left_idx, left_min_dist)
878        };
879
880        // Search the closer child first
881        self.search_k_nearest_optimized(first_idx, query, k, heap, search_radius);
882
883        // Re-check pruning condition for second child after first search
884        let updated_kth_dist = if heap.len() < k {
885            *search_radius
886        } else {
887            match heap.peek() {
888                Some(&(dist_, _)) => dist_.into_inner(),
889                None => *search_radius,
890            }
891        };
892
893        // Only search second child if it could contain better points
894        if second_min_dist <= updated_kth_dist {
895            self.search_k_nearest_optimized(second_idx, query, k, heap, search_radius);
896        }
897    }
898
899    /// Approximate nearest neighbor search with controlled accuracy
900    ///
901    /// This method provides faster approximate nearest neighbor search by
902    /// exploring only a limited number of branches in the tree.
903    ///
904    /// # Arguments
905    ///
906    /// * `query` - Query point coordinates
907    /// * `k` - Number of nearest neighbors to find
908    /// * `max_checks` - Maximum number of distance computations
909    ///
910    /// # Returns
911    ///
912    /// Vector of (point_index, distance) tuples, sorted by distance
913    pub fn approximate_k_nearest_neighbors(
914        &self,
915        query: &[F],
916        k: usize,
917        max_checks: usize,
918    ) -> InterpolateResult<Vec<(usize, F)>> {
919        // Check query dimension
920        if query.len() != self.dim {
921            return Err(InterpolateError::DimensionMismatch(format!(
922                "Query dimension {} doesn't match Ball Tree dimension {}",
923                query.len(),
924                self.dim
925            )));
926        }
927
928        // Handle empty tree
929        if self.root.is_none() {
930            return Err(InterpolateError::InvalidState(
931                "Ball Tree is empty".to_string(),
932            ));
933        }
934
935        // Limit k to the number of points
936        let k = k.min(self.points.shape()[0]);
937
938        if k == 0 {
939            return Ok(Vec::new());
940        }
941
942        // For very small trees or large max_checks, use exact search
943        if self.points.shape()[0] <= self.leafsize || max_checks >= self.points.shape()[0] {
944            return self.k_nearest_neighbors(query, k);
945        }
946
947        use std::collections::{BinaryHeap, VecDeque};
948
949        let mut heap = BinaryHeap::with_capacity(k + 1);
950        let mut checks_performed = 0;
951        let mut nodes_to_visit = VecDeque::new();
952
953        // Start with root
954        nodes_to_visit.push_back((self.root.unwrap(), F::zero()));
955
956        while let Some((node_idx, _min_dist)) = nodes_to_visit.pop_front() {
957            if checks_performed >= max_checks {
958                break;
959            }
960
961            let node = &self.nodes[node_idx];
962
963            // Calculate distance from query to ball center
964            let _center_dist = euclidean_distance(query, &node.center);
965
966            // If this is a leaf node, check all points
967            if node.left.is_none() && node.right.is_none() {
968                for &idx in &node.indices {
969                    if checks_performed >= max_checks {
970                        break;
971                    }
972
973                    let point = self.points.row(idx);
974                    let dist = euclidean_distance(query, &point.to_vec());
975                    checks_performed += 1;
976
977                    heap.push((OrderedFloat(dist), idx));
978
979                    // If heap is too large, remove the farthest point
980                    if heap.len() > k {
981                        heap.pop();
982                    }
983                }
984            } else {
985                // Add children to queue, prioritizing by distance
986                if let Some(left_idx) = node.left {
987                    let left_node = &self.nodes[left_idx];
988                    let left_center_dist = euclidean_distance(query, &left_node.center);
989                    let left_min_dist = if left_center_dist > left_node.radius {
990                        left_center_dist - left_node.radius
991                    } else {
992                        F::zero()
993                    };
994
995                    nodes_to_visit.push_back((left_idx, left_min_dist));
996                }
997
998                if let Some(right_idx) = node.right {
999                    let right_node = &self.nodes[right_idx];
1000                    let right_center_dist = euclidean_distance(query, &right_node.center);
1001                    let right_min_dist = if right_center_dist > right_node.radius {
1002                        right_center_dist - right_node.radius
1003                    } else {
1004                        F::zero()
1005                    };
1006
1007                    nodes_to_visit.push_back((right_idx, right_min_dist));
1008                }
1009
1010                // Sort queue by minimum distance to prioritize closer nodes
1011                nodes_to_visit
1012                    .make_contiguous()
1013                    .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
1014            }
1015        }
1016
1017        // Convert heap to sorted vector
1018        let mut results: Vec<(usize, F)> = heap
1019            .into_iter()
1020            .map(|(dist, idx)| (idx, dist.into_inner()))
1021            .collect();
1022
1023        // Sort by distance
1024        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
1025
1026        Ok(results)
1027    }
1028}
1029
1030/// Compute the centroid (center) of a set of points
1031#[allow(dead_code)]
1032fn compute_centroid<F: Float + FromPrimitive>(points: &Array2<F>, indices: &[usize]) -> Vec<F> {
1033    let n_points = indices.len();
1034    let n_dims = points.shape()[1];
1035
1036    if n_points == 0 {
1037        return vec![F::zero(); n_dims];
1038    }
1039
1040    let mut center = vec![F::zero(); n_dims];
1041
1042    // Sum all point coordinates
1043    for &idx in indices {
1044        let point = points.row(idx);
1045        for d in 0..n_dims {
1046            center[d] = center[d] + point[d];
1047        }
1048    }
1049
1050    // Divide by number of _points
1051    let n = F::from_usize(n_points).unwrap();
1052    for val in center.iter_mut() {
1053        *val = *val / n;
1054    }
1055
1056    center
1057}
1058
1059/// Compute the radius of a ball containing all points
1060#[allow(dead_code)]
1061fn compute_radius<F: Float>(points: &Array2<F>, indices: &[usize], center: &[F]) -> F {
1062    let n_points = indices.len();
1063
1064    if n_points == 0 {
1065        return F::zero();
1066    }
1067
1068    let mut max_dist = F::zero();
1069
1070    // Find the maximum distance from center to any point
1071    for &idx in indices {
1072        let point = points.row(idx);
1073        let dist = euclidean_distance(&point.to_vec(), center);
1074
1075        if dist > max_dist {
1076            max_dist = dist;
1077        }
1078    }
1079
1080    max_dist
1081}
1082
1083/// Find the dimension with the largest spread of values
1084#[allow(dead_code)]
1085fn find_max_spread_dimension<F: Float>(points: &Array2<F>, indices: &[usize]) -> (usize, F) {
1086    let n_points = indices.len();
1087    let n_dims = points.shape()[1];
1088
1089    if n_points <= 1 {
1090        return (0, F::zero());
1091    }
1092
1093    let mut max_dim = 0;
1094    let mut max_spread = F::neg_infinity();
1095
1096    // For each dimension, find the range of values
1097    for d in 0..n_dims {
1098        let mut min_val = F::infinity();
1099        let mut max_val = F::neg_infinity();
1100
1101        for &idx in indices {
1102            let val = points[[idx, d]];
1103
1104            if val < min_val {
1105                min_val = val;
1106            }
1107
1108            if val > max_val {
1109                max_val = val;
1110            }
1111        }
1112
1113        let spread = max_val - min_val;
1114
1115        if spread > max_spread {
1116            max_spread = spread;
1117            max_dim = d;
1118        }
1119    }
1120
1121    (max_dim, max_spread)
1122}
1123
1124/// Find two points that are far apart along a given dimension
1125#[allow(dead_code)]
1126fn find_distant_points<F: Float>(
1127    points: &Array2<F>,
1128    indices: &[usize],
1129    dim: usize,
1130) -> (usize, usize) {
1131    let n_points = indices.len();
1132
1133    if n_points <= 1 {
1134        return (indices[0], indices[0]);
1135    }
1136
1137    // Find min and max points along the dimension
1138    let mut min_idx = indices[0];
1139    let mut max_idx = indices[0];
1140    let mut min_val = points[[min_idx, dim]];
1141    let mut max_val = min_val;
1142
1143    for &idx in indices.iter().skip(1) {
1144        let val = points[[idx, dim]];
1145
1146        if val < min_val {
1147            min_val = val;
1148            min_idx = idx;
1149        }
1150
1151        if val > max_val {
1152            max_val = val;
1153            max_idx = idx;
1154        }
1155    }
1156
1157    (min_idx, max_idx)
1158}
1159
1160/// Partition points based on which of two seed points they're closer to
1161#[allow(dead_code)]
1162fn partition_by_seeds<F: Float>(
1163    points: &Array2<F>,
1164    indices: &[usize],
1165    seed1: usize,
1166    seed2: usize,
1167) -> (Vec<usize>, Vec<usize>) {
1168    let seed1_point = points.row(seed1).to_vec();
1169    let seed2_point = points.row(seed2).to_vec();
1170
1171    let mut left_indices = Vec::new();
1172    let mut right_indices = Vec::new();
1173
1174    // Always include the seeds in their respective partitions
1175    left_indices.push(seed1);
1176    right_indices.push(seed2);
1177
1178    // Partition the remaining points
1179    for &idx in indices {
1180        if idx == seed1 || idx == seed2 {
1181            continue; // Skip the seeds
1182        }
1183
1184        let point = points.row(idx).to_vec();
1185        let dist1 = euclidean_distance(&point, &seed1_point);
1186        let dist2 = euclidean_distance(&point, &seed2_point);
1187
1188        if dist1 <= dist2 {
1189            left_indices.push(idx);
1190        } else {
1191            right_indices.push(idx);
1192        }
1193    }
1194
1195    // If one partition is empty, move some points from the other
1196    if left_indices.is_empty() && right_indices.len() >= 2 {
1197        left_indices.push(right_indices.pop().unwrap());
1198    } else if right_indices.is_empty() && left_indices.len() >= 2 {
1199        right_indices.push(left_indices.pop().unwrap());
1200    }
1201
1202    (left_indices, right_indices)
1203}
1204
1205/// Calculate Euclidean distance between two points
1206#[allow(dead_code)]
1207fn euclidean_distance<F: Float>(a: &[F], b: &[F]) -> F {
1208    debug_assert_eq!(a.len(), b.len());
1209
1210    let mut sum_sq = F::zero();
1211
1212    for i in 0..a.len() {
1213        let diff = a[i] - b[i];
1214        sum_sq = sum_sq + diff * diff;
1215    }
1216
1217    sum_sq.sqrt()
1218}
1219
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{arr2, Array2};

    /// Five 3-D points shared by every test: the origin, the three unit vectors and the
    /// centre of the unit cube.
    fn sample_points() -> Array2<f64> {
        arr2(&[
            [0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [0.5, 0.5, 0.5],
        ])
    }

    #[test]
    fn test_balltree_creation() {
        let balltree = BallTree::new(sample_points()).unwrap();

        // Check tree properties
        assert_eq!(balltree.len(), 5);
        assert_eq!(balltree.dim(), 3);
        assert!(!balltree.is_empty());
    }

    #[test]
    fn test_nearest_neighbor() {
        let balltree = BallTree::new(sample_points()).unwrap();

        // Every stored point is its own nearest neighbor at distance zero.
        for i in 0..5 {
            let point = balltree.points().row(i).to_vec();
            let (idx, dist) = balltree.nearest_neighbor(&point).unwrap();
            assert_eq!(idx, i);
            assert!(dist < 1e-10);
        }

        // Test near matches
        let query = vec![0.6, 0.6, 0.6];
        let (idx, _) = balltree.nearest_neighbor(&query).unwrap();
        assert_eq!(idx, 4); // Closest to (0.5, 0.5, 0.5)

        let query = vec![0.9, 0.1, 0.1];
        let (idx, _) = balltree.nearest_neighbor(&query).unwrap();
        assert_eq!(idx, 1); // Closest to (1.0, 0.0, 0.0)
    }

    #[test]
    fn test_k_nearest_neighbors() {
        let balltree = BallTree::new(sample_points()).unwrap();
        let query = vec![0.6, 0.6, 0.6];

        let neighbors = balltree.k_nearest_neighbors(&query, 3).unwrap();

        assert_eq!(neighbors.len(), 3);
        // (0.5, 0.5, 0.5) is sqrt(0.03) away and must come first.
        assert_eq!(neighbors[0].0, 4);
        assert!((neighbors[0].1 - 0.03_f64.sqrt()).abs() < 1e-10);

        // The unit vectors (indices 1, 2, 3) tie at sqrt(0.88); the origin is farther at
        // sqrt(1.08). Any two of the tied points may fill the remaining slots.
        for &(idx, dist) in &neighbors[1..] {
            assert!([1, 2, 3].contains(&idx));
            assert!((dist - 0.88_f64.sqrt()).abs() < 1e-10);
        }
        assert_ne!(neighbors[1].0, neighbors[2].0);
    }

    #[test]
    fn test_points_within_radius() {
        let balltree = BallTree::new(sample_points()).unwrap();

        let query = vec![0.0, 0.0, 0.0];
        let radius = 0.7;

        let results = balltree.points_within_radius(&query, radius).unwrap();

        // (0.5, 0.5, 0.5) is sqrt(0.75) ≈ 0.866 from the origin and the unit vectors are 1.0
        // away, so only the origin itself lies within radius 0.7.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 0);
    }

    #[test]
    fn test_approximate_k_nearest_neighbors() {
        let balltree = BallTree::new(sample_points()).unwrap();
        let query = vec![0.6, 0.6, 0.6];

        // A budget covering every point must reproduce the exact search.
        let exact = balltree.k_nearest_neighbors(&query, 3).unwrap();
        let approx = balltree
            .approximate_k_nearest_neighbors(&query, 3, 5)
            .unwrap();
        assert_eq!(approx, exact);

        // k = 0 yields no neighbors.
        assert!(balltree
            .approximate_k_nearest_neighbors(&query, 0, 5)
            .unwrap()
            .is_empty());

        // Wrong query dimension is rejected.
        assert!(balltree
            .approximate_k_nearest_neighbors(&[0.0, 0.0], 1, 5)
            .is_err());
    }
}