use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::sync::atomic::{AtomicU64, Ordering};
use std::path::Path;
use crate::{VectorIndexWrapper, VectorIndex, Vector, SearchResult, SimilarityMetric, EmbeddingFunction, PersistenceError, save_collection_to_file, load_collection_from_file};
use crate::errors::{VectorLiteError, VectorLiteResult};
/// Client that owns a set of named collections together with the embedding
/// function used when text is added to or searched in those collections.
pub struct VectorLiteClient {
/// Collections keyed by their name; each one is held behind an `Arc`.
collections: HashMap<String, CollectionRef>,
/// Embedding function passed to collection operations, and queried for its
/// dimension when a new collection's index is created.
embedding_function: Arc<dyn EmbeddingFunction>,
}
/// Placeholder for client settings; it currently has no fields.
///
/// `Debug`, `Clone` and `Default` are derived so the type can be constructed
/// with `Settings::default()`, copied and printed like other public types.
#[derive(Debug, Clone, Default)]
pub struct Settings {}
impl VectorLiteClient {
    /// Creates a client with no collections that uses `embedding_function`
    /// for every text operation.
    pub fn new(embedding_function: Box<dyn EmbeddingFunction>) -> Self {
        Self {
            collections: HashMap::new(),
            embedding_function: Arc::from(embedding_function),
        }
    }

    /// Creates an empty collection called `name` backed by the requested
    /// index type.
    ///
    /// The index dimension is taken from the client's embedding function and
    /// the collection's ID counter starts at 0.
    ///
    /// # Errors
    ///
    /// Returns `VectorLiteError::CollectionAlreadyExists` if a collection with
    /// this name is already registered; nothing is modified in that case.
    pub fn create_collection(&mut self, name: &str, index_type: IndexType) -> VectorLiteResult<()> {
        use std::collections::hash_map::Entry;

        // A single entry lookup replaces the `contains_key` + `insert` pair.
        let slot = match self.collections.entry(name.to_string()) {
            Entry::Occupied(_) => {
                return Err(VectorLiteError::CollectionAlreadyExists { name: name.to_string() });
            }
            Entry::Vacant(slot) => slot,
        };

        let dimension = self.embedding_function.dimension();
        let index = match index_type {
            IndexType::Flat => VectorIndexWrapper::Flat(crate::FlatIndex::new(dimension, Vec::new())),
            IndexType::HNSW => VectorIndexWrapper::HNSW(Box::new(crate::HNSWIndex::new(dimension))),
        };
        let collection = Collection {
            name: name.to_string(),
            index: Arc::new(RwLock::new(index)),
            next_id: Arc::new(AtomicU64::new(0)),
        };
        slot.insert(Arc::new(collection));
        Ok(())
    }

    /// Returns the collection registered under `name`, if any.
    pub fn get_collection(&self, name: &str) -> Option<&CollectionRef> {
        self.collections.get(name)
    }

    /// Returns the names of all registered collections.
    ///
    /// The names come straight from the map's key iterator, so their order is
    /// unspecified.
    pub fn list_collections(&self) -> Vec<String> {
        self.collections.keys().cloned().collect()
    }

    /// Removes the collection registered under `name`.
    ///
    /// # Errors
    ///
    /// Returns `VectorLiteError::CollectionNotFound` if no such collection
    /// exists.
    pub fn delete_collection(&mut self, name: &str) -> VectorLiteResult<()> {
        match self.collections.remove(name) {
            Some(_) => Ok(()),
            None => Err(VectorLiteError::CollectionNotFound { name: name.to_string() }),
        }
    }

    /// Reports whether a collection is registered under `name`.
    pub fn has_collection(&self, name: &str) -> bool {
        self.collections.contains_key(name)
    }

    /// Embeds `text` with the client's embedding function and adds it, with
    /// the optional `metadata`, to the named collection.
    ///
    /// Returns the ID assigned to the new vector.
    ///
    /// # Errors
    ///
    /// Returns `VectorLiteError::CollectionNotFound` if the collection does
    /// not exist, or whatever error the collection's
    /// `add_text_with_metadata` reports.
    pub fn add_text_to_collection(&self, collection_name: &str, text: &str, metadata: Option<serde_json::Value>) -> VectorLiteResult<u64> {
        self.require_collection(collection_name)?
            .add_text_with_metadata(text, metadata, self.embedding_function.as_ref())
    }

