use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use super::StorageBackend;
/// A single record held in a vector collection.
///
/// Bundles the raw text, its embedding and free-form JSON metadata under one
/// identifier. The type round-trips through serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorDocument {
    /// Identifier of the document.
    pub id: String,
    /// Textual content of the document.
    pub content: String,
    /// Embedding vector associated with `content`.
    pub embedding: Vec<f32>,
    /// Arbitrary JSON values keyed by string.
    pub metadata: HashMap<String, serde_json::Value>,
}
impl VectorDocument {
pub fn new(id: impl Into<String>, content: impl Into<String>, embedding: Vec<f32>) -> Self {
Self {
id: id.into(),
content: content.into(),
embedding,
metadata: HashMap::new(),
}
}
pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
self.metadata.insert(key.into(), value);
self
}
}
/// One hit returned by a vector search: the matched document plus its score.
#[derive(Debug, Clone)]
pub struct VectorSearchResult {
    /// The document that matched the query.
    pub document: VectorDocument,
    /// Score assigned to this hit by the backend.
    ///
    /// NOTE(review): scale and direction depend on the backend and distance
    /// metric; confirm before comparing scores across stores.
    pub score: f32,
}
/// A single `key` / `value` pair used inside a metadata filter.
#[derive(Debug, Clone)]
pub struct FilterCondition {
    /// Name of the metadata field the condition refers to.
    pub key: String,
    /// JSON value the field is compared against.
    pub value: serde_json::Value,
}
impl FilterCondition {
pub fn new(key: impl Into<String>, value: serde_json::Value) -> Self {
Self {
key: key.into(),
value,
}
}
}
/// Metadata filter passed to vector searches.
///
/// The default value has no conditions at all.
///
/// NOTE(review): how each group is evaluated is decided by the backend that
/// consumes the filter. Presumably `conditions` must all match, at least one
/// of `any` must match, and none of `none` may match; confirm against the
/// store implementations.
#[derive(Debug, Clone, Default)]
pub struct VectorFilter {
    /// Key/value conditions; being a map, it holds at most one value per key.
    pub conditions: HashMap<String, serde_json::Value>,
    /// Conditions in the "any" group; the same key may appear several times.
    pub any: Vec<FilterCondition>,
    /// Conditions in the "none" group; the same key may appear several times.
    pub none: Vec<FilterCondition>,
}
impl VectorFilter {
pub fn new() -> Self {
Self::default()
}
pub fn with_condition(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
self.conditions.insert(key.into(), value);
self
}
pub fn with_any_condition(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
self.any.push(FilterCondition::new(key, value));
self
}
pub fn with_none_condition(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
self.none.push(FilterCondition::new(key, value));
self
}
pub fn is_empty(&self) -> bool {
self.conditions.is_empty() && self.any.is_empty() && self.none.is_empty()
}
}
/// Asynchronous interface to a vector database, layered on top of
/// [`StorageBackend`].
///
/// Every operation is scoped to a named collection and reports failures
/// through `anyhow::Result`.
///
/// NOTE(review): apart from `search_with_threshold`, the methods carry no
/// default body, so their exact semantics (idempotence, behaviour on missing
/// collections or ids, ordering of results) are defined by each implementor;
/// the summaries below follow the signatures and should be confirmed against
/// the concrete backends.
#[async_trait]
pub trait VectorStore: StorageBackend {
    /// Makes sure `collection` is available for vectors of size `dimension`.
    async fn ensure_collection(&self, collection: &str, dimension: usize) -> anyhow::Result<()>;

    /// Deletes `collection`.
    async fn delete_collection(&self, collection: &str) -> anyhow::Result<()>;

    /// Returns the names of the collections known to the store.
    async fn list_collections(&self) -> anyhow::Result<Vec<String>>;

    /// Reports whether `collection` exists.
    async fn collection_exists(&self, collection: &str) -> anyhow::Result<bool>;

    /// Inserts or updates a single `document` in `collection`.
    async fn upsert(&self, collection: &str, document: VectorDocument) -> anyhow::Result<()>;

    /// Inserts or updates several `documents` in `collection`.
    async fn upsert_batch(
        &self,
        collection: &str,
        documents: Vec<VectorDocument>,
    ) -> anyhow::Result<()>;

    /// Looks up the document stored under `id`; yields `None` when the store
    /// has nothing to return for it.
    async fn get(&self, collection: &str, id: &str) -> anyhow::Result<Option<VectorDocument>>;

