use super::distance::DistanceMetric;
use super::flat::DataRef;
use rand::prelude::*;
/// Range-query index that pre-filters rows by their projections onto a set of
/// pseudo-random axes before running the exact distance check.
///
/// The index borrows both the dataset and the metric for `'a`; it owns only
/// the axes and the per-axis sorted projection tables built by `new`.
pub(crate) struct ProjIndex<'a, T, D>
where
    T: DataRef + ?Sized,
    D: DistanceMetric,
{
    /// Borrowed dataset; rows are read through `n()`, `d()` and `row(i)`.
    data: &'a T,
    /// Borrowed metric used for the final, exact distance test.
    metric: &'a D,
    /// One direction vector per axis, each with `data.d()` components.
    /// `new` scales a vector to unit length unless its norm is at most
    /// `f32::EPSILON`.
    axes: Vec<Vec<f32>>,
    /// For every axis (same order as `axes`): `(projection, row index)` pairs
    /// for all rows, sorted ascending by projection.
    sorted_projs: Vec<Vec<(f32, usize)>>,
}
impl<'a, T: DataRef + ?Sized, D: DistanceMetric> ProjIndex<'a, T, D> {
    /// Builds a projection index over `data` with `num_axes` pseudo-random axes.
    ///
    /// For every axis a direction vector of `data.d()` components is drawn from
    /// a generator seeded with a fixed constant, so the same inputs always
    /// produce the same index. Each vector is scaled to unit length unless its
    /// norm is at most `f32::EPSILON`. Every row is then projected (dot
    /// product) onto every axis and the `(projection, row)` pairs are stored
    /// sorted by projection.
    ///
    /// When `data` has no rows the index is left empty and `data.d()` is never
    /// queried.
    pub(crate) fn new(data: &'a T, metric: &'a D, num_axes: usize) -> Self {
        let n = data.n();
        if n == 0 {
            return Self {
                data,
                metric,
                axes: Vec::new(),
                sorted_projs: Vec::new(),
            };
        }

        let d = data.d();
        // Fixed seed: index construction is deterministic across runs.
        let mut rng = StdRng::seed_from_u64(12345);

        let axes: Vec<Vec<f32>> = (0..num_axes)
            .map(|_| {
                // Centre the samples around zero (presumably uniform in
                // [-0.5, 0.5) -- depends on the `rand` distribution for `f32`).
                let mut v: Vec<f32> = (0..d).map(|_| rng.random::<f32>() - 0.5).collect();
                let norm: f32 = v.iter().map(|&x| x * x).sum::<f32>().sqrt();
                // Skip normalisation for (near-)zero vectors to avoid dividing
                // by a vanishing norm.
                if norm > f32::EPSILON {
                    for x in &mut v {
                        *x /= norm;
                    }
                }
                v
            })
            .collect();

        let sorted_projs: Vec<Vec<(f32, usize)>> = axes
            .iter()
            .map(|axis| {
                let mut projs: Vec<(f32, usize)> = (0..n)
                    .map(|i| {
                        let proj: f32 = data
                            .row(i)
                            .iter()
                            .zip(axis.iter())
                            .map(|(&x, &a)| x * a)
                            .sum();
                        (proj, i)
                    })
                    .collect();
                // `total_cmp` gives a total order even if a projection is NaN.
                projs.sort_by(|a, b| a.0.total_cmp(&b.0));
                projs
            })
            .collect();

        Self {
            data,
            metric,
            axes,
            sorted_projs,
        }
    }

    /// Returns the indices (ascending) of all rows whose distance to `query`
    /// under the index's metric is at most `radius`.
    ///
    /// A row is only handed to the exact distance check if, on every axis, its
    /// projection lies within `radius` of the query's projection. An index
    /// built with zero axes has no filter, so every row is checked directly.
    ///
    /// A negative or NaN `radius` never panics: with at least one axis the
    /// projection windows are empty and the result is empty.
    ///
    /// NOTE(review): the projection filter can only be trusted not to drop
    /// true neighbours when the metric's distance is never smaller than the
    /// difference of projections onto a unit axis (this holds for Euclidean
    /// distance, up to `f32` rounding). Confirm before using another metric.
    pub(crate) fn range_query(&self, query: &[f32], radius: f32) -> Vec<usize> {
        let n = self.data.n();
        if n == 0 {
            return Vec::new();
        }

        let num_axes = self.axes.len();
        // Every axis lists each row exactly once, so a counter never exceeds
        // `num_axes`; `usize` therefore cannot overflow or need truncation.
        let mut hits = vec![0usize; n];

        for (axis, sorted) in self.axes.iter().zip(&self.sorted_projs) {
            let proj_q: f32 = query.iter().zip(axis.iter()).map(|(&x, &a)| x * a).sum();
            let lo = proj_q - radius;
            let hi = proj_q + radius;
            let start = sorted.partition_point(|&(v, _)| v < lo);
            // Clamp so that an inverted window (negative radius) or an
            // unpartitioned table (NaN projections) yields an empty range
            // instead of an out-of-order slice panic.
            let end = sorted.partition_point(|&(v, _)| v <= hi).max(start);
            for &(_, idx) in &sorted[start..end] {
                hits[idx] += 1;
            }
        }

        hits.iter()
            .enumerate()
            .filter(|&(i, &count)| {
                count == num_axes && self.metric.distance(query, self.data.row(i)) <= radius
            })
            .map(|(i, _)| i)
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cluster::distance::Euclidean;

    /// Two tight clusters far apart: a small-radius query at the origin must
    /// return exactly the rows of the near cluster.
    #[test]
    fn range_query_correctness() {
        let data = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.1],
            vec![10.0, 10.0],
            vec![10.1, 10.1],
        ];
        let index = ProjIndex::new(data.as_slice(), &Euclidean, 8);
        let mut results = index.range_query(&[0.0, 0.0], 0.5);
        results.sort();
        // Exact match: also rules out row 3 and duplicate entries.
        assert_eq!(results, vec![0, 1]);
    }

    /// An index over an empty dataset answers every query with no rows.
    #[test]
    fn empty_data() {
        let data: Vec<Vec<f32>> = vec![];
        let index = ProjIndex::new(data.as_slice(), &Euclidean, 8);
        assert!(index.range_query(&[0.0], 1.0).is_empty());
    }

    /// Compares the index against a brute-force scan on random 20-dimensional
    /// data, in both directions (no missed and no spurious neighbours).
    #[test]
    fn high_dimensional() {
        let mut rng = StdRng::seed_from_u64(42);
        let data: Vec<Vec<f32>> = (0..100)
            .map(|_| (0..20).map(|_| rng.random::<f32>()).collect())
            .collect();
        let index = ProjIndex::new(data.as_slice(), &Euclidean, 12);
        let query = &data[0];
        let radius = 1.0;
        let mut brute: Vec<usize> = (0..100)
            .filter(|&i| Euclidean.distance(query, &data[i]) <= radius)
            .collect();
        brute.sort();
        let mut indexed = index.range_query(query, radius);
        indexed.sort();
        for &b in &brute {
            assert!(indexed.contains(&b), "projection index missed neighbor {b}");
        }
        // Without this direction an index that returned every row would pass.
        for &i in &indexed {
            assert!(
                brute.contains(&i),
                "projection index returned non-neighbor {i}"
            );
        }
    }
}