use super::vector::cosine_similarity;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A document stored together with its embedding vector and free-form metadata.
///
/// The type derives serde's `Serialize`/`Deserialize`; `metadata` may be
/// omitted from serialized input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddedDocument {
/// Identifier of the document; `VectorIndex` uses it as the lookup key.
pub id: String,
/// Content string carried alongside the embedding.
pub content: String,
/// Embedding vector; `VectorIndex` scores it against queries with
/// `cosine_similarity`.
pub embedding: Vec<f32>,
/// JSON values keyed by name. Falls back to an empty map when the field is
/// missing during deserialization (`#[serde(default)]`).
#[serde(default)]
pub metadata: HashMap<String, serde_json::Value>,
}
impl EmbeddedDocument {
    /// Creates a document from an id, its content and a precomputed embedding.
    ///
    /// `id` and `content` accept anything convertible into `String`. The
    /// metadata map starts out empty; chain [`metadata`](Self::metadata) calls
    /// to populate it.
    pub fn new(id: impl Into<String>, content: impl Into<String>, embedding: Vec<f32>) -> Self {
        let id = id.into();
        let content = content.into();
        let metadata = HashMap::new();
        Self {
            id,
            content,
            embedding,
            metadata,
        }
    }

    /// Builder-style setter: stores `value` under `key` and returns the document.
    ///
    /// Setting the same key again overwrites the earlier value.
    pub fn metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        let key = key.into();
        // The previously stored value (if any) is intentionally discarded.
        let _replaced = self.metadata.insert(key, value);
        self
    }
}
/// One hit returned by `VectorIndex::search` or
/// `VectorIndex::search_with_threshold`.
#[derive(Debug, Clone)]
pub struct VectorSearchResult {
/// Clone of the indexed document that matched.
pub document: EmbeddedDocument,
/// Value returned by `cosine_similarity` for the query and this document's
/// embedding; results are ordered by this score, highest first.
pub score: f32,
}
/// In-memory collection of [`EmbeddedDocument`]s addressable by id and
/// searchable by similarity score.
///
/// Invariant: `id_to_index` holds exactly one entry per stored document,
/// mapping the document's `id` to its current position in `documents`.
#[derive(Debug, Clone)]
pub struct VectorIndex {
    /// Stored documents. Positions are not stable across removals.
    documents: Vec<EmbeddedDocument>,
    /// Maps each document id to that document's position in `documents`.
    id_to_index: HashMap<String, usize>,
}
impl VectorIndex {
    /// Creates an empty index.
    pub fn new() -> Self {
        Self {
            documents: Vec::new(),
            id_to_index: HashMap::new(),
        }
    }

    /// Builds an index from `documents`, inserting them in order.
    ///
    /// Documents sharing an id follow the replacement rule of
    /// [`add`](Self::add): the later one overwrites the earlier one.
    pub fn with_documents(documents: Vec<EmbeddedDocument>) -> Self {
        let mut index = Self::new();
        index.add_batch(documents);
        index
    }

    /// Inserts `document`, or replaces the stored document with the same id.
    ///
    /// A replacement keeps the position of the document it overwrites.
    pub fn add(&mut self, document: EmbeddedDocument) {
        if let Some(&idx) = self.id_to_index.get(&document.id) {
            self.documents[idx] = document;
        } else {
            let idx = self.documents.len();
            self.id_to_index.insert(document.id.clone(), idx);
            self.documents.push(document);
        }
    }

    /// Inserts every document of `documents` in order via [`add`](Self::add).
    pub fn add_batch(&mut self, documents: Vec<EmbeddedDocument>) {
        // Upper bound on growth: documents whose id already exists replace in
        // place and need no new slot.
        self.documents.reserve(documents.len());
        self.id_to_index.reserve(documents.len());
        for doc in documents {
            self.add(doc);
        }
    }

    /// Removes the document with the given `id`.
    ///
    /// Returns `true` if a document was removed and `false` if the id was
    /// unknown. Removal moves the last document into the vacated slot, so the
    /// relative order of the remaining documents is not preserved.
    pub fn remove(&mut self, id: &str) -> bool {
        if let Some(idx) = self.id_to_index.remove(id) {
            self.documents.swap_remove(idx);
            // If a document was moved into `idx`, re-point its map entry.
            // Nothing moved when the removed document was the last one.
            if idx < self.documents.len() {
                let swapped_id = &self.documents[idx].id;
                self.id_to_index.insert(swapped_id.clone(), idx);
            }
            true
        } else {
            false
        }
    }

    /// Returns the document stored under `id`, if any.
    pub fn get(&self, id: &str) -> Option<&EmbeddedDocument> {
        self.id_to_index.get(id).map(|&idx| &self.documents[idx])
    }

    /// Returns up to `limit` documents ordered by descending
    /// `cosine_similarity` between `query` and each document's embedding.
    ///
    /// Documents with equal scores keep their storage order. A NaN score is
    /// ranked after every other score.
    pub fn search(&self, query: &[f32], limit: usize) -> Vec<VectorSearchResult> {
        let scored: Vec<(f32, &EmbeddedDocument)> = self
            .documents
            .iter()
            .map(|doc| (cosine_similarity(query, &doc.embedding), doc))
            .collect();
        Self::rank_top(scored, limit)
    }