    /// Embeds `query_text` and searches the named collection for up to `k`
    /// results using `similarity_metric`.
    ///
    /// # Errors
    ///
    /// Returns `VectorLiteError::CollectionNotFound` if the collection does
    /// not exist, or whatever error the collection's `search_text` reports.
    pub fn search_text_in_collection(&self, collection_name: &str, query_text: &str, k: usize, similarity_metric: SimilarityMetric) -> VectorLiteResult<Vec<SearchResult>> {
        self.require_collection(collection_name)?
            .search_text(query_text, k, similarity_metric, self.embedding_function.as_ref())
    }

    /// Deletes the vector with the given `id` from the named collection.
    ///
    /// # Errors
    ///
    /// Returns `VectorLiteError::CollectionNotFound` if the collection does
    /// not exist, or whatever error the collection's `delete` reports.
    pub fn delete_from_collection(&self, collection_name: &str, id: u64) -> VectorLiteResult<()> {
        self.require_collection(collection_name)?.delete(id)
    }

    /// Fetches a copy of the vector with the given `id` from the named
    /// collection, or `None` if the collection holds no such vector.
    ///
    /// # Errors
    ///
    /// Returns `VectorLiteError::CollectionNotFound` if the collection does
    /// not exist, or whatever error the collection's `get_vector` reports.
    pub fn get_vector_from_collection(&self, collection_name: &str, id: u64) -> VectorLiteResult<Option<Vector>> {
        self.require_collection(collection_name)?.get_vector(id)
    }

    /// Returns a summary (name, count, emptiness, dimension) of the named
    /// collection.
    ///
    /// # Errors
    ///
    /// Returns `VectorLiteError::CollectionNotFound` if the collection does
    /// not exist, or whatever error the collection's `get_info` reports.
    pub fn get_collection_info(&self, collection_name: &str) -> VectorLiteResult<CollectionInfo> {
        self.require_collection(collection_name)?.get_info()
    }

    /// Registers an already-built collection under its own name.
    ///
    /// # Errors
    ///
    /// Returns `VectorLiteError::CollectionAlreadyExists` if a collection with
    /// the same name is already registered; the existing one is kept.
    pub fn add_collection(&mut self, collection: Collection) -> VectorLiteResult<()> {
        use std::collections::hash_map::Entry;

        let name = collection.name().to_string();
        match self.collections.entry(name) {
            Entry::Occupied(existing) => Err(VectorLiteError::CollectionAlreadyExists {
                name: existing.key().clone(),
            }),
            Entry::Vacant(slot) => {
                slot.insert(Arc::new(collection));
                Ok(())
            }
        }
    }

    /// Looks up a collection by name, turning a miss into
    /// `VectorLiteError::CollectionNotFound`.
    fn require_collection(&self, name: &str) -> VectorLiteResult<&CollectionRef> {
        self.collections
            .get(name)
            .ok_or_else(|| VectorLiteError::CollectionNotFound { name: name.to_string() })
    }
}
/// Selects which index implementation backs a newly created collection.
///
/// `PartialEq`, `Eq` and `Hash` are derived so callers can compare index
/// types and use them as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IndexType {
    /// Back the collection with a `FlatIndex`.
    Flat,
    /// Back the collection with an `HNSWIndex`.
    HNSW,
}
/// A named vector collection: an index guarded by a read/write lock plus an
/// atomic counter that supplies the ID for each newly added vector.
pub struct Collection {
/// Name of the collection, as returned by `name()`.
name: String,
/// The underlying index; read operations take the read lock, add and delete
/// take the write lock.
index: Arc<RwLock<VectorIndexWrapper>>,
/// ID that will be assigned to the next vector added to this collection.
next_id: Arc<AtomicU64>,
}
/// Shared handle to a `Collection`; this is the value type stored in the
/// client's collection map and returned by `VectorLiteClient::get_collection`.
type CollectionRef = Arc<Collection>;
/// Serializable summary of a collection, as produced by `Collection::get_info`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct CollectionInfo {
/// Name of the collection.
pub name: String,
/// Value of the index's `len()` at the time the summary was taken.
pub count: usize,
/// Value of the index's `is_empty()` at the time the summary was taken.
pub is_empty: bool,
/// Value of the index's `dimension()`.
pub dimension: usize,
}
impl std::fmt::Debug for Collection {
    /// Writes the collection as a `Collection` debug struct containing its
    /// name and the current value of the ID counter; the index itself is not
    /// printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let upcoming_id = self.next_id.load(Ordering::Relaxed);
        let mut builder = f.debug_struct("Collection");
        builder.field("name", &self.name);
        builder.field("next_id", &upcoming_id);
        builder.finish()
    }
}
impl Collection {
    /// Wraps an existing index in a collection called `name`.
    ///
    /// The ID counter starts one past the largest ID reported by the index's
    /// `max_id()`, or at 0 when the index reports none. The increment
    /// saturates, so an index that already holds `u64::MAX` cannot make the
    /// counter overflow (a debug-build panic) or wrap around to 0.
    pub fn new(name: String, index: VectorIndexWrapper) -> Self {
        let next_id = match &index {
            VectorIndexWrapper::Flat(flat_index) => flat_index
                .max_id()
                .map_or(0, |max_id| max_id.saturating_add(1)),
            VectorIndexWrapper::HNSW(hnsw_index) => hnsw_index
                .max_id()
                .map_or(0, |max_id| max_id.saturating_add(1)),
        };
        Self {
            name,
            index: Arc::new(RwLock::new(index)),
            next_id: Arc::new(AtomicU64::new(next_id)),
        }
    }

