use async_trait::async_trait;
use anyhow::Result;
use std::collections::HashMap;
use crate::types::{Document, SearchOptions, SearchResult, SearchFilter};
/// Asynchronous backend that stores [`Document`]s inside named collections and
/// answers vector-similarity queries against them.
///
/// The `Send + Sync` supertraits allow a single store to be shared between tasks.
/// All async methods are desugared by `#[async_trait]`.
#[async_trait]
pub trait VectorStore: Send + Sync {
    /// Exposes the concrete implementation as [`std::any::Any`] so callers can downcast it.
    fn as_any(&self) -> &dyn std::any::Any;

    /// Prepares the store for use.
    async fn initialize(&self) -> Result<()>;

    /// Reports whether the store is initialized.
    async fn is_initialized(&self) -> bool;

    /// Sets the vector dimensionality the store should use.
    async fn set_dimensions(&self, dimensions: usize) -> Result<()>;

    /// Inserts one document into `collection_name` and returns its id.
    async fn add_document(&self, collection_name: &str, document: Document) -> Result<String>;

    /// Inserts several documents into `collection_name` and returns their ids.
    /// NOTE(review): presumably in input order — confirm per implementation.
    async fn add_documents(
        &self,
        collection_name: &str,
        documents: Vec<Document>,
    ) -> Result<Vec<String>>;

    /// Runs a similarity search for `query_vector` in `collection_name`.
    async fn search(
        &self,
        collection_name: &str,
        query_vector: Vec<f32>,
        options: SearchOptions,
    ) -> Result<Vec<SearchResult>>;

    /// Fetches the document with `id`, or `None` when it is not found.
    async fn get_document(&self, collection_name: &str, id: &str) -> Result<Option<Document>>;

    /// Replaces the document stored under `id`.
    async fn update_document(
        &self,
        collection_name: &str,
        id: &str,
        document: Document,
    ) -> Result<()>;

    /// Deletes the document with `id`.
    /// NOTE(review): the `bool` presumably signals whether something was deleted — confirm.
    async fn delete_document(&self, collection_name: &str, id: &str) -> Result<bool>;

    /// Lists documents in `collection_name`, optionally capped by `limit` and narrowed by `filter`.
    async fn list_documents(
        &self,
        collection_name: &str,
        limit: Option<usize>,
        filter: Option<SearchFilter>,
    ) -> Result<Vec<Document>>;

    /// Creates a collection named `name` holding vectors of length `vector_size`.
    async fn create_collection(&self, name: &str, vector_size: usize) -> Result<()>;

    /// Deletes the collection named `name`.
    /// NOTE(review): the `bool` presumably signals whether it existed — confirm.
    async fn delete_collection(&self, name: &str) -> Result<bool>;

    /// Returns the names of all collections.
    async fn list_collections(&self) -> Result<Vec<String>>;

    /// Returns metadata for collection `name`, or `None` when it is not found.
    async fn get_collection_info(&self, name: &str) -> Result<Option<CollectionInfo>>;

    /// Iterates over points of `collection_name` without a query vector,
    /// optionally narrowed by `filter` and capped by `limit`.
    async fn scroll_collection(
        &self,
        collection_name: &str,
        filter: Option<SearchFilter>,
        limit: Option<usize>,
    ) -> Result<Vec<SearchResult>>;

    /// Returns per-collection health information keyed by collection name.
    /// NOTE(review): keyed by name is assumed from the `HashMap<String, _>` type — confirm.
    async fn get_collections_health(&self) -> Result<HashMap<String, CollectionHealth>>;

    /// Shuts the store down.
    async fn shutdown(&self) -> Result<()>;

    /// Clears any document cache the implementation maintains.
    async fn clear_document_cache(&self) -> Result<()>;

    /// Disables the backend optimizer for `collection_name`.
    /// NOTE(review): presumably paired with `enable_optimizer` around bulk loads — confirm.
    async fn disable_optimizer(&self, collection_name: &str) -> Result<()>;

