use super::tree::{GpuKdTree, KdQueryResult};
use crate::error::InterpolateResult;
/// Tuning knobs that decide when [`knn_auto_dispatch`] attempts the GPU path.
///
/// The GPU path is only attempted when *both* thresholds are met (and the
/// `gpu_kdtree` feature is compiled in); otherwise the tree's own
/// `knn_batch` is used.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KdTreeConfig {
    /// Minimum number of points in the tree (inclusive) for the GPU path to
    /// be considered.
    pub gpu_threshold_points: usize,
    /// Minimum number of queries in a batch (inclusive) for the GPU path to
    /// be considered.
    pub gpu_threshold_queries: usize,
}
impl Default for KdTreeConfig {
    /// Returns the stock thresholds: 100 000 tree points and 1 000 queries.
    fn default() -> Self {
        let gpu_threshold_points = 100_000;
        let gpu_threshold_queries = 1_000;
        Self {
            gpu_threshold_points,
            gpu_threshold_queries,
        }
    }
}
/// Runs a batched k-nearest-neighbour query, choosing a backend automatically.
///
/// When the crate is built with the `gpu_kdtree` feature and both the tree
/// size and the query count reach the thresholds in `config`, the wgpu
/// linear-scan backend is tried first. If that backend reports an error the
/// error is discarded and the call falls back to `tree.knn_batch`, so a GPU
/// failure never surfaces to the caller.
///
/// # Arguments
///
/// * `tree` - the tree to search.
/// * `queries` - one coordinate vector per query.
/// * `k` - number of neighbours requested per query.
/// * `config` - thresholds controlling when the GPU backend is attempted.
///
/// # Errors
///
/// Returns whatever error `tree.knn_batch` produces on the fallback path.
pub fn knn_auto_dispatch(
    tree: &GpuKdTree,
    queries: &[Vec<f64>],
    k: usize,
    config: &KdTreeConfig,
) -> InterpolateResult<Vec<KdQueryResult>> {
    #[cfg(feature = "gpu_kdtree")]
    {
        let enough_points = tree.n_points() >= config.gpu_threshold_points;
        let worth_offloading = enough_points && queries.len() >= config.gpu_threshold_queries;
        if worth_offloading {
            // Best effort: a GPU failure is not fatal, we simply fall through
            // to the tree's own batch query below.
            if let Ok(gpu_results) = super::wgpu_linear_scan::knn_wgpu(tree, queries, k) {
                return Ok(gpu_results);
            }
        }
    }

    // Without the GPU feature `config` is never read; silence the warning.
    #[cfg(not(feature = "gpu_kdtree"))]
    let _ = config;

    tree.knn_batch(queries, k)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gpu_kdtree::GpuKdTree;

    /// A tiny tree with default thresholds must still find the closest point.
    #[test]
    fn test_knn_auto_dispatch_cpu_below_threshold() {
        let points: Vec<Vec<f64>> = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.5, 0.5]];
        let tree = GpuKdTree::new(points).expect("build");
        let probe = vec![vec![0.4_f64, 0.4]];

        let found =
            knn_auto_dispatch(&tree, &probe, 1, &KdTreeConfig::default()).expect("dispatch");

        assert_eq!(found.len(), 1);
        // (0.5, 0.5) is the closest of the three points to (0.4, 0.4).
        let nearest = &found[0];
        assert_eq!(nearest.indices[0], 2);
    }

    /// Every query yields exactly `k` indices and `k` squared distances.
    #[test]
    fn test_knn_auto_dispatch_returns_correct_count() {
        let points: Vec<Vec<f64>> = (0_u32..10).map(|i| vec![f64::from(i), 0.0]).collect();
        let tree = GpuKdTree::new(points).expect("build");
        let probes: Vec<Vec<f64>> = (0_u32..5).map(|i| vec![f64::from(i) + 0.1, 0.0]).collect();
        // Thresholds that can never be reached by this tiny input.
        let never_gpu = KdTreeConfig {
            gpu_threshold_points: usize::MAX,
            gpu_threshold_queries: usize::MAX,
        };

        let found = knn_auto_dispatch(&tree, &probes, 3, &never_gpu).expect("dispatch");

        assert_eq!(found.len(), 5);
        found.iter().for_each(|hit| {
            assert_eq!(hit.indices.len(), 3);
            assert_eq!(hit.distances_sq.len(), 3);
        });
    }

    /// An empty query batch produces an empty result set rather than an error.
    #[test]
    fn test_knn_auto_dispatch_empty_queries() {
        let tree = GpuKdTree::new(vec![vec![1.0_f64, 2.0]]).expect("build");
        let defaults = KdTreeConfig::default();

        let found = knn_auto_dispatch(&tree, &[], 1, &defaults).expect("dispatch empty queries");

        assert!(found.is_empty());
    }

    /// The default configuration exposes the documented threshold values.
    #[test]
    fn test_kdtree_config_default_thresholds() {
        let KdTreeConfig {
            gpu_threshold_points,
            gpu_threshold_queries,
        } = KdTreeConfig::default();
        assert_eq!(gpu_threshold_points, 100_000);
        assert_eq!(gpu_threshold_queries, 1_000);
    }
}