use crate::DType;
use numr::error::Result;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
use super::distance::DistanceMetric;
use super::kdtree::{KNNResult, RadiusResult};
/// Construction options for a ball tree.
///
/// Start from [`Default`] and override individual fields with struct-update
/// syntax, e.g. `BallTreeOptions { leaf_size: 16, ..Default::default() }`.
/// Deriving `PartialEq` lets callers and tests compare whole option sets
/// (e.g. the options stored on a built tree against the ones requested).
#[derive(Debug, Clone, PartialEq)]
pub struct BallTreeOptions {
    /// Leaf-size parameter passed to tree construction (default `40`).
    ///
    /// NOTE(review): presumably the maximum number of points kept in a leaf
    /// node, which would make `0` invalid — confirm against the
    /// `balltree_build` implementations.
    pub leaf_size: usize,
    /// Distance metric associated with the tree
    /// (default [`DistanceMetric::Euclidean`]).
    pub metric: DistanceMetric,
}

impl Default for BallTreeOptions {
    /// Returns `leaf_size = 40` with the Euclidean metric.
    fn default() -> Self {
        Self {
            leaf_size: 40,
            metric: DistanceMetric::Euclidean,
        }
    }
}
/// A ball tree whose structure is held entirely in runtime tensors.
///
/// Values of this type are returned by `BallTreeAlgorithms::balltree_build`
/// and borrowed by that trait's query methods.
///
/// NOTE(review): the field descriptions below are inferred from field names;
/// exact shapes, dtypes and any sentinel values (e.g. for absent children)
/// are defined by the backend implementations — confirm there.
#[derive(Debug, Clone)]
pub struct BallTree<R>
where
    R: Runtime<DType = DType>,
{
    /// Presumably the points the tree was built over.
    pub data: Tensor<R>,
    /// Presumably the center of each node's bounding ball.
    pub centers: Tensor<R>,
    /// Presumably the radius of each node's bounding ball.
    pub radii: Tensor<R>,
    /// Presumably the index of each node's left child.
    pub left_children: Tensor<R>,
    /// Presumably the index of each node's right child.
    pub right_children: Tensor<R>,
    /// Presumably a permutation of point indices grouped by leaf.
    pub point_indices: Tensor<R>,
    /// Presumably each leaf's starting offset into `point_indices`.
    pub leaf_starts: Tensor<R>,
    /// Presumably the number of points held by each leaf.
    pub leaf_sizes: Tensor<R>,
    /// Options associated with this tree (presumably those used to build it).
    pub options: BallTreeOptions,
}
/// Backend operations for constructing and searching a `BallTree`.
pub trait BallTreeAlgorithms<R>
where
    R: Runtime<DType = DType>,
{
    /// Builds a ball tree over `points` according to `options`.
    ///
    /// NOTE(review): the expected layout of `points` (presumably one point
    /// per row) is defined by the implementation — confirm there.
    ///
    /// # Errors
    /// Returns an error when the backend cannot build the tree; the exact
    /// conditions are implementation-defined.
    fn balltree_build(&self, points: &Tensor<R>, options: BallTreeOptions) -> Result<BallTree<R>>;

    /// Runs a `k`-nearest-neighbour search of `query` against `tree`.
    ///
    /// NOTE(review): presumably one neighbour set per query point — confirm
    /// against the implementation.
    ///
    /// # Errors
    /// Propagates any backend failure during the search.
    fn balltree_query(
        &self,
        tree: &BallTree<R>,
        query: &Tensor<R>,
        k: usize,
    ) -> Result<KNNResult<R>>;

    /// Collects the points of `tree` lying within `radius` of `query`.
    ///
    /// NOTE(review): whether the boundary is inclusive, and which distance the
    /// radius is measured in, are implementation-defined — confirm there.
    ///
    /// # Errors
    /// Propagates any backend failure during the search.
    fn balltree_query_radius(
        &self,
        tree: &BallTree<R>,
        query: &Tensor<R>,
        radius: f64,
    ) -> Result<RadiusResult<R>>;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default options use a leaf size of 40 and the Euclidean metric.
    #[test]
    fn test_balltree_options_default() {
        let BallTreeOptions { leaf_size, metric } = BallTreeOptions::default();
        assert_eq!(leaf_size, 40);
        assert_eq!(metric, DistanceMetric::Euclidean);
    }
}