use std::collections::HashMap;
use std::sync::Arc;
use crate::cache::lru::LruCache;
use crate::types::NodeId;
use crate::vector::{
create_vector_store, vector_store_clear, vector_store_delete, vector_store_get,
vector_store_insert, vector_store_stats, DistanceMetric, IvfConfig, IvfIndex, SearchOptions,
VectorManifest, VectorSearchResult, VectorStoreConfig,
};
/// Default capacity of the per-index LRU of recently written node ids.
const DEFAULT_CACHE_MAX_SIZE: usize = 10_000;
/// Default minimum number of live vectors before an IVF index is trained;
/// smaller collections are searched by brute force.
const DEFAULT_TRAINING_THRESHOLD: usize = 1_000;
/// Floor applied when clamping the automatically derived (√n) IVF cluster count.
const MIN_CLUSTERS: usize = 16;
/// Ceiling applied when clamping the automatically derived (√n) IVF cluster count.
const MAX_CLUSTERS: usize = 1 << 10;
/// Configuration for a [`VectorIndex`]: vector width, distance metric, storage
/// layout forwarded to the vector store, and IVF / caching parameters.
#[derive(Debug, Clone)]
pub struct VectorIndexOptions {
/// Number of components every stored vector and every query must have.
pub dimensions: usize,
/// Distance metric used by the vector store, the IVF index and brute-force search.
pub metric: DistanceMetric,
/// Forwarded to `VectorStoreConfig::with_row_group_size`; meaning is defined by the store.
pub row_group_size: usize,
/// Forwarded to `VectorStoreConfig::with_fragment_target_size`; meaning is defined by the store.
pub fragment_target_size: usize,
/// Forwarded to `VectorStoreConfig::with_normalize`.
/// NOTE(review): presumably makes the store keep unit-length copies of inserted
/// vectors — confirm against the store implementation.
pub normalize: bool,
/// Explicit IVF cluster count; `None` derives one from the live vector count when
/// the index is built.
pub n_clusters: Option<usize>,
/// Default number of IVF clusters probed per search; a query can override it via
/// `SimilarOptions::n_probe`.
pub n_probe: usize,
/// Minimum number of live vectors before an IVF index is trained; below it,
/// searches scan every stored vector.
pub training_threshold: usize,
/// Capacity of the index's LRU cache of written node ids.
pub cache_max_size: usize,
}
impl Default for VectorIndexOptions {
    /// Cosine metric with normalization on, automatic cluster count, 10 probes,
    /// 1024-row groups, 100 000-row fragment target, and the module-level defaults
    /// for the training threshold and cache capacity.
    ///
    /// `dimensions` is left at 0; use [`VectorIndexOptions::new`] to supply a real width.
    fn default() -> Self {
        let metric = DistanceMetric::Cosine;
        Self {
            dimensions: 0,
            metric,
            normalize: true,
            row_group_size: 1024,
            fragment_target_size: 100_000,
            n_clusters: None,
            n_probe: 10,
            training_threshold: DEFAULT_TRAINING_THRESHOLD,
            cache_max_size: DEFAULT_CACHE_MAX_SIZE,
        }
    }
}
impl VectorIndexOptions {
    /// Options for `dimensions`-wide vectors compared by cosine distance with
    /// normalization enabled; every other field takes its [`Default`] value.
    pub fn new(dimensions: usize) -> Self {
        Self {
            dimensions,
            metric: DistanceMetric::Cosine,
            normalize: true,
            ..Self::default()
        }
    }

    /// Sets the distance metric.
    ///
    /// Choosing cosine also switches normalization on; any other metric leaves
    /// `normalize` exactly as it was.
    /// NOTE(review): `new` starts with `normalize == true`, so e.g. a Euclidean index
    /// keeps normalizing unless `with_normalize(false)` is also called — confirm
    /// this is intended.
    pub fn with_metric(self, metric: DistanceMetric) -> Self {
        let normalize = self.normalize || metric == DistanceMetric::Cosine;
        Self {
            metric,
            normalize,
            ..self
        }
    }

    /// Sets the row-group size forwarded to the vector store.
    pub fn with_row_group_size(self, size: usize) -> Self {
        Self {
            row_group_size: size,
            ..self
        }
    }

