use crate::{
domain::{
entities::Event,
repositories::{
EventRepository, SearchResult, VectorEntry, VectorSearchQuery, VectorSearchRepository,
},
value_objects::{DistanceMetric, EmbeddingVector},
},
error::{AllSourceError, Result},
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use uuid::Uuid;
/// Tuning knobs for [`VectorSearchService`] query behavior.
#[derive(Debug, Clone)]
pub struct VectorSearchConfig {
    /// Result count used when a search request does not specify `k`.
    pub default_k: usize,
    /// Hard ceiling on `k`; larger requested values are clamped to this.
    pub max_k: usize,
    /// Fallback similarity floor.
    /// NOTE(review): not consulted by `search` in this module — confirm external use.
    pub default_min_similarity: f32,
    /// Fallback distance metric.
    /// NOTE(review): not consulted by `search` in this module (cosine is hard-coded
    /// as the unspecified-metric default) — confirm external use.
    pub default_metric: DistanceMetric,
    /// Whether stored source text should accompany results.
    /// NOTE(review): not consulted anywhere in this module — confirm external use.
    pub include_source_text: bool,
}
impl Default for VectorSearchConfig {
    /// Defaults: 10 results, capped at 100, no similarity floor,
    /// cosine metric, source text included.
    fn default() -> Self {
        Self {
            default_k: 10,
            max_k: 100,
            default_min_similarity: 0.0,
            default_metric: DistanceMetric::Cosine,
            include_source_text: true,
        }
    }
}
/// Request to index a single event's embedding.
#[derive(Debug, Clone)]
pub struct IndexEventRequest {
    /// Event the embedding belongs to.
    pub event_id: Uuid,
    /// Tenant owning the event.
    pub tenant_id: String,
    /// Embedding vector to store.
    pub embedding: EmbeddingVector,
    /// Optional original text stored alongside the vector.
    pub source_text: Option<String>,
}
/// Wire-format request for [`VectorSearchService::search`].
///
/// All fields are optional on the wire except `include_events`, which
/// defaults to `false` when absent.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SemanticSearchRequest {
    /// Query vector; `search` rejects `None` with an invalid-input error.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub query_embedding: Option<Vec<f32>>,
    /// Number of results to return; clamped to the configured maximum.
    #[serde(default)]
    pub k: Option<usize>,
    /// Restrict results to this tenant.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tenant_id: Option<String>,
    /// Restrict results to this event type.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub event_type: Option<String>,
    /// Drop hits scoring below this similarity.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_similarity: Option<f32>,
    /// Drop hits farther than this distance.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_distance: Option<f32>,
    /// Metric name: "cosine" (default when absent), "euclidean", or "dot_product".
    #[serde(default)]
    pub metric: Option<String>,
    /// When true, enrich hits with event summaries (requires an attached
    /// event repository).
    #[serde(default)]
    pub include_events: bool,
}
/// One hit in a semantic search response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticSearchResultItem {
    /// Id of the matching event.
    pub event_id: Uuid,
    /// Similarity score reported by the vector repository.
    pub score: f32,
    /// Source text stored with the vector, if any; omitted from JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub source_text: Option<String>,
    /// Event summary; present only when enrichment was requested and the
    /// event was found. Omitted from JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub event: Option<EventSummary>,
}
/// Serializable projection of a domain [`Event`] attached to search hits.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EventSummary {
    pub id: Uuid,
    pub event_type: String,
    pub entity_id: String,
    pub tenant_id: String,
    pub timestamp: chrono::DateTime<chrono::Utc>,
    /// Event payload; omitted from JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub payload: Option<serde_json::Value>,
}
impl From<&Event> for EventSummary {
    /// Projects a domain [`Event`] into its serializable summary form.
    fn from(event: &Event) -> Self {
        Self {
            id: event.id(),
            event_type: event.event_type_str().to_string(),
            entity_id: event.entity_id_str().to_string(),
            tenant_id: event.tenant_id_str().to_string(),
            timestamp: event.timestamp(),
            // Always populated on this path; the Option exists so other
            // construction sites can omit the payload from serialization.
            payload: Some(event.payload().clone()),
        }
    }
}
/// Response envelope for [`VectorSearchService::search`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticSearchResponse {
    /// Matching items, in the order returned by the repository.
    pub results: Vec<SemanticSearchResultItem>,
    /// Convenience duplicate of `results.len()`.
    pub count: usize,
    /// Name of the metric the search executed with.
    pub metric: String,
    /// Execution statistics for this search.
    pub stats: SearchStats,
}
/// Execution statistics attached to a search response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchStats {
    /// Total vectors in the index at search time (unfiltered, best-effort;
    /// reported as 0 if the count itself fails).
    pub vectors_searched: usize,
    /// Wall-clock search duration in microseconds.
    pub search_time_us: u64,
}
/// Application service coordinating embedding indexing and semantic search.
pub struct VectorSearchService {
    /// Vector storage and similarity-search backend.
    vector_repo: Arc<dyn VectorSearchRepository>,
    /// Optional event store used to enrich hits with event summaries.
    event_repo: Option<Arc<dyn EventRepository>>,
    /// Query-behavior tuning; see [`VectorSearchConfig`].
    config: VectorSearchConfig,
}
impl VectorSearchService {
    /// Creates a service over `vector_repo` with default configuration and
    /// no event enrichment.
    pub fn new(vector_repo: Arc<dyn VectorSearchRepository>) -> Self {
        Self {
            vector_repo,
            event_repo: None,
            config: VectorSearchConfig::default(),
        }
    }

