1use crate::{
2 domain::{
3 entities::Event,
4 repositories::{
5 EventRepository, SearchResult, VectorEntry, VectorSearchQuery, VectorSearchRepository,
6 },
7 value_objects::{DistanceMetric, EmbeddingVector},
8 },
9 error::{AllSourceError, Result},
10};
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use uuid::Uuid;
14
15#[derive(Debug, Clone)]
17pub struct VectorSearchConfig {
18 pub default_k: usize,
20 pub max_k: usize,
22 pub default_min_similarity: f32,
24 pub default_metric: DistanceMetric,
26 pub include_source_text: bool,
28}
29
30impl Default for VectorSearchConfig {
31 fn default() -> Self {
32 Self {
33 default_k: 10,
34 max_k: 100,
35 default_min_similarity: 0.0,
36 default_metric: DistanceMetric::Cosine,
37 include_source_text: true,
38 }
39 }
40}
41
42#[derive(Debug, Clone)]
44pub struct IndexEventRequest {
45 pub event_id: Uuid,
46 pub tenant_id: String,
47 pub embedding: EmbeddingVector,
48 pub source_text: Option<String>,
49}
50
51#[derive(Debug, Clone, Default, Serialize, Deserialize)]
53pub struct SemanticSearchRequest {
54 #[serde(skip_serializing_if = "Option::is_none")]
56 pub query_embedding: Option<Vec<f32>>,
57 #[serde(default)]
59 pub k: Option<usize>,
60 #[serde(skip_serializing_if = "Option::is_none")]
62 pub tenant_id: Option<String>,
63 #[serde(skip_serializing_if = "Option::is_none")]
65 pub event_type: Option<String>,
66 #[serde(skip_serializing_if = "Option::is_none")]
68 pub min_similarity: Option<f32>,
69 #[serde(skip_serializing_if = "Option::is_none")]
71 pub max_distance: Option<f32>,
72 #[serde(default)]
74 pub metric: Option<String>,
75 #[serde(default)]
77 pub include_events: bool,
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct SemanticSearchResultItem {
83 pub event_id: Uuid,
85 pub score: f32,
87 #[serde(skip_serializing_if = "Option::is_none")]
89 pub source_text: Option<String>,
90 #[serde(skip_serializing_if = "Option::is_none")]
92 pub event: Option<EventSummary>,
93}
94
95#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct EventSummary {
98 pub id: Uuid,
99 pub event_type: String,
100 pub entity_id: String,
101 pub tenant_id: String,
102 pub timestamp: chrono::DateTime<chrono::Utc>,
103 #[serde(skip_serializing_if = "Option::is_none")]
104 pub payload: Option<serde_json::Value>,
105}
106
107impl From<&Event> for EventSummary {
108 fn from(event: &Event) -> Self {
109 Self {
110 id: event.id(),
111 event_type: event.event_type_str().to_string(),
112 entity_id: event.entity_id_str().to_string(),
113 tenant_id: event.tenant_id_str().to_string(),
114 timestamp: event.timestamp(),
115 payload: Some(event.payload().clone()),
116 }
117 }
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct SemanticSearchResponse {
123 pub results: Vec<SemanticSearchResultItem>,
125 pub count: usize,
127 pub metric: String,
129 pub stats: SearchStats,
131}
132
133#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct SearchStats {
136 pub vectors_searched: usize,
138 pub search_time_us: u64,
140}
141
142pub struct VectorSearchService {
152 vector_repo: Arc<dyn VectorSearchRepository>,
153 event_repo: Option<Arc<dyn EventRepository>>,
154 config: VectorSearchConfig,
155}
156
157impl VectorSearchService {
158 pub fn new(vector_repo: Arc<dyn VectorSearchRepository>) -> Self {
159 Self {
160 vector_repo,
161 event_repo: None,
162 config: VectorSearchConfig::default(),
163 }
164 }
165
166 pub fn with_event_repo(mut self, event_repo: Arc<dyn EventRepository>) -> Self {
167 self.event_repo = Some(event_repo);
168 self
169 }
170
171 pub fn with_config(mut self, config: VectorSearchConfig) -> Self {
172 self.config = config;
173 self
174 }
175
176 pub async fn index_event(&self, request: IndexEventRequest) -> Result<()> {
178 if let Some(source_text) = &request.source_text {
179 self.vector_repo
180 .store_with_text(
181 request.event_id,
182 &request.embedding,
183 &request.tenant_id,
184 source_text,
185 )
186 .await
187 } else {
188 self.vector_repo
189 .store(request.event_id, &request.embedding, &request.tenant_id)
190 .await
191 }
192 }
193
194 pub async fn index_events_batch(
196 &self,
197 requests: Vec<IndexEventRequest>,
198 ) -> Result<BatchIndexResult> {
199 if requests.is_empty() {
200 return Ok(BatchIndexResult {
201 indexed: 0,
202 failed: 0,
203 errors: vec![],
204 });
205 }
206
207 let entries: Vec<_> = requests
208 .iter()
209 .map(|r| (r.event_id, r.embedding.clone(), r.tenant_id.clone()))
210 .collect();
211
212 self.vector_repo.store_batch(&entries).await?;
213
214 Ok(BatchIndexResult {
215 indexed: requests.len(),
216 failed: 0,
217 errors: vec![],
218 })
219 }
220
221 pub async fn search(&self, request: SemanticSearchRequest) -> Result<SemanticSearchResponse> {
223 let start_time = std::time::Instant::now();
224
225 let query_embedding = request.query_embedding.ok_or_else(|| {
227 AllSourceError::InvalidInput("query_embedding is required".to_string())
228 })?;
229
230 let query_vector = EmbeddingVector::new(query_embedding)?;
231
232 let metric = match request.metric.as_deref() {
234 Some("cosine") | None => DistanceMetric::Cosine,
235 Some("euclidean") => DistanceMetric::Euclidean,
236 Some("dot_product") => DistanceMetric::DotProduct,
237 Some(m) => {
238 return Err(AllSourceError::InvalidInput(format!(
239 "Unknown metric: {m}. Supported: cosine, euclidean, dot_product"
240 )));
241 }
242 };
243
244 let k = request
246 .k
247 .unwrap_or(self.config.default_k)
248 .min(self.config.max_k);
249
250 let mut query = VectorSearchQuery::new(query_vector, k).with_metric(metric);
251
252 if let Some(tenant_id) = request.tenant_id {
253 query = query.with_tenant(tenant_id);
254 }
255
256 if let Some(event_type) = request.event_type {
257 query = query.with_event_type(event_type);
258 }
259
260 if let Some(min_sim) = request.min_similarity {
261 query = query.with_min_similarity(min_sim);
262 }
263
264 if let Some(max_dist) = request.max_distance {
265 query = query.with_max_distance(max_dist);
266 }
267
268 let search_results = self.vector_repo.search(&query).await?;
270 let vectors_searched = self.vector_repo.count(None).await.unwrap_or(0);
271
272 let results = if request.include_events {
274 self.enrich_with_events(search_results).await?
275 } else {
276 search_results
277 .into_iter()
278 .map(|r| SemanticSearchResultItem {
279 event_id: r.event_id,
280 score: r.score.value(),
281 source_text: r.source_text,
282 event: None,
283 })
284 .collect()
285 };
286
287 let search_time_us = start_time.elapsed().as_micros() as u64;
288 let count = results.len();
289
290 Ok(SemanticSearchResponse {
291 results,
292 count,
293 metric: format!("{metric:?}").to_lowercase(),
294 stats: SearchStats {
295 vectors_searched,
296 search_time_us,
297 },
298 })
299 }
300
301 pub async fn get_embedding(&self, event_id: Uuid) -> Result<Option<VectorEntry>> {
303 self.vector_repo.get_by_event_id(event_id).await
304 }
305
306 pub async fn delete_embedding(&self, event_id: Uuid) -> Result<bool> {
308 self.vector_repo.delete(event_id).await
309 }
310
311 pub async fn delete_tenant_embeddings(&self, tenant_id: &str) -> Result<usize> {
313 self.vector_repo.delete_by_tenant(tenant_id).await
314 }
315
316 pub async fn get_stats(&self) -> Result<IndexStats> {
318 let total_vectors = self.vector_repo.count(None).await?;
319 let dimensions = self.vector_repo.dimensions().await?;
320
321 Ok(IndexStats {
322 total_vectors,
323 dimensions,
324 })
325 }
326
327 pub async fn health_check(&self) -> Result<()> {
329 self.vector_repo.health_check().await
330 }
331
332 async fn enrich_with_events(
334 &self,
335 results: Vec<SearchResult>,
336 ) -> Result<Vec<SemanticSearchResultItem>> {
337 let Some(event_repo) = &self.event_repo else {
338 return Ok(results
340 .into_iter()
341 .map(|r| SemanticSearchResultItem {
342 event_id: r.event_id,
343 score: r.score.value(),
344 source_text: r.source_text,
345 event: None,
346 })
347 .collect());
348 };
349
350 let mut enriched = Vec::with_capacity(results.len());
351
352 for result in results {
353 let event = event_repo.find_by_id(result.event_id).await?;
354
355 enriched.push(SemanticSearchResultItem {
356 event_id: result.event_id,
357 score: result.score.value(),
358 source_text: result.source_text,
359 event: event.as_ref().map(EventSummary::from),
360 });
361 }
362
363 Ok(enriched)
364 }
365}
366
367#[derive(Debug, Clone, Serialize, Deserialize)]
369pub struct BatchIndexResult {
370 pub indexed: usize,
371 pub failed: usize,
372 pub errors: Vec<String>,
373}
374
375#[derive(Debug, Clone, Serialize, Deserialize)]
377pub struct IndexStats {
378 pub total_vectors: usize,
379 pub dimensions: Option<usize>,
380}
381
#[cfg(test)]
mod tests {
    // Service-level tests run against the in-memory vector repository; no
    // event repository is attached.
    use super::*;
    use crate::infrastructure::repositories::InMemoryVectorSearchRepository;