    /// Sets the fragment target size forwarded to the vector store.
    pub fn with_fragment_target_size(self, size: usize) -> Self {
        Self {
            fragment_target_size: size,
            ..self
        }
    }

    /// Explicitly enables or disables vector normalization in the store.
    pub fn with_normalize(self, normalize: bool) -> Self {
        Self { normalize, ..self }
    }

    /// Fixes the IVF cluster count instead of deriving it from the data size.
    pub fn with_n_clusters(self, n_clusters: usize) -> Self {
        Self {
            n_clusters: Some(n_clusters),
            ..self
        }
    }

    /// Sets the default number of IVF clusters probed per search.
    pub fn with_n_probe(self, n_probe: usize) -> Self {
        Self { n_probe, ..self }
    }

    /// Sets the minimum live-vector count required before an IVF index is trained.
    pub fn with_training_threshold(self, threshold: usize) -> Self {
        Self {
            training_threshold: threshold,
            ..self
        }
    }

    /// Sets the capacity of the node-id LRU cache.
    pub fn with_cache_max_size(self, size: usize) -> Self {
        Self {
            cache_max_size: size,
            ..self
        }
    }
}
pub struct SimilarOptions {
pub k: usize,
pub threshold: Option<f32>,
pub n_probe: Option<usize>,
pub filter: Option<std::sync::Arc<dyn Fn(NodeId) -> bool + Send + Sync>>,
}
impl std::fmt::Debug for SimilarOptions {
    /// Formats every field; the filter closure has no `Debug` form, so only whether
    /// one is present is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_filter = self.filter.is_some();
        let mut builder = f.debug_struct("SimilarOptions");
        builder.field("k", &self.k);
        builder.field("threshold", &self.threshold);
        builder.field("n_probe", &self.n_probe);
        builder.field("filter", &has_filter);
        builder.finish()
    }
}
impl SimilarOptions {
    /// Requests the `k` nearest neighbours with no similarity threshold, no probe
    /// override and no filter.
    pub fn new(k: usize) -> Self {
        Self {
            filter: None,
            n_probe: None,
            threshold: None,
            k,
        }
    }

    /// Drops hits whose similarity is below `threshold`.
    pub fn with_threshold(self, threshold: f32) -> Self {
        Self {
            threshold: Some(threshold),
            ..self
        }
    }

    /// Overrides the index's default number of probed IVF clusters for this query.
    pub fn with_n_probe(self, n_probe: usize) -> Self {
        Self {
            n_probe: Some(n_probe),
            ..self
        }
    }

