use std::collections::HashMap;
use crate::Region;
use vicinity::hnsw::HNSWIndex;
/// Errors returned by [`RegionIndex`] operations.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// A query was issued while the index is not in the built state, i.e. before
/// `RegionIndex::build` has succeeded or after a later `RegionIndex::add`.
#[error("index must be built before search")]
NotBuilt,
/// An error reported by the underlying `vicinity` index; converted
/// automatically from `vicinity::RetrieveError` by the `?` operator.
#[error("vicinity: {0}")]
Vicinity(#[from] vicinity::RetrieveError),
}
/// An index over regions of type `R`, keyed by caller-supplied `u32` ids.
///
/// Region centers are stored in an HNSW graph for candidate retrieval; the
/// regions themselves are kept alongside so candidates can be checked against
/// the actual region geometry.
pub struct RegionIndex<R: Region> {
/// HNSW graph holding one center vector per call to `add`.
hnsw: HNSWIndex,
/// Region storage; positions are handed out in insertion order.
regions: Vec<R>,
/// Maps a caller-supplied id to the position of its region in `regions`.
id_to_pos: HashMap<u32, usize>,
/// `true` once `build` has succeeded; reset to `false` by every `add`.
built: bool,
}
/// Construction-time parameters for the HNSW graph behind a `RegionIndex`.
///
/// Each field is forwarded unchanged to the HNSW builder setter of the same
/// name. The descriptions below use the usual HNSW meaning of these knobs —
/// NOTE(review): confirm the exact semantics against the `vicinity` docs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct IndexParams {
    /// Value passed to the builder's `m` setting (presumably the number of
    /// links created per inserted node).
    pub m: usize,
    /// Value passed to the builder's `m_max` setting (presumably the cap on
    /// links kept per node).
    pub m_max: usize,
    /// Value passed to the builder's `ef_construction` setting (presumably the
    /// candidate-list size used while inserting).
    pub ef_construction: usize,
}
impl Default for IndexParams {
fn default() -> Self {
Self {
m: 16,
m_max: 32,
ef_construction: 200,
}
}
}
/// Query-time parameters for `RegionIndex` searches.
///
/// The struct is `Copy`, so one value can be reused across many queries even
/// though the search methods take it by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SearchParams {
    /// Forwarded as the `ef` argument of the underlying HNSW search call
    /// (presumably the search-time candidate-list size — confirm against the
    /// `vicinity` docs).
    pub ef: usize,
    /// Over-retrieval factor: `RegionIndex::search` asks the HNSW layer for
    /// `k * overretrieve` candidates before reranking, and
    /// `RegionIndex::containing` caps its candidate count at
    /// `ef * overretrieve`.
    pub overretrieve: usize,
}
impl Default for SearchParams {
fn default() -> Self {
Self {
ef: 200,
overretrieve: 10,
}
}
}
/// A single search hit: the region id paired with its distance to the query.
pub type SearchResult = (u32, f32);
impl<R: Region> RegionIndex<R> {
    /// Creates an empty, unbuilt index whose HNSW graph is constructed for
    /// dimension `dim`.
    ///
    /// `params` is forwarded to the HNSW builder, and centers are compared
    /// with the L2 metric.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the HNSW builder.
    pub fn new(dim: usize, params: IndexParams) -> Result<Self, Error> {
        let hnsw = HNSWIndex::builder(dim)
            .m(params.m)
            .m_max(params.m_max)
            .ef_construction(params.ef_construction)
            .metric(vicinity::DistanceMetric::L2)
            .build()?;
        Ok(Self {
            hnsw,
            regions: Vec::new(),
            id_to_pos: HashMap::new(),
            built: false,
        })
    }