    /// Embeds `text` and adds it to the collection without metadata.
    ///
    /// Returns the ID assigned to the new vector.
    ///
    /// # Errors
    ///
    /// * Any error returned by `embedding_function.generate_embedding`.
    /// * `VectorLiteError::LockError` if the index write lock cannot be
    ///   acquired.
    /// * `VectorLiteError::DimensionMismatch`, `DuplicateVectorId` or
    ///   `InternalError`, depending on the message the index returns when the
    ///   insertion fails.
    pub fn add_text(&self, text: &str, embedding_function: &dyn EmbeddingFunction) -> VectorLiteResult<u64> {
        self.insert_text(text, None, embedding_function, "add_text")
    }

    /// Embeds `text` and adds it to the collection together with the optional
    /// `metadata`.
    ///
    /// Returns the ID assigned to the new vector.
    ///
    /// # Errors
    ///
    /// * Any error returned by `embedding_function.generate_embedding`.
    /// * `VectorLiteError::LockError` if the index write lock cannot be
    ///   acquired.
    /// * `VectorLiteError::DimensionMismatch`, `DuplicateVectorId` or
    ///   `InternalError`, depending on the message the index returns when the
    ///   insertion fails.
    pub fn add_text_with_metadata(&self, text: &str, metadata: Option<serde_json::Value>, embedding_function: &dyn EmbeddingFunction) -> VectorLiteResult<u64> {
        self.insert_text(text, metadata, embedding_function, "add_text_with_metadata")
    }

    /// Shared implementation of `add_text` and `add_text_with_metadata`.
    ///
    /// `operation` is the public method name; it is only used in the lock
    /// error message.
    fn insert_text(
        &self,
        text: &str,
        metadata: Option<serde_json::Value>,
        embedding_function: &dyn EmbeddingFunction,
        operation: &str,
    ) -> VectorLiteResult<u64> {
        // Do the fallible embedding and the lock acquisition before taking an
        // ID from the counter, so neither kind of failure consumes an ID.
        let values = embedding_function.generate_embedding(text)?;
        let vector_dimension = values.len();
        let owned_text = text.to_string();

        let mut index = self.index.write().map_err(|_| {
            VectorLiteError::LockError(format!("Failed to acquire write lock for {}", operation))
        })?;

        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
        let vector = Vector {
            id,
            values,
            text: owned_text,
            metadata,
        };

        // The index reports failures as plain strings, so the typed error is
        // chosen by inspecting the message.
        index.add(vector).map_err(|e| {
            if e.contains("dimension") {
                VectorLiteError::DimensionMismatch {
                    expected: index.dimension(),
                    actual: vector_dimension,
                }
            } else if e.contains("already exists") {
                VectorLiteError::DuplicateVectorId { id }
            } else {
                VectorLiteError::InternalError(e)
            }
        })?;
        Ok(id)
    }

