pub mod flat;
pub mod hnsw;
pub mod hnsw_pq;
pub mod hnsw_quantized;
pub mod streaming;
pub use flat::FlatIndex;
pub use hnsw::{BuildStrategy, HNSWConfig, HNSWIndex};
pub use hnsw_pq::{PQHNSWConfig, PQHNSWIndex};
pub use hnsw_quantized::{BinaryHNSWIndex, QuantizedHNSWConfig, SQ8HNSWIndex};
pub use streaming::{
BatchBuilder, BatchConfig, BatchIndex, BatchProgress, BatchResult, FilteredSearchBuilder,
PaginationConfig, SearchPage, SearchResultIterator,
};
use crate::{Document, Result, SearchResult};
/// Common interface shared by the vector indexes in this module (at least the
/// flat, HNSW, SQ8 and binary indexes implement it).
///
/// The trait is object-safe, so callers may hold an index as
/// `Box<dyn VectorIndex>` without knowing the concrete type.
pub trait VectorIndex {
    // ---- mutation -------------------------------------------------------

    /// Inserts `document` into the index.
    ///
    /// # Errors
    /// Returns an error when the implementation rejects the document.
    /// NOTE(review): a likely cause is an embedding whose length differs from
    /// [`embedding_dim`](Self::embedding_dim) — confirm per implementation.
    fn add(&mut self, document: Document) -> Result<()>;

    /// Removes every stored document, leaving the index empty.
    fn clear(&mut self);

    // ---- querying -------------------------------------------------------

    /// Looks up the stored documents closest to `query`, asking for `k` hits.
    ///
    /// NOTE(review): the distance metric, whether fewer than `k` hits may be
    /// returned, and the ordering are implementation-defined; callers in this
    /// module expect the best match first — confirm per implementation.
    fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>>;

    // ---- size / shape ---------------------------------------------------

    /// Number of documents currently stored.
    fn len(&self) -> usize;

