use crate::RetrieveError;
/// Tuning knobs supplied when constructing an [`EsgIndex`].
#[derive(Debug, Clone)]
pub struct EsgParams {
    /// Number of checkpoints.
    ///
    /// NOTE(review): no code visible in this file reads this value; confirm
    /// whether it is consumed elsewhere or reserved for future use.
    pub num_checkpoints: usize,
    /// Forwarded to the HNSW builder's `m` setting when the index is built.
    pub hnsw_m: usize,
    /// Forwarded to the HNSW builder's `ef_construction` setting when the
    /// index is built.
    pub hnsw_ef_construction: usize,
    /// Lower bound for the `ef` value used at query time; the effective value
    /// is the larger of this and the requested `k`.
    pub ef_search: usize,
}
impl Default for EsgParams {
fn default() -> Self {
Self {
num_checkpoints: 16,
hnsw_m: 16,
hnsw_ef_construction: 200,
ef_search: 100,
}
}
}
/// A document id paired with the scalar attribute used for range filtering.
#[derive(Debug, Clone)]
struct AttributedPoint {
    /// Caller-supplied identifier of the document.
    doc_id: u32,
    /// Attribute value compared against the `[lo, hi]` bounds at query time.
    attribute: f64,
}
/// Vector index that supports nearest-neighbour queries restricted to points
/// whose scalar attribute lies in a closed range.
///
/// Points are staged with `add`, then `build` sorts them by attribute,
/// normalizes the vectors and (with the `hnsw` feature) constructs a single
/// HNSW index over all of them.
pub struct EsgIndex {
/// Required length of every inserted vector and every query.
dimension: usize,
/// Parameters supplied at construction time.
params: EsgParams,
/// Set once `build` succeeds; `add` is rejected afterwards and searches
/// require it.
built: bool,
/// Flat storage of the vectors written by `build`: `dimension` consecutive
/// values per point, in the same order as `sorted_points`.
vectors: Vec<f32>,
/// Number of points accepted by `add`.
num_vectors: usize,
/// `(doc_id, attribute)` pairs in ascending attribute order, filled by `build`.
sorted_points: Vec<AttributedPoint>,
/// Pending `(doc_id, vector, attribute)` triples, consumed by `build`.
staging: Vec<(u32, Vec<f32>, f64)>,
/// HNSW index over all points; `None` until `build` succeeds.
#[cfg(feature = "hnsw")]
full_index: Option<crate::hnsw::HNSWIndex>,
}
impl EsgIndex {
    /// Vectors whose L2 norm does not exceed this value are stored and queried
    /// unscaled instead of being divided by their norm.
    #[cfg(feature = "hnsw")]
    const MIN_NORM: f32 = 1e-10;

    /// Creates an empty, unbuilt index for vectors of length `dimension`.
    ///
    /// # Errors
    ///
    /// Returns [`RetrieveError::InvalidParameter`] when `dimension` is zero.
    pub fn new(dimension: usize, params: EsgParams) -> Result<Self, RetrieveError> {
        if dimension == 0 {
            return Err(RetrieveError::InvalidParameter(
                "dimension must be > 0".into(),
            ));
        }
        Ok(Self {
            dimension,
            params,
            built: false,
            vectors: Vec::new(),
            num_vectors: 0,
            sorted_points: Vec::new(),
            staging: Vec::new(),
            #[cfg(feature = "hnsw")]
            full_index: None,
        })
    }

    /// Stages one point for the next call to `build`.
    ///
    /// The vector is stored as given; normalization happens in `build`.
    /// A point whose `attribute` is NaN can never satisfy a range bound and is
    /// therefore never returned by `range_search` or `search`.
    ///
    /// # Errors
    ///
    /// * [`RetrieveError::InvalidParameter`] if the index has already been built.
    /// * [`RetrieveError::DimensionMismatch`] if `vector.len()` differs from the
    ///   index dimension.
    pub fn add(
        &mut self,
        doc_id: u32,
        vector: Vec<f32>,
        attribute: f64,
    ) -> Result<(), RetrieveError> {
        if self.built {
            return Err(RetrieveError::InvalidParameter(
                "cannot add after build".into(),
            ));
        }
        if vector.len() != self.dimension {
            return Err(RetrieveError::DimensionMismatch {
                query_dim: vector.len(),
                doc_dim: self.dimension,
            });
        }
        self.staging.push((doc_id, vector, attribute));
        self.num_vectors += 1;
        Ok(())
    }