    /// Removes the vector with the given `id` from the index.
    ///
    /// # Errors
    ///
    /// * `VectorLiteError::LockError` if the index write lock cannot be
    ///   acquired.
    /// * `VectorLiteError::VectorNotFound` if the index's error message
    ///   contains "does not exist"; `VectorLiteError::InternalError` for any
    ///   other index error.
    pub fn delete(&self, id: u64) -> VectorLiteResult<()> {
        let mut index = self.index.write().map_err(|_| VectorLiteError::LockError("Failed to acquire write lock for delete".to_string()))?;
        index.delete(id).map_err(|e| {
            if e.contains("does not exist") {
                VectorLiteError::VectorNotFound { id }
            } else {
                VectorLiteError::InternalError(e)
            }
        })
    }

    /// Embeds `query_text` and returns the index's search results for it,
    /// passing `k` and `similarity_metric` through to the index.
    ///
    /// The embedding is generated before the read lock is taken.
    ///
    /// # Errors
    ///
    /// * Any error returned by `embedding_function.generate_embedding`.
    /// * `VectorLiteError::LockError` if the index read lock cannot be
    ///   acquired.
    pub fn search_text(&self, query_text: &str, k: usize, similarity_metric: SimilarityMetric, embedding_function: &dyn EmbeddingFunction) -> VectorLiteResult<Vec<SearchResult>> {
        let query_embedding = embedding_function.generate_embedding(query_text)?;
        let index = self.index.read().map_err(|_| VectorLiteError::LockError("Failed to acquire read lock for search_text".to_string()))?;
        Ok(index.search(&query_embedding, k, similarity_metric))
    }

    /// Returns a clone of the vector stored under `id`, or `None` if the
    /// index has no such vector.
    ///
    /// # Errors
    ///
    /// Returns `VectorLiteError::LockError` if the index read lock cannot be
    /// acquired.
    pub fn get_vector(&self, id: u64) -> VectorLiteResult<Option<Vector>> {
        let index = self.index.read().map_err(|_| VectorLiteError::LockError("Failed to acquire read lock for get_vector".to_string()))?;
        Ok(index.get_vector(id).cloned())
    }

    /// Returns the collection's name together with the index's current
    /// length, emptiness and dimension.
    ///
    /// # Errors
    ///
    /// Returns `VectorLiteError::LockError` if the index read lock cannot be
    /// acquired.
    pub fn get_info(&self) -> VectorLiteResult<CollectionInfo> {
        let index = self.index.read().map_err(|_| VectorLiteError::LockError("Failed to acquire read lock for get_info".to_string()))?;
        Ok(CollectionInfo {
            name: self.name.clone(),
            count: index.len(),
            is_empty: index.is_empty(),
            dimension: index.dimension(),
        })
    }

    /// Returns the collection's name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the current value of the ID counter, i.e. the ID the next
    /// added vector will receive.
    pub fn next_id(&self) -> u64 {
        self.next_id.load(Ordering::Relaxed)
    }