    /// Like [`search`](Self::search), but only considers documents whose score
    /// is at least `min_score`.
    ///
    /// The comparison is `score >= min_score`, so a NaN score never passes.
    pub fn search_with_threshold(
        &self,
        query: &[f32],
        limit: usize,
        min_score: f32,
    ) -> Vec<VectorSearchResult> {
        let scored: Vec<(f32, &EmbeddedDocument)> = self
            .documents
            .iter()
            .map(|doc| (cosine_similarity(query, &doc.embedding), doc))
            .filter(|(score, _)| *score >= min_score)
            .collect();
        Self::rank_top(scored, limit)
    }

    /// Returns the number of stored documents.
    pub fn len(&self) -> usize {
        self.documents.len()
    }

    /// Returns `true` when the index holds no documents.
    pub fn is_empty(&self) -> bool {
        self.documents.is_empty()
    }

    /// Removes every document and every id mapping.
    pub fn clear(&mut self) {
        self.documents.clear();
        self.id_to_index.clear();
    }

    /// Returns the ids of all stored documents in storage order.
    pub fn ids(&self) -> Vec<&str> {
        self.documents.iter().map(|d| d.id.as_str()).collect()
    }

    /// Sorts `scored` by descending score, keeps the first `limit` entries and
    /// clones only those documents into results.
    ///
    /// Working on borrowed documents until after truncation avoids cloning
    /// candidates that are discarded anyway.
    fn rank_top(
        mut scored: Vec<(f32, &EmbeddedDocument)>,
        limit: usize,
    ) -> Vec<VectorSearchResult> {
        // Stable sort: ties keep the order in which documents are stored.
        scored.sort_by(|a, b| Self::cmp_score_desc(a.0, b.0));
        scored.truncate(limit);
        scored
            .into_iter()
            .map(|(score, doc)| VectorSearchResult {
                document: doc.clone(),
                score,
            })
            .collect()
    }

    /// Total ordering that puts higher scores first and NaN last.
    ///
    /// `partial_cmp(..).unwrap_or(Equal)` alone treats NaN as equal to every
    /// value, which is not a total order; `sort_by` may then panic or produce
    /// an unspecified order. Ranking NaN explicitly keeps the sort well defined.
    fn cmp_score_desc(a: f32, b: f32) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Neither side is NaN, so `partial_cmp` always yields `Some`.
            (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
        }
    }
}
impl Default for VectorIndex {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a document whose content is derived from its id.
    fn sample_doc(id: &str, embedding: Vec<f32>) -> EmbeddedDocument {
        let content = format!("Content for {}", id);
        EmbeddedDocument::new(id, content, embedding)
    }

    /// A stored document can be fetched back by id with its embedding intact.
    #[test]
    fn test_add_and_get() {
        let mut store = VectorIndex::new();
        store.add(sample_doc("1", vec![1.0, 0.0, 0.0]));

        let found = store.get("1").unwrap();
        assert_eq!(found.id, "1");
        assert_eq!(found.embedding, vec![1.0, 0.0, 0.0]);
    }

    /// Adding a second document with the same id overwrites the first.
    #[test]
    fn test_add_replaces_existing() {
        let mut store = VectorIndex::new();
        store.add(sample_doc("1", vec![1.0, 0.0, 0.0]));
        store.add(sample_doc("1", vec![0.0, 1.0, 0.0]));

        assert_eq!(store.len(), 1);
        assert_eq!(store.get("1").unwrap().embedding, vec![0.0, 1.0, 0.0]);
    }

    /// Removal succeeds once, fails the second time, and leaves other
    /// documents reachable.
    #[test]
    fn test_remove() {
        let mut store = VectorIndex::new();
        store.add(sample_doc("1", vec![1.0, 0.0, 0.0]));
        store.add(sample_doc("2", vec![0.0, 1.0, 0.0]));

        assert!(store.remove("1"));
        assert!(!store.remove("1"));
        assert_eq!(store.len(), 1);
        assert!(store.get("2").is_some());
    }

    /// The best match comes first and the result list honours the limit.
    #[test]
    fn test_search() {
        let mut store = VectorIndex::new();
        store.add(sample_doc("1", vec![1.0, 0.0, 0.0]));
        store.add(sample_doc("2", vec![0.0, 1.0, 0.0]));
        store.add(sample_doc("3", vec![0.7, 0.7, 0.0]));

        let probe = vec![1.0, 0.0, 0.0];
        let hits = store.search(&probe, 2);

        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].document.id, "1");
        assert!((hits[0].score - 1.0).abs() < 0.001);
    }

    /// Documents scoring below the threshold are left out.
    #[test]
    fn test_search_with_threshold() {
        let mut store = VectorIndex::new();
        store.add(sample_doc("1", vec![1.0, 0.0, 0.0]));
        store.add(sample_doc("2", vec![0.0, 1.0, 0.0]));
        store.add(sample_doc("3", vec![0.9, 0.1, 0.0]));

        let probe = vec![1.0, 0.0, 0.0];
        let hits = store.search_with_threshold(&probe, 10, 0.5);

        assert_eq!(hits.len(), 2);
        assert!(hits.iter().all(|hit| hit.score >= 0.5));
    }

    /// Bulk construction indexes every supplied document.
    #[test]
    fn test_with_documents() {
        let batch = vec![
            sample_doc("1", vec![1.0, 0.0]),
            sample_doc("2", vec![0.0, 1.0]),
        ];
        let store = VectorIndex::with_documents(batch);

        assert_eq!(store.len(), 2);
        assert!(store.get("1").is_some());
        assert!(store.get("2").is_some());
    }

    /// Clearing leaves the index empty.
    #[test]
    fn test_clear() {
        let mut store = VectorIndex::new();
        store.add(sample_doc("1", vec![1.0, 0.0]));
        store.add(sample_doc("2", vec![0.0, 1.0]));

        store.clear();

        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
    }
}