use crate::{
    domain::{
        entities::Event,
        repositories::{
            EventRepository, SearchResult, VectorEntry, VectorSearchQuery, VectorSearchRepository,
        },
        value_objects::{DistanceMetric, EmbeddingVector},
    },
    error::{AllSourceError, Result},
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use uuid::Uuid;

15#[derive(Debug, Clone)]
17pub struct VectorSearchConfig {
18 pub default_k: usize,
20 pub max_k: usize,
22 pub default_min_similarity: f32,
24 pub default_metric: DistanceMetric,
26 pub include_source_text: bool,
28}
29
30impl Default for VectorSearchConfig {
31 fn default() -> Self {
32 Self {
33 default_k: 10,
34 max_k: 100,
35 default_min_similarity: 0.0,
36 default_metric: DistanceMetric::Cosine,
37 include_source_text: true,
38 }
39 }
40}
41
42#[derive(Debug, Clone)]
44pub struct IndexEventRequest {
45 pub event_id: Uuid,
46 pub tenant_id: String,
47 pub embedding: EmbeddingVector,
48 pub source_text: Option<String>,
49}
50
51#[derive(Debug, Clone, Default, Serialize, Deserialize)]
53pub struct SemanticSearchRequest {
54 #[serde(skip_serializing_if = "Option::is_none")]
56 pub query_embedding: Option<Vec<f32>>,
57 #[serde(default)]
59 pub k: Option<usize>,
60 #[serde(skip_serializing_if = "Option::is_none")]
62 pub tenant_id: Option<String>,
63 #[serde(skip_serializing_if = "Option::is_none")]
65 pub event_type: Option<String>,
66 #[serde(skip_serializing_if = "Option::is_none")]
68 pub min_similarity: Option<f32>,
69 #[serde(skip_serializing_if = "Option::is_none")]
71 pub max_distance: Option<f32>,
72 #[serde(default)]
74 pub metric: Option<String>,
75 #[serde(default)]
77 pub include_events: bool,
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct SemanticSearchResultItem {
83 pub event_id: Uuid,
85 pub score: f32,
87 #[serde(skip_serializing_if = "Option::is_none")]
89 pub source_text: Option<String>,
90 #[serde(skip_serializing_if = "Option::is_none")]
92 pub event: Option<EventSummary>,
93}
94
95#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct EventSummary {
98 pub id: Uuid,
99 pub event_type: String,
100 pub entity_id: String,
101 pub tenant_id: String,
102 pub timestamp: chrono::DateTime<chrono::Utc>,
103 #[serde(skip_serializing_if = "Option::is_none")]
104 pub payload: Option<serde_json::Value>,
105}
106
107impl From<&Event> for EventSummary {
108 fn from(event: &Event) -> Self {
109 Self {
110 id: event.id(),
111 event_type: event.event_type_str().to_string(),
112 entity_id: event.entity_id_str().to_string(),
113 tenant_id: event.tenant_id_str().to_string(),
114 timestamp: event.timestamp(),
115 payload: Some(event.payload().clone()),
116 }
117 }
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct SemanticSearchResponse {
123 pub results: Vec<SemanticSearchResultItem>,
125 pub count: usize,
127 pub metric: String,
129 pub stats: SearchStats,
131}
132
133#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct SearchStats {
136 pub vectors_searched: usize,
138 pub search_time_us: u64,
140}
141
142pub struct VectorSearchService {
152 vector_repo: Arc<dyn VectorSearchRepository>,
153 event_repo: Option<Arc<dyn EventRepository>>,
154 config: VectorSearchConfig,
155}
156
157impl VectorSearchService {
158 pub fn new(vector_repo: Arc<dyn VectorSearchRepository>) -> Self {
159 Self {
160 vector_repo,
161 event_repo: None,
162 config: VectorSearchConfig::default(),
163 }
164 }
165
166 pub fn with_event_repo(mut self, event_repo: Arc<dyn EventRepository>) -> Self {
167 self.event_repo = Some(event_repo);
168 self
169 }
170
171 pub fn with_config(mut self, config: VectorSearchConfig) -> Self {
172 self.config = config;
173 self
174 }
175
176 pub async fn index_event(&self, request: IndexEventRequest) -> Result<()> {
178 if let Some(source_text) = &request.source_text {
179 self.vector_repo
180 .store_with_text(
181 request.event_id,
182 &request.embedding,
183 &request.tenant_id,
184 source_text,
185 )
186 .await
187 } else {
188 self.vector_repo
189 .store(request.event_id, &request.embedding, &request.tenant_id)
190 .await
191 }
192 }
193
194 pub async fn index_events_batch(
196 &self,
197 requests: Vec<IndexEventRequest>,
198 ) -> Result<BatchIndexResult> {
199 if requests.is_empty() {
200 return Ok(BatchIndexResult {
201 indexed: 0,
202 failed: 0,
203 errors: vec![],
204 });
205 }
206
207 let entries: Vec<_> = requests
208 .iter()
209 .map(|r| (r.event_id, r.embedding.clone(), r.tenant_id.clone()))
210 .collect();
211
212 self.vector_repo.store_batch(&entries).await?;
213
214 Ok(BatchIndexResult {
215 indexed: requests.len(),
216 failed: 0,
217 errors: vec![],
218 })
219 }
220
221 pub async fn search(&self, request: SemanticSearchRequest) -> Result<SemanticSearchResponse> {
223 let start_time = std::time::Instant::now();
224
225 let query_embedding = request.query_embedding.ok_or_else(|| {
227 AllSourceError::InvalidInput("query_embedding is required".to_string())
228 })?;
229
230 let query_vector = EmbeddingVector::new(query_embedding)?;
231
232 let metric = match request.metric.as_deref() {
234 Some("cosine") | None => DistanceMetric::Cosine,
235 Some("euclidean") => DistanceMetric::Euclidean,
236 Some("dot_product") => DistanceMetric::DotProduct,
237 Some(m) => {
238 return Err(AllSourceError::InvalidInput(format!(
239 "Unknown metric: {}. Supported: cosine, euclidean, dot_product",
240 m
241 )));
242 }
243 };
244
245 let k = request
247 .k
248 .unwrap_or(self.config.default_k)
249 .min(self.config.max_k);
250
251 let mut query = VectorSearchQuery::new(query_vector, k).with_metric(metric);
252
253 if let Some(tenant_id) = request.tenant_id {
254 query = query.with_tenant(tenant_id);
255 }
256
257 if let Some(event_type) = request.event_type {
258 query = query.with_event_type(event_type);
259 }
260
261 if let Some(min_sim) = request.min_similarity {
262 query = query.with_min_similarity(min_sim);
263 }
264
265 if let Some(max_dist) = request.max_distance {
266 query = query.with_max_distance(max_dist);
267 }
268
269 let search_results = self.vector_repo.search(&query).await?;
271 let vectors_searched = self.vector_repo.count(None).await.unwrap_or(0);
272
273 let results = if request.include_events {
275 self.enrich_with_events(search_results).await?
276 } else {
277 search_results
278 .into_iter()
279 .map(|r| SemanticSearchResultItem {
280 event_id: r.event_id,
281 score: r.score.value(),
282 source_text: r.source_text,
283 event: None,
284 })
285 .collect()
286 };
287
288 let search_time_us = start_time.elapsed().as_micros() as u64;
289 let count = results.len();
290
291 Ok(SemanticSearchResponse {
292 results,
293 count,
294 metric: format!("{:?}", metric).to_lowercase(),
295 stats: SearchStats {
296 vectors_searched,
297 search_time_us,
298 },
299 })
300 }
301
302 pub async fn get_embedding(&self, event_id: Uuid) -> Result<Option<VectorEntry>> {
304 self.vector_repo.get_by_event_id(event_id).await
305 }
306
307 pub async fn delete_embedding(&self, event_id: Uuid) -> Result<bool> {
309 self.vector_repo.delete(event_id).await
310 }
311
312 pub async fn delete_tenant_embeddings(&self, tenant_id: &str) -> Result<usize> {
314 self.vector_repo.delete_by_tenant(tenant_id).await
315 }
316
317 pub async fn get_stats(&self) -> Result<IndexStats> {
319 let total_vectors = self.vector_repo.count(None).await?;
320 let dimensions = self.vector_repo.dimensions().await?;
321
322 Ok(IndexStats {
323 total_vectors,
324 dimensions,
325 })
326 }
327
328 pub async fn health_check(&self) -> Result<()> {
330 self.vector_repo.health_check().await
331 }
332
333 async fn enrich_with_events(
335 &self,
336 results: Vec<SearchResult>,
337 ) -> Result<Vec<SemanticSearchResultItem>> {
338 let event_repo = match &self.event_repo {
339 Some(repo) => repo,
340 None => {
341 return Ok(results
343 .into_iter()
344 .map(|r| SemanticSearchResultItem {
345 event_id: r.event_id,
346 score: r.score.value(),
347 source_text: r.source_text,
348 event: None,
349 })
350 .collect());
351 }
352 };
353
354 let mut enriched = Vec::with_capacity(results.len());
355
356 for result in results {
357 let event = event_repo.find_by_id(result.event_id).await?;
358
359 enriched.push(SemanticSearchResultItem {
360 event_id: result.event_id,
361 score: result.score.value(),
362 source_text: result.source_text,
363 event: event.as_ref().map(EventSummary::from),
364 });
365 }
366
367 Ok(enriched)
368 }
369}
370
371#[derive(Debug, Clone, Serialize, Deserialize)]
373pub struct BatchIndexResult {
374 pub indexed: usize,
375 pub failed: usize,
376 pub errors: Vec<String>,
377}
378
379#[derive(Debug, Clone, Serialize, Deserialize)]
381pub struct IndexStats {
382 pub total_vectors: usize,
383 pub dimensions: Option<usize>,
384}
385
#[cfg(test)]
mod tests {
    use super::*;
    use crate::infrastructure::repositories::InMemoryVectorSearchRepository;

    /// Builds a service over a fresh in-memory repository with the default
    /// configuration.
    fn create_test_service() -> VectorSearchService {
        let repo = Arc::new(InMemoryVectorSearchRepository::new());
        VectorSearchService::new(repo)
    }

    /// Produces a deterministic `dims`-dimensional embedding; `seed` shifts
    /// every component so different seeds give different vectors.
    fn create_test_embedding(dims: usize, seed: f32) -> EmbeddingVector {
        let values: Vec<f32> = (0..dims).map(|i| (i as f32 + seed) / dims as f32).collect();
        EmbeddingVector::new(values).unwrap()
    }

    /// The closest stored vector comes back first and `k` limits the count.
    #[tokio::test]
    async fn test_index_and_search() {
        let service = create_test_service();

        let embeddings = vec![
            (Uuid::new_v4(), vec![1.0, 0.0, 0.0_f32]),
            (Uuid::new_v4(), vec![0.9, 0.1, 0.0]),
            (Uuid::new_v4(), vec![0.0, 1.0, 0.0]),
        ];

        for (id, values) in &embeddings {
            service
                .index_event(IndexEventRequest {
                    event_id: *id,
                    tenant_id: "tenant-1".to_string(),
                    embedding: EmbeddingVector::new(values.clone()).unwrap(),
                    source_text: None,
                })
                .await
                .unwrap();
        }

        let response = service
            .search(SemanticSearchRequest {
                query_embedding: Some(vec![1.0, 0.0, 0.0]),
                k: Some(2),
                tenant_id: Some("tenant-1".to_string()),
                ..Default::default()
            })
            .await
            .unwrap();

        assert_eq!(response.count, 2);
        assert_eq!(response.results[0].event_id, embeddings[0].0);
    }

    /// A batch of ten requests is fully indexed and reflected in the stats.
    #[tokio::test]
    async fn test_batch_index() {
        let service = create_test_service();

        let requests: Vec<_> = (0..10)
            .map(|i| IndexEventRequest {
                event_id: Uuid::new_v4(),
                tenant_id: "tenant-1".to_string(),
                embedding: create_test_embedding(384, i as f32),
                source_text: Some(format!("Document {}", i)),
            })
            .collect();

        let result = service.index_events_batch(requests).await.unwrap();
        assert_eq!(result.indexed, 10);
        assert_eq!(result.failed, 0);

        let stats = service.get_stats().await.unwrap();
        assert_eq!(stats.total_vectors, 10);
        assert_eq!(stats.dimensions, Some(384));
    }

    /// An orthogonal vector is filtered out by a 0.5 similarity floor.
    #[tokio::test]
    async fn test_search_with_min_similarity() {
        let service = create_test_service();

        service
            .index_event(IndexEventRequest {
                event_id: Uuid::new_v4(),
                tenant_id: "tenant-1".to_string(),
                embedding: EmbeddingVector::new(vec![1.0, 0.0, 0.0]).unwrap(),
                source_text: None,
            })
            .await
            .unwrap();

        service
            .index_event(IndexEventRequest {
                event_id: Uuid::new_v4(),
                tenant_id: "tenant-1".to_string(),
                embedding: EmbeddingVector::new(vec![0.0, 1.0, 0.0]).unwrap(),
                source_text: None,
            })
            .await
            .unwrap();

        let response = service
            .search(SemanticSearchRequest {
                query_embedding: Some(vec![1.0, 0.0, 0.0]),
                k: Some(10),
                tenant_id: Some("tenant-1".to_string()),
                min_similarity: Some(0.5),
                ..Default::default()
            })
            .await
            .unwrap();

        assert_eq!(response.count, 1);
    }

    /// A stored embedding can be fetched, deleted, and is gone afterwards.
    #[tokio::test]
    async fn test_delete_embedding() {
        let service = create_test_service();

        let event_id = Uuid::new_v4();
        service
            .index_event(IndexEventRequest {
                event_id,
                tenant_id: "tenant-1".to_string(),
                embedding: create_test_embedding(384, 1.0),
                source_text: None,
            })
            .await
            .unwrap();

        assert!(service.get_embedding(event_id).await.unwrap().is_some());

        let deleted = service.delete_embedding(event_id).await.unwrap();
        assert!(deleted);

        assert!(service.get_embedding(event_id).await.unwrap().is_none());
    }

    /// The in-memory repository reports healthy.
    #[tokio::test]
    async fn test_health_check() {
        let service = create_test_service();
        assert!(service.health_check().await.is_ok());
    }

    /// An unrecognised metric name is rejected with a descriptive error.
    #[tokio::test]
    async fn test_invalid_metric() {
        let service = create_test_service();

        let result = service
            .search(SemanticSearchRequest {
                query_embedding: Some(vec![1.0, 0.0, 0.0]),
                metric: Some("invalid".to_string()),
                ..Default::default()
            })
            .await;

        let err = result.expect_err("an unknown metric must be rejected");
        assert!(err.to_string().contains("Unknown metric"));
    }

    /// A request without a query embedding is rejected.
    #[tokio::test]
    async fn test_missing_query_embedding() {
        let service = create_test_service();

        let result = service
            .search(SemanticSearchRequest {
                query_embedding: None,
                ..Default::default()
            })
            .await;

        let err = result.expect_err("a missing query embedding must be rejected");
        assert!(err.to_string().contains("query_embedding is required"));
    }
}