use crate::DType;
use numr::error::Result;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
use super::distance::DistanceMetric;
/// Configuration for building a [`KDTree`].
///
/// Start from [`KDTreeOptions::default`] (leaf size 10, Euclidean metric) and override
/// individual fields with struct-update syntax, e.g.
/// `KDTreeOptions { leaf_size: 32, ..Default::default() }`.
///
/// `PartialEq` is derived so that configurations can be compared directly; this relies on
/// `DistanceMetric: PartialEq`, which the module's own tests already require.
#[derive(Debug, Clone, PartialEq)]
pub struct KDTreeOptions {
    /// Target leaf size for the tree. NOTE(review): presumably the maximum number of points
    /// per leaf, and presumably must be non-zero — confirm against the `kdtree_build` backend.
    pub leaf_size: usize,
    /// Distance metric associated with the tree (presumably used to score query results —
    /// confirm against the backend).
    pub metric: DistanceMetric,
}
impl Default for KDTreeOptions {
    /// Returns options with a leaf size of 10 and the Euclidean distance metric.
    fn default() -> Self {
        const DEFAULT_LEAF_SIZE: usize = 10;
        let metric = DistanceMetric::Euclidean;
        Self {
            metric,
            leaf_size: DEFAULT_LEAF_SIZE,
        }
    }
}
/// A KD-tree over a point set, held as tensors on runtime `R`.
///
/// The tree is stored in flat tensors rather than linked nodes. NOTE(review): the node-level
/// fields (`split_dims`, `split_values`, `left_children`, `right_children`) are presumably
/// parallel arrays indexed by node id, and the leaf fields presumably index into
/// `point_indices`. Confirm against the backend implementing
/// [`KDTreeAlgorithms::kdtree_build`].
#[derive(Debug, Clone)]
pub struct KDTree<R>
where
    R: Runtime<DType = DType>,
{
    /// The indexed points. Presumably shaped `(n_points, n_dims)` — TODO confirm.
    pub data: Tensor<R>,
    /// Splitting dimension per node (presumably internal nodes only — confirm).
    pub split_dims: Tensor<R>,
    /// Splitting threshold per node along the matching `split_dims` entry (presumably).
    pub split_values: Tensor<R>,
    /// Left-child node index per node; the leaf/sentinel encoding is backend-defined.
    pub left_children: Tensor<R>,
    /// Right-child node index per node; the leaf/sentinel encoding is backend-defined.
    pub right_children: Tensor<R>,
    /// Indices into `data`, presumably permuted so that each leaf's points are contiguous.
    pub point_indices: Tensor<R>,
    /// Per-leaf start offset, presumably into `point_indices` — confirm.
    pub leaf_starts: Tensor<R>,
    /// Per-leaf point count, presumably paired with `leaf_starts` — confirm.
    pub leaf_sizes: Tensor<R>,
    /// Options associated with this tree (presumably those passed to `kdtree_build`).
    pub options: KDTreeOptions,
}
/// Output of a k-nearest-neighbor query ([`KDTreeAlgorithms::kdtree_query`]).
#[derive(Debug, Clone)]
pub struct KNNResult<R>
where
    R: Runtime<DType = DType>,
{
    /// Neighbor distances. Presumably shaped `(n_queries, k)` — TODO confirm shape and ordering.
    pub distances: Tensor<R>,
    /// Neighbor indices, presumably into the tree's `data` and aligned with `distances`.
    pub indices: Tensor<R>,
}
/// Output of a fixed-radius query ([`KDTreeAlgorithms::kdtree_query_radius`]).
///
/// Different queries may match different numbers of points. NOTE(review): the layout is
/// presumably CSR-style, so the matches for query `i` occupy
/// `offsets[i]..offsets[i] + counts[i]` in `distances` / `indices`. Confirm against the
/// backend.
#[derive(Debug, Clone)]
pub struct RadiusResult<R>
where
    R: Runtime<DType = DType>,
{
    /// Distances of all matches, concatenated across queries (presumably).
    pub distances: Tensor<R>,
    /// Point indices of all matches, aligned with `distances` (presumably).
    pub indices: Tensor<R>,
    /// Number of matches per query (presumably).
    pub counts: Tensor<R>,
    /// Start offset of each query's matches in the flat arrays (presumably).
    pub offsets: Tensor<R>,
}
/// Backend operations for building and querying a [`KDTree`] on runtime `R`.
pub trait KDTreeAlgorithms<R>
where
    R: Runtime<DType = DType>,
{
    /// Builds a KD-tree over `points` using `options`.
    ///
    /// NOTE(review): `points` is presumably `(n_points, n_dims)` — confirm with implementors.
    ///
    /// # Errors
    /// Returns `Err` when the backend rejects the input or fails; the exact conditions are
    /// implementation-defined.
    fn kdtree_build(&self, points: &Tensor<R>, options: KDTreeOptions) -> Result<KDTree<R>>;

    /// Finds the `k` nearest points in `tree` for each row of `query`.
    ///
    /// # Errors
    /// Returns `Err` when the backend rejects the input (e.g. presumably a dimension mismatch
    /// with the tree's data, or an out-of-range `k` — confirm) or fails.
    fn kdtree_query(&self, tree: &KDTree<R>, query: &Tensor<R>, k: usize) -> Result<KNNResult<R>>;

    /// Finds every point in `tree` within `radius` of each row of `query`.
    ///
    /// NOTE(review): whether the boundary is inclusive and how `radius` relates to the tree's
    /// metric are defined by the implementor — confirm.
    ///
    /// # Errors
    /// Returns `Err` when the backend rejects the input or fails; the exact conditions are
    /// implementation-defined.
    fn kdtree_query_radius(
        &self,
        tree: &KDTree<R>,
        query: &Tensor<R>,
        radius: f64,
    ) -> Result<RadiusResult<R>>;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default options must use a leaf size of 10 and the Euclidean metric.
    #[test]
    fn test_kdtree_options_default() {
        let KDTreeOptions { leaf_size, metric } = KDTreeOptions::default();
        assert_eq!(leaf_size, 10);
        assert_eq!(metric, DistanceMetric::Euclidean);
    }
}