Skip to main content

allsource_core/application/use_cases/
semantic_search.rs

1use crate::{
2    application::{
3        dto::EventDto,
4        services::{SemanticSearchRequest, VectorSearchService},
5    },
6    domain::{repositories::EventRepository, value_objects::EmbeddingVector},
7    error::{AllSourceError, Result},
8};
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use uuid::Uuid;
12
13/// Use Case: Semantic Search
14///
15/// This use case handles semantic (vector-based) search operations.
16///
17/// Responsibilities:
18/// - Validate search parameters
19/// - Execute vector similarity search
20/// - Optionally enrich results with full event data
21/// - Apply filters and pagination
22pub struct SemanticSearchUseCase {
23    vector_service: Arc<VectorSearchService>,
24    event_repository: Arc<dyn EventRepository>,
25}
26
27impl SemanticSearchUseCase {
28    pub fn new(
29        vector_service: Arc<VectorSearchService>,
30        event_repository: Arc<dyn EventRepository>,
31    ) -> Self {
32        Self {
33            vector_service,
34            event_repository,
35        }
36    }
37
38    /// Execute semantic search and return results
39    pub async fn execute(
40        &self,
41        request: SemanticSearchUseCaseRequest,
42    ) -> Result<SemanticSearchUseCaseResponse> {
43        // Validate query
44        let embedding = request.query_embedding.ok_or_else(|| {
45            AllSourceError::InvalidInput("query_embedding is required".to_string())
46        })?;
47
48        if embedding.is_empty() {
49            return Err(AllSourceError::InvalidInput(
50                "query_embedding cannot be empty".to_string(),
51            ));
52        }
53
54        // Validate k
55        let k = request.k.unwrap_or(10);
56        if k == 0 {
57            return Err(AllSourceError::InvalidInput(
58                "k must be greater than 0".to_string(),
59            ));
60        }
61        if k > 1000 {
62            return Err(AllSourceError::InvalidInput(
63                "k cannot exceed 1000".to_string(),
64            ));
65        }
66
67        // Build search request
68        let search_request = SemanticSearchRequest {
69            query_embedding: Some(embedding),
70            k: Some(k),
71            tenant_id: request.tenant_id.clone(),
72            event_type: request.event_type.clone(),
73            min_similarity: request.min_similarity,
74            max_distance: request.max_distance,
75            metric: request.metric.clone(),
76            include_events: request.include_events.unwrap_or(false),
77        };
78
79        // Execute search
80        let search_response = self.vector_service.search(search_request).await?;
81
82        // If we need full events, fetch them
83        let events = if request.include_events.unwrap_or(false) {
84            let mut events = Vec::with_capacity(search_response.results.len());
85            for result in &search_response.results {
86                if let Some(event) = self.event_repository.find_by_id(result.event_id).await? {
87                    events.push(EventDto::from(&event));
88                }
89            }
90            Some(events)
91        } else {
92            None
93        };
94
95        Ok(SemanticSearchUseCaseResponse {
96            results: search_response
97                .results
98                .into_iter()
99                .map(|r| SemanticSearchResultDto {
100                    event_id: r.event_id,
101                    score: r.score,
102                    source_text: r.source_text,
103                })
104                .collect(),
105            events,
106            count: search_response.count,
107            metric: search_response.metric,
108            vectors_searched: search_response.stats.vectors_searched,
109            search_time_us: search_response.stats.search_time_us,
110        })
111    }
112
113    /// Find similar events to a given event
114    pub async fn find_similar(
115        &self,
116        event_id: Uuid,
117        k: usize,
118        tenant_id: Option<String>,
119    ) -> Result<SemanticSearchUseCaseResponse> {
120        // Get the embedding for the source event
121        let entry = self
122            .vector_service
123            .get_embedding(event_id)
124            .await?
125            .ok_or_else(|| {
126                AllSourceError::EventNotFound(format!("No embedding found for event {}", event_id))
127            })?;
128
129        // Search for similar events (excluding the source event)
130        let search_request = SemanticSearchRequest {
131            query_embedding: Some(entry.embedding.values().to_vec()),
132            k: Some(k + 1), // Get one extra to exclude the source
133            tenant_id,
134            event_type: None,
135            min_similarity: None,
136            max_distance: None,
137            metric: None,
138            include_events: false,
139        };
140
141        let mut response = self.vector_service.search(search_request).await?;
142
143        // Filter out the source event and limit to k results
144        response.results.retain(|r| r.event_id != event_id);
145        response.results.truncate(k);
146        response.count = response.results.len();
147
148        Ok(SemanticSearchUseCaseResponse {
149            results: response
150                .results
151                .into_iter()
152                .map(|r| SemanticSearchResultDto {
153                    event_id: r.event_id,
154                    score: r.score,
155                    source_text: r.source_text,
156                })
157                .collect(),
158            events: None,
159            count: response.count,
160            metric: response.metric,
161            vectors_searched: response.stats.vectors_searched,
162            search_time_us: response.stats.search_time_us,
163        })
164    }
165}
166
167/// Request for semantic search use case
168#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct SemanticSearchUseCaseRequest {
170    /// The query embedding vector
171    pub query_embedding: Option<Vec<f32>>,
172    /// Number of results to return (default: 10, max: 1000)
173    pub k: Option<usize>,
174    /// Filter by tenant
175    pub tenant_id: Option<String>,
176    /// Filter by event type
177    pub event_type: Option<String>,
178    /// Minimum similarity threshold
179    pub min_similarity: Option<f32>,
180    /// Maximum distance threshold
181    pub max_distance: Option<f32>,
182    /// Distance metric ("cosine", "euclidean", "dot_product")
183    pub metric: Option<String>,
184    /// Whether to include full event data
185    pub include_events: Option<bool>,
186}
187
188impl Default for SemanticSearchUseCaseRequest {
189    fn default() -> Self {
190        Self {
191            query_embedding: None,
192            k: Some(10),
193            tenant_id: None,
194            event_type: None,
195            min_similarity: None,
196            max_distance: None,
197            metric: None,
198            include_events: None,
199        }
200    }
201}
202
203/// A single result from semantic search
204#[derive(Debug, Clone, Serialize, Deserialize)]
205pub struct SemanticSearchResultDto {
206    pub event_id: Uuid,
207    pub score: f32,
208    pub source_text: Option<String>,
209}
210
211/// Response from semantic search use case
212#[derive(Debug, Clone, Serialize, Deserialize)]
213pub struct SemanticSearchUseCaseResponse {
214    /// Search results
215    pub results: Vec<SemanticSearchResultDto>,
216    /// Full event data (if requested)
217    pub events: Option<Vec<EventDto>>,
218    /// Number of results
219    pub count: usize,
220    /// Metric used for scoring
221    pub metric: String,
222    /// Number of vectors searched
223    pub vectors_searched: usize,
224    /// Search time in microseconds
225    pub search_time_us: u64,
226}
227
228/// Use Case: Index Event Embedding
229///
230/// Handles indexing of event embeddings for semantic search.
231pub struct IndexEventEmbeddingUseCase {
232    vector_service: Arc<VectorSearchService>,
233}
234
235impl IndexEventEmbeddingUseCase {
236    pub fn new(vector_service: Arc<VectorSearchService>) -> Self {
237        Self { vector_service }
238    }
239
240    /// Index a single event embedding
241    pub async fn execute(&self, request: IndexEventEmbeddingRequest) -> Result<()> {
242        // Validate embedding
243        let embedding = EmbeddingVector::new(request.embedding)?;
244
245        // Index the embedding
246        self.vector_service
247            .index_event(crate::application::services::IndexEventRequest {
248                event_id: request.event_id,
249                tenant_id: request.tenant_id,
250                embedding,
251                source_text: request.source_text,
252            })
253            .await
254    }
255
256    /// Index multiple embeddings in batch
257    pub async fn execute_batch(
258        &self,
259        requests: Vec<IndexEventEmbeddingRequest>,
260    ) -> Result<BatchIndexResponse> {
261        let mut indexed = 0;
262        let mut failed = 0;
263        let mut errors = Vec::new();
264
265        for request in requests {
266            match EmbeddingVector::new(request.embedding) {
267                Ok(embedding) => {
268                    match self
269                        .vector_service
270                        .index_event(crate::application::services::IndexEventRequest {
271                            event_id: request.event_id,
272                            tenant_id: request.tenant_id,
273                            embedding,
274                            source_text: request.source_text,
275                        })
276                        .await
277                    {
278                        Ok(_) => indexed += 1,
279                        Err(e) => {
280                            failed += 1;
281                            errors.push(format!("Event {}: {}", request.event_id, e));
282                        }
283                    }
284                }
285                Err(e) => {
286                    failed += 1;
287                    errors.push(format!("Event {}: {}", request.event_id, e));
288                }
289            }
290        }
291
292        Ok(BatchIndexResponse {
293            indexed,
294            failed,
295            errors,
296        })
297    }
298}
299
300/// Request to index an event embedding
301#[derive(Debug, Clone, Serialize, Deserialize)]
302pub struct IndexEventEmbeddingRequest {
303    pub event_id: Uuid,
304    pub tenant_id: String,
305    pub embedding: Vec<f32>,
306    pub source_text: Option<String>,
307}
308
309/// Response from batch indexing
310#[derive(Debug, Clone, Serialize, Deserialize)]
311pub struct BatchIndexResponse {
312    pub indexed: usize,
313    pub failed: usize,
314    pub errors: Vec<String>,
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        application::services::IndexEventRequest, domain::entities::Event,
        infrastructure::repositories::InMemoryVectorSearchRepository,
    };
    use async_trait::async_trait;
    use chrono::Utc;
    use serde_json::json;

    /// Read-only event repository backed by a fixed list; only `find_by_id`
    /// and `health_check` are implemented.
    struct MockEventRepository {
        events: Vec<Event>,
    }

    impl MockEventRepository {
        fn with_events(events: Vec<Event>) -> Self {
            Self { events }
        }
    }

    #[async_trait]
    impl EventRepository for MockEventRepository {
        async fn save(&self, _event: &Event) -> Result<()> {
            unimplemented!()
        }

        async fn save_batch(&self, _events: &[Event]) -> Result<()> {
            unimplemented!()
        }

        async fn find_by_id(&self, id: Uuid) -> Result<Option<Event>> {
            Ok(self.events.iter().find(|e| e.id() == id).cloned())
        }

        async fn find_by_entity(&self, _entity_id: &str, _tenant_id: &str) -> Result<Vec<Event>> {
            unimplemented!()
        }

        async fn find_by_type(&self, _event_type: &str, _tenant_id: &str) -> Result<Vec<Event>> {
            unimplemented!()
        }

        async fn find_by_time_range(
            &self,
            _tenant_id: &str,
            _start: chrono::DateTime<Utc>,
            _end: chrono::DateTime<Utc>,
        ) -> Result<Vec<Event>> {
            unimplemented!()
        }

        async fn find_by_entity_as_of(
            &self,
            _entity_id: &str,
            _tenant_id: &str,
            _as_of: chrono::DateTime<Utc>,
        ) -> Result<Vec<Event>> {
            unimplemented!()
        }

        async fn count(&self, _tenant_id: &str) -> Result<usize> {
            unimplemented!()
        }

        async fn health_check(&self) -> Result<()> {
            Ok(())
        }
    }

    /// Builds the use case over an empty in-memory vector index and a mock
    /// event repository holding one event.
    ///
    /// Returns the use case, its vector service, and the stored event's id.
    fn create_test_use_case() -> (SemanticSearchUseCase, Arc<VectorSearchService>, Uuid) {
        let vector_repo = Arc::new(InMemoryVectorSearchRepository::new());
        let vector_service = Arc::new(VectorSearchService::new(vector_repo));

        let event = Event::from_strings(
            "user.created".to_string(),
            "user-1".to_string(),
            "tenant-1".to_string(),
            json!({"name": "Test"}),
            None,
        )
        .unwrap();
        let stored_id = event.id();

        let event_repo = Arc::new(MockEventRepository::with_events(vec![event]));

        (
            SemanticSearchUseCase::new(vector_service.clone(), event_repo),
            vector_service,
            stored_id,
        )
    }

    /// Indexes `values` for `event_id` under `tenant-1`, panicking on failure.
    async fn index(
        vector_service: &VectorSearchService,
        event_id: Uuid,
        values: Vec<f32>,
        source_text: Option<&str>,
    ) {
        vector_service
            .index_event(IndexEventRequest {
                event_id,
                tenant_id: "tenant-1".to_string(),
                embedding: EmbeddingVector::new(values).unwrap(),
                source_text: source_text.map(String::from),
            })
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_semantic_search() {
        let (use_case, vector_service, _) = create_test_use_case();

        let id1 = Uuid::new_v4();
        let id2 = Uuid::new_v4();
        index(&vector_service, id1, vec![1.0, 0.0, 0.0], Some("first document")).await;
        index(&vector_service, id2, vec![0.0, 1.0, 0.0], Some("second document")).await;

        let response = use_case
            .execute(SemanticSearchUseCaseRequest {
                query_embedding: Some(vec![1.0, 0.0, 0.0]),
                k: Some(2),
                tenant_id: Some("tenant-1".to_string()),
                ..Default::default()
            })
            .await
            .unwrap();

        assert_eq!(response.count, 2);
        assert_eq!(response.results[0].event_id, id1);
        assert!((response.results[0].score - 1.0).abs() < 1e-6);
        // Events were not requested, so none are loaded.
        assert!(response.events.is_none());
    }

    #[tokio::test]
    async fn test_semantic_search_include_events() {
        let (use_case, vector_service, stored_id) = create_test_use_case();

        // Only `stored_id` exists in the event repository; `orphan_id` has an
        // embedding but no backing event.
        let orphan_id = Uuid::new_v4();
        index(&vector_service, stored_id, vec![1.0, 0.0, 0.0], None).await;
        index(&vector_service, orphan_id, vec![0.0, 1.0, 0.0], None).await;

        let response = use_case
            .execute(SemanticSearchUseCaseRequest {
                query_embedding: Some(vec![1.0, 0.0, 0.0]),
                k: Some(2),
                tenant_id: Some("tenant-1".to_string()),
                include_events: Some(true),
                ..Default::default()
            })
            .await
            .unwrap();

        assert_eq!(response.count, 2);
        // Hits without a stored event are skipped rather than failing the search.
        let events = response.events.expect("events were requested");
        assert_eq!(events.len(), 1);
    }

    #[tokio::test]
    async fn test_find_similar() {
        let (use_case, vector_service, _) = create_test_use_case();

        let id1 = Uuid::new_v4();
        let id2 = Uuid::new_v4();
        let id3 = Uuid::new_v4();
        index(&vector_service, id1, vec![1.0, 0.0, 0.0], None).await;
        index(&vector_service, id2, vec![0.9, 0.1, 0.0], None).await;
        index(&vector_service, id3, vec![0.0, 1.0, 0.0], None).await;

        let response = use_case
            .find_similar(id1, 2, Some("tenant-1".to_string()))
            .await
            .unwrap();

        // The source event is excluded and at most k neighbours come back.
        assert!(!response.results.iter().any(|r| r.event_id == id1));
        assert!(response.results.len() <= 2);
        assert_eq!(response.count, response.results.len());

        // id2 is the closest neighbour of id1.
        assert!(!response.results.is_empty(), "expected at least one neighbour");
        assert_eq!(response.results[0].event_id, id2);
    }

    #[tokio::test]
    async fn test_find_similar_unknown_event() {
        let (use_case, _, _) = create_test_use_case();

        let result = use_case.find_similar(Uuid::new_v4(), 2, None).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_validation_errors() {
        let (use_case, _, _) = create_test_use_case();

        let invalid_requests = [
            // Missing embedding
            SemanticSearchUseCaseRequest {
                query_embedding: None,
                ..Default::default()
            },
            // Empty embedding
            SemanticSearchUseCaseRequest {
                query_embedding: Some(vec![]),
                ..Default::default()
            },
            // k = 0
            SemanticSearchUseCaseRequest {
                query_embedding: Some(vec![1.0, 0.0, 0.0]),
                k: Some(0),
                ..Default::default()
            },
            // k too large
            SemanticSearchUseCaseRequest {
                query_embedding: Some(vec![1.0, 0.0, 0.0]),
                k: Some(2000),
                ..Default::default()
            },
        ];

        for request in invalid_requests {
            let result = use_case.execute(request).await;
            assert!(matches!(result, Err(AllSourceError::InvalidInput(_))));
        }
    }

    #[tokio::test]
    async fn test_index_use_case() {
        use crate::domain::repositories::VectorSearchRepository;

        let vector_repo = Arc::new(InMemoryVectorSearchRepository::new());
        let vector_service = Arc::new(VectorSearchService::new(vector_repo.clone()));
        let use_case = IndexEventEmbeddingUseCase::new(vector_service);

        let event_id = Uuid::new_v4();
        use_case
            .execute(IndexEventEmbeddingRequest {
                event_id,
                tenant_id: "tenant-1".to_string(),
                embedding: vec![1.0, 0.0, 0.0],
                source_text: Some("test content".to_string()),
            })
            .await
            .unwrap();

        assert_eq!(
            VectorSearchRepository::count(&*vector_repo, None)
                .await
                .unwrap(),
            1
        );
    }

    #[tokio::test]
    async fn test_batch_index_use_case() {
        let vector_repo = Arc::new(InMemoryVectorSearchRepository::new());
        let vector_service = Arc::new(VectorSearchService::new(vector_repo.clone()));
        let use_case = IndexEventEmbeddingUseCase::new(vector_service);

        let requests: Vec<_> = (0..5)
            .map(|i| IndexEventEmbeddingRequest {
                event_id: Uuid::new_v4(),
                tenant_id: "tenant-1".to_string(),
                embedding: vec![i as f32, 0.0, 0.0],
                source_text: None,
            })
            .collect();

        let response = use_case.execute_batch(requests).await.unwrap();
        assert_eq!(response.indexed, 5);
        assert_eq!(response.failed, 0);
        assert!(response.errors.is_empty());
    }
}