    /// Sorts the staged points by attribute, normalizes their vectors and
    /// builds the HNSW index over all of them.
    ///
    /// Calling this on an already built index is a no-op that returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// * [`RetrieveError::EmptyIndex`] if no point has been added.
    /// * Any error reported by the HNSW builder or index. A builder error
    ///   (raised while validating the HNSW parameters) leaves the staged
    ///   points untouched. An error from inserting into or finalizing the HNSW
    ///   index happens after the staged points have been moved into the flat
    ///   storage, and the index stays unbuilt.
    #[cfg(feature = "hnsw")]
    pub fn build(&mut self) -> Result<(), RetrieveError> {
        if self.built {
            return Ok(());
        }
        if self.num_vectors == 0 {
            return Err(RetrieveError::EmptyIndex);
        }

        // Construct the HNSW index first so that a parameter error cannot
        // destroy the staged data.
        let mut hnsw = crate::hnsw::HNSWIndex::builder(self.dimension)
            .m(self.params.hnsw_m)
            .ef_construction(self.params.hnsw_ef_construction)
            // Vectors are normalized below, so the HNSW index must not do it again.
            .auto_normalize(false)
            .build()?;

        // `total_cmp` gives a total order, so NaN attributes cannot break the sort.
        self.staging.sort_unstable_by(|a, b| a.2.total_cmp(&b.2));

        // Taking the buffer (instead of draining it) also releases its
        // allocation once the points have been flattened.
        let staged = std::mem::take(&mut self.staging);
        self.sorted_points = Vec::with_capacity(staged.len());
        self.vectors = Vec::with_capacity(staged.len() * self.dimension);
        for (doc_id, vector, attribute) in staged {
            self.sorted_points
                .push(AttributedPoint { doc_id, attribute });
            let norm = Self::l2_norm(&vector);
            if norm > Self::MIN_NORM {
                self.vectors.extend(vector.iter().map(|x| x / norm));
            } else {
                // (Near-)zero vector: dividing would blow up, keep it as is.
                self.vectors.extend_from_slice(&vector);
            }
        }

        for (rank, point) in self.sorted_points.iter().enumerate() {
            hnsw.add_slice(point.doc_id, self.get_vector(rank))?;
        }
        hnsw.build()?;

        self.full_index = Some(hnsw);
        self.built = true;
        Ok(())
    }

    /// Searches for up to `k` neighbours of `query` among the points whose
    /// attribute lies in the closed interval `[lo, hi]`.
    ///
    /// The query is normalized the same way as the stored vectors. The HNSW
    /// index is asked for `4 * k` candidates, which are then filtered by
    /// attribute; the first `k` survivors are returned sorted by ascending
    /// score as reported by the HNSW index. Because filtering happens after
    /// the candidate search, fewer than `k` results may be returned even when
    /// more points lie in range. A `lo > hi` or NaN bound matches nothing.
    ///
    /// # Errors
    ///
    /// * [`RetrieveError::InvalidParameter`] if the index has not been built.
    /// * [`RetrieveError::DimensionMismatch`] if `query.len()` differs from the
    ///   index dimension.
    /// * Any error reported by the HNSW search.
    #[cfg(feature = "hnsw")]
    pub fn range_search(
        &self,
        query: &[f32],
        k: usize,
        lo: f64,
        hi: f64,
    ) -> Result<Vec<(u32, f32)>, RetrieveError> {
        if !self.built {
            return Err(RetrieveError::InvalidParameter(
                "index must be built before search".into(),
            ));
        }
        if query.len() != self.dimension {
            return Err(RetrieveError::DimensionMismatch {
                query_dim: query.len(),
                doc_dim: self.dimension,
            });
        }
        let hnsw = self.full_index.as_ref().ok_or_else(|| {
            RetrieveError::InvalidParameter("index must be built before search".into())
        })?;

        let query_norm = Self::l2_norm(query);
        let query_normalized: Vec<f32> = if query_norm > Self::MIN_NORM {
            query.iter().map(|x| x / query_norm).collect()
        } else {
            query.to_vec()
        };

        let ef = self.params.ef_search.max(k);

        // Only points inside the attribute window can match, so each candidate
        // is checked against that window instead of the whole point list.
        let window = self.attribute_window(lo, hi);
        let in_range = |doc_id: u32| -> bool {
            window
                .iter()
                .any(|p| p.doc_id == doc_id && p.attribute >= lo && p.attribute <= hi)
        };

        // Saturating arithmetic: a huge `k` must not overflow the over-fetch factors.
        let candidates =
            hnsw.search(&query_normalized, k.saturating_mul(4), ef.saturating_mul(2))?;
        let mut results: Vec<(u32, f32)> = candidates
            .into_iter()
            .filter(|(doc_id, _)| in_range(*doc_id))
            .take(k)
            .collect();
        results.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
        Ok(results)
    }

