use crate::RetrieveError;
use smallvec::SmallVec;
/// Vector index whose state is populated in two phases: vectors are appended
/// with `add`/`add_slice`, then `build` runs graph construction once, after
/// which the index only accepts `search` calls.
#[derive(Debug)]
pub struct NSWIndex {
/// Flat storage of every added vector, back to back; vector `i` occupies
/// `vectors[i * dimension..(i + 1) * dimension]`.
pub(crate) vectors: Vec<f32>,
/// Required length of every stored vector and of every query.
pub(crate) dimension: usize,
/// Number of vectors appended so far.
pub(crate) num_vectors: usize,
/// Per-node neighbor id lists handed to the search routine.
/// Presumably filled in by graph construction — confirm in `construction`.
pub(crate) neighbors: Vec<SmallVec<[u32; 16]>>,
/// Graph/search tuning parameters supplied at creation time.
pub(crate) params: NSWParams,
/// Set to `true` once `build` has completed successfully; from then on
/// further insertions are rejected.
built: bool,
/// Node id at which `search` starts its traversal. `None` until it is set
/// (presumably by graph construction — confirm in `construction`).
pub(crate) entry_point: Option<u32>,
}
/// Tuning parameters for an `NSWIndex`.
///
/// The index constructors reject `m == 0` and `m_max == 0`. How each value is
/// interpreted is decided by the construction/search code that reads it.
#[derive(Clone, Debug)]
pub struct NSWParams {
/// Presumably the number of links created per inserted node — confirm
/// against the graph construction code. Must be greater than 0.
pub m: usize,
/// Presumably the upper bound on links kept per node — confirm against
/// the graph construction code. Must be greater than 0.
pub m_max: usize,
/// Presumably the candidate-list size used while building the graph —
/// confirm against the graph construction code.
pub ef_construction: usize,
/// Presumably the default candidate-list size for queries. Note that
/// `NSWIndex::search` takes its `ef` as an explicit argument and does not
/// read this field.
pub ef_search: usize,
}
impl Default for NSWParams {
fn default() -> Self {
Self {
m: 16,
m_max: 16,
ef_construction: 200,
ef_search: 50,
}
}
}
impl NSWIndex {
    /// Creates an empty, unbuilt index using `m` and `m_max` together with the
    /// default values for the remaining [`NSWParams`] fields.
    ///
    /// # Errors
    ///
    /// Same as [`NSWIndex::with_params`]: `RetrieveError::EmptyQuery` when
    /// `dimension` is 0, `RetrieveError::Other` when `m` or `m_max` is 0.
    pub fn new(dimension: usize, m: usize, m_max: usize) -> Result<Self, RetrieveError> {
        // Delegate so both constructors share one set of validation rules.
        Self::with_params(
            dimension,
            NSWParams {
                m,
                m_max,
                ..Default::default()
            },
        )
    }

    /// Creates an empty, unbuilt index with fully specified parameters.
    ///
    /// # Errors
    ///
    /// * `RetrieveError::EmptyQuery` when `dimension` is 0.
    /// * `RetrieveError::Other` when `params.m` or `params.m_max` is 0.
    pub fn with_params(dimension: usize, params: NSWParams) -> Result<Self, RetrieveError> {
        if dimension == 0 {
            return Err(RetrieveError::EmptyQuery);
        }
        if params.m == 0 || params.m_max == 0 {
            return Err(RetrieveError::Other(
                "m and m_max must be greater than 0".to_string(),
            ));
        }
        Ok(Self {
            vectors: Vec::new(),
            dimension,
            num_vectors: 0,
            neighbors: Vec::new(),
            params,
            built: false,
            entry_point: None,
        })
    }

    /// Appends an owned vector to the index; see [`NSWIndex::add_slice`].
    ///
    /// # Errors
    ///
    /// Same as [`NSWIndex::add_slice`].
    pub fn add(&mut self, _doc_id: u32, vector: Vec<f32>) -> Result<(), RetrieveError> {
        self.add_slice(_doc_id, &vector)
    }

