use crate::Region;
use vicinity::hnsw::HNSWIndex;
pub struct RegionIndex<R: Region> {
hnsw: HNSWIndex,
regions: Vec<R>,
ids: Vec<u32>,
built: bool,
}
/// Construction parameters for the HNSW graph behind a `RegionIndex`.
///
/// `RegionIndex::new` forwards each field to the same-named `HNSWIndex` builder
/// setting.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct IndexParams {
    /// Builder `m` setting (presumably the per-node neighbor count of upper HNSW
    /// layers — confirm against the `vicinity` docs).
    pub m: usize,
    /// Builder `m_max` setting (presumably the neighbor cap — confirm against the
    /// `vicinity` docs).
    pub m_max: usize,
    /// Builder `ef_construction` setting used while inserting into the graph.
    pub ef_construction: usize,
}

impl Default for IndexParams {
    /// Returns `m = 16`, `m_max = 32`, `ef_construction = 200`.
    fn default() -> Self {
        Self {
            m: 16,
            m_max: 32,
            ef_construction: 200,
        }
    }
}
/// Query-time parameters for `RegionIndex::search` and `RegionIndex::containing`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SearchParams {
    /// `ef` value passed to every HNSW search.
    pub ef: usize,
    /// Candidate multiplier. `search` asks HNSW for `k * overretrieve` (at least `k`)
    /// nearest centers before reranking by exact region distance; `containing` asks
    /// for `ef * overretrieve` centers before filtering by containment.
    pub overretrieve: usize,
}

impl Default for SearchParams {
    /// Returns `ef = 200`, `overretrieve = 10`.
    fn default() -> Self {
        Self {
            ef: 200,
            overretrieve: 10,
        }
    }
}
/// A search hit `(id, distance)`: `id` is the caller-supplied region id and
/// `distance` is the value `Region::distance_to_point` returned for the query.
pub type SearchResult = (u32, f32);
impl<R: Region> RegionIndex<R> {
    /// Creates an empty index for regions whose centers have `dim` coordinates.
    ///
    /// The HNSW graph is configured from `params` and built with `auto_normalize(true)`.
    ///
    /// NOTE(review): if `auto_normalize` rescales every center to unit length, as the
    /// name suggests (confirm against the `vicinity` docs), then candidates are ranked by
    /// the *direction* of each center rather than its Euclidean position:
    /// - centers on the same ray from the origin become indistinguishable;
    /// - a center at the origin cannot be normalized at all.
    ///
    /// The exact rerank in `search` only fixes the ordering of candidates that were
    /// actually fetched, so recall would then lean heavily on `SearchParams::overretrieve`.
    ///
    /// # Errors
    /// Returns the HNSW builder's error, rendered as a string, if it rejects the
    /// configuration.
    pub fn new(dim: usize, params: IndexParams) -> Result<Self, String> {
        let hnsw = HNSWIndex::builder(dim)
            .m(params.m)
            .m_max(params.m_max)
            .ef_construction(params.ef_construction)
            .auto_normalize(true)
            .build()
            .map_err(|e: vicinity::RetrieveError| e.to_string())?;
        Ok(Self {
            hnsw,
            regions: Vec::new(),
            ids: Vec::new(),
            built: false,
        })
    }

    /// Inserts `region` under the caller-chosen `id`, reporting failures as errors.
    ///
    /// Marks the index as unbuilt; call [`build`](Self::build) before the next
    /// [`search`](Self::search) or [`containing`](Self::containing). Ids are not checked
    /// for uniqueness: adding the same `id` twice stores two independent entries, and
    /// both may appear in results.
    ///
    /// # Errors
    /// Returns an error, without recording the region, in two cases:
    /// - the index already holds more regions than a `u32` key can address;
    /// - the HNSW index rejects the region's center (the error is rendered as a string).
    pub fn try_add(&mut self, id: u32, region: R) -> Result<(), String> {
        // The HNSW graph is keyed by each region's position in `regions`, not by the
        // caller's id. Resolving a candidate is then O(1) indexing instead of a linear
        // scan, and duplicate caller ids cannot collide inside the graph.
        let pos = self.regions.len();
        if pos > u32::MAX as usize {
            return Err("region index is full: HNSW keys are u32".into());
        }
        self.hnsw
            .add(pos as u32, region.center().to_vec())
            .map_err(|e| e.to_string())?;
        self.regions.push(region);
        self.ids.push(id);
        self.built = false;
        Ok(())
    }

    /// Inserts `region` under the caller-chosen `id`; see [`try_add`](Self::try_add).
    ///
    /// # Panics
    /// Panics if [`try_add`](Self::try_add) returns an error (for example, when the HNSW
    /// index rejects the region's center).
    pub fn add(&mut self, id: u32, region: R) {
        if let Err(e) = self.try_add(id, region) {
            panic!("failed to add center to HNSW: {}", e);
        }
    }