    /// Searches for up to `k` neighbours of `query` without restricting the
    /// attribute: equivalent to `range_search` over `[-inf, +inf]`.
    ///
    /// Points with a NaN attribute are still excluded, since NaN compares
    /// false against every bound.
    ///
    /// # Errors
    ///
    /// Same as [`EsgIndex::range_search`].
    #[cfg(feature = "hnsw")]
    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u32, f32)>, RetrieveError> {
        self.range_search(query, k, f64::NEG_INFINITY, f64::INFINITY)
    }

    /// Number of points accepted by `add`.
    pub fn len(&self) -> usize {
        self.num_vectors
    }

    /// Returns `true` when no point has been added.
    pub fn is_empty(&self) -> bool {
        self.num_vectors == 0
    }

    /// Returns the stored vector of the point at position `rank` in
    /// attribute order.
    ///
    /// # Panics
    ///
    /// Panics if `rank` is out of bounds for the flat vector storage.
    #[inline]
    fn get_vector(&self, rank: usize) -> &[f32] {
        let start = rank * self.dimension;
        &self.vectors[start..start + self.dimension]
    }

    /// Euclidean (L2) norm of `values`, accumulated in `f32`.
    #[cfg(feature = "hnsw")]
    fn l2_norm(values: &[f32]) -> f32 {
        values.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    /// Returns a contiguous slice of `sorted_points` that contains every point
    /// whose attribute satisfies `lo <= attribute <= hi`.
    ///
    /// The slice may contain a few non-matching points (for example when a
    /// bound is NaN), so callers must still apply the exact comparison.
    #[cfg(feature = "hnsw")]
    fn attribute_window(&self, lo: f64, hi: f64) -> &[AttributedPoint] {
        // `sorted_points` is ordered by `total_cmp`, which places -0.0 before
        // +0.0 while `>=`/`<=` treat them as equal. Widen zero bounds so both
        // zeros stay inside the window.
        let lower = if lo == 0.0 { -0.0 } else { lo };
        let upper = if hi == 0.0 { 0.0 } else { hi };

        let start = self
            .sorted_points
            .partition_point(|p| p.attribute.total_cmp(&lower).is_lt());
        let end = self
            .sorted_points
            .partition_point(|p| p.attribute.total_cmp(&upper).is_le());

        // `get` yields `None` when `start > end` (i.e. `lo > hi`): empty window.
        self.sorted_points.get(start..end).unwrap_or(&[])
    }
}
#[cfg(test)]
#[cfg(feature = "hnsw")]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Vector dimensionality shared by every test in this module.
    const DIM: usize = 16;

    /// Deterministic test vector: component `i` is `sin(0.1 * seed + 0.01 * i)`.
    fn make_vector(dim: usize, seed: u32) -> Vec<f32> {
        let phase = seed as f32 * 0.1;
        (0..dim).map(|i| (phase + i as f32 * 0.01).sin()).collect()
    }

    /// Small parameter set that keeps index construction cheap in tests.
    fn small_params() -> EsgParams {
        EsgParams {
            num_checkpoints: 4,
            hnsw_m: 8,
            hnsw_ef_construction: 50,
            ef_search: 50,
        }
    }

    /// Builds an index holding doc ids `0..count`, where doc `id` carries the
    /// vector `make_vector(DIM, id)` and the attribute `attribute_of(id)`.
    fn built_index(count: u32, attribute_of: impl Fn(u32) -> f64) -> EsgIndex {
        let mut index = EsgIndex::new(DIM, small_params()).unwrap();
        for id in 0..count {
            index
                .add(id, make_vector(DIM, id), attribute_of(id))
                .unwrap();
        }
        index.build().unwrap();
        index
    }

    /// Every hit of a range query must carry an attribute inside the range.
    #[test]
    fn build_and_range_search() {
        let index = built_index(50, |id| f64::from(id) * 2.0);
        let query = make_vector(DIM, 15);

        let results = index.range_search(&query, 5, 20.0, 60.0).unwrap();

        for (doc_id, _) in &results {
            let attr = f64::from(*doc_id) * 2.0;
            assert!(
                (20.0..=60.0).contains(&attr),
                "doc_id {} has attribute {}, expected in [20, 60]",
                doc_id,
                attr
            );
        }
    }

    /// An unrestricted search over a populated index yields at least one hit.
    #[test]
    fn full_range_search() {
        let index = built_index(30, f64::from);
        let query = make_vector(DIM, 0);

        let results = index.search(&query, 5).unwrap();

        assert!(!results.is_empty());
    }

    /// A three-value-wide range can only produce the three matching documents.
    #[test]
    fn narrow_range_returns_subset() {
        let index = built_index(40, f64::from);
        let query = make_vector(DIM, 10);

        let results = index.range_search(&query, 10, 9.0, 11.0).unwrap();

        assert!(results.len() <= 3);
        for (doc_id, _) in &results {
            assert!(
                (9..=11).contains(doc_id),
                "unexpected doc_id {} in narrow range",
                doc_id
            );
        }
    }

    /// A range that contains no stored attribute produces no hits.
    #[test]
    fn empty_range_returns_empty() {
        let index = built_index(20, f64::from);
        let query = make_vector(DIM, 0);

        let results = index.range_search(&query, 5, 100.0, 200.0).unwrap();

        assert!(results.is_empty());
    }
}