    /// Re-enables the backend optimizer for `collection_name`.
    async fn enable_optimizer(&self, collection_name: &str) -> Result<()>;
}
/// Metadata describing a single collection, as returned by
/// [`VectorStore::get_collection_info`].
///
/// Derives `PartialEq`/`Eq` (all fields are comparable) so values can be
/// compared directly by callers and tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CollectionInfo {
    /// Collection name.
    pub name: String,
    /// Length of the vectors stored in the collection.
    pub vector_size: usize,
    /// Distance metric as reported by the backend.
    /// NOTE(review): free-form string; exact values are backend-specific — confirm.
    pub distance: String,
    /// Number of points in the collection.
    pub points_count: usize,
    /// Number of storage segments, when the backend reports it.
    pub segments_count: Option<usize>,
    /// On-disk data size, when reported. NOTE(review): presumably bytes — confirm.
    pub disk_data_size: Option<u64>,
    /// In-memory data size, when reported. NOTE(review): presumably bytes — confirm.
    pub ram_data_size: Option<u64>,
}
/// Health snapshot for one collection, as returned (keyed by collection name)
/// from [`VectorStore::get_collections_health`].
///
/// Derives `PartialEq`/`Eq` for consistency with [`CollectionInfo`]; every field,
/// including `chrono::DateTime<Utc>`, implements both traits.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CollectionHealth {
    /// Collection name.
    pub name: String,
    /// Backend-reported status.
    /// NOTE(review): free-form string; the set of values is backend-specific — confirm.
    pub status: String,
    /// Number of points in the collection.
    pub points_count: usize,
    /// Number of storage segments.
    pub segments_count: usize,
    /// On-disk size. NOTE(review): presumably bytes — confirm.
    pub disk_size: u64,
    /// In-memory size. NOTE(review): presumably bytes — confirm.
    pub ram_size: u64,
    /// UTC timestamp associated with this snapshot.
    /// NOTE(review): whether this is snapshot time or last-write time is not visible here.
    pub last_updated: chrono::DateTime<chrono::Utc>,
}
/// Free-standing vector math helpers (similarity, distance, normalization).
pub mod utils {
    use anyhow::{Result, anyhow};

    /// Fails with the shared "same dimensions" error unless both slices have equal length.
    fn ensure_same_dimensions(vec_a: &[f32], vec_b: &[f32]) -> Result<()> {
        if vec_a.len() != vec_b.len() {
            return Err(anyhow!("Vectors must have the same dimensions"));
        }
        Ok(())
    }