    /// Tenant used by every fixture.
    const TENANT: &str = "tenant-1";

    /// Builds a service over a fresh in-memory repository.
    fn create_test_service() -> VectorSearchService {
        VectorSearchService::new(Arc::new(InMemoryVectorSearchRepository::new()))
    }

    /// Deterministic `dims`-dimensional embedding whose components depend on `seed`.
    fn create_test_embedding(dims: usize, seed: f32) -> EmbeddingVector {
        let scale = dims as f32;
        let values = (0..dims)
            .map(|i| (i as f32 + seed) / scale)
            .collect::<Vec<f32>>();
        EmbeddingVector::new(values).unwrap()
    }

    /// Indexes `embedding` for `event_id` under [`TENANT`] without source text.
    async fn index_plain(service: &VectorSearchService, event_id: Uuid, embedding: EmbeddingVector) {
        service
            .index_event(IndexEventRequest {
                event_id,
                tenant_id: TENANT.to_string(),
                embedding,
                source_text: None,
            })
            .await
            .unwrap();
    }

    /// Search request for `query` limited to `k` results within [`TENANT`].
    fn tenant_query(query: Vec<f32>, k: usize) -> SemanticSearchRequest {
        SemanticSearchRequest {
            query_embedding: Some(query),
            k: Some(k),
            tenant_id: Some(TENANT.to_string()),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_index_and_search() {
        let service = create_test_service();

        let embeddings = vec![
            (Uuid::new_v4(), vec![1.0, 0.0, 0.0_f32]),
            (Uuid::new_v4(), vec![0.9, 0.1, 0.0]),
            (Uuid::new_v4(), vec![0.0, 1.0, 0.0]),
        ];
        // The first vector is identical to the query, so it must rank first.
        let closest_id = embeddings[0].0;

        for (event_id, values) in embeddings {
            index_plain(&service, event_id, EmbeddingVector::new(values).unwrap()).await;
        }

        let response = service
            .search(tenant_query(vec![1.0, 0.0, 0.0], 2))
            .await
            .unwrap();

        assert_eq!(response.count, 2);
        assert_eq!(response.results[0].event_id, closest_id);
    }

    #[tokio::test]
    async fn test_batch_index() {
        let service = create_test_service();

        let mut requests = Vec::with_capacity(10);
        for i in 0..10 {
            requests.push(IndexEventRequest {
                event_id: Uuid::new_v4(),
                tenant_id: TENANT.to_string(),
                embedding: create_test_embedding(384, i as f32),
                source_text: Some(format!("Document {i}")),
            });
        }

        let outcome = service.index_events_batch(requests).await.unwrap();
        assert_eq!((outcome.indexed, outcome.failed), (10, 0));

        let stats = service.get_stats().await.unwrap();
        assert_eq!(stats.total_vectors, 10);
        assert_eq!(stats.dimensions, Some(384));
    }

    #[tokio::test]
    async fn test_search_with_min_similarity() {
        let service = create_test_service();

        // Two orthogonal vectors: only the one aligned with the query clears 0.5.
        index_plain(
            &service,
            Uuid::new_v4(),
            EmbeddingVector::new(vec![1.0, 0.0, 0.0]).unwrap(),
        )
        .await;
        index_plain(
            &service,
            Uuid::new_v4(),
            EmbeddingVector::new(vec![0.0, 1.0, 0.0]).unwrap(),
        )
        .await;

        let request = SemanticSearchRequest {
            min_similarity: Some(0.5),
            ..tenant_query(vec![1.0, 0.0, 0.0], 10)
        };
        let response = service.search(request).await.unwrap();

        assert_eq!(response.count, 1);
    }

    #[tokio::test]
    async fn test_delete_embedding() {
        let service = create_test_service();
        let event_id = Uuid::new_v4();

        index_plain(&service, event_id, create_test_embedding(384, 1.0)).await;
        assert!(service.get_embedding(event_id).await.unwrap().is_some());

        assert!(service.delete_embedding(event_id).await.unwrap());
        assert!(service.get_embedding(event_id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_health_check() {
        let service = create_test_service();
        service
            .health_check()
            .await
            .expect("in-memory repository should report healthy");
    }

    #[tokio::test]
    async fn test_invalid_metric() {
        let service = create_test_service();

        let err = service
            .search(SemanticSearchRequest {
                query_embedding: Some(vec![1.0, 0.0, 0.0]),
                metric: Some("invalid".to_string()),
                ..Default::default()
            })
            .await
            .unwrap_err();

        assert!(err.to_string().contains("Unknown metric"));
    }

    #[tokio::test]
    async fn test_missing_query_embedding() {
        let service = create_test_service();

        let err = service
            .search(SemanticSearchRequest {
                query_embedding: None,
                ..Default::default()
            })
            .await
            .unwrap_err();

        assert!(err.to_string().contains("query_embedding is required"));
    }
}