use crate::error::Result;
use std::fmt::Debug;
/// Metric an index uses to compare vectors.
///
/// The [`Default`] metric is [`DistanceType::L2`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum DistanceType {
    /// Euclidean (L2) distance.
    #[default]
    L2,
    /// Inner (dot) product.
    InnerProduct,
    /// Cosine-based comparison.
    Cosine,
}
/// Settings used when constructing a vector index.
#[derive(Clone, Debug)]
pub struct IndexConfig {
    /// Expected number of components per vector.
    pub dimension: usize,
    /// Metric used to compare vectors.
    pub distance_type: DistanceType,
    /// Whether vectors should be normalized.
    /// NOTE(review): this type only stores the flag — confirm which component applies it.
    pub normalize: bool,
}
impl IndexConfig {
    /// Creates a configuration for vectors with `dimension` components, using
    /// [`DistanceType::L2`] and no normalization.
    ///
    /// This is a `const fn`, like the builder methods below, so a complete
    /// configuration can be built in `const`/`static` contexts.
    ///
    /// NOTE(review): `dimension` is not validated here (0 is accepted) — confirm
    /// that index implementations reject invalid dimensions.
    #[must_use]
    pub const fn new(dimension: usize) -> Self {
        Self {
            dimension,
            distance_type: DistanceType::L2,
            normalize: false,
        }
    }

    /// Returns the configuration with its metric replaced by `distance_type`.
    #[must_use]
    pub const fn with_distance(mut self, distance_type: DistanceType) -> Self {
        self.distance_type = distance_type;
        self
    }

    /// Returns the configuration with its normalization flag set to `normalize`.
    #[must_use]
    pub const fn with_normalize(mut self, normalize: bool) -> Self {
        self.normalize = normalize;
        self
    }
}
/// One hit returned from a vector search.
#[derive(Clone, Debug)]
pub struct SearchResult {
    /// Identifier the matching vector was stored under.
    pub id: String,
    /// Raw value reported by the index backend for this hit.
    pub distance: f32,
    /// Relevance derived from `distance` by `SearchResult::new`; higher is better.
    pub score: f32,
}
impl SearchResult {
    /// Builds a result for `id`.
    ///
    /// The raw `distance` is stored as-is, and `score` is derived from it
    /// according to `distance_type`.
    #[must_use]
    pub fn new(id: String, distance: f32, distance_type: DistanceType) -> Self {
        let score = Self::distance_to_score(distance, distance_type);
        Self { id, distance, score }
    }

    /// Maps a raw backend value to a score in `[0.0, 1.0]`, where higher is better.
    ///
    /// * `L2`: `1 / (1 + d)`. A distance of `0` scores `1.0`, and the score decays
    ///   toward `0.0` as `d` grows. Negative inputs are treated as `0.0`, so the
    ///   score can never exceed `1.0` or become infinite. Such inputs can occur,
    ///   for example, from floating-point cancellation in squared-distance
    ///   computations.
    /// * `InnerProduct` / `Cosine`: the raw value is clamped to `[0.0, 1.0]` and
    ///   used directly, so it is treated as a similarity (higher = closer).
    ///   NOTE(review): a backend that reports cosine *distance* (`1 - sim`) would
    ///   need the value inverted first — confirm against implementors.
    ///
    /// A NaN input yields a NaN score in every branch.
    fn distance_to_score(distance: f32, distance_type: DistanceType) -> f32 {
        match distance_type {
            DistanceType::L2 => {
                // `NaN < 0.0` is false, so NaN propagates rather than being
                // mistaken for an exact match.
                let d = if distance < 0.0 { 0.0 } else { distance };
                1.0 / (1.0 + d)
            }
            DistanceType::InnerProduct | DistanceType::Cosine => distance.clamp(0.0, 1.0),
        }
    }
}
/// Common interface for vector index backends.
///
/// Implementors must be `Send + Sync + Debug`.
pub trait VectorIndex: Send + Sync + Debug {
    /// Inserts `vector` under `id`.
    ///
    /// # Errors
    /// Returns an implementation-defined error if the vector cannot be added.
    fn add(&mut self, id: String, vector: &[f32]) -> Result<()>;

    /// Inserts each `(id, vector)` pair by calling [`VectorIndex::add`] in order.
    ///
    /// Pairs are formed with `zip`. If `ids` and `vectors` differ in length, the
    /// surplus elements of the longer input are silently ignored.
    /// NOTE(review): consider rejecting mismatched lengths with an error.
    ///
    /// # Errors
    /// Returns the first error produced by `add`. Pairs added before the failure
    /// are not rolled back by this default implementation.
    fn add_batch(&mut self, ids: Vec<String>, vectors: &[Vec<f32>]) -> Result<()> {
        for (id, vector) in ids.into_iter().zip(vectors.iter()) {
            self.add(id, vector)?;
        }
        Ok(())
    }