    /// Euclidean (L2) norm of `vector`; `0.0` for an empty slice.
    fn l2_norm(vector: &[f32]) -> f32 {
        vector.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    /// Cosine similarity of `vec_a` and `vec_b`, in the range `[-1.0, 1.0]`.
    ///
    /// Returns `0.0` when either vector has zero norm (including empty inputs).
    /// A NaN result (e.g. from NaN components) is passed through unchanged.
    ///
    /// # Errors
    /// Returns an error when the slices differ in length.
    pub fn cosine_similarity(vec_a: &[f32], vec_b: &[f32]) -> Result<f32> {
        ensure_same_dimensions(vec_a, vec_b)?;
        let dot_product: f32 = vec_a.iter().zip(vec_b).map(|(a, b)| a * b).sum();
        let norm_a = l2_norm(vec_a);
        let norm_b = l2_norm(vec_b);
        if norm_a == 0.0 || norm_b == 0.0 {
            return Ok(0.0);
        }
        // f32 rounding can push the ratio a few ulps outside [-1, 1] (e.g. a vector
        // compared with itself). Clamp so callers can rely on the documented range;
        // `f32::clamp` returns NaN unchanged.
        Ok((dot_product / (norm_a * norm_b)).clamp(-1.0, 1.0))
    }

    /// Euclidean (L2) distance between `vec_a` and `vec_b`.
    ///
    /// # Errors
    /// Returns an error when the slices differ in length.
    pub fn euclidean_distance(vec_a: &[f32], vec_b: &[f32]) -> Result<f32> {
        ensure_same_dimensions(vec_a, vec_b)?;
        let sum_squared_diff: f32 = vec_a
            .iter()
            .zip(vec_b)
            .map(|(a, b)| (a - b).powi(2))
            .sum();
        Ok(sum_squared_diff.sqrt())
    }

    /// Scales `vector` in place to unit L2 norm.
    ///
    /// # Errors
    /// - Returns an error when the norm is zero (including an empty slice).
    /// - Returns an error when the norm is not finite, i.e. the input contains
    ///   inf/NaN or the sum of squares overflows f32. Dividing by such a norm would
    ///   otherwise silently fill the vector with NaN or zeros.
    ///
    /// On error the vector is left unmodified.
    pub fn normalize_vector(vector: &mut [f32]) -> Result<()> {
        let norm = l2_norm(vector);
        if norm == 0.0 {
            return Err(anyhow!("Cannot normalize zero vector"));
        }
        if !norm.is_finite() {
            return Err(anyhow!("Cannot normalize vector with non-finite norm"));
        }
        vector.iter_mut().for_each(|x| *x /= norm);
        Ok(())
    }

    /// Returns the fixed one-element placeholder vector `[1.0]`.
    ///
    /// NOTE(review): presumably used where a backend requires some vector whose
    /// contents do not matter — confirm with callers.
    pub fn generate_dummy_vector() -> Vec<f32> {
        vec![1.0]
    }
}
#[cfg(test)]
mod tests {
    use super::utils::*;

    /// Absolute tolerance for comparing f32 results with hand-computed values.
    const EPS: f32 = 1e-5;

    #[test]
    fn test_cosine_similarity() {
        let vec_a = vec![1.0, 2.0, 3.0];
        let vec_b = vec![4.0, 5.0, 6.0];
        // 32 / (sqrt(14) * sqrt(77)) = 32 / sqrt(1078) ≈ 0.9746318
        let similarity = cosine_similarity(&vec_a, &vec_b).unwrap();
        assert!((similarity - 0.974_631_8).abs() < EPS, "got {}", similarity);
    }

    #[test]
    fn test_cosine_similarity_zero_vector_is_zero() {
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]).unwrap(), 0.0);
    }

    #[test]
    fn test_cosine_similarity_dimension_mismatch() {
        assert!(cosine_similarity(&[1.0], &[1.0, 2.0]).is_err());
    }

    #[test]
    fn test_euclidean_distance() {
        let vec_a = vec![1.0, 2.0, 3.0];
        let vec_b = vec![4.0, 5.0, 6.0];
        // sqrt(3^2 + 3^2 + 3^2) = sqrt(27) ≈ 5.1961524
        let distance = euclidean_distance(&vec_a, &vec_b).unwrap();
        assert!((distance - 5.196_152_4).abs() < EPS, "got {}", distance);
    }

    #[test]
    fn test_euclidean_distance_identical_and_mismatch() {
        assert_eq!(euclidean_distance(&[1.0, 2.0], &[1.0, 2.0]).unwrap(), 0.0);
        assert!(euclidean_distance(&[1.0, 2.0], &[1.0]).is_err());
    }

    #[test]
    fn test_normalize_vector() {
        let mut vector = vec![3.0, 4.0];
        normalize_vector(&mut vector).unwrap();
        let magnitude: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((magnitude - 1.0).abs() < 1e-6);
        assert!((vector[0] - 0.6).abs() < 1e-6);
        assert!((vector[1] - 0.8).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_zero_vector_errors() {
        let mut vector = vec![0.0, 0.0, 0.0];
        assert!(normalize_vector(&mut vector).is_err());
    }

    #[test]
    fn test_generate_dummy_vector() {
        assert_eq!(generate_dummy_vector(), vec![1.0]);
    }
}