1use crate::domain::entities::Event;
2use crate::domain::repositories::{
3 EventRepository, SearchResult, VectorEntry, VectorSearchQuery, VectorSearchRepository,
4};
5use crate::domain::value_objects::{DistanceMetric, EmbeddingVector};
6use crate::error::{AllSourceError, Result};
7use serde::{Deserialize, Serialize};
8use std::sync::Arc;
9use uuid::Uuid;
10
/// Tunable settings for [`VectorSearchService`].
#[derive(Debug, Clone)]
pub struct VectorSearchConfig {
    /// Result count used when a search request does not specify `k`.
    pub default_k: usize,
    /// Upper bound for `k`; larger requested values are clamped to this.
    pub max_k: usize,
    /// NOTE(review): presumably the similarity floor for requests that give
    /// none — no code visible in this file reads it; confirm.
    pub default_min_similarity: f32,
    /// Metric intended for searches that do not name one.
    pub default_metric: DistanceMetric,
    /// NOTE(review): presumably toggles whether source text is kept/returned —
    /// no code visible in this file reads it; confirm.
    pub include_source_text: bool,
}
25
26impl Default for VectorSearchConfig {
27 fn default() -> Self {
28 Self {
29 default_k: 10,
30 max_k: 100,
31 default_min_similarity: 0.0,
32 default_metric: DistanceMetric::Cosine,
33 include_source_text: true,
34 }
35 }
36}
37
/// Input for indexing a single event embedding.
#[derive(Debug, Clone)]
pub struct IndexEventRequest {
    /// Identifier of the event the embedding belongs to.
    pub event_id: Uuid,
    /// Tenant the embedding is stored under.
    pub tenant_id: String,
    /// The embedding vector to store.
    pub embedding: EmbeddingVector,
    /// Optional text the embedding was derived from; when present,
    /// `index_event` stores it together with the vector.
    pub source_text: Option<String>,
}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct SemanticSearchRequest {
50 #[serde(skip_serializing_if = "Option::is_none")]
52 pub query_embedding: Option<Vec<f32>>,
53 #[serde(default)]
55 pub k: Option<usize>,
56 #[serde(skip_serializing_if = "Option::is_none")]
58 pub tenant_id: Option<String>,
59 #[serde(skip_serializing_if = "Option::is_none")]
61 pub event_type: Option<String>,
62 #[serde(skip_serializing_if = "Option::is_none")]
64 pub min_similarity: Option<f32>,
65 #[serde(skip_serializing_if = "Option::is_none")]
67 pub max_distance: Option<f32>,
68 #[serde(default)]
70 pub metric: Option<String>,
71 #[serde(default)]
73 pub include_events: bool,
74}
75
76impl Default for SemanticSearchRequest {
77 fn default() -> Self {
78 Self {
79 query_embedding: None,
80 k: None,
81 tenant_id: None,
82 event_type: None,
83 min_similarity: None,
84 max_distance: None,
85 metric: None,
86 include_events: false,
87 }
88 }
89}
90
/// One hit returned by a semantic search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticSearchResultItem {
    /// Identifier of the matched event.
    pub event_id: Uuid,
    /// Score reported by the vector repository for this hit.
    pub score: f32,
    /// Source text carried by the repository's search result, if any;
    /// omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub source_text: Option<String>,
    /// Summary of the matched event, filled only when enrichment was requested
    /// and the event was found; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub event: Option<EventSummary>,
}
105
/// Serializable snapshot of an [`Event`] attached to search results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EventSummary {
    /// Event identifier.
    pub id: Uuid,
    /// Event type, as a string.
    pub event_type: String,
    /// Entity identifier, as a string.
    pub entity_id: String,
    /// Tenant identifier, as a string.
    pub tenant_id: String,
    /// Event timestamp (UTC).
    pub timestamp: chrono::DateTime<chrono::Utc>,
    /// Event payload; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub payload: Option<serde_json::Value>,
}
117
118impl From<&Event> for EventSummary {
119 fn from(event: &Event) -> Self {
120 Self {
121 id: event.id(),
122 event_type: event.event_type_str().to_string(),
123 entity_id: event.entity_id_str().to_string(),
124 tenant_id: event.tenant_id_str().to_string(),
125 timestamp: event.timestamp(),
126 payload: Some(event.payload().clone()),
127 }
128 }
129}
130
/// Result of [`VectorSearchService::search`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticSearchResponse {
    /// Hits, in the order the vector repository returned them.
    pub results: Vec<SemanticSearchResultItem>,
    /// Number of entries in `results`.
    pub count: usize,
    /// Name of the distance metric the search used.
    pub metric: String,
    /// Execution statistics for the search.
    pub stats: SearchStats,
}
143
/// Execution statistics reported alongside search results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchStats {
    /// Vector count reported by the repository at search time (0 when the
    /// count call fails).
    pub vectors_searched: usize,
    /// Elapsed time of the `search` call, in microseconds.
    pub search_time_us: u64,
}
152
/// Application service that indexes event embeddings and runs semantic
/// searches over them.
///
/// Built with [`VectorSearchService::new`] and optionally customised through
/// the `with_event_repo` / `with_config` builder methods.
pub struct VectorSearchService {
    /// Storage and search backend for embeddings.
    vector_repo: Arc<dyn VectorSearchRepository>,
    /// Optional event store used to attach event summaries to search hits.
    event_repo: Option<Arc<dyn EventRepository>>,
    /// Settings applied to searches (defaults and limits).
    config: VectorSearchConfig,
}
167
168impl VectorSearchService {
169 pub fn new(vector_repo: Arc<dyn VectorSearchRepository>) -> Self {
170 Self {
171 vector_repo,
172 event_repo: None,
173 config: VectorSearchConfig::default(),
174 }
175 }
176
177 pub fn with_event_repo(mut self, event_repo: Arc<dyn EventRepository>) -> Self {
178 self.event_repo = Some(event_repo);
179 self
180 }
181
182 pub fn with_config(mut self, config: VectorSearchConfig) -> Self {
183 self.config = config;
184 self
185 }
186
187 pub async fn index_event(&self, request: IndexEventRequest) -> Result<()> {
189 if let Some(source_text) = &request.source_text {
190 self.vector_repo
191 .store_with_text(
192 request.event_id,
193 &request.embedding,
194 &request.tenant_id,
195 source_text,
196 )
197 .await
198 } else {
199 self.vector_repo
200 .store(request.event_id, &request.embedding, &request.tenant_id)
201 .await
202 }
203 }
204
205 pub async fn index_events_batch(
207 &self,
208 requests: Vec<IndexEventRequest>,
209 ) -> Result<BatchIndexResult> {
210 if requests.is_empty() {
211 return Ok(BatchIndexResult {
212 indexed: 0,
213 failed: 0,
214 errors: vec![],
215 });
216 }
217
218 let entries: Vec<_> = requests
219 .iter()
220 .map(|r| (r.event_id, r.embedding.clone(), r.tenant_id.clone()))
221 .collect();
222
223 self.vector_repo.store_batch(&entries).await?;
224
225 Ok(BatchIndexResult {
226 indexed: requests.len(),
227 failed: 0,
228 errors: vec![],
229 })
230 }
231
232 pub async fn search(&self, request: SemanticSearchRequest) -> Result<SemanticSearchResponse> {
234 let start_time = std::time::Instant::now();
235
236 let query_embedding = request
238 .query_embedding
239 .ok_or_else(|| AllSourceError::InvalidInput("query_embedding is required".to_string()))?;
240
241 let query_vector = EmbeddingVector::new(query_embedding)?;
242
243 let metric = match request.metric.as_deref() {
245 Some("cosine") | None => DistanceMetric::Cosine,
246 Some("euclidean") => DistanceMetric::Euclidean,
247 Some("dot_product") => DistanceMetric::DotProduct,
248 Some(m) => {
249 return Err(AllSourceError::InvalidInput(format!(
250 "Unknown metric: {}. Supported: cosine, euclidean, dot_product",
251 m
252 )));
253 }
254 };
255
256 let k = request
258 .k
259 .unwrap_or(self.config.default_k)
260 .min(self.config.max_k);
261
262 let mut query = VectorSearchQuery::new(query_vector, k).with_metric(metric);
263
264 if let Some(tenant_id) = request.tenant_id {
265 query = query.with_tenant(tenant_id);
266 }
267
268 if let Some(event_type) = request.event_type {
269 query = query.with_event_type(event_type);
270 }
271
272 if let Some(min_sim) = request.min_similarity {
273 query = query.with_min_similarity(min_sim);
274 }
275
276 if let Some(max_dist) = request.max_distance {
277 query = query.with_max_distance(max_dist);
278 }
279
280 let search_results = self.vector_repo.search(&query).await?;
282 let vectors_searched = self.vector_repo.count(None).await.unwrap_or(0);
283
284 let results = if request.include_events {
286 self.enrich_with_events(search_results).await?
287 } else {
288 search_results
289 .into_iter()
290 .map(|r| SemanticSearchResultItem {
291 event_id: r.event_id,
292 score: r.score.value(),
293 source_text: r.source_text,
294 event: None,
295 })
296 .collect()
297 };
298
299 let search_time_us = start_time.elapsed().as_micros() as u64;
300 let count = results.len();
301
302 Ok(SemanticSearchResponse {
303 results,
304 count,
305 metric: format!("{:?}", metric).to_lowercase(),
306 stats: SearchStats {
307 vectors_searched,
308 search_time_us,
309 },
310 })
311 }
312
313 pub async fn get_embedding(&self, event_id: Uuid) -> Result<Option<VectorEntry>> {
315 self.vector_repo.get_by_event_id(event_id).await
316 }
317
318 pub async fn delete_embedding(&self, event_id: Uuid) -> Result<bool> {
320 self.vector_repo.delete(event_id).await
321 }
322
323 pub async fn delete_tenant_embeddings(&self, tenant_id: &str) -> Result<usize> {
325 self.vector_repo.delete_by_tenant(tenant_id).await
326 }
327
328 pub async fn get_stats(&self) -> Result<IndexStats> {
330 let total_vectors = self.vector_repo.count(None).await?;
331 let dimensions = self.vector_repo.dimensions().await?;
332
333 Ok(IndexStats {
334 total_vectors,
335 dimensions,
336 })
337 }
338
339 pub async fn health_check(&self) -> Result<()> {
341 self.vector_repo.health_check().await
342 }
343
344 async fn enrich_with_events(
346 &self,
347 results: Vec<SearchResult>,
348 ) -> Result<Vec<SemanticSearchResultItem>> {
349 let event_repo = match &self.event_repo {
350 Some(repo) => repo,
351 None => {
352 return Ok(results
354 .into_iter()
355 .map(|r| SemanticSearchResultItem {
356 event_id: r.event_id,
357 score: r.score.value(),
358 source_text: r.source_text,
359 event: None,
360 })
361 .collect());
362 }
363 };
364
365 let mut enriched = Vec::with_capacity(results.len());
366
367 for result in results {
368 let event = event_repo.find_by_id(result.event_id).await?;
369
370 enriched.push(SemanticSearchResultItem {
371 event_id: result.event_id,
372 score: result.score.value(),
373 source_text: result.source_text,
374 event: event.as_ref().map(EventSummary::from),
375 });
376 }
377
378 Ok(enriched)
379 }
380}
381
/// Outcome of a batch indexing call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchIndexResult {
    /// Number of requests that were stored.
    pub indexed: usize,
    /// Number of requests that failed; the batch method in this file always
    /// reports 0 because it returns an error on the first failure instead.
    pub failed: usize,
    /// Failure messages; always empty in the code visible in this file.
    pub errors: Vec<String>,
}
389
/// Snapshot of the vector index returned by `get_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexStats {
    /// Vector count reported by the repository with no filter applied.
    pub total_vectors: usize,
    /// Dimensionality reported by the repository; `None` when it reports none
    /// (presumably an empty index — confirm).
    pub dimensions: Option<usize>,
}
396
#[cfg(test)]
mod tests {
    use super::*;
    use crate::infrastructure::repositories::InMemoryVectorSearchRepository;

