use crate::RetrieveError;
use smallvec::SmallVec;
#[cfg(feature = "vamana")]
/// Tuning parameters for a [`VamanaIndex`].
///
/// Use [`Default::default`] for the crate's baseline configuration and
/// override individual fields with struct-update syntax.
#[derive(Clone, Debug, PartialEq)]
pub struct VamanaParams {
    /// Intended upper bound on the number of neighbors kept per graph node.
    /// NOTE(review): enforced by the graph-construction module (not shown here).
    pub max_degree: usize,
    /// Pruning slack factor used during graph construction. Presumably the
    /// `alpha` of Vamana/DiskANN robust pruning; confirm against the construction code.
    pub alpha: f32,
    /// Candidate-list size used while building the graph (presumably; consumed
    /// outside this file).
    pub ef_construction: usize,
    /// Suggested search-time candidate-list size. `VamanaIndex::search` in this
    /// file does not read it; that method takes `ef` as an explicit argument.
    pub ef_search: usize,
}
#[cfg(feature = "vamana")]
impl Default for VamanaParams {
    /// Baseline configuration: `max_degree = 64`, `alpha = 1.3`,
    /// `ef_construction = 200`, `ef_search = 50`.
    fn default() -> Self {
        let max_degree = 64;
        let alpha = 1.3;
        let ef_construction = 200;
        let ef_search = 50;
        VamanaParams {
            max_degree,
            alpha,
            ef_construction,
            ef_search,
        }
    }
}
#[cfg(feature = "vamana")]
/// A Vamana-style graph index over fixed-dimension `f32` vectors.
///
/// Vectors are appended with `add`. After that, `build` is called once, and
/// only then can the index be searched.
#[derive(Debug)]
pub struct VamanaIndex {
    /// Length of every stored vector; always > 0 (checked in `new`).
    pub(crate) dimension: usize,
    /// All vectors stored back to back: vector `i` occupies
    /// `vectors[i * dimension .. (i + 1) * dimension]`.
    pub(crate) vectors: Vec<f32>,
    /// One neighbor list per stored vector, indexed by insertion position.
    /// `add` pushes an empty list; presumably filled in by graph construction.
    pub(crate) neighbors: Vec<SmallVec<[u32; 16]>>,
    /// Parameters supplied at construction time.
    pub(crate) params: VamanaParams,
    /// Number of vectors added so far; `add` keeps it in step with `neighbors.len()`.
    pub(crate) num_vectors: usize,
    /// Set once `build` succeeds; forbids further `add` calls and enables `search`.
    built: bool,
}
#[cfg(feature = "vamana")]
impl VamanaIndex {
    /// Creates an empty, unbuilt index for vectors of length `dimension`.
    ///
    /// # Errors
    ///
    /// Returns [`RetrieveError::Other`] if `dimension` is zero.
    pub fn new(dimension: usize, params: VamanaParams) -> Result<Self, RetrieveError> {
        if dimension == 0 {
            return Err(RetrieveError::Other("Dimension must be > 0".to_string()));
        }
        Ok(Self {
            dimension,
            vectors: Vec::new(),
            neighbors: Vec::new(),
            params,
            num_vectors: 0,
            built: false,
        })
    }

    /// Appends `vector` to the index and gives it an empty neighbor list.
    ///
    /// The `_id` argument is currently ignored. The vector is addressed by its
    /// insertion position (0, 1, 2, ...), which is also the `u32` used for it in
    /// neighbor lists. NOTE(review): confirm whether callers expect `search` to
    /// return these positions or the ids they passed in.
    ///
    /// # Errors
    ///
    /// - [`RetrieveError::Other`] if the index has already been built.
    /// - [`RetrieveError::DimensionMismatch`] if `vector.len()` differs from the
    ///   index dimension (`query_dim` = expected, `doc_dim` = supplied).
    /// - [`RetrieveError::Other`] if the new vector's position would not fit in
    ///   a `u32` node index.
    pub fn add(&mut self, _id: u32, vector: Vec<f32>) -> Result<(), RetrieveError> {
        if self.built {
            return Err(RetrieveError::Other(
                "Cannot add vectors after build".to_string(),
            ));
        }
        if vector.len() != self.dimension {
            return Err(RetrieveError::DimensionMismatch {
                query_dim: self.dimension,
                doc_dim: vector.len(),
            });
        }
        // Neighbor lists store node positions as `u32`. Reject a node whose
        // position cannot be represented rather than letting it be truncated later.
        if u32::try_from(self.num_vectors).is_err() {
            return Err(RetrieveError::Other(
                "Index is full: node count exceeds u32 range".to_string(),
            ));
        }
        self.vectors.extend_from_slice(&vector);
        self.neighbors.push(SmallVec::new());
        self.num_vectors += 1;
        Ok(())
    }