    /// Registers `region` under `id` and hands its center to the HNSW index.
    ///
    /// Every call clears the built flag, so [`build`](Self::build) must be
    /// called again before the next search.
    ///
    /// If `id` is already present, its stored region is replaced in place, so
    /// [`len`](Self::len) keeps counting distinct ids and no unreachable
    /// region is left behind.
    /// NOTE(review): the center previously handed to the HNSW index for that
    /// id is not removed here — confirm how `vicinity` treats repeated ids.
    ///
    /// # Panics
    ///
    /// Panics if the HNSW index rejects the center vector.
    pub fn add(&mut self, id: u32, region: R) {
        let center = region.center().to_vec();
        self.hnsw
            .add(id, center)
            .expect("failed to add center to HNSW");
        match self.id_to_pos.entry(id) {
            std::collections::hash_map::Entry::Occupied(slot) => {
                // Re-adding an id overwrites its region instead of appending a
                // second one that nothing could ever look up.
                self.regions[*slot.get()] = region;
            }
            std::collections::hash_map::Entry::Vacant(slot) => {
                slot.insert(self.regions.len());
                self.regions.push(region);
            }
        }
        self.built = false;
    }

    /// Builds the HNSW graph and marks the index as searchable.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the HNSW build step; the index stays
    /// unbuilt in that case.
    pub fn build(&mut self) -> Result<(), Error> {
        self.hnsw.build()?;
        self.built = true;
        Ok(())
    }

    /// Returns up to `k` regions nearest to `query`, closest first.
    ///
    /// The HNSW index is asked for `k * params.overretrieve` candidates (never
    /// fewer than `k`) ranked by center distance. Each candidate is then
    /// re-scored with `Region::distance_to_point`, and the best `k` are kept.
    /// Regions whose centers are not among the retrieved candidates cannot
    /// appear in the result.
    ///
    /// Results are ordered by ascending distance, then ascending id; NaN
    /// distances sort last.
    ///
    /// # Errors
    ///
    /// Returns [`Error::NotBuilt`] if the index is not built, and propagates
    /// errors from the HNSW search.
    ///
    /// # Panics
    ///
    /// Panics if the HNSW index yields an id that was never registered
    /// through [`add`](Self::add).
    #[must_use = "search results are not used"]
    pub fn search(
        &self,
        query: &[f32],
        k: usize,
        params: SearchParams,
    ) -> Result<Vec<SearchResult>, Error> {
        if !self.built {
            return Err(Error::NotBuilt);
        }
        let fetch_k = k.saturating_mul(params.overretrieve).max(k);
        let candidates = self.hnsw.search(query, fetch_k, params.ef)?;
        let mut reranked: Vec<SearchResult> = candidates
            .into_iter()
            .map(|(doc_id, _center_dist)| {
                // Invariant: every id in the HNSW index was recorded by `add`.
                let region = &self.regions[self.id_to_pos[&doc_id]];
                (doc_id, region.distance_to_point(query))
            })
            .collect();
        reranked.sort_unstable_by(Self::by_distance);
        reranked.truncate(k);
        Ok(reranked)
    }

    /// Returns the ids of candidate regions that contain `point`.
    ///
    /// Candidates come from an HNSW search around `point`, limited to
    /// `min(len, params.ef * params.overretrieve)` entries. A containing
    /// region whose center is not retrieved is missed; use
    /// [`containing_exhaustive`](Self::containing_exhaustive) when that is not
    /// acceptable. Ids are returned in the order the HNSW search produced
    /// them.
    ///
    /// # Errors
    ///
    /// Returns [`Error::NotBuilt`] if the index is not built, and propagates
    /// errors from the HNSW search.
    ///
    /// # Panics
    ///
    /// Panics if the HNSW index yields an id that was never registered
    /// through [`add`](Self::add).
    pub fn containing(&self, point: &[f32], params: SearchParams) -> Result<Vec<u32>, Error> {
        if !self.built {
            return Err(Error::NotBuilt);
        }
        // Saturate rather than overflow on large parameters; the product is
        // only an upper bound that is clamped to the region count anyway.
        let budget = params.ef.saturating_mul(params.overretrieve);
        let fetch_k = self.regions.len().min(budget);
        let candidates = self.hnsw.search(point, fetch_k, params.ef)?;
        let result: Vec<u32> = candidates
            .into_iter()
            .filter(|(doc_id, _)| self.regions[self.id_to_pos[doc_id]].contains(point))
            .map(|(doc_id, _)| doc_id)
            .collect();
        Ok(result)
    }