    /// Tenant shared by every test in this module.
    const TENANT: &str = "tenant-1";

    /// Builds a service on top of a fresh in-memory vector repository.
    fn create_test_service() -> VectorSearchService {
        VectorSearchService::new(Arc::new(InMemoryVectorSearchRepository::new()))
    }

    /// Builds a deterministic embedding: component `i` is `(i + seed) / dims`.
    fn create_test_embedding(dims: usize, seed: f32) -> EmbeddingVector {
        let scale = dims as f32;
        let mut values: Vec<f32> = Vec::with_capacity(dims);
        for i in 0..dims {
            values.push((i as f32 + seed) / scale);
        }
        EmbeddingVector::new(values).unwrap()
    }

    /// Indexes one embedding without source text, panicking on failure.
    async fn index_plain(service: &VectorSearchService, event_id: Uuid, embedding: EmbeddingVector) {
        let request = IndexEventRequest {
            event_id,
            tenant_id: TENANT.to_string(),
            embedding,
            source_text: None,
        };
        service.index_event(request).await.unwrap();
    }

    #[tokio::test]
    async fn test_index_and_search() {
        let service = create_test_service();

        let ids = [Uuid::new_v4(), Uuid::new_v4(), Uuid::new_v4()];
        let vectors = [
            vec![1.0, 0.0, 0.0_f32],
            vec![0.9, 0.1, 0.0],
            vec![0.0, 1.0, 0.0],
        ];

        for (id, values) in ids.iter().zip(vectors.iter()) {
            let embedding = EmbeddingVector::new(values.clone()).unwrap();
            index_plain(&service, *id, embedding).await;
        }

        let request = SemanticSearchRequest {
            query_embedding: Some(vec![1.0, 0.0, 0.0]),
            k: Some(2),
            tenant_id: Some(TENANT.to_string()),
            ..Default::default()
        };
        let response = service.search(request).await.unwrap();

        assert_eq!(response.count, 2);
        // The vector identical to the query must rank first.
        assert_eq!(response.results[0].event_id, ids[0]);
    }

