use anyhow::Result;
use std::collections::HashMap;
use crate::similarity;
use crate::Vector;
use crate::VectorId;
/// Common interface for vector indexes whose entries are addressed by a string id.
///
/// `insert`, `search_knn`, `search_threshold` and `get_vector` must be supplied by
/// every implementor. The remaining methods ship with the default bodies below,
/// which implementors may override.
pub trait VectorIndex: Send + Sync {
    /// Stores `vector` under the identifier `uri`.
    fn insert(&mut self, uri: String, vector: Vector) -> Result<()>;

    /// Runs a k-nearest-neighbour style query for `query`, yielding `(id, score)` pairs.
    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>>;

    /// Runs a threshold query for `query`, yielding `(id, score)` pairs.
    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>>;

    /// Returns a reference to the vector stored under `uri`, if there is one.
    fn get_vector(&self, uri: &str) -> Option<&Vector>;

    /// Adds a vector together with optional metadata.
    ///
    /// The default body discards `_metadata` and forwards to [`VectorIndex::insert`].
    fn add_vector(
        &mut self,
        id: VectorId,
        vector: Vector,
        _metadata: Option<HashMap<String, String>>,
    ) -> Result<()> {
        Self::insert(self, id, vector)
    }

    /// Updates the vector stored under `id`.
    ///
    /// The default body simply forwards to [`VectorIndex::insert`].
    fn update_vector(&mut self, id: VectorId, vector: Vector) -> Result<()> {
        Self::insert(self, id, vector)
    }

    /// Updates the metadata attached to `_id`.
    ///
    /// The default body does nothing and reports success.
    fn update_metadata(
        &mut self,
        _id: VectorId,
        _metadata: HashMap<String, String>,
    ) -> Result<()> {
        Ok(())
    }

    /// Removes the entry stored under `_id`.
    ///
    /// The default body does nothing and reports success.
    fn remove_vector(&mut self, _id: VectorId) -> Result<()> {
        Ok(())
    }

    /// Returns the stored `(id, vector)` pairs as owned values.
    ///
    /// The default body returns an empty list.
    fn iter_vectors(&self) -> Vec<(String, Vector)> {
        vec![]
    }
}
/// In-memory [`VectorIndex`] implementation that keeps its entries in a flat `Vec`.
pub struct MemoryVectorIndex {
/// Stored `(id, vector)` pairs; the trait implementation in this file scans this list linearly.
vectors: Vec<(String, Vector)>,
/// Similarity settings; its `primary_metric` is what the search methods use to score vectors.
similarity_config: similarity::SimilarityConfig,
}
impl MemoryVectorIndex {
    /// Creates an empty index configured with `SimilarityConfig::default()`.
    pub fn new() -> Self {
        Self::with_similarity_config(similarity::SimilarityConfig::default())
    }

    /// Creates an empty index that uses `config` for its similarity settings.
    pub fn with_similarity_config(config: similarity::SimilarityConfig) -> Self {
        let vectors = Vec::new();
        Self {
            vectors,
            similarity_config: config,
        }
    }
}
impl Default for MemoryVectorIndex {
fn default() -> Self {
Self::new()
}
}
/// Brute-force [`VectorIndex`] implementation: every lookup and search walks the
/// stored entries linearly.
impl VectorIndex for MemoryVectorIndex {
    /// Stores `vector` under `uri`, replacing the entry that already uses that id, if any.
    ///
    /// Always returns `Ok(())`.
    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
        match self.vectors.iter().position(|(id, _)| id == &uri) {
            Some(pos) => self.vectors[pos] = (uri, vector),
            None => self.vectors.push((uri, vector)),
        }
        Ok(())
    }

    /// Scores every stored vector against `query` with the configured primary metric
    /// and returns at most `k` `(id, score)` pairs, highest score first.
    ///
    /// A failed metric computation is scored as `0.0` (best effort) instead of
    /// aborting the search. Entries whose score is NaN are ranked after every
    /// non-NaN score.
    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
        use std::cmp::Ordering;

        // Nothing can be returned, so skip scoring and sorting entirely.
        if k == 0 {
            return Ok(Vec::new());
        }

        let metric = self.similarity_config.primary_metric;
        let query_f32 = query.as_f32();
        let mut scored: Vec<(String, f32)> = self
            .vectors
            .iter()
            .map(|(uri, stored)| {
                let stored_f32 = stored.as_f32();
                let score = metric.similarity(&query_f32, &stored_f32).unwrap_or(0.0);
                (uri.clone(), score)
            })
            .collect();

        // `partial_cmp(..).unwrap_or(Equal)` alone treats NaN as equal to every other
        // score, which is not a total order; `sort_by` may then panic or return an
        // arbitrary ordering. NaN is handled explicitly so the comparator stays total.
        // The sort is stable, so equal scores keep their insertion order.
        scored.sort_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
            (false, false) => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
        });
        scored.truncate(k);
        Ok(scored)
    }

    /// Returns every `(id, score)` pair whose score against `query` is at least
    /// `threshold`, in storage order (the result is not sorted by score).
    ///
    /// A failed metric computation is scored as `0.0`. NaN scores never satisfy the
    /// `>=` comparison and are therefore left out.
    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
        let metric = self.similarity_config.primary_metric;
        let query_f32 = query.as_f32();
        let matches: Vec<(String, f32)> = self
            .vectors
            .iter()
            .filter_map(|(uri, stored)| {
                let stored_f32 = stored.as_f32();
                let score = metric.similarity(&query_f32, &stored_f32).unwrap_or(0.0);
                // Clone the id only for entries that pass the threshold.
                (score >= threshold).then(|| (uri.clone(), score))
            })
            .collect();
        Ok(matches)
    }

    /// Returns a reference to the vector stored under `uri`, or `None` if the id is unknown.
    fn get_vector(&self, uri: &str) -> Option<&Vector> {
        self.vectors
            .iter()
            .find(|(id, _)| id == uri)
            .map(|(_, vector)| vector)
    }

    /// Replaces the vector stored under `id`.
    ///
    /// # Errors
    ///
    /// Unlike `insert`, this fails when no entry with `id` exists.
    fn update_vector(&mut self, id: VectorId, vector: Vector) -> Result<()> {
        match self.vectors.iter().position(|(uri, _)| uri == &id) {
            Some(pos) => {
                self.vectors[pos] = (id, vector);
                Ok(())
            }
            None => Err(anyhow::anyhow!("Vector with id '{}' not found", id)),
        }
    }

    /// Removes the entry stored under `id`, keeping the order of the remaining entries.
    ///
    /// # Errors
    ///
    /// Fails when no entry with `id` exists.
    fn remove_vector(&mut self, id: VectorId) -> Result<()> {
        match self.vectors.iter().position(|(uri, _)| uri == &id) {
            Some(pos) => {
                self.vectors.remove(pos);
                Ok(())
            }
            None => Err(anyhow::anyhow!("Vector with id '{}' not found", id)),
        }
    }

    /// Returns an owned copy of every stored `(id, vector)` pair, in storage order.
    fn iter_vectors(&self) -> Vec<(String, Vector)> {
        self.vectors.clone()
    }
}