use async_trait::async_trait;
use std::collections::HashMap;
use crate::context::types::{
ContextError, ContextItem, KnowledgeItem, MemoryItem, VectorBatchOperation, VectorId,
};
use crate::context::vector_db::VectorDatabaseStats;
use crate::types::AgentId;
use serde_json::Value;
/// Metric selector for comparing vectors.
///
/// [`DistanceMetric::Cosine`] is the default (`DistanceMetric::default()`).
/// This enum only names a metric. How each metric is actually computed is
/// decided by whichever code consumes it. NOTE(review): presumably
/// vector-database implementations — confirm.
///
/// `Hash` is derived alongside `Eq` so the metric can be used as a key in
/// `HashMap`/`HashSet`, e.g. for per-metric configuration or caches.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum DistanceMetric {
    /// Cosine metric. This is the default variant.
    #[default]
    Cosine,
    /// Euclidean metric.
    Euclidean,
    /// Dot-product metric.
    DotProduct,
}
/// Asynchronous interface to a vector database that stores and searches
/// embeddings for knowledge items and per-agent memory items.
///
/// The trait requires `Send + Sync`, so implementations can be shared across
/// threads. Every operation reports failure through [`ContextError`].
#[async_trait]
pub trait VectorDb: Send + Sync {
    /// Prepares the backend for use.
    ///
    /// NOTE(review): presumably creates any missing collections or indexes and
    /// should be called before other methods — confirm against implementations.
    async fn initialize(&self) -> Result<(), ContextError>;

    /// Stores `item` together with its `embedding` and returns the
    /// [`VectorId`] of the stored vector.
    async fn store_knowledge_item(
        &self,
        item: &KnowledgeItem,
        embedding: Vec<f32>,
    ) -> Result<VectorId, ContextError>;

    /// Stores `memory` for the agent `agent_id` together with its `embedding`
    /// and returns the [`VectorId`] of the stored vector.
    async fn store_memory_item(
        &self,
        agent_id: AgentId,
        memory: &MemoryItem,
        embedding: Vec<f32>,
    ) -> Result<VectorId, ContextError>;

    /// Stores all entries described by `batch` and returns their vector IDs.
    ///
    /// NOTE(review): presumably the IDs are returned in batch order. Also
    /// unclear whether a failure can leave the batch partially stored —
    /// confirm against implementations.
    async fn batch_store(&self, batch: VectorBatchOperation)
        -> Result<Vec<VectorId>, ContextError>;

    /// Searches knowledge items for `agent_id` using `query_embedding`.
    ///
    /// `limit` presumably caps the number of returned items (TODO confirm).
    async fn search_knowledge_base(
        &self,
        agent_id: AgentId,
        query_embedding: Vec<f32>,
        limit: usize,
    ) -> Result<Vec<KnowledgeItem>, ContextError>;

    /// Performs a semantic search for `agent_id` and returns matching
    /// [`ContextItem`]s.
    ///
    /// NOTE(review): `limit` presumably caps the result count and `threshold`
    /// is presumably a minimum similarity score. Whether higher or lower
    /// scores are better may depend on the metric in use — confirm.
    async fn semantic_search(
        &self,
        agent_id: AgentId,
        query_embedding: Vec<f32>,
        limit: usize,
        threshold: f32,
    ) -> Result<Vec<ContextItem>, ContextError>;

    /// Like `semantic_search`, but also takes string `filters` and returns
    /// raw `VectorSearchResult`s instead of context items.
    ///
    /// NOTE(review): `filters` presumably maps metadata field names to
    /// required values (exact-match?) — confirm against implementations.
    async fn advanced_search(
        &self,
        agent_id: AgentId,
        query_embedding: Vec<f32>,
        filters: HashMap<String, String>,
        limit: usize,
        threshold: f32,
    ) -> Result<Vec<super::types::VectorSearchResult>, ContextError>;

    /// Deletes the vector identified by `vector_id`.
    ///
    /// NOTE(review): only a `VectorId` is taken, so this may also delete
    /// memory vectors despite the name — confirm.
    async fn delete_knowledge_item(&self, vector_id: VectorId) -> Result<(), ContextError>;

    /// Deletes every vector listed in `vector_ids`.
    async fn batch_delete(&self, vector_ids: Vec<VectorId>) -> Result<(), ContextError>;

    /// Updates the metadata stored with `vector_id` using `metadata`.
    ///
    /// NOTE(review): unclear from the signature whether keys are merged into
    /// or replace the existing metadata — confirm.
    async fn update_metadata(
        &self,
        vector_id: VectorId,
        metadata: HashMap<String, Value>,
    ) -> Result<(), ContextError>;

    /// Returns statistics about the database as [`VectorDatabaseStats`].
    async fn get_stats(&self) -> Result<VectorDatabaseStats, ContextError>;

    /// Creates an index on the field named `field_name`.
    ///
    /// NOTE(review): presumably a metadata/payload field — confirm.
    async fn create_index(&self, field_name: &str) -> Result<(), ContextError>;

    /// Asks the backend to optimize its collection; the exact effect is
    /// implementation-specific.
    async fn optimize_collection(&self) -> Result<(), ContextError>;

    /// Reports whether the backend is healthy.
    ///
    /// NOTE(review): presumably `Ok(true)` means healthy. Confirm how
    /// `Ok(false)` differs from `Err` in implementations.
    async fn health_check(&self) -> Result<bool, ContextError>;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `DistanceMetric::default()` must yield the `Cosine` variant.
    #[test]
    fn test_distance_metric_default() {
        let expected = DistanceMetric::Cosine;
        assert_eq!(DistanceMetric::default(), expected);
    }
}