    #[tokio::test]
    async fn test_batch_index() {
        let service = create_test_service();

        let mut requests = Vec::with_capacity(10);
        for i in 0..10 {
            requests.push(IndexEventRequest {
                event_id: Uuid::new_v4(),
                tenant_id: TENANT.to_string(),
                embedding: create_test_embedding(384, i as f32),
                source_text: Some(format!("Document {}", i)),
            });
        }

        let result = service.index_events_batch(requests).await.unwrap();
        assert_eq!(result.indexed, 10);
        assert_eq!(result.failed, 0);

        let stats = service.get_stats().await.unwrap();
        assert_eq!(stats.total_vectors, 10);
        assert_eq!(stats.dimensions, Some(384));
    }

    #[tokio::test]
    async fn test_search_with_min_similarity() {
        let service = create_test_service();

        // Two orthogonal vectors: only the first matches the query direction.
        let aligned = EmbeddingVector::new(vec![1.0, 0.0, 0.0]).unwrap();
        index_plain(&service, Uuid::new_v4(), aligned).await;

        let orthogonal = EmbeddingVector::new(vec![0.0, 1.0, 0.0]).unwrap();
        index_plain(&service, Uuid::new_v4(), orthogonal).await;

        let request = SemanticSearchRequest {
            query_embedding: Some(vec![1.0, 0.0, 0.0]),
            k: Some(10),
            tenant_id: Some(TENANT.to_string()),
            min_similarity: Some(0.5),
            ..Default::default()
        };
        let response = service.search(request).await.unwrap();

        assert_eq!(response.count, 1);
    }