    /// Attaches an event repository, enabling `include_events` enrichment
    /// in [`Self::search`].
    pub fn with_event_repo(mut self, event_repo: Arc<dyn EventRepository>) -> Self {
        self.event_repo = Some(event_repo);
        self
    }

    /// Replaces the service configuration.
    pub fn with_config(mut self, config: VectorSearchConfig) -> Self {
        self.config = config;
        self
    }

    /// Indexes one event embedding; source text, when supplied, is stored
    /// alongside the vector.
    pub async fn index_event(&self, request: IndexEventRequest) -> Result<()> {
        if let Some(source_text) = &request.source_text {
            self.vector_repo
                .store_with_text(
                    request.event_id,
                    &request.embedding,
                    &request.tenant_id,
                    source_text,
                )
                .await
        } else {
            self.vector_repo
                .store(request.event_id, &request.embedding, &request.tenant_id)
                .await
        }
    }

    /// Indexes a batch of event embeddings.
    ///
    /// Fix: the previous implementation sent every request through
    /// `store_batch`, silently discarding `source_text`. Requests carrying
    /// source text are now stored individually via `store_with_text` so the
    /// text is preserved; the rest still use the batched path. As before,
    /// the first repository error aborts the whole operation, so `failed`
    /// and `errors` remain empty on success.
    pub async fn index_events_batch(
        &self,
        requests: Vec<IndexEventRequest>,
    ) -> Result<BatchIndexResult> {
        if requests.is_empty() {
            return Ok(BatchIndexResult {
                indexed: 0,
                failed: 0,
                errors: vec![],
            });
        }
        // Text-free requests can take the efficient batched path.
        let plain_entries: Vec<_> = requests
            .iter()
            .filter(|r| r.source_text.is_none())
            .map(|r| (r.event_id, r.embedding.clone(), r.tenant_id.clone()))
            .collect();
        if !plain_entries.is_empty() {
            self.vector_repo.store_batch(&plain_entries).await?;
        }
        // Requests with source text must go one at a time so the text is
        // stored alongside the vector.
        for request in &requests {
            if let Some(source_text) = &request.source_text {
                self.vector_repo
                    .store_with_text(
                        request.event_id,
                        &request.embedding,
                        &request.tenant_id,
                        source_text,
                    )
                    .await?;
            }
        }
        Ok(BatchIndexResult {
            indexed: requests.len(),
            failed: 0,
            errors: vec![],
        })
    }