    /// Keeps only hits whose node id satisfies `filter`.
    pub fn with_filter<F>(self, filter: F) -> Self
    where
        F: Fn(NodeId) -> bool + Send + Sync + 'static,
    {
        let shared: Arc<dyn Fn(NodeId) -> bool + Send + Sync> = Arc::new(filter);
        Self {
            filter: Some(shared),
            ..self
        }
    }
}
/// One result returned by `VectorIndex::search`.
#[derive(Debug, Clone)]
pub struct VectorSearchHit {
/// Node whose vector matched.
pub node_id: NodeId,
/// Distance under the configured metric; lower means closer (the brute-force path
/// sorts ascending by it, and negates dot products so this holds for every metric).
pub distance: f32,
/// Similarity reported for the match (the brute-force path derives it with
/// `DistanceMetric::distance_to_similarity`); `SimilarOptions::threshold` is
/// compared against this value.
pub similarity: f32,
}
/// Snapshot of index size and configuration returned by `VectorIndex::stats`.
#[derive(Debug, Clone)]
pub struct VectorIndexStats {
/// `total_vectors` as reported by the vector store.
/// NOTE(review): presumably also counts deleted/superseded rows — confirm.
pub total_vectors: usize,
/// Live vectors as reported by the vector store.
pub live_vectors: usize,
/// Configured vector width.
pub dimensions: usize,
/// Configured distance metric.
pub metric: DistanceMetric,
/// Whether an IVF index exists and is trained; `false` when there is no index.
pub index_trained: bool,
/// Cluster count configured on the IVF index, or `None` when there is no index.
pub index_clusters: Option<usize>,
}
/// Vector index keyed by [`NodeId`]. Vectors live in a vector store (`manifest`);
/// once enough have accumulated, an IVF index is trained over them for approximate
/// search. Otherwise searches score every stored vector.
pub struct VectorIndex {
/// Backing vector store; `node_to_vector` maps each node to its vector id.
manifest: VectorManifest,
/// IVF index, present only after a build that met `training_threshold`.
index: Option<IvfIndex>,
/// LRU of recently written node ids, sized by `cache_max_size`.
/// NOTE(review): updated by `set`/`delete`/`clear` but never read in this file.
node_cache: LruCache<NodeId, ()>,
/// node id -> vector id (cast to `u32`), maintained by `set`/`delete`/`clear`.
/// NOTE(review): never read in this file and, unlike `node_cache`, unbounded.
cached_node_ids: HashMap<NodeId, u32>,
/// Options this index was created with.
options: VectorIndexOptions,
/// True initially, after `clear`, and whenever vectors are added without a trained
/// index to absorb them; the next `search` then calls `build_index` first.
needs_training: bool,
/// True only while `build_index` runs; `set`/`delete` refuse to mutate during it.
is_building: bool,
}
impl VectorIndex {
pub fn new(options: VectorIndexOptions) -> Self {
let config = VectorStoreConfig::new(options.dimensions)
.with_metric(options.metric)
.with_row_group_size(options.row_group_size)
.with_fragment_target_size(options.fragment_target_size)
.with_normalize(options.normalize);
let manifest = create_vector_store(config);
Self {
manifest,
index: None,
node_cache: LruCache::new(options.cache_max_size),
cached_node_ids: HashMap::new(),
options,
needs_training: true,
is_building: false,
}
}
pub fn set(&mut self, node_id: NodeId, vector: &[f32]) -> Result<(), VectorIndexError> {
if self.is_building {
return Err(VectorIndexError::BuildInProgress);
}
if vector.len() != self.options.dimensions {
return Err(VectorIndexError::DimensionMismatch {
expected: self.options.dimensions,
got: vector.len(),
});
}
if let Some(&existing_vector_id) = self.manifest.node_to_vector.get(&node_id) {
if let Some(ref mut index) = self.index {
if index.trained {
if let Some(existing_vector) = vector_store_get(&self.manifest, node_id) {
index.delete(existing_vector_id, existing_vector);
}
}
}
}
let vector_id = vector_store_insert(&mut self.manifest, node_id, vector)
.map_err(|e| VectorIndexError::StoreError(e.to_string()))?;
self.node_cache.set(node_id, ());
self.cached_node_ids.insert(node_id, vector_id as u32);
if let Some(ref mut index) = self.index {
if index.trained {
if let Some(stored_vector) = vector_store_get(&self.manifest, node_id) {
let _ = index.insert(vector_id as u64, stored_vector);
}
} else {
self.needs_training = true;
}
} else {
self.needs_training = true;
}
Ok(())
}
pub fn get(&self, node_id: NodeId) -> Option<Vec<f32>> {
vector_store_get(&self.manifest, node_id).map(|s| s.to_vec())
}
pub fn delete(&mut self, node_id: NodeId) -> Result<bool, VectorIndexError> {
if self.is_building {
return Err(VectorIndexError::BuildInProgress);
}
if let Some(ref mut index) = self.index {
if index.trained {
if let Some(&vector_id) = self.manifest.node_to_vector.get(&node_id) {
if let Some(vector) = vector_store_get(&self.manifest, node_id) {
index.delete(vector_id, vector);
}
}
}
}
self.node_cache.remove(&node_id);
self.cached_node_ids.remove(&node_id);
let deleted = vector_store_delete(&mut self.manifest, node_id);
Ok(deleted)
}
pub fn has(&self, node_id: NodeId) -> bool {
self.manifest.node_to_vector.contains_key(&node_id)
}
pub fn build_index(&mut self) -> Result<(), VectorIndexError> {
if self.is_building {
return Err(VectorIndexError::BuildAlreadyInProgress);
}
self.is_building = true;
let result = self.build_index_internal();
self.is_building = false;
result
}
fn build_index_internal(&mut self) -> Result<(), VectorIndexError> {
let dimensions = self.options.dimensions;
let stats = vector_store_stats(&self.manifest);
let live_vectors = stats.live_vectors;
if live_vectors < self.options.training_threshold {
self.index = None;
self.needs_training = false;
return Ok(());
}
let n_clusters = self.options.n_clusters.unwrap_or_else(|| {
let sqrt_n = (live_vectors as f64).sqrt() as usize;
sqrt_n.clamp(MIN_CLUSTERS, MAX_CLUSTERS)
});
let mut training_data = Vec::with_capacity(live_vectors * dimensions);
let mut vector_ids = Vec::with_capacity(live_vectors);
for (&node_id, &vector_id) in &self.manifest.node_to_vector {
if let Some(vector) = vector_store_get(&self.manifest, node_id) {
training_data.extend_from_slice(vector);
vector_ids.push(vector_id);
}
}
let ivf_config = IvfConfig::new(n_clusters)
.with_n_probe(self.options.n_probe)
.with_metric(self.options.metric);
let mut index = IvfIndex::new(dimensions, ivf_config);
index
.add_training_vectors(&training_data, vector_ids.len())
.map_err(|e| VectorIndexError::TrainingError(e.to_string()))?;
index
.train()
.map_err(|e| VectorIndexError::TrainingError(e.to_string()))?;
for (i, &vector_id) in vector_ids.iter().enumerate() {
let offset = i * dimensions;
let vector = &training_data[offset..offset + dimensions];
let _ = index.insert(vector_id, vector);
}
self.index = Some(index);
self.needs_training = false;
Ok(())
}
pub fn search(
&mut self,
query: &[f32],
options: SimilarOptions,
) -> Result<Vec<VectorSearchHit>, VectorIndexError> {
let dimensions = self.options.dimensions;
if query.len() != dimensions {
return Err(VectorIndexError::DimensionMismatch {
expected: dimensions,
got: query.len(),
});
}
if !is_valid_vector(query) {
return Err(VectorIndexError::InvalidVector);
}
if self.needs_training {
self.build_index()?;
}
let k = options.k;
let n_probe = options.n_probe.unwrap_or(self.options.n_probe);
let results: Vec<VectorSearchResult> = if let Some(ref index) = self.index {
if index.trained {
let search_opts = SearchOptions {
n_probe: Some(n_probe),
filter: None,
threshold: None,
};
index.search(&self.manifest, query, k * 2, Some(search_opts))
} else {
self.brute_force_search(query, k * 2)
}
} else {
self.brute_force_search(query, k * 2)
};
let mut hits = Vec::with_capacity(k);
for result in results {
if let Some(ref filter) = options.filter {
if !filter(result.node_id) {
continue;
}
}
if let Some(threshold) = options.threshold {
if result.similarity < threshold {
continue;
}
}
hits.push(VectorSearchHit {
node_id: result.node_id,
distance: result.distance,
similarity: result.similarity,
});
if hits.len() >= k {
break;
}
}
Ok(hits)
}
fn brute_force_search(&self, query: &[f32], k: usize) -> Vec<VectorSearchResult> {
use crate::vector::{cosine_distance, dot_product, euclidean_distance, normalize};
let metric = self.options.metric;
let query_normalized: Vec<f32>;
let query_for_search = if metric == DistanceMetric::Cosine {
query_normalized = normalize(query);
&query_normalized
} else {
query
};
let mut candidates: Vec<VectorSearchResult> = Vec::new();
for (&node_id, &vector_id) in &self.manifest.node_to_vector {
if let Some(vector) = vector_store_get(&self.manifest, node_id) {
let distance = match metric {
DistanceMetric::Cosine => cosine_distance(query_for_search, vector),
DistanceMetric::Euclidean => euclidean_distance(query_for_search, vector),
DistanceMetric::DotProduct => -dot_product(query_for_search, vector), };
let similarity = metric.distance_to_similarity(distance);
candidates.push(VectorSearchResult {
vector_id,
node_id,
distance,
similarity,
});
}
}
candidates.sort_by(|a, b| {
a.distance
.partial_cmp(&b.distance)
.unwrap_or(std::cmp::Ordering::Equal)
});
candidates.truncate(k);
candidates
}
pub fn stats(&self) -> VectorIndexStats {
let store_stats = vector_store_stats(&self.manifest);
VectorIndexStats {
total_vectors: store_stats.total_vectors,
live_vectors: store_stats.live_vectors,
dimensions: self.options.dimensions,
metric: self.options.metric,
index_trained: self.index.as_ref().map(|i| i.trained).unwrap_or(false),
index_clusters: self.index.as_ref().map(|i| i.config.n_clusters),
}
}
pub fn clear(&mut self) {
vector_store_clear(&mut self.manifest);
self.node_cache = LruCache::new(self.options.cache_max_size);
self.cached_node_ids.clear();
self.index = None;
self.needs_training = true;
}
pub fn len(&self) -> usize {
self.manifest.node_to_vector.len()
}
pub fn is_empty(&self) -> bool {
self.manifest.node_to_vector.is_empty()
}
}
/// Errors reported by [`VectorIndex`] operations.
#[derive(Debug, Clone)]
pub enum VectorIndexError {
    /// A mutation (`set`/`delete`) was attempted while `build_index` was running.
    BuildInProgress,
    /// `build_index` was called while a build was already running.
    BuildAlreadyInProgress,
    /// A vector's length did not match the configured dimensions.
    DimensionMismatch { expected: usize, got: usize },
    /// A vector contained NaN or infinite components.
    InvalidVector,
    /// The vector store rejected an operation; carries the store's message.
    StoreError(String),
    /// IVF training failed; carries the underlying message.
    TrainingError(String),
}