    #[tokio::test]
    async fn test_delete_embedding() {
        let service = create_test_service();

        let event_id = Uuid::new_v4();
        index_plain(&service, event_id, create_test_embedding(384, 1.0)).await;

        let before = service.get_embedding(event_id).await.unwrap();
        assert!(before.is_some());

        let deleted = service.delete_embedding(event_id).await.unwrap();
        assert!(deleted);

        let after = service.get_embedding(event_id).await.unwrap();
        assert!(after.is_none());
    }

    #[tokio::test]
    async fn test_health_check() {
        let service = create_test_service();
        assert!(service.health_check().await.is_ok());
    }

    #[tokio::test]
    async fn test_invalid_metric() {
        let service = create_test_service();

        let request = SemanticSearchRequest {
            query_embedding: Some(vec![1.0, 0.0, 0.0]),
            metric: Some("invalid".to_string()),
            ..Default::default()
        };

        let err = service
            .search(request)
            .await
            .expect_err("an unknown metric name must be rejected");
        assert!(err.to_string().contains("Unknown metric"));
    }

    #[tokio::test]
    async fn test_missing_query_embedding() {
        let service = create_test_service();

        let request = SemanticSearchRequest {
            query_embedding: None,
            ..Default::default()
        };

        let err = service
            .search(request)
            .await
            .expect_err("a request without query_embedding must be rejected");
        assert!(err.to_string().contains("query_embedding is required"));
    }
}