    /// Builds the neighbor graph over all added vectors.
    ///
    /// Delegates to `super::construction::construct_graph`. The index is marked
    /// built only if that call succeeds.
    ///
    /// # Errors
    ///
    /// - [`RetrieveError::EmptyIndex`] if no vectors were added.
    /// - [`RetrieveError::Other`] if the index is already built.
    /// - Any error returned by graph construction, propagated unchanged.
    pub fn build(&mut self) -> Result<(), RetrieveError> {
        if self.num_vectors == 0 {
            return Err(RetrieveError::EmptyIndex);
        }
        if self.built {
            return Err(RetrieveError::Other("Index already built".to_string()));
        }
        super::construction::construct_graph(self)?;
        self.built = true;
        Ok(())
    }

    /// Searches the built index for up to `k` neighbors of `query`, using
    /// candidate-list size `ef`.
    ///
    /// Delegates to `super::search::search`. That routine defines the meaning
    /// and ordering of the returned `(u32, f32)` pairs, presumably
    /// (node position, distance).
    ///
    /// # Errors
    ///
    /// - [`RetrieveError::Other`] if the index has not been built.
    /// - [`RetrieveError::DimensionMismatch`] if `query.len()` differs from the
    ///   index dimension (`query_dim` = supplied, `doc_dim` = expected).
    /// - Any error returned by the search routine.
    pub fn search(
        &self,
        query: &[f32],
        k: usize,
        ef: usize,
    ) -> Result<Vec<(u32, f32)>, RetrieveError> {
        if !self.built {
            return Err(RetrieveError::Other(
                "Index must be built before search".to_string(),
            ));
        }
        // Reject malformed queries here so distance code never compares slices
        // of different lengths.
        if query.len() != self.dimension {
            return Err(RetrieveError::DimensionMismatch {
                query_dim: query.len(),
                doc_dim: self.dimension,
            });
        }
        super::search::search(self, query, k, ef)
    }

    /// Returns the stored vector at insertion position `idx`.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= num_vectors` (the slice range falls outside `vectors`).
    pub(crate) fn get_vector(&self, idx: usize) -> &[f32] {
        let start = idx * self.dimension;
        let end = start + self.dimension;
        &self.vectors[start..end]
    }
}
#[cfg(all(test, feature = "vamana"))]
mod tests {
    //! Unit tests for `VamanaIndex` construction, insertion, and state checks.
    //! These tests do not build a graph, so they are independent of the
    //! construction and search modules.
    use super::*;

    #[test]
    fn test_vamana_create() {
        let params = VamanaParams::default();
        let index = VamanaIndex::new(128, params);
        assert!(index.is_ok());
    }

    #[test]
    fn test_vamana_create_rejects_zero_dimension() {
        let result = VamanaIndex::new(0, VamanaParams::default());
        assert!(matches!(result, Err(RetrieveError::Other(_))));
    }

    #[test]
    fn test_vamana_add() {
        let params = VamanaParams::default();
        let mut index = VamanaIndex::new(128, params).unwrap();
        let vector = vec![0.1; 128];
        assert!(index.add(0, vector).is_ok());
        assert_eq!(index.num_vectors, 1);
    }

    #[test]
    fn test_vamana_add_rejects_wrong_dimension() {
        let mut index = VamanaIndex::new(4, VamanaParams::default()).unwrap();
        let result = index.add(0, vec![0.1; 3]);
        assert!(matches!(
            result,
            Err(RetrieveError::DimensionMismatch {
                query_dim: 4,
                doc_dim: 3
            })
        ));
        // A rejected vector must not leave partial state behind.
        assert_eq!(index.num_vectors, 0);
        assert!(index.vectors.is_empty());
        assert!(index.neighbors.is_empty());
    }

    #[test]
    fn test_vamana_build_empty_fails() {
        let mut index = VamanaIndex::new(4, VamanaParams::default()).unwrap();
        assert!(matches!(index.build(), Err(RetrieveError::EmptyIndex)));
    }

    #[test]
    fn test_vamana_search_before_build_fails() {
        let mut index = VamanaIndex::new(4, VamanaParams::default()).unwrap();
        index.add(0, vec![0.1; 4]).unwrap();
        assert!(matches!(
            index.search(&[0.1; 4], 1, 10),
            Err(RetrieveError::Other(_))
        ));
    }

    #[test]
    fn test_get_vector_returns_inserted_rows() {
        let mut index = VamanaIndex::new(2, VamanaParams::default()).unwrap();
        index.add(0, vec![1.0, 2.0]).unwrap();
        index.add(1, vec![3.0, 4.0]).unwrap();
        assert_eq!(index.get_vector(0), &[1.0_f32, 2.0][..]);
        assert_eq!(index.get_vector(1), &[3.0_f32, 4.0][..]);
        assert_eq!(index.neighbors.len(), index.num_vectors);
    }
}