    /// Appends a copy of `vector` to the flat vector storage.
    ///
    /// `_doc_id` is currently ignored: nothing is recorded for it, so the
    /// vector is only identifiable by its insertion position.
    ///
    /// # Errors
    ///
    /// * `RetrieveError::Other` when the index has already been built.
    /// * `RetrieveError::DimensionMismatch` when `vector.len()` differs from
    ///   the index dimension.
    pub fn add_slice(&mut self, _doc_id: u32, vector: &[f32]) -> Result<(), RetrieveError> {
        if self.built {
            return Err(RetrieveError::Other(
                "Cannot add vectors after index is built".to_string(),
            ));
        }
        if vector.len() != self.dimension {
            return Err(RetrieveError::DimensionMismatch {
                query_dim: self.dimension,
                doc_dim: vector.len(),
            });
        }
        self.vectors.extend_from_slice(vector);
        self.num_vectors += 1;
        Ok(())
    }

    /// Runs graph construction over the added vectors and marks the index as
    /// built. Calling it again on an already built index is a no-op.
    ///
    /// # Errors
    ///
    /// * `RetrieveError::EmptyIndex` when no vectors have been added.
    /// * Any error returned by graph construction; the index then stays
    ///   unbuilt.
    pub fn build(&mut self) -> Result<(), RetrieveError> {
        if self.built {
            return Ok(());
        }
        if self.num_vectors == 0 {
            return Err(RetrieveError::EmptyIndex);
        }
        super::construction::construct_graph(self)?;
        self.built = true;
        Ok(())
    }

    /// Returns up to `k` `(id, score)` pairs for `query`, ordered by ascending
    /// score.
    ///
    /// The traversal is run with a candidate budget of `max(ef, k)`. Scores
    /// are whatever `greedy_search` reports; this method assumes a lower score
    /// means a closer match, which is what the ascending output order implies.
    ///
    /// # Errors
    ///
    /// * `RetrieveError::Other` when the index has not been built.
    /// * `RetrieveError::DimensionMismatch` when `query.len()` differs from
    ///   the index dimension.
    /// * `RetrieveError::EmptyIndex` when the index holds no vectors or has no
    ///   entry point.
    /// * Any error returned by the underlying graph traversal.
    pub fn search(
        &self,
        query: &[f32],
        k: usize,
        ef: usize,
    ) -> Result<Vec<(u32, f32)>, RetrieveError> {
        if !self.built {
            return Err(RetrieveError::Other(
                "Index must be built before search".to_string(),
            ));
        }
        if query.len() != self.dimension {
            // NOTE(review): going by the field names, `query_dim` would be
            // expected to carry `query.len()` and `doc_dim` the index
            // dimension. Left as-is because callers may rely on the current
            // values — confirm against the `RetrieveError` definition.
            return Err(RetrieveError::DimensionMismatch {
                query_dim: self.dimension,
                doc_dim: query.len(),
            });
        }
        if self.num_vectors == 0 {
            return Err(RetrieveError::EmptyIndex);
        }
        let entry_point = self.entry_point.ok_or(RetrieveError::EmptyIndex)?;

        // Nothing was requested: skip the traversal, which would otherwise be
        // handed a zero-sized budget when `ef` is 0 as well.
        if k == 0 {
            return Ok(Vec::new());
        }

        let candidates = super::search::greedy_search(
            query,
            entry_point,
            &self.neighbors,
            &self.vectors,
            self.dimension,
            ef.max(k),
        )?;

        // Order every candidate before cutting down to `k`. Truncating first
        // would only be correct if the traversal already returned candidates
        // best-first, which this method cannot rely on.
        let mut ranked: Vec<(u32, f32)> = candidates.into_iter().collect();
        ranked.sort_by(|a, b| a.1.total_cmp(&b.1));
        ranked.truncate(k);
        Ok(ranked)
    }

    /// Returns the stored vector at position `idx` as a slice of length
    /// `dimension`.
    ///
    /// # Panics
    ///
    /// Panics when `idx` does not refer to a stored vector (the slice range
    /// falls outside the flat storage).
    pub(crate) fn get_vector(&self, idx: usize) -> &[f32] {
        let start = idx * self.dimension;
        let end = start + self.dimension;
        &self.vectors[start..end]
    }
}