    /// `true` when no documents are stored.
    ///
    /// Provided in terms of [`len`](Self::len); implementors may override it.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Embedding dimensionality the index was constructed with.
    fn embedding_dim(&self) -> usize;
}
/// A [`VectorIndex`] that can also hand back its stored documents.
///
/// Like its supertrait, it is object-safe and may be used as
/// `Box<dyn VectorIndexSnapshot>`; all `VectorIndex` methods remain callable
/// through such a trait object.
pub trait VectorIndexSnapshot
where
    Self: VectorIndex,
{
    /// Returns an owned `Vec` holding every document currently in the index.
    ///
    /// NOTE(review): ordering appears to be unspecified (the tests compare the
    /// result as a set) — confirm before relying on any particular order.
    fn get_all_documents(&self) -> Vec<Document>;
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Document;

    /// Builds a metadata-free test document whose content is derived from `id`.
    fn make_doc(id: &str, embedding: Vec<f32>) -> Document {
        Document {
            id: id.into(),
            content: format!("content-{id}"),
            embedding,
            metadata: None,
        }
    }

    /// `HNSWIndex` must support the full add/search/clear cycle through `dyn VectorIndex`.
    #[test]
    fn vector_index_object_safety_hnsw() {
        let mut index: Box<dyn VectorIndex> = Box::new(HNSWIndex::with_defaults(3));
        assert!(index.is_empty());
        assert_eq!(index.embedding_dim(), 3);
        index.add(make_doc("a", vec![1.0, 0.0, 0.0])).unwrap();
        index.add(make_doc("b", vec![0.0, 1.0, 0.0])).unwrap();
        assert_eq!(index.len(), 2);
        let results = index.search(&[1.0, 0.0, 0.0], 1).unwrap();
        // Guard before indexing so an empty result set fails with a clear
        // message rather than an index-out-of-bounds panic.
        assert!(!results.is_empty(), "search returned no results");
        assert_eq!(results[0].id, "a");
        index.clear();
        assert!(index.is_empty());
    }

    /// `FlatIndex` snapshots must reflect inserts and be emptied by `clear`.
    #[test]
    fn vector_index_snapshot_flat() {
        let mut index: Box<dyn VectorIndexSnapshot> = Box::new(FlatIndex::new(3));
        index.add(make_doc("x", vec![0.5, 0.5, 0.0])).unwrap();
        index.add(make_doc("y", vec![0.0, 0.5, 0.5])).unwrap();
        let docs = index.get_all_documents();
        assert_eq!(docs.len(), 2);
        // Compare as a set: snapshot ordering is not part of the contract.
        let ids: std::collections::HashSet<_> = docs.iter().map(|d| d.id.as_str()).collect();
        assert!(ids.contains("x"));
        assert!(ids.contains("y"));
        index.clear();
        assert!(index.is_empty());
        assert!(index.get_all_documents().is_empty());
    }

    /// `HNSWIndex` snapshots must return the inserted document.
    #[test]
    fn vector_index_snapshot_hnsw() {
        let mut index: Box<dyn VectorIndexSnapshot> = Box::new(HNSWIndex::with_defaults(3));
        index.add(make_doc("p", vec![1.0, 0.0, 0.0])).unwrap();
        let docs = index.get_all_documents();
        assert_eq!(docs.len(), 1);
        assert_eq!(docs[0].id, "p");
    }

    /// The SQ8-quantized index must find the single stored document by id.
    #[test]
    fn vector_index_sq8() {
        let mut index: Box<dyn VectorIndex> = Box::new(SQ8HNSWIndex::for_normalized(
            4,
            QuantizedHNSWConfig::default(),
        ));
        index.add(make_doc("q", vec![0.5, -0.3, 0.8, 0.1])).unwrap();
        assert_eq!(index.len(), 1);
        assert_eq!(index.embedding_dim(), 4);
        let results = index.search(&[0.5, -0.3, 0.8, 0.1], 1).unwrap();
        assert_eq!(results.len(), 1);
        // Quantization must not lose the document's identity.
        assert_eq!(results[0].id, "q");
    }

    /// The binary-quantized index must accept, count and retrieve a document.
    #[test]
    fn vector_index_binary() {
        let mut index: Box<dyn VectorIndex> =
            Box::new(BinaryHNSWIndex::new(4, QuantizedHNSWConfig::default()));
        index.add(make_doc("r", vec![0.1, 0.2, 0.3, 0.4])).unwrap();
        assert_eq!(index.len(), 1);
        assert_eq!(index.embedding_dim(), 4);
        let results = index.search(&[0.1, 0.2, 0.3, 0.4], 1).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "r");
    }

    /// `BatchBuilder` over an `HNSWIndex` must forward every added document.
    #[test]
    fn batch_builder_via_blanket_impl() {
        let mut index = HNSWIndex::with_defaults(3);
        let config = BatchConfig::default().with_batch_size(10);
        let mut builder = BatchBuilder::new(&mut index, config);
        builder.add(make_doc("d1", vec![1.0, 0.0, 0.0])).unwrap();
        builder.add(make_doc("d2", vec![0.0, 1.0, 0.0])).unwrap();
        let result = builder.finish();
        assert_eq!(result.documents_indexed, 2);
        assert_eq!(index.len(), 2);
    }

    /// `BatchBuilder` over a `FlatIndex` must forward every added document.
    #[test]
    fn batch_builder_flat_via_blanket_impl() {
        let mut index = FlatIndex::new(3);
        let config = BatchConfig::default();
        let mut builder = BatchBuilder::new(&mut index, config);
        builder.add(make_doc("f1", vec![1.0, 0.0, 0.0])).unwrap();
        let result = builder.finish();
        assert_eq!(result.documents_indexed, 1);
        // The builder's counter alone does not prove the index was written to.
        assert_eq!(index.len(), 1);
    }
}