use glam::Vec3;
/// Metric used to measure how far apart two points are.
///
/// In the formulas below, `dx`, `dy`, `dz` are the per-axis differences
/// between the two points.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DistanceMetric {
    /// Straight-line (L2) distance: `sqrt(dx² + dy² + dz²)`.
    Euclidean,
    /// Taxicab (L1) distance: `|dx| + |dy| + |dz|`.
    Manhattan,
    /// Chessboard (L∞) distance: `max(|dx|, |dy|, |dz|)`.
    Chebyshev,
}
/// Bitwise operation applied word by word when combining two bitsets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BitsetOp {
    /// Intersection: a bit is set only if it is set in both inputs (`&`).
    And,
    /// Union: a bit is set if it is set in either input (`|`).
    Or,
    /// Symmetric difference: a bit is set if it is set in exactly one input (`^`).
    Xor,
}
/// Engine that evaluates point distances and bitset operations.
///
/// The `Send + Sync` supertraits allow a backend to be shared across threads.
pub trait ComputeBackend: Send + Sync {
    /// Returns the distance between `a` and `b` under `metric`.
    fn distance(&self, a: &Vec3, b: &Vec3, metric: DistanceMetric) -> f32;

    /// Returns the distance from each element of `points` to `query` under
    /// `metric`, in the same order as `points`.
    ///
    /// The default implementation calls [`ComputeBackend::distance`] once per
    /// point. Backends that can evaluate many points at once should override it.
    fn distance_batch(&self, points: &[Vec3], query: &Vec3, metric: DistanceMetric) -> Vec<f32> {
        points
            .iter()
            .map(|p| self.distance(p, query, metric))
            .collect()
    }

    /// Combines two bitsets, stored as slices of `u64` words, with `op`.
    ///
    /// How inputs of different lengths are handled is up to the implementation.
    fn bitset_op(&self, a: &[u64], b: &[u64], op: BitsetOp) -> Vec<u64>;
}
/// Stateless compute backend that evaluates every operation directly on the
/// CPU with plain scalar iterator code.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct CpuBackend;
impl ComputeBackend for CpuBackend {
    /// Computes the distance between `a` and `b` from the absolute per-axis
    /// differences, combined according to `metric`.
    fn distance(&self, a: &Vec3, b: &Vec3, metric: DistanceMetric) -> f32 {
        let dx = (a.x - b.x).abs();
        let dy = (a.y - b.y).abs();
        let dz = (a.z - b.z).abs();
        match metric {
            DistanceMetric::Euclidean => (dx * dx + dy * dy + dz * dz).sqrt(),
            DistanceMetric::Manhattan => dx + dy + dz,
            DistanceMetric::Chebyshev => dx.max(dy).max(dz),
        }
    }

    /// Evaluates `distance` from each point to `query`, preserving the order
    /// of `points`. An empty `points` slice yields an empty vector.
    fn distance_batch(&self, points: &[Vec3], query: &Vec3, metric: DistanceMetric) -> Vec<f32> {
        points
            .iter()
            .map(|p| self.distance(p, query, metric))
            .collect()
    }