impl std::fmt::Display for VectorIndexError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::BuildInProgress => {
                f.write_str("Cannot modify vectors while index is being built")
            }
            Self::BuildAlreadyInProgress => f.write_str("Index build already in progress"),
            Self::DimensionMismatch { expected, got } => {
                write!(f, "Vector dimension mismatch: expected {expected}, got {got}")
            }
            Self::InvalidVector => f.write_str("Invalid vector: contains NaN or Inf values"),
            Self::StoreError(msg) => write!(f, "Store error: {msg}"),
            Self::TrainingError(msg) => write!(f, "Training error: {msg}"),
        }
    }
}

impl std::error::Error for VectorIndexError {}
/// Returns `true` when no component of `vector` is NaN or ±infinity.
/// An empty slice is considered valid.
fn is_valid_vector(vector: &[f32]) -> bool {
    !vector
        .iter()
        .copied()
        .any(|component| component.is_nan() || component.is_infinite())
}
/// Free-function constructor; equivalent to calling [`VectorIndex::new`] with `options`.
pub fn create_vector_index(options: VectorIndexOptions) -> VectorIndex {
VectorIndex::new(options)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Standard basis vector of width 4 with a 1.0 at position `axis`.
    fn basis(axis: usize) -> [f32; 4] {
        let mut v = [0.0_f32; 4];
        v[axis] = 1.0;
        v
    }

    #[test]
    fn test_vector_index_options_default() {
        let options = VectorIndexOptions::new(768);
        assert_eq!(options.dimensions, 768);
        assert_eq!(options.metric, DistanceMetric::Cosine);
        assert!(options.normalize);
        assert_eq!(options.training_threshold, DEFAULT_TRAINING_THRESHOLD);
    }

    #[test]
    fn test_vector_index_options_builder() {
        let options = VectorIndexOptions::new(512)
            .with_metric(DistanceMetric::Euclidean)
            .with_normalize(false)
            .with_n_probe(20)
            .with_training_threshold(500);
        assert_eq!(
            (
                options.dimensions,
                options.metric,
                options.normalize,
                options.n_probe,
                options.training_threshold
            ),
            (512, DistanceMetric::Euclidean, false, 20, 500)
        );
    }

    #[test]
    fn test_similar_options() {
        let options = SimilarOptions::new(10).with_threshold(0.8).with_n_probe(5);
        assert_eq!(options.k, 10);
        assert_eq!(options.threshold, Some(0.8));
        assert_eq!(options.n_probe, Some(5));
    }

    #[test]
    fn test_vector_index_new() {
        let index = VectorIndex::new(VectorIndexOptions::new(128));
        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
    }

    #[test]
    fn test_vector_index_set_get() {
        let mut index = VectorIndex::new(VectorIndexOptions::new(4));
        index.set(1, &basis(0)).unwrap();
        assert!(index.has(1));
        assert!(!index.has(2));
        let stored = index.get(1).expect("node 1 was just inserted");
        assert_eq!(stored.len(), 4);
    }

    #[test]
    fn test_vector_index_delete() {
        let mut index = VectorIndex::new(VectorIndexOptions::new(4));
        index.set(1, &basis(0)).unwrap();
        assert!(index.has(1));
        assert!(index.delete(1).unwrap());
        assert!(!index.has(1));
    }

    #[test]
    fn test_vector_index_dimension_mismatch() {
        let mut index = VectorIndex::new(VectorIndexOptions::new(4));
        let too_short = [1.0_f32, 0.0, 0.0];
        let outcome = index.set(1, &too_short);
        assert!(matches!(
            outcome,
            Err(VectorIndexError::DimensionMismatch { expected: 4, got: 3 })
        ));
    }

    #[test]
    fn test_vector_index_invalid_vector() {
        // Only checks that inserting a NaN component does not panic; the returned
        // result is deliberately not asserted here.
        let mut index = VectorIndex::new(VectorIndexOptions::new(4));
        let poisoned = [1.0_f32, f32::NAN, 0.0, 0.0];
        let _ = index.set(1, &poisoned);
    }

    #[test]
    fn test_vector_index_clear() {
        let mut index = VectorIndex::new(VectorIndexOptions::new(4));
        index.set(1, &basis(0)).unwrap();
        index.set(2, &basis(1)).unwrap();
        assert_eq!(index.len(), 2);
        index.clear();
        assert!(index.is_empty());
        assert_eq!(index.len(), 0);
    }

    #[test]
    fn test_vector_index_stats() {
        let mut index = VectorIndex::new(VectorIndexOptions::new(4));
        index.set(1, &basis(0)).unwrap();
        index.set(2, &basis(1)).unwrap();
        let VectorIndexStats {
            dimensions,
            live_vectors,
            index_trained,
            ..
        } = index.stats();
        assert_eq!(dimensions, 4);
        assert_eq!(live_vectors, 2);
        assert!(!index_trained);
    }

    #[test]
    fn test_brute_force_search() {
        let mut index =
            VectorIndex::new(VectorIndexOptions::new(4).with_training_threshold(1000));
        index.set(1, &basis(0)).unwrap();
        index.set(2, &basis(1)).unwrap();
        index.set(3, &[0.707, 0.707, 0.0, 0.0]).unwrap();
        let hits = index.search(&basis(0), SimilarOptions::new(3)).unwrap();
        let best = hits.first().expect("search should return at least one hit");
        assert_eq!(best.node_id, 1);
        assert!(best.similarity > 0.99);
    }

    #[test]
    fn test_search_with_threshold() {
        let mut index =
            VectorIndex::new(VectorIndexOptions::new(4).with_training_threshold(1000));
        index.set(1, &basis(0)).unwrap();
        index.set(2, &basis(1)).unwrap();
        let hits = index
            .search(&basis(0), SimilarOptions::new(10).with_threshold(0.5))
            .unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].node_id, 1);
    }

    #[test]
    fn test_is_valid_vector() {
        assert!(is_valid_vector(&[1.0, 2.0, 3.0]));
        assert!(is_valid_vector(&[0.0, 0.0, 0.0]));
        assert!(!is_valid_vector(&[1.0, f32::NAN, 3.0]));
        assert!(!is_valid_vector(&[1.0, f32::INFINITY, 3.0]));
        assert!(!is_valid_vector(&[f32::NEG_INFINITY, 2.0, 3.0]));
    }

    #[test]
    fn test_create_vector_index() {
        assert!(create_vector_index(VectorIndexOptions::new(256)).is_empty());
    }
}