    /// Deletes the document stored under `id` from `collection`.
    async fn delete(&self, collection: &str, id: &str) -> anyhow::Result<()>;

    /// Deletes the documents stored under each of `ids` from `collection`.
    async fn delete_batch(&self, collection: &str, ids: &[String]) -> anyhow::Result<()>;

    /// Searches `collection` with `query_embedding`, optionally restricted by
    /// `filter`.
    ///
    /// NOTE(review): `limit` is presumably the maximum number of hits
    /// returned; confirm against the backends.
    async fn search(
        &self,
        collection: &str,
        query_embedding: &[f32],
        limit: usize,
        filter: Option<VectorFilter>,
    ) -> anyhow::Result<Vec<VectorSearchResult>>;

    /// Runs [`VectorStore::search`] with the same arguments and keeps only the
    /// hits whose `score` is at least `min_score`.
    ///
    /// The threshold is applied to whatever `search` returned, so `limit` is
    /// evaluated before the cut-off and fewer hits may come back than `search`
    /// produced. The relative order of the surviving hits is unchanged. A hit
    /// with a NaN score never satisfies the comparison and is dropped.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `search`.
    async fn search_with_threshold(
        &self,
        collection: &str,
        query_embedding: &[f32],
        limit: usize,
        min_score: f32,
        filter: Option<VectorFilter>,
    ) -> anyhow::Result<Vec<VectorSearchResult>> {
        let mut hits = self
            .search(collection, query_embedding, limit, filter)
            .await?;
        // Prune in place; `retain` preserves the order produced by `search`.
        hits.retain(|hit| hit.score >= min_score);
        Ok(hits)
    }

    /// Returns a count for `collection` (presumably the number of stored
    /// documents; confirm against the backends).
    async fn count(&self, collection: &str) -> anyhow::Result<u64>;

    /// Returns a [`CollectionInfo`] describing `collection`.
    async fn collection_info(&self, collection: &str) -> anyhow::Result<CollectionInfo>;
}
/// Summary information about one collection.
#[derive(Debug, Clone)]
pub struct CollectionInfo {
    /// Name of the collection.
    pub name: String,
    /// Dimension recorded for the collection's vectors.
    pub dimension: usize,
    /// Count reported for the collection (presumably its number of documents;
    /// confirm against the backend that fills this in).
    pub count: u64,
    /// Distance metric associated with the collection.
    pub distance_metric: DistanceMetric,
}
/// Distance / similarity measure associated with a collection.
///
/// The default is [`DistanceMetric::Cosine`]. `Hash` is derived alongside `Eq`
/// so the metric can be used as a `HashMap` / `HashSet` key.
///
/// NOTE(review): how a metric is mapped onto a concrete backend (and how the
/// resulting scores are scaled) is decided by the store implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum DistanceMetric {
    /// Cosine similarity (the default).
    #[default]
    Cosine,
    /// Euclidean distance.
    Euclidean,
    /// Dot product.
    DotProduct,
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// The constructor and metadata builder populate every field.
    #[test]
    fn test_document_creation() {
        let document = VectorDocument::new("doc1", "Hello world", vec![0.1, 0.2, 0.3])
            .with_metadata("source", json!("test"))
            .with_metadata("page", json!(1));

        assert_eq!(document.id, "doc1");
        assert_eq!(document.content, "Hello world");
        assert_eq!(document.embedding.len(), 3);
        // Indexing panics on a missing key, just like `get(..).unwrap()`.
        assert_eq!(document.metadata["source"], json!("test"));
        assert_eq!(document.metadata["page"], json!(1));
    }

    /// Each builder method feeds its own condition group.
    #[test]
    fn test_filter_creation() {
        let built = VectorFilter::new()
            .with_condition("tenant_id", json!("acme"))
            .with_condition("status", json!("active"))
            .with_any_condition("visibility", json!("team:legal"))
            .with_none_condition("denyScopes", json!("team:blocked"));

        assert!(!built.is_empty());
        let group_sizes = (built.conditions.len(), built.any.len(), built.none.len());
        assert_eq!(group_sizes, (2, 1, 1));
    }

    /// A document survives a JSON round trip.
    #[test]
    fn test_document_serialization() {
        let original = VectorDocument::new("doc1", "Test content", vec![0.1, 0.2])
            .with_metadata("key", json!("value"));

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: VectorDocument = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.id, "doc1");
        assert_eq!(decoded.embedding, vec![0.1, 0.2]);
    }
}