    /// Returns the ids of all registered regions that contain `point`.
    ///
    /// Every region is tested directly, so the HNSW index is not consulted and
    /// the index does not need to be built. The order of the returned ids is
    /// unspecified (it follows hash-map iteration order).
    pub fn containing_exhaustive(&self, point: &[f32]) -> Vec<u32> {
        self.id_to_pos
            .iter()
            .filter(|(_, &pos)| self.regions[pos].contains(point))
            .map(|(&id, _)| id)
            .collect()
    }

    /// Returns up to `k` regions nearest to `query` by scoring every region.
    ///
    /// The HNSW index is not consulted and the index does not need to be
    /// built. Results are ordered by ascending distance, then ascending id;
    /// NaN distances sort last.
    pub fn search_exhaustive(&self, query: &[f32], k: usize) -> Vec<SearchResult> {
        let mut results: Vec<SearchResult> = self
            .id_to_pos
            .iter()
            .map(|(&id, &pos)| (id, self.regions[pos].distance_to_point(query)))
            .collect();
        results.sort_unstable_by(Self::by_distance);
        results.truncate(k);
        results
    }

    /// Runs an HNSW search that scores nodes with the caller's `dist_fn`.
    ///
    /// `query`, `k`, `ef` and `dist_fn` are passed straight through to the
    /// HNSW index, and its result is returned as-is.
    /// NOTE(review): the `u32` given to `dist_fn` is whatever id the HNSW layer
    /// supplies — confirm whether that is the id passed to `add`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::NotBuilt`] if the index is not built, and propagates
    /// errors from the HNSW search.
    pub fn search_with_distance(
        &self,
        query: &[f32],
        k: usize,
        ef: usize,
        dist_fn: &dyn Fn(&[f32], u32) -> f32,
    ) -> Result<Vec<SearchResult>, Error> {
        if !self.built {
            return Err(Error::NotBuilt);
        }
        Ok(self.hnsw.search_with_distance(query, k, ef, dist_fn)?)
    }

    /// Returns the region registered under `id`, if any.
    pub fn get(&self, id: u32) -> Option<&R> {
        self.id_to_pos.get(&id).map(|&pos| &self.regions[pos])
    }

    /// Returns the number of stored regions.
    #[must_use]
    pub fn len(&self) -> usize {
        self.regions.len()
    }