    /// Returns up to `k` results for `query`.
    ///
    /// Results are presumably ordered best-first; the default
    /// `search_with_ids` relies on that ordering.
    ///
    /// # Errors
    /// Returns an implementation-defined error if the search fails.
    fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>>;

    /// Returns up to `k` results for `query`, restricted to entries whose id is
    /// in `ids`.
    ///
    /// The default implementation over-fetches `10 * k` candidates via
    /// [`VectorIndex::search`], capped at [`VectorIndex::len`]. It then keeps the
    /// first `k` candidates whose id is allowed. The result is approximate: if
    /// fewer than `k` allowed ids appear among those candidates, fewer than `k`
    /// results are returned even when more matches exist. Backends with native
    /// filtering should override this method.
    ///
    /// # Errors
    /// Propagates any error from [`VectorIndex::search`].
    fn search_with_ids(&self, query: &[f32], k: usize, ids: &[String]) -> Result<Vec<SearchResult>> {
        // Saturate so a very large `k` cannot overflow: plain multiplication
        // panics in debug builds and wraps to a too-small fetch size in release.
        let fetch = self.len().min(k.saturating_mul(10));
        let results = self.search(query, fetch)?;
        let allowed: std::collections::HashSet<&str> = ids.iter().map(String::as_str).collect();
        Ok(results
            .into_iter()
            .filter(|r| allowed.contains(r.id.as_str()))
            .take(k)
            .collect())
    }

    /// Removes the entry stored under `id`.
    ///
    /// The returned flag presumably reports whether an entry was removed;
    /// confirm against implementors.
    ///
    /// # Errors
    /// Returns an implementation-defined error if removal fails.
    fn remove(&mut self, id: &str) -> Result<bool>;

    /// Returns whether an entry is stored under `id`.
    fn contains(&self, id: &str) -> bool;

    /// Number of stored entries.
    fn len(&self) -> usize;

    /// Returns `true` when [`VectorIndex::len`] is zero.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Vector dimension this index was configured for.
    fn dimension(&self) -> usize;

    /// Metric this index uses.
    fn distance_type(&self) -> DistanceType;

    /// Removes all entries.
    fn clear(&mut self);

    /// Memory footprint of the index.
    /// NOTE(review): units presumably bytes — confirm against implementors.
    fn memory_usage(&self) -> usize;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// L2 scores follow `1 / (1 + d)` for non-negative distances.
    #[test]
    fn test_l2_score() {
        let r = SearchResult::new("a".to_string(), 0.0, DistanceType::L2);
        assert!((r.score - 1.0).abs() < 1e-6);
        let r = SearchResult::new("b".to_string(), 1.0, DistanceType::L2);
        assert!((r.score - 0.5).abs() < 1e-6);
    }

    /// Inner-product and cosine scores are the raw value clamped to `[0, 1]`,
    /// while the raw value stays in `distance`.
    #[test]
    fn test_similarity_scores_are_clamped() {
        let r = SearchResult::new("c".to_string(), 1.5, DistanceType::InnerProduct);
        assert!((r.score - 1.0).abs() < 1e-6);
        assert!((r.distance - 1.5).abs() < 1e-6);
        let r = SearchResult::new("d".to_string(), -0.25, DistanceType::InnerProduct);
        assert!(r.score.abs() < 1e-6);
        let r = SearchResult::new("e".to_string(), 0.75, DistanceType::Cosine);
        assert!((r.score - 0.75).abs() < 1e-6);
    }

    /// `IndexConfig::new` defaults to L2 with normalization off.
    #[test]
    fn test_config_defaults() {
        let config = IndexConfig::new(8);
        assert_eq!(config.dimension, 8);
        assert_eq!(config.distance_type, DistanceType::L2);
        assert_eq!(DistanceType::default(), DistanceType::L2);
        assert!(!config.normalize);
    }

    /// The builder methods override the defaults.
    #[test]
    fn test_config() {
        let config = IndexConfig::new(256)
            .with_distance(DistanceType::Cosine)
            .with_normalize(true);
        assert_eq!(config.dimension, 256);
        assert_eq!(config.distance_type, DistanceType::Cosine);
        assert!(config.normalize);
    }
}