    /// Combines `a` and `b` word by word with `op`.
    ///
    /// The result has `max(a.len(), b.len())` words. The shorter input is
    /// treated as if padded with zero words, so no bits of the longer input
    /// are dropped.
    fn bitset_op(&self, a: &[u64], b: &[u64], op: BitsetOp) -> Vec<u64> {
        // Resolve the operator once rather than matching on every word.
        let combine: fn(u64, u64) -> u64 = match op {
            BitsetOp::And => |x: u64, y: u64| x & y,
            BitsetOp::Or => |x: u64, y: u64| x | y,
            BitsetOp::Xor => |x: u64, y: u64| x ^ y,
        };
        let (longer, shorter) = if a.len() >= b.len() { (a, b) } else { (b, a) };
        let mut out = Vec::with_capacity(longer.len());
        // Overlapping words: `zip` stops at the shorter input's end.
        out.extend(a.iter().zip(b).map(|(&x, &y)| combine(x, y)));
        // Remaining words of the longer input, paired with an implicit zero
        // word. A plain `zip` used to discard these, losing bits for OR/XOR.
        // All three ops are commutative, so the operand order here is irrelevant.
        out.extend(longer[shorter.len()..].iter().map(|&w| combine(w, 0)));
        out
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing computed `f32` distances with exact expectations.
    const EPS: f32 = 1e-3;

    #[test]
    fn test_distance_metric_euclidean() {
        let backend = CpuBackend;
        let a = Vec3::new(0.0, 0.0, 0.0);
        let b = Vec3::new(3.0, 4.0, 0.0);
        let d = backend.distance(&a, &b, DistanceMetric::Euclidean);
        assert!(
            (d - 5.0).abs() < EPS,
            "Euclidean distance should be 5.0, got {}",
            d
        );
    }

    #[test]
    fn test_distance_metric_manhattan() {
        let backend = CpuBackend;
        let a = Vec3::new(0.0, 0.0, 0.0);
        let b = Vec3::new(3.0, 4.0, 0.0);
        let d = backend.distance(&a, &b, DistanceMetric::Manhattan);
        assert!(
            (d - 7.0).abs() < EPS,
            "Manhattan distance should be 7.0, got {}",
            d
        );
    }

    #[test]
    fn test_distance_metric_chebyshev() {
        let backend = CpuBackend;
        let a = Vec3::new(0.0, 0.0, 0.0);
        let b = Vec3::new(3.0, 4.0, 5.0);
        let d = backend.distance(&a, &b, DistanceMetric::Chebyshev);
        assert!(
            (d - 5.0).abs() < EPS,
            "Chebyshev distance should be 5.0, got {}",
            d
        );
    }

    #[test]
    fn test_distance_batch_euclidean() {
        let backend = CpuBackend;
        let query = Vec3::new(1.0, 1.0, 1.0);
        let points = [Vec3::new(1.0, 1.0, 1.0), Vec3::new(4.0, 5.0, 1.0)];
        let distances = backend.distance_batch(&points, &query, DistanceMetric::Euclidean);
        assert_eq!(distances.len(), 2);
        assert!(distances[0].abs() < EPS, "first should be 0.0");
        assert!((distances[1] - 5.0).abs() < EPS, "second should be 5.0");
    }

    #[test]
    fn test_distance_batch_manhattan() {
        let backend = CpuBackend;
        let query = Vec3::new(0.0, 0.0, 0.0);
        let points = [Vec3::new(1.0, 0.0, 0.0), Vec3::new(3.0, 4.0, 0.0)];
        let distances = backend.distance_batch(&points, &query, DistanceMetric::Manhattan);
        assert_eq!(distances.len(), 2);
        assert!((distances[0] - 1.0).abs() < EPS, "first should be 1.0");
        assert!((distances[1] - 7.0).abs() < EPS, "second should be 7.0");
    }

    #[test]
    fn test_distance_batch_chebyshev() {
        let backend = CpuBackend;
        let query = Vec3::new(0.0, 0.0, 0.0);
        let points = [Vec3::new(1.0, 2.0, 3.0), Vec3::new(5.0, 4.0, 2.0)];
        let distances = backend.distance_batch(&points, &query, DistanceMetric::Chebyshev);
        assert_eq!(distances.len(), 2);
        assert!((distances[0] - 3.0).abs() < EPS, "first should be 3.0");
        assert!((distances[1] - 5.0).abs() < EPS, "second should be 5.0");
    }

    #[test]
    fn test_distance_batch_empty() {
        let backend = CpuBackend;
        let query = Vec3::new(0.0, 0.0, 0.0);
        let distances = backend.distance_batch(&[], &query, DistanceMetric::Euclidean);
        assert!(distances.is_empty(), "empty input should give empty output");
    }

    #[test]
    fn test_bitset_ops_equal_length() {
        let backend = CpuBackend;
        let a = [0b1100_u64, u64::MAX];
        let b = [0b1010_u64, 0];
        assert_eq!(backend.bitset_op(&a, &b, BitsetOp::And), vec![0b1000, 0]);
        assert_eq!(backend.bitset_op(&a, &b, BitsetOp::Or), vec![0b1110, u64::MAX]);
        assert_eq!(backend.bitset_op(&a, &b, BitsetOp::Xor), vec![0b0110, u64::MAX]);
    }
}