    /// Returns `true` if no region has been added.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.regions.is_empty()
    }

    /// Total order used to rank hits: ascending distance, then ascending id.
    ///
    /// A NaN distance compares greater than every number and equal to other
    /// NaNs. Mapping incomparable pairs to `Equal` instead would not be a total
    /// order, which `sort_unstable_by` requires.
    fn by_distance(a: &SearchResult, b: &SearchResult) -> std::cmp::Ordering {
        a.1.partial_cmp(&b.1)
            // `partial_cmp` is `None` only when at least one side is NaN.
            .unwrap_or_else(|| a.1.is_nan().cmp(&b.1.is_nan()))
            .then_with(|| a.0.cmp(&b.0))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AxisBox;

    /// Builds a 3-D index of 20 unit cubes; cube `id` spans
    /// `[2 * id, 2 * id + 1]` on every axis.
    fn build_test_index() -> RegionIndex<AxisBox> {
        let mut index = RegionIndex::new(3, IndexParams::default()).unwrap();
        for id in 0..20 {
            let lo = id as f32 * 2.0;
            let hi = lo + 1.0;
            index.add(id, AxisBox::new(vec![lo; 3], vec![hi; 3]));
        }
        index.build().unwrap();
        index
    }

    /// A point inside cube 0 must return cube 0 at distance zero.
    #[test]
    fn search_finds_nearest_box() {
        let index = build_test_index();
        let hits = index
            .search(&[0.5, 0.5, 0.5], 1, SearchParams::default())
            .unwrap();
        assert_eq!(hits.len(), 1);
        let (id, dist) = hits[0];
        assert_eq!(id, 0);
        assert_eq!(dist, 0.0);
    }

    /// Reranking puts the containing cube first and keeps hits sorted.
    #[test]
    fn search_reranks_correctly() {
        let index = build_test_index();
        let query = [10.5, 10.5, 10.5];
        let params = SearchParams {
            ef: 100,
            overretrieve: 10,
        };
        let results = index.search(&query, 5, params).unwrap();
        assert_eq!(results[0].0, 5);
        assert_eq!(results[0].1, 0.0);
        assert!(
            results.windows(2).all(|pair| pair[0].1 <= pair[1].1),
            "results not sorted: {:?}",
            results
        );
    }

    /// A caller-supplied distance function drives the HNSW search.
    #[test]
    fn search_with_custom_distance() {
        let index = build_test_index();
        // Euclidean distance from the query to the center of cube `internal_id`.
        let dist_fn = |q: &[f32], internal_id: u32| -> f32 {
            let base = internal_id as f32 * 2.0;
            let center = [base + 0.5, base + 0.5, base + 0.5];
            let squared = center
                .iter()
                .zip(q)
                .map(|(c, p)| (c - p).powi(2))
                .sum::<f32>();
            squared.sqrt()
        };
        let results = index
            .search_with_distance(&[6.5, 6.5, 6.5], 3, 200, &dist_fn)
            .unwrap();
        assert_eq!(results.len(), 3);
        assert!(
            results.windows(2).all(|pair| pair[0].1 <= pair[1].1),
            "results not sorted: {:?}",
            results
        );
        assert!(results[0].1 < 2.0, "closest result too far: {}", results[0].1);
    }

    /// The exhaustive containment query reports both nested boxes and skips a
    /// disjoint one.
    #[test]
    fn containing_exhaustive_finds_enclosing_boxes() {
        let mut index = RegionIndex::new(2, IndexParams::default()).unwrap();
        index.add(0, AxisBox::new(vec![0.0, 0.0], vec![10.0, 10.0]));
        index.add(1, AxisBox::new(vec![4.0, 4.0], vec![6.0, 6.0]));
        index.add(2, AxisBox::new(vec![20.0, 20.0], vec![21.0, 21.0]));
        // Filler boxes away from the probe point.
        for id in 3..15 {
            let lo = (id as f32) * 3.0;
            let hi = lo + 0.5;
            index.add(id, AxisBox::new(vec![lo, lo], vec![hi, hi]));
        }
        index.build().unwrap();

        let enclosing = index.containing_exhaustive(&[5.0, 5.0]);
        assert!(enclosing.contains(&0));
        assert!(enclosing.contains(&1));
        assert!(!enclosing.contains(&2));
    }

    /// `get` finds a registered id and reports `None` for an unknown one.
    #[test]
    fn get_returns_region() {
        let mut index = RegionIndex::new(2, IndexParams::default()).unwrap();
        index.add(42, AxisBox::new(vec![0.0, 0.0], vec![1.0, 1.0]));
        index.build().unwrap();
        assert!(index.get(42).is_some());
        assert!(index.get(99).is_none());
    }

    /// Searching an index that was never built is an error.
    #[test]
    fn error_on_search_before_build() {
        let index: RegionIndex<AxisBox> = RegionIndex::new(2, IndexParams::default()).unwrap();
        let outcome = index.search(&[0.0, 0.0], 1, SearchParams::default());
        assert!(outcome.is_err());
    }
}