    /// Builds the HNSW graph over every center added so far.
    ///
    /// # Errors
    /// Returns the HNSW build error as a string; the built flag is only set on success.
    pub fn build(&mut self) -> Result<(), String> {
        self.hnsw.build().map_err(|e| e.to_string())?;
        self.built = true;
        Ok(())
    }

    /// Returns up to `k` regions nearest to `query`, ordered by exact region distance.
    ///
    /// How the result is computed:
    /// 1. Ask the HNSW graph for the `k * params.overretrieve` nearest region centers,
    ///    using `ef = params.ef`. The count is at least `k` and at most
    ///    [`len`](Self::len).
    /// 2. Rescore each candidate with `Region::distance_to_point`.
    ///
    /// A region whose center is not among the candidates is never returned, so the
    /// result is approximate. [`search_exhaustive`](Self::search_exhaustive) is the
    /// exact scan.
    ///
    /// Equal distances are ordered by ascending id. NaN distances are ordered by
    /// [`f32::total_cmp`] instead of corrupting the sort.
    ///
    /// # Errors
    /// Returns an error if the index has not been built since the last insertion, or if
    /// the HNSW search fails.
    pub fn search(
        &self,
        query: &[f32],
        k: usize,
        params: SearchParams,
    ) -> Result<Vec<SearchResult>, String> {
        if !self.built {
            return Err("index must be built before search".into());
        }
        // Requesting more candidates than stored regions can never help, so cap it.
        let fetch_k = k
            .saturating_mul(params.overretrieve)
            .max(k)
            .min(self.len());
        if fetch_k == 0 {
            // `k == 0` or an empty index: nothing can be returned.
            return Ok(Vec::new());
        }
        let candidates = self
            .hnsw
            .search(query, fetch_k, params.ef)
            .map_err(|e| e.to_string())?;
        let mut reranked: Vec<SearchResult> = candidates
            .into_iter()
            .map(|(slot, _center_dist)| {
                let (id, region) = self.entry(slot);
                (id, region.distance_to_point(query))
            })
            .collect();
        Self::sort_by_distance(&mut reranked);
        reranked.truncate(k);
        Ok(reranked)
    }

    /// Returns the ids of regions containing `point`, using HNSW to pick candidates.
    ///
    /// Up to `params.ef * params.overretrieve` regions with the nearest centers (capped
    /// at [`len`](Self::len)) are fetched and filtered with `Region::contains`. A large
    /// region whose center is not among them is missed even if it contains `point`;
    /// [`containing_exhaustive`](Self::containing_exhaustive) gives the exact answer.
    /// Ids come back in HNSW candidate order.
    ///
    /// NOTE(review): with the default params, the fetch count (up to 2000) exceeds
    /// `ef` (200). Confirm that `vicinity` widens `ef` to at least the requested count;
    /// otherwise recall here is lower than the fetch count implies.
    ///
    /// # Errors
    /// Returns an error if the index has not been built since the last insertion, or if
    /// the HNSW search fails.
    pub fn containing(&self, point: &[f32], params: SearchParams) -> Result<Vec<u32>, String> {
        if !self.built {
            return Err("index must be built before search".into());
        }
        // An overflowing product just means "fetch everything", not a panic.
        let fetch_k = params
            .ef
            .saturating_mul(params.overretrieve)
            .min(self.len());
        if fetch_k == 0 {
            return Ok(Vec::new());
        }
        let candidates = self
            .hnsw
            .search(point, fetch_k, params.ef)
            .map_err(|e| e.to_string())?;
        Ok(candidates
            .into_iter()
            .filter_map(|(slot, _)| {
                let (id, region) = self.entry(slot);
                region.contains(point).then_some(id)
            })
            .collect())
    }

    /// Returns the ids of all regions containing `point`, by checking every region.
    ///
    /// Ids are returned in insertion order. This works whether or not the index has
    /// been built.
    pub fn containing_exhaustive(&self, point: &[f32]) -> Vec<u32> {
        self.ids
            .iter()
            .zip(&self.regions)
            .filter(|(_, r)| r.contains(point))
            .map(|(&id, _)| id)
            .collect()
    }

    /// Returns the `k` regions nearest to `query` by scoring every region.
    ///
    /// This is the exact counterpart of [`search`](Self::search), with the same ordering
    /// rules. It works whether or not the index has been built.
    pub fn search_exhaustive(&self, query: &[f32], k: usize) -> Vec<SearchResult> {
        let mut results: Vec<SearchResult> = self
            .ids
            .iter()
            .zip(&self.regions)
            .map(|(&id, r)| (id, r.distance_to_point(query)))
            .collect();
        Self::sort_by_distance(&mut results);
        results.truncate(k);
        results
    }

    /// Number of stored regions.
    pub fn len(&self) -> usize {
        self.regions.len()
    }

    /// Whether no regions have been added.
    pub fn is_empty(&self) -> bool {
        self.regions.is_empty()
    }

    /// Resolves an HNSW key (a position assigned by `try_add`) to the caller's id and
    /// the stored region.
    ///
    /// # Panics
    /// Panics if `slot` is out of range, which would mean the HNSW index returned a key
    /// this wrapper never inserted.
    fn entry(&self, slot: u32) -> (u32, &R) {
        let pos = slot as usize;
        (self.ids[pos], &self.regions[pos])
    }

