1use crate::{
2 application::{
3 dto::EventDto,
4 services::{SemanticSearchRequest, VectorSearchService},
5 },
6 domain::{repositories::EventRepository, value_objects::EmbeddingVector},
7 error::{AllSourceError, Result},
8};
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use uuid::Uuid;
12
/// Application use case for searching events by embedding similarity.
///
/// Delegates the actual search (and embedding lookup) to the vector search
/// service, and optionally loads the full events for the hits from the event
/// repository.
pub struct SemanticSearchUseCase {
    /// Service used to search stored embeddings and to read a single event's embedding.
    vector_service: Arc<VectorSearchService>,
    /// Repository used to load full events for search hits when `include_events` is requested.
    event_repository: Arc<dyn EventRepository>,
}
26
27impl SemanticSearchUseCase {
28 pub fn new(
29 vector_service: Arc<VectorSearchService>,
30 event_repository: Arc<dyn EventRepository>,
31 ) -> Self {
32 Self {
33 vector_service,
34 event_repository,
35 }
36 }
37
38 pub async fn execute(
40 &self,
41 request: SemanticSearchUseCaseRequest,
42 ) -> Result<SemanticSearchUseCaseResponse> {
43 let embedding = request.query_embedding.ok_or_else(|| {
45 AllSourceError::InvalidInput("query_embedding is required".to_string())
46 })?;
47
48 if embedding.is_empty() {
49 return Err(AllSourceError::InvalidInput(
50 "query_embedding cannot be empty".to_string(),
51 ));
52 }
53
54 let k = request.k.unwrap_or(10);
56 if k == 0 {
57 return Err(AllSourceError::InvalidInput(
58 "k must be greater than 0".to_string(),
59 ));
60 }
61 if k > 1000 {
62 return Err(AllSourceError::InvalidInput(
63 "k cannot exceed 1000".to_string(),
64 ));
65 }
66
67 let search_request = SemanticSearchRequest {
69 query_embedding: Some(embedding),
70 k: Some(k),
71 tenant_id: request.tenant_id.clone(),
72 event_type: request.event_type.clone(),
73 min_similarity: request.min_similarity,
74 max_distance: request.max_distance,
75 metric: request.metric.clone(),
76 include_events: request.include_events.unwrap_or(false),
77 };
78
79 let search_response = self.vector_service.search(search_request).await?;
81
82 let events = if request.include_events.unwrap_or(false) {
84 let mut events = Vec::with_capacity(search_response.results.len());
85 for result in &search_response.results {
86 if let Some(event) = self.event_repository.find_by_id(result.event_id).await? {
87 events.push(EventDto::from(&event));
88 }
89 }
90 Some(events)
91 } else {
92 None
93 };
94
95 Ok(SemanticSearchUseCaseResponse {
96 results: search_response
97 .results
98 .into_iter()
99 .map(|r| SemanticSearchResultDto {
100 event_id: r.event_id,
101 score: r.score,
102 source_text: r.source_text,
103 })
104 .collect(),
105 events,
106 count: search_response.count,
107 metric: search_response.metric,
108 vectors_searched: search_response.stats.vectors_searched,
109 search_time_us: search_response.stats.search_time_us,
110 })
111 }
112
113 pub async fn find_similar(
115 &self,
116 event_id: Uuid,
117 k: usize,
118 tenant_id: Option<String>,
119 ) -> Result<SemanticSearchUseCaseResponse> {
120 let entry = self
122 .vector_service
123 .get_embedding(event_id)
124 .await?
125 .ok_or_else(|| {
126 AllSourceError::EventNotFound(format!("No embedding found for event {}", event_id))
127 })?;
128
129 let search_request = SemanticSearchRequest {
131 query_embedding: Some(entry.embedding.values().to_vec()),
132 k: Some(k + 1), tenant_id,
134 event_type: None,
135 min_similarity: None,
136 max_distance: None,
137 metric: None,
138 include_events: false,
139 };
140
141 let mut response = self.vector_service.search(search_request).await?;
142
143 response.results.retain(|r| r.event_id != event_id);
145 response.results.truncate(k);
146 response.count = response.results.len();
147
148 Ok(SemanticSearchUseCaseResponse {
149 results: response
150 .results
151 .into_iter()
152 .map(|r| SemanticSearchResultDto {
153 event_id: r.event_id,
154 score: r.score,
155 source_text: r.source_text,
156 })
157 .collect(),
158 events: None,
159 count: response.count,
160 metric: response.metric,
161 vectors_searched: response.stats.vectors_searched,
162 search_time_us: response.stats.search_time_us,
163 })
164 }
165}
166
/// Input for `SemanticSearchUseCase::execute`.
///
/// All fields are optional at the type level, but `execute` rejects a request
/// whose `query_embedding` is `None` or empty.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticSearchUseCaseRequest {
    /// Embedding to search with; required and must be non-empty.
    pub query_embedding: Option<Vec<f32>>,
    /// Number of results to return; `execute` uses 10 when `None` and accepts `1..=1000`.
    pub k: Option<usize>,
    /// Forwarded unchanged to the vector search service (presumably a tenant filter).
    pub tenant_id: Option<String>,
    /// Forwarded unchanged to the vector search service (presumably an event-type filter).
    pub event_type: Option<String>,
    /// Forwarded unchanged to the vector search service as a similarity threshold.
    pub min_similarity: Option<f32>,
    /// Forwarded unchanged to the vector search service as a distance threshold.
    pub max_distance: Option<f32>,
    /// Forwarded unchanged to the vector search service; accepted metric names are defined there.
    pub metric: Option<String>,
    /// When `Some(true)`, `execute` also loads the full events for the hits; treated as `false` when `None`.
    pub include_events: Option<bool>,
}
187
188impl Default for SemanticSearchUseCaseRequest {
189 fn default() -> Self {
190 Self {
191 query_embedding: None,
192 k: Some(10),
193 tenant_id: None,
194 event_type: None,
195 min_similarity: None,
196 max_distance: None,
197 metric: None,
198 include_events: None,
199 }
200 }
201}
202
/// One search hit, copied field-for-field from the vector search service's result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticSearchResultDto {
    /// Id of the event whose embedding matched.
    pub event_id: Uuid,
    /// Score reported by the vector search service; its scale depends on the metric used.
    pub score: f32,
    /// Source text stored alongside the embedding at index time, if any.
    pub source_text: Option<String>,
}
210
/// Output of `SemanticSearchUseCase::execute` and `SemanticSearchUseCase::find_similar`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticSearchUseCaseResponse {
    /// Search hits in the order returned by the vector search service.
    pub results: Vec<SemanticSearchResultDto>,
    /// Full events for the hits, present only when `include_events` was requested.
    /// Hits with no stored event are omitted, so this may be shorter than `results`.
    pub events: Option<Vec<EventDto>>,
    /// Number of hits, as reported by the vector search service (`find_similar`
    /// recomputes it after removing the source event).
    pub count: usize,
    /// Name of the metric the vector search service reported using.
    pub metric: String,
    /// Number of vectors examined, from the search statistics.
    pub vectors_searched: usize,
    /// Search duration from the search statistics (presumably microseconds, per the name).
    pub search_time_us: u64,
}
227
/// Application use case for storing event embeddings in the vector index.
pub struct IndexEventEmbeddingUseCase {
    /// Service that performs the actual indexing.
    vector_service: Arc<VectorSearchService>,
}
234
235impl IndexEventEmbeddingUseCase {
236 pub fn new(vector_service: Arc<VectorSearchService>) -> Self {
237 Self { vector_service }
238 }
239
240 pub async fn execute(&self, request: IndexEventEmbeddingRequest) -> Result<()> {
242 let embedding = EmbeddingVector::new(request.embedding)?;
244
245 self.vector_service
247 .index_event(crate::application::services::IndexEventRequest {
248 event_id: request.event_id,
249 tenant_id: request.tenant_id,
250 embedding,
251 source_text: request.source_text,
252 })
253 .await
254 }
255
256 pub async fn execute_batch(
258 &self,
259 requests: Vec<IndexEventEmbeddingRequest>,
260 ) -> Result<BatchIndexResponse> {
261 let mut indexed = 0;
262 let mut failed = 0;
263 let mut errors = Vec::new();
264
265 for request in requests {
266 match EmbeddingVector::new(request.embedding) {
267 Ok(embedding) => {
268 match self
269 .vector_service
270 .index_event(crate::application::services::IndexEventRequest {
271 event_id: request.event_id,
272 tenant_id: request.tenant_id,
273 embedding,
274 source_text: request.source_text,
275 })
276 .await
277 {
278 Ok(_) => indexed += 1,
279 Err(e) => {
280 failed += 1;
281 errors.push(format!("Event {}: {}", request.event_id, e));
282 }
283 }
284 }
285 Err(e) => {
286 failed += 1;
287 errors.push(format!("Event {}: {}", request.event_id, e));
288 }
289 }
290 }
291
292 Ok(BatchIndexResponse {
293 indexed,
294 failed,
295 errors,
296 })
297 }
298}
299
/// Input for `IndexEventEmbeddingUseCase::execute` / `execute_batch`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexEventEmbeddingRequest {
    /// Id of the event the embedding belongs to.
    pub event_id: Uuid,
    /// Tenant the embedding is indexed under.
    pub tenant_id: String,
    /// Raw embedding values; validated by `EmbeddingVector::new` before indexing.
    pub embedding: Vec<f32>,
    /// Optional text stored alongside the embedding and returned with search hits.
    pub source_text: Option<String>,
}
308
/// Summary returned by `IndexEventEmbeddingUseCase::execute_batch`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchIndexResponse {
    /// Number of requests indexed successfully.
    pub indexed: usize,
    /// Number of requests that failed validation or indexing.
    pub failed: usize,
    /// One message per failure, formatted as `"Event {event_id}: {error}"`.
    pub errors: Vec<String>,
}
316
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        domain::entities::Event, infrastructure::repositories::InMemoryVectorSearchRepository,
    };
    use async_trait::async_trait;
    use chrono::Utc;
    use serde_json::json;

    /// Test double for `EventRepository`.
    ///
    /// Only `find_by_id` and `health_check` are implemented; every other
    /// method panics via `unimplemented!()`.
    struct MockEventRepository {
        events: Vec<Event>,
    }

    impl MockEventRepository {
        fn with_events(events: Vec<Event>) -> Self {
            Self { events }
        }
    }

    #[async_trait]
    impl EventRepository for MockEventRepository {
        async fn save(&self, _event: &Event) -> Result<()> {
            unimplemented!()
        }

        async fn save_batch(&self, _events: &[Event]) -> Result<()> {
            unimplemented!()
        }

        async fn find_by_id(&self, id: Uuid) -> Result<Option<Event>> {
            Ok(self.events.iter().find(|e| e.id() == id).cloned())
        }

        async fn find_by_entity(&self, _entity_id: &str, _tenant_id: &str) -> Result<Vec<Event>> {
            unimplemented!()
        }

        async fn find_by_type(&self, _event_type: &str, _tenant_id: &str) -> Result<Vec<Event>> {
            unimplemented!()
        }

        async fn find_by_time_range(
            &self,
            _tenant_id: &str,
            _start: chrono::DateTime<Utc>,
            _end: chrono::DateTime<Utc>,
        ) -> Result<Vec<Event>> {
            unimplemented!()
        }

        async fn find_by_entity_as_of(
            &self,
            _entity_id: &str,
            _tenant_id: &str,
            _as_of: chrono::DateTime<Utc>,
        ) -> Result<Vec<Event>> {
            unimplemented!()
        }

        async fn count(&self, _tenant_id: &str) -> Result<usize> {
            unimplemented!()
        }

        async fn health_check(&self) -> Result<()> {
            Ok(())
        }
    }

    /// Builds the `user.created` event used to populate the mock repository.
    fn test_event() -> Event {
        Event::from_strings(
            "user.created".to_string(),
            "user-1".to_string(),
            "tenant-1".to_string(),
            json!({"name": "Test"}),
            None,
        )
        .unwrap()
    }

    /// Indexes `values` for `event_id` under `tenant-1`, panicking on any failure.
    async fn index_embedding(
        vector_service: &VectorSearchService,
        event_id: Uuid,
        values: Vec<f32>,
        source_text: Option<&str>,
    ) {
        vector_service
            .index_event(crate::application::services::IndexEventRequest {
                event_id,
                tenant_id: "tenant-1".to_string(),
                embedding: EmbeddingVector::new(values).unwrap(),
                source_text: source_text.map(str::to_string),
            })
            .await
            .unwrap();
    }

    /// Returns a use case over a fresh in-memory vector index, together with
    /// the shared `VectorSearchService` so tests can index embeddings directly.
    fn create_test_use_case() -> (SemanticSearchUseCase, Arc<VectorSearchService>) {
        let vector_repo = Arc::new(InMemoryVectorSearchRepository::new());
        let vector_service = Arc::new(VectorSearchService::new(vector_repo));
        let event_repo = Arc::new(MockEventRepository::with_events(vec![test_event()]));

        (
            SemanticSearchUseCase::new(vector_service.clone(), event_repo),
            vector_service,
        )
    }

    #[tokio::test]
    async fn test_semantic_search() {
        let (use_case, vector_service) = create_test_use_case();

        let id1 = Uuid::new_v4();
        let id2 = Uuid::new_v4();

        index_embedding(&vector_service, id1, vec![1.0, 0.0, 0.0], Some("first document")).await;
        index_embedding(&vector_service, id2, vec![0.0, 1.0, 0.0], Some("second document")).await;

        let response = use_case
            .execute(SemanticSearchUseCaseRequest {
                query_embedding: Some(vec![1.0, 0.0, 0.0]),
                k: Some(2),
                tenant_id: Some("tenant-1".to_string()),
                ..Default::default()
            })
            .await
            .unwrap();

        assert_eq!(response.count, 2);
        assert_eq!(response.results[0].event_id, id1);
        assert!((response.results[0].score - 1.0).abs() < 1e-6);
        // `include_events` was not requested, so no events are loaded.
        assert!(response.events.is_none());
    }

    #[tokio::test]
    async fn test_semantic_search_include_events() {
        let vector_repo = Arc::new(InMemoryVectorSearchRepository::new());
        let vector_service = Arc::new(VectorSearchService::new(vector_repo));

        let event = test_event();
        let stored_id = event.id();
        let use_case = SemanticSearchUseCase::new(
            vector_service.clone(),
            Arc::new(MockEventRepository::with_events(vec![event])),
        );

        // One hit is backed by a stored event; the other has none and must be skipped.
        index_embedding(&vector_service, stored_id, vec![1.0, 0.0, 0.0], None).await;
        index_embedding(&vector_service, Uuid::new_v4(), vec![0.0, 1.0, 0.0], None).await;

        let response = use_case
            .execute(SemanticSearchUseCaseRequest {
                query_embedding: Some(vec![1.0, 0.0, 0.0]),
                k: Some(2),
                tenant_id: Some("tenant-1".to_string()),
                include_events: Some(true),
                ..Default::default()
            })
            .await
            .unwrap();

        assert_eq!(response.count, 2);
        let events = response
            .events
            .expect("include_events = Some(true) must populate events");
        assert_eq!(events.len(), 1);
    }

    #[tokio::test]
    async fn test_find_similar() {
        let (use_case, vector_service) = create_test_use_case();

        let id1 = Uuid::new_v4();
        let id2 = Uuid::new_v4();
        let id3 = Uuid::new_v4();

        index_embedding(&vector_service, id1, vec![1.0, 0.0, 0.0], None).await;
        index_embedding(&vector_service, id2, vec![0.9, 0.1, 0.0], None).await;
        index_embedding(&vector_service, id3, vec![0.0, 1.0, 0.0], None).await;

        let response = use_case
            .find_similar(id1, 2, Some("tenant-1".to_string()))
            .await
            .unwrap();

        // The source event is never reported as similar to itself.
        assert!(!response.results.iter().any(|r| r.event_id == id1));
        assert!(response.results.len() <= 2);

        assert_eq!(response.results[0].event_id, id2);
    }

    #[tokio::test]
    async fn test_validation_errors() {
        let (use_case, _) = create_test_use_case();

        // Missing embedding.
        let result = use_case
            .execute(SemanticSearchUseCaseRequest {
                query_embedding: None,
                ..Default::default()
            })
            .await;
        assert!(result.is_err());

        // Empty embedding.
        let result = use_case
            .execute(SemanticSearchUseCaseRequest {
                query_embedding: Some(vec![]),
                ..Default::default()
            })
            .await;
        assert!(result.is_err());

        // k below the accepted range.
        let result = use_case
            .execute(SemanticSearchUseCaseRequest {
                query_embedding: Some(vec![1.0, 0.0, 0.0]),
                k: Some(0),
                ..Default::default()
            })
            .await;
        assert!(result.is_err());

        // k above the accepted range.
        let result = use_case
            .execute(SemanticSearchUseCaseRequest {
                query_embedding: Some(vec![1.0, 0.0, 0.0]),
                k: Some(2000),
                ..Default::default()
            })
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_index_use_case() {
        use crate::domain::repositories::VectorSearchRepository;

        let vector_repo = Arc::new(InMemoryVectorSearchRepository::new());
        let vector_service = Arc::new(VectorSearchService::new(vector_repo.clone()));
        let use_case = IndexEventEmbeddingUseCase::new(vector_service);

        let event_id = Uuid::new_v4();
        use_case
            .execute(IndexEventEmbeddingRequest {
                event_id,
                tenant_id: "tenant-1".to_string(),
                embedding: vec![1.0, 0.0, 0.0],
                source_text: Some("test content".to_string()),
            })
            .await
            .unwrap();

        assert_eq!(
            VectorSearchRepository::count(&*vector_repo, None)
                .await
                .unwrap(),
            1
        );
    }

    #[tokio::test]
    async fn test_batch_index_use_case() {
        let vector_repo = Arc::new(InMemoryVectorSearchRepository::new());
        let vector_service = Arc::new(VectorSearchService::new(vector_repo.clone()));
        let use_case = IndexEventEmbeddingUseCase::new(vector_service);

        let requests: Vec<_> = (0..5)
            .map(|i| IndexEventEmbeddingRequest {
                event_id: Uuid::new_v4(),
                tenant_id: "tenant-1".to_string(),
                embedding: vec![i as f32, 0.0, 0.0],
                source_text: None,
            })
            .collect();

        let response = use_case.execute_batch(requests).await.unwrap();
        assert_eq!(response.indexed, 5);
        assert_eq!(response.failed, 0);
    }
}