    /// Acquires the index read lock and returns the guard.
    ///
    /// # Errors
    ///
    /// Returns a plain error string if the lock cannot be acquired.
    pub fn index_read(&self) -> Result<std::sync::RwLockReadGuard<'_, VectorIndexWrapper>, String> {
        self.index.read().map_err(|_| "Failed to acquire read lock".to_string())
    }

    /// Saves this collection to `path` by delegating to
    /// `save_collection_to_file`.
    ///
    /// # Errors
    ///
    /// Returns whatever `PersistenceError` that function reports.
    pub fn save_to_file(&self, path: &Path) -> Result<(), PersistenceError> {
        save_collection_to_file(self, path)
    }

    /// Loads a collection from `path` by delegating to
    /// `load_collection_from_file`.
    ///
    /// # Errors
    ///
    /// Returns whatever `PersistenceError` that function reports.
    pub fn load_from_file(path: &Path) -> Result<Self, PersistenceError> {
        load_collection_from_file(path)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Embedding function that returns a constant all-ones vector of a fixed
    /// dimension, regardless of the input text.
    struct MockEmbeddingFunction {
        dimension: usize,
    }

    impl MockEmbeddingFunction {
        fn new(dimension: usize) -> Self {
            Self { dimension }
        }
    }

    impl EmbeddingFunction for MockEmbeddingFunction {
        fn generate_embedding(&self, _text: &str) -> crate::embeddings::Result<Vec<f64>> {
            Ok(vec![1.0; self.dimension])
        }

        fn dimension(&self) -> usize {
            self.dimension
        }
    }

    /// Builds a client backed by a 3-dimensional mock embedding function.
    fn new_client() -> VectorLiteClient {
        VectorLiteClient::new(Box::new(MockEmbeddingFunction::new(3)))
    }

    /// Builds a client that already contains one empty collection.
    fn client_with_collection(name: &str, index_type: IndexType) -> VectorLiteClient {
        let mut client = new_client();
        client.create_collection(name, index_type).unwrap();
        client
    }

    #[test]
    fn test_client_creation() {
        let client = new_client();
        assert!(client.collections.is_empty());
        assert!(client.list_collections().is_empty());
    }

    #[test]
    fn test_create_collection() {
        let mut client = new_client();
        client.create_collection("test_collection", IndexType::Flat).unwrap();
        assert!(client.has_collection("test_collection"));
        assert_eq!(client.list_collections(), vec!["test_collection"]);
    }

    #[test]
    fn test_create_duplicate_collection() {
        let mut client = client_with_collection("test_collection", IndexType::Flat);
        let result = client.create_collection("test_collection", IndexType::Flat);
        assert!(matches!(result, Err(VectorLiteError::CollectionAlreadyExists { .. })));
    }

    #[test]
    fn test_get_collection() {
        let client = client_with_collection("test_collection", IndexType::Flat);
        let collection = client
            .get_collection("test_collection")
            .expect("collection was just created");
        assert_eq!(collection.name(), "test_collection");
        assert!(client.get_collection("non_existent").is_none());
    }

    #[test]
    fn test_delete_collection() {
        let mut client = client_with_collection("test_collection", IndexType::Flat);
        assert!(client.has_collection("test_collection"));

        client.delete_collection("test_collection").unwrap();
        assert!(!client.has_collection("test_collection"));

        // Deleting an unknown collection must report the specific variant,
        // not just "some error".
        let result = client.delete_collection("non_existent");
        assert!(matches!(result, Err(VectorLiteError::CollectionNotFound { .. })));
    }

    #[test]
    fn test_add_text_to_collection() {
        let client = client_with_collection("test_collection", IndexType::Flat);

        let first = client.add_text_to_collection("test_collection", "Hello world", None).unwrap();
        assert_eq!(first, 0);
        let second = client.add_text_to_collection("test_collection", "Another text", None).unwrap();
        assert_eq!(second, 1);

        let info = client.get_collection_info("test_collection").unwrap();
        assert_eq!(info.count, 2);
    }

    #[test]
    fn test_add_text_to_nonexistent_collection() {
        let client = new_client();
        let result = client.add_text_to_collection("non_existent", "Hello world", None);
        assert!(matches!(result, Err(VectorLiteError::CollectionNotFound { .. })));
    }

    #[test]
    fn test_collection_operations() {
        let client = client_with_collection("test_collection", IndexType::Flat);

        // A fresh collection is empty.
        let info = client.get_collection_info("test_collection").unwrap();
        assert!(info.is_empty);
        assert_eq!(info.count, 0);
        assert_eq!(info.name, "test_collection");

        // IDs are handed out sequentially starting at 0.
        let id = client.add_text_to_collection("test_collection", "Hello world", None).unwrap();
        assert_eq!(id, 0);
        let info = client.get_collection_info("test_collection").unwrap();
        assert!(!info.is_empty);
        assert_eq!(info.count, 1);

        let id = client.add_text_to_collection("test_collection", "Another text", None).unwrap();
        assert_eq!(id, 1);
        let info = client.get_collection_info("test_collection").unwrap();
        assert_eq!(info.count, 2);

        let results = client
            .search_text_in_collection("test_collection", "Hello", 1, SimilarityMetric::Cosine)
            .unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 0);

        let vector = client.get_vector_from_collection("test_collection", 0).unwrap();
        assert_eq!(vector.map(|v| v.id), Some(0));

        // After deletion the vector is gone and the count drops.
        client.delete_from_collection("test_collection", 0).unwrap();
        let info = client.get_collection_info("test_collection").unwrap();
        assert_eq!(info.count, 1);
        let vector = client.get_vector_from_collection("test_collection", 0).unwrap();
        assert!(vector.is_none());
    }

    #[test]
    fn test_collection_with_hnsw_index() {
        let client = client_with_collection("hnsw_collection", IndexType::HNSW);

        let id1 = client.add_text_to_collection("hnsw_collection", "First document", None).unwrap();
        let id2 = client.add_text_to_collection("hnsw_collection", "Second document", None).unwrap();
        assert_eq!(id1, 0);
        assert_eq!(id2, 1);

        let info = client.get_collection_info("hnsw_collection").unwrap();
        assert_eq!(info.count, 2);

        let results = client
            .search_text_in_collection("hnsw_collection", "First", 1, SimilarityMetric::Cosine)
            .unwrap();
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_collection_save_and_load() {
        let client = client_with_collection("test_collection", IndexType::Flat);
        client.add_text_to_collection("test_collection", "Hello world", None).unwrap();
        client.add_text_to_collection("test_collection", "Another text", None).unwrap();
        let collection = client.get_collection("test_collection").unwrap();

        let temp_dir = tempfile::TempDir::new().unwrap();
        let file_path = temp_dir.path().join("test_collection.vlc");
        collection.save_to_file(&file_path).unwrap();
        assert!(file_path.exists());

        let loaded_collection = Collection::load_from_file(&file_path).unwrap();
        assert_eq!(loaded_collection.name(), "test_collection");
        let info = loaded_collection.get_info().unwrap();
        assert_eq!(info.count, 2);
        assert_eq!(info.dimension, 3);
        assert!(!info.is_empty);

        let test_embedding_fn = MockEmbeddingFunction::new(3);
        let results = loaded_collection
            .search_text("Hello", 2, SimilarityMetric::Cosine, &test_embedding_fn)
            .unwrap();
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_collection_save_and_load_hnsw() {
        let client = client_with_collection("test_hnsw_collection", IndexType::HNSW);
        client.add_text_to_collection("test_hnsw_collection", "First document", None).unwrap();
        client.add_text_to_collection("test_hnsw_collection", "Second document", None).unwrap();
        let collection = client.get_collection("test_hnsw_collection").unwrap();

        // Sanity-check the collection before it is saved.
        let info = collection.get_info().unwrap();
        assert_eq!(info.count, 2);
        assert_eq!(info.dimension, 3);
        let test_embedding_fn = MockEmbeddingFunction::new(3);
        let results = collection
            .search_text("First", 1, SimilarityMetric::Cosine, &test_embedding_fn)
            .unwrap();
        assert_eq!(results.len(), 1);

        let temp_dir = tempfile::TempDir::new().unwrap();
        let file_path = temp_dir.path().join("test_hnsw_collection.vlc");
        collection.save_to_file(&file_path).unwrap();
        assert!(file_path.exists());

        let loaded_collection = Collection::load_from_file(&file_path).unwrap();
        assert_eq!(loaded_collection.name(), "test_hnsw_collection");
        let info = loaded_collection.get_info().unwrap();
        assert_eq!(info.count, 2);
        assert_eq!(info.dimension, 3);
        assert!(!info.is_empty);

        let results = loaded_collection
            .search_text("First", 1, SimilarityMetric::Cosine, &test_embedding_fn)
            .unwrap();
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_collection_save_nonexistent_directory() {
        let client = client_with_collection("test_collection", IndexType::Flat);
        client.add_text_to_collection("test_collection", "Hello world", None).unwrap();
        let collection = client.get_collection("test_collection").unwrap();

        // Saving into a directory that does not exist yet is expected to
        // succeed and leave the file in place.
        let temp_dir = tempfile::TempDir::new().unwrap();
        let file_path = temp_dir.path().join("nonexistent").join("test_collection.vlc");
        let result = collection.save_to_file(&file_path);
        assert!(result.is_ok());
        assert!(file_path.exists());
    }

    #[test]
    fn test_collection_load_nonexistent_file() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let file_path = temp_dir.path().join("nonexistent.vlc");
        let result = Collection::load_from_file(&file_path);
        assert!(matches!(result, Err(PersistenceError::Io(_))));
    }

    #[test]
    fn test_collection_load_invalid_json() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let file_path = temp_dir.path().join("invalid.vlc");
        std::fs::write(&file_path, "invalid json content").unwrap();
        let result = Collection::load_from_file(&file_path);
        assert!(matches!(result, Err(PersistenceError::Serialization(_))));
    }
}