    /// Sorts hits by ascending distance, breaking ties by ascending id.
    ///
    /// `total_cmp` gives a total order even with NaN. The previous
    /// `partial_cmp().unwrap_or(Equal)` comparator was not transitive once NaN appeared,
    /// and `sort_unstable_by` may panic on such comparators.
    fn sort_by_distance(results: &mut [SearchResult]) {
        results.sort_unstable_by(|a, b| a.1.total_cmp(&b.1).then_with(|| a.0.cmp(&b.0)));
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AxisBox;

    /// Builds and finalizes an index of `n` 3-D unit cubes, cube `i` spanning
    /// `[i, i + 1]` on every axis.
    fn diagonal_cubes(n: u32) -> RegionIndex<AxisBox> {
        let mut idx = RegionIndex::new(3, Default::default()).unwrap();
        for i in 0..n {
            let lo = i as f32;
            idx.add(i, AxisBox::new(vec![lo; 3], vec![lo + 1.0; 3]));
        }
        idx.build().unwrap();
        idx
    }

    /// Builds and finalizes a 2-D index of three boxes: boxes 0 and 1 enclose
    /// `[5, 5]`, and box 2 lies far away.
    fn nested_boxes() -> RegionIndex<AxisBox> {
        let mut idx = RegionIndex::new(2, Default::default()).unwrap();
        idx.add(0, AxisBox::new(vec![0.0, 0.0], vec![10.0, 10.0]));
        idx.add(1, AxisBox::new(vec![4.0, 4.0], vec![6.0, 6.0]));
        idx.add(2, AxisBox::new(vec![20.0, 20.0], vec![21.0, 21.0]));
        idx.build().unwrap();
        idx
    }

    #[test]
    fn search_finds_nearest_box() {
        let mut idx = RegionIndex::new(2, Default::default()).unwrap();
        idx.add(0, AxisBox::new(vec![0.0, 0.0], vec![1.0, 1.0]));
        idx.add(1, AxisBox::new(vec![10.0, 10.0], vec![11.0, 11.0]));
        idx.add(2, AxisBox::new(vec![5.0, 5.0], vec![6.0, 6.0]));
        idx.build().unwrap();
        let results = idx.search(&[0.5, 0.5], 1, Default::default()).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 0);
        assert_eq!(results[0].1, 0.0);
    }

    #[test]
    fn search_reranks_correctly() {
        let idx = diagonal_cubes(50);
        let query = [3.5, 3.5, 3.5];
        let results = idx
            .search(&query, 5, SearchParams { ef: 100, overretrieve: 10 })
            .unwrap();
        assert_eq!(results[0].0, 3);
        assert_eq!(results[0].1, 0.0);
        for w in results.windows(2) {
            assert!(w[0].1 <= w[1].1, "results not sorted: {:?}", results);
        }
    }

    #[test]
    fn search_agrees_with_exhaustive_when_every_region_is_a_candidate() {
        let idx = diagonal_cubes(50);
        let query = [3.5, 3.5, 3.5];
        // 5 * 10 = 50 candidates covers every region, so the rerank must match the scan.
        let approx = idx
            .search(&query, 5, SearchParams { ef: 100, overretrieve: 10 })
            .unwrap();
        let exact = idx.search_exhaustive(&query, 5);
        // Compare distances only: equidistant cubes (e.g. 2 and 4) may come back in
        // either order.
        let approx_d: Vec<f32> = approx.iter().map(|r| r.1).collect();
        let exact_d: Vec<f32> = exact.iter().map(|r| r.1).collect();
        assert_eq!(approx_d, exact_d);
    }

    #[test]
    fn searching_an_unbuilt_index_is_an_error() {
        let mut idx = RegionIndex::new(2, Default::default()).unwrap();
        idx.add(0, AxisBox::new(vec![0.0, 0.0], vec![1.0, 1.0]));
        assert!(idx.search(&[0.5, 0.5], 1, Default::default()).is_err());
        assert!(idx.containing(&[0.5, 0.5], Default::default()).is_err());
    }

    #[test]
    fn containing_finds_enclosing_boxes() {
        let idx = nested_boxes();
        let point = [5.0, 5.0];
        let result = idx.containing_exhaustive(&point);
        assert!(result.contains(&0));
        assert!(result.contains(&1));
        assert!(!result.contains(&2));
    }

    #[test]
    fn containing_via_hnsw_matches_exhaustive_on_small_index() {
        let idx = nested_boxes();
        let point = [5.0, 5.0];
        // Three regions: the default fetch count covers all of them.
        let mut approx = idx.containing(&point, Default::default()).unwrap();
        approx.sort_unstable();
        let mut exact = idx.containing_exhaustive(&point);
        exact.sort_unstable();
        assert_eq!(approx, exact);
        assert_eq!(approx, vec![0, 1]);
    }
}