    /// Runs a semantic search.
    ///
    /// Validates the request, builds a [`VectorSearchQuery`], executes it,
    /// and optionally enriches hits with event summaries when
    /// `include_events` is set and an event repository is attached.
    ///
    /// # Errors
    /// Returns [`AllSourceError::InvalidInput`] when `query_embedding` is
    /// missing or `metric` is not one of `cosine`, `euclidean`,
    /// `dot_product`; repository and embedding-validation errors propagate.
    pub async fn search(&self, request: SemanticSearchRequest) -> Result<SemanticSearchResponse> {
        let start_time = std::time::Instant::now();
        let query_embedding = request.query_embedding.ok_or_else(|| {
            AllSourceError::InvalidInput("query_embedding is required".to_string())
        })?;
        let query_vector = EmbeddingVector::new(query_embedding)?;
        // Resolve the metric together with its canonical wire name so the
        // response echoes the exact spelling the API accepts. (Fix: the name
        // was previously derived by lowercasing the Debug form, which
        // rendered DotProduct as "dotproduct" instead of "dot_product".)
        // NOTE(review): `config.default_metric` is not consulted here; an
        // unspecified metric is hard-coded to cosine — confirm intent.
        let (metric, metric_name) = match request.metric.as_deref() {
            Some("cosine") | None => (DistanceMetric::Cosine, "cosine"),
            Some("euclidean") => (DistanceMetric::Euclidean, "euclidean"),
            Some("dot_product") => (DistanceMetric::DotProduct, "dot_product"),
            Some(m) => {
                return Err(AllSourceError::InvalidInput(format!(
                    "Unknown metric: {m}. Supported: cosine, euclidean, dot_product"
                )));
            }
        };
        // Clamp k so a client cannot request an unbounded result set.
        let k = request
            .k
            .unwrap_or(self.config.default_k)
            .min(self.config.max_k);
        let mut query = VectorSearchQuery::new(query_vector, k).with_metric(metric);
        if let Some(tenant_id) = request.tenant_id {
            query = query.with_tenant(tenant_id);
        }
        if let Some(event_type) = request.event_type {
            query = query.with_event_type(event_type);
        }
        if let Some(min_sim) = request.min_similarity {
            query = query.with_min_similarity(min_sim);
        }
        if let Some(max_dist) = request.max_distance {
            query = query.with_max_distance(max_dist);
        }
        let search_results = self.vector_repo.search(&query).await?;
        // Best-effort statistic: total (unfiltered) vectors in the index.
        // A failed count must not fail the search itself.
        let vectors_searched = self.vector_repo.count(None).await.unwrap_or(0);
        let results = if request.include_events {
            self.enrich_with_events(search_results).await?
        } else {
            search_results.into_iter().map(Self::plain_item).collect()
        };
        let search_time_us = start_time.elapsed().as_micros() as u64;
        let count = results.len();
        Ok(SemanticSearchResponse {
            results,
            count,
            metric: metric_name.to_string(),
            stats: SearchStats {
                vectors_searched,
                search_time_us,
            },
        })
    }

    /// Fetches the stored vector entry for `event_id`, if any.
    pub async fn get_embedding(&self, event_id: Uuid) -> Result<Option<VectorEntry>> {
        self.vector_repo.get_by_event_id(event_id).await
    }

    /// Deletes the embedding for `event_id`; returns whether one was removed.
    pub async fn delete_embedding(&self, event_id: Uuid) -> Result<bool> {
        self.vector_repo.delete(event_id).await
    }

    /// Deletes every embedding belonging to `tenant_id`; returns the count removed.
    pub async fn delete_tenant_embeddings(&self, tenant_id: &str) -> Result<usize> {
        self.vector_repo.delete_by_tenant(tenant_id).await
    }

    /// Reports index-wide statistics (vector count and dimensionality).
    pub async fn get_stats(&self) -> Result<IndexStats> {
        let total_vectors = self.vector_repo.count(None).await?;
        let dimensions = self.vector_repo.dimensions().await?;
        Ok(IndexStats {
            total_vectors,
            dimensions,
        })
    }

    /// Proxies the underlying repository's health check.
    pub async fn health_check(&self) -> Result<()> {
        self.vector_repo.health_check().await
    }

    /// Maps a raw repository hit to a response item with no event attached.
    /// (Shared by `search` and `enrich_with_events`, which previously
    /// duplicated this mapping inline.)
    fn plain_item(result: SearchResult) -> SemanticSearchResultItem {
        SemanticSearchResultItem {
            event_id: result.event_id,
            score: result.score.value(),
            source_text: result.source_text,
            event: None,
        }
    }

    /// Attaches event summaries to search hits when an event repository is
    /// available; otherwise returns the hits unenriched. Lookups run
    /// sequentially, one per hit.
    async fn enrich_with_events(
        &self,
        results: Vec<SearchResult>,
    ) -> Result<Vec<SemanticSearchResultItem>> {
        let Some(event_repo) = &self.event_repo else {
            return Ok(results.into_iter().map(Self::plain_item).collect());
        };
        let mut enriched = Vec::with_capacity(results.len());
        for result in results {
            // A hit whose event cannot be found still surfaces, just
            // without a summary.
            let event = event_repo.find_by_id(result.event_id).await?;
            enriched.push(SemanticSearchResultItem {
                event_id: result.event_id,
                score: result.score.value(),
                source_text: result.source_text,
                event: event.as_ref().map(EventSummary::from),
            });
        }
        Ok(enriched)
    }
}
/// Outcome of [`VectorSearchService::index_events_batch`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchIndexResult {
    /// Number of requests successfully indexed.
    pub indexed: usize,
    /// Number of requests that failed. Currently always 0: the first
    /// repository error aborts the whole batch instead of being recorded.
    pub failed: usize,
    /// Per-request error messages (currently always empty; see `failed`).
    pub errors: Vec<String>,
}
/// Aggregate statistics for the vector index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexStats {
    /// Total number of stored vectors.
    pub total_vectors: usize,
    /// Vector dimensionality, if the repository reports one.
    pub dimensions: Option<usize>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::infrastructure::repositories::InMemoryVectorSearchRepository;

    /// Service wired to a fresh in-memory vector repository, default config.
    fn create_test_service() -> VectorSearchService {
        VectorSearchService::new(Arc::new(InMemoryVectorSearchRepository::new()))
    }

    /// Deterministic `dims`-dimensional embedding derived from `seed`.
    fn create_test_embedding(dims: usize, seed: f32) -> EmbeddingVector {
        let values = (0..dims)
            .map(|i| (i as f32 + seed) / dims as f32)
            .collect::<Vec<f32>>();
        EmbeddingVector::new(values).unwrap()
    }

    /// Indexes a raw vector under "tenant-1" and returns its generated id.
    async fn index_vector(service: &VectorSearchService, values: Vec<f32>) -> Uuid {
        let event_id = Uuid::new_v4();
        service
            .index_event(IndexEventRequest {
                event_id,
                tenant_id: "tenant-1".to_string(),
                embedding: EmbeddingVector::new(values).unwrap(),
                source_text: None,
            })
            .await
            .unwrap();
        event_id
    }

    #[tokio::test]
    async fn test_index_and_search() {
        let service = create_test_service();
        let nearest = index_vector(&service, vec![1.0, 0.0, 0.0_f32]).await;
        index_vector(&service, vec![0.9, 0.1, 0.0]).await;
        index_vector(&service, vec![0.0, 1.0, 0.0]).await;
        let response = service
            .search(SemanticSearchRequest {
                query_embedding: Some(vec![1.0, 0.0, 0.0]),
                k: Some(2),
                tenant_id: Some("tenant-1".to_string()),
                ..Default::default()
            })
            .await
            .unwrap();
        // The identical vector must rank first of the two returned hits.
        assert_eq!(response.count, 2);
        assert_eq!(response.results[0].event_id, nearest);
    }

    #[tokio::test]
    async fn test_batch_index() {
        let service = create_test_service();
        let requests = (0..10)
            .map(|i| IndexEventRequest {
                event_id: Uuid::new_v4(),
                tenant_id: "tenant-1".to_string(),
                embedding: create_test_embedding(384, i as f32),
                source_text: Some(format!("Document {i}")),
            })
            .collect::<Vec<_>>();
        let result = service.index_events_batch(requests).await.unwrap();
        assert_eq!(result.indexed, 10);
        assert_eq!(result.failed, 0);
        let stats = service.get_stats().await.unwrap();
        assert_eq!(stats.total_vectors, 10);
        assert_eq!(stats.dimensions, Some(384));
    }

    #[tokio::test]
    async fn test_search_with_min_similarity() {
        let service = create_test_service();
        index_vector(&service, vec![1.0, 0.0, 0.0]).await;
        index_vector(&service, vec![0.0, 1.0, 0.0]).await;
        let response = service
            .search(SemanticSearchRequest {
                query_embedding: Some(vec![1.0, 0.0, 0.0]),
                k: Some(10),
                tenant_id: Some("tenant-1".to_string()),
                min_similarity: Some(0.5),
                ..Default::default()
            })
            .await
            .unwrap();
        // Only the parallel vector clears the 0.5 similarity floor.
        assert_eq!(response.count, 1);
    }

    #[tokio::test]
    async fn test_delete_embedding() {
        let service = create_test_service();
        let event_id = Uuid::new_v4();
        service
            .index_event(IndexEventRequest {
                event_id,
                tenant_id: "tenant-1".to_string(),
                embedding: create_test_embedding(384, 1.0),
                source_text: None,
            })
            .await
            .unwrap();
        // Present before deletion, reported deleted, gone afterwards.
        assert!(service.get_embedding(event_id).await.unwrap().is_some());
        assert!(service.delete_embedding(event_id).await.unwrap());
        assert!(service.get_embedding(event_id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_health_check() {
        assert!(create_test_service().health_check().await.is_ok());
    }

    #[tokio::test]
    async fn test_invalid_metric() {
        let service = create_test_service();
        let err = service
            .search(SemanticSearchRequest {
                query_embedding: Some(vec![1.0, 0.0, 0.0]),
                metric: Some("invalid".to_string()),
                ..Default::default()
            })
            .await
            .expect_err("an unknown metric must be rejected");
        assert!(err.to_string().contains("Unknown metric"));
    }

    #[tokio::test]
    async fn test_missing_query_embedding() {
        let service = create_test_service();
        let err = service
            .search(SemanticSearchRequest {
                query_embedding: None,
                ..Default::default()
            })
            .await
            .expect_err("a missing query embedding must be rejected");
        assert!(err.to_string().contains("query_embedding is required"));
    }
}