// File: allsource_core/application/services/vector_search.rs

1use crate::domain::entities::Event;
2use crate::domain::repositories::{
3    EventRepository, SearchResult, VectorEntry, VectorSearchQuery, VectorSearchRepository,
4};
5use crate::domain::value_objects::{DistanceMetric, EmbeddingVector};
6use crate::error::{AllSourceError, Result};
7use serde::{Deserialize, Serialize};
8use std::sync::Arc;
9use uuid::Uuid;
10
11/// Configuration for the vector search service
12#[derive(Debug, Clone)]
13pub struct VectorSearchConfig {
14    /// Default number of results to return
15    pub default_k: usize,
16    /// Maximum number of results to return
17    pub max_k: usize,
18    /// Default similarity threshold for cosine similarity
19    pub default_min_similarity: f32,
20    /// Default distance metric
21    pub default_metric: DistanceMetric,
22    /// Whether to include source text in results
23    pub include_source_text: bool,
24}
25
26impl Default for VectorSearchConfig {
27    fn default() -> Self {
28        Self {
29            default_k: 10,
30            max_k: 100,
31            default_min_similarity: 0.0,
32            default_metric: DistanceMetric::Cosine,
33            include_source_text: true,
34        }
35    }
36}
37
38/// Request to index an event's embedding
39#[derive(Debug, Clone)]
40pub struct IndexEventRequest {
41    pub event_id: Uuid,
42    pub tenant_id: String,
43    pub embedding: EmbeddingVector,
44    pub source_text: Option<String>,
45}
46
47/// Request for semantic search
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct SemanticSearchRequest {
50    /// The query embedding vector
51    #[serde(skip_serializing_if = "Option::is_none")]
52    pub query_embedding: Option<Vec<f32>>,
53    /// Number of results to return (default: 10)
54    #[serde(default)]
55    pub k: Option<usize>,
56    /// Tenant ID filter
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub tenant_id: Option<String>,
59    /// Event type filter
60    #[serde(skip_serializing_if = "Option::is_none")]
61    pub event_type: Option<String>,
62    /// Minimum similarity threshold (for cosine/dot product)
63    #[serde(skip_serializing_if = "Option::is_none")]
64    pub min_similarity: Option<f32>,
65    /// Maximum distance threshold (for euclidean)
66    #[serde(skip_serializing_if = "Option::is_none")]
67    pub max_distance: Option<f32>,
68    /// Distance metric (default: cosine)
69    #[serde(default)]
70    pub metric: Option<String>,
71    /// Whether to include full event data in results
72    #[serde(default)]
73    pub include_events: bool,
74}
75
76impl Default for SemanticSearchRequest {
77    fn default() -> Self {
78        Self {
79            query_embedding: None,
80            k: None,
81            tenant_id: None,
82            event_type: None,
83            min_similarity: None,
84            max_distance: None,
85            metric: None,
86            include_events: false,
87        }
88    }
89}
90
91/// A single semantic search result
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct SemanticSearchResultItem {
94    /// The event ID that matched
95    pub event_id: Uuid,
96    /// The similarity/distance score
97    pub score: f32,
98    /// The source text (if available)
99    #[serde(skip_serializing_if = "Option::is_none")]
100    pub source_text: Option<String>,
101    /// The full event (if requested)
102    #[serde(skip_serializing_if = "Option::is_none")]
103    pub event: Option<EventSummary>,
104}
105
106/// Summary of an event for search results
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct EventSummary {
109    pub id: Uuid,
110    pub event_type: String,
111    pub entity_id: String,
112    pub tenant_id: String,
113    pub timestamp: chrono::DateTime<chrono::Utc>,
114    #[serde(skip_serializing_if = "Option::is_none")]
115    pub payload: Option<serde_json::Value>,
116}
117
118impl From<&Event> for EventSummary {
119    fn from(event: &Event) -> Self {
120        Self {
121            id: event.id(),
122            event_type: event.event_type_str().to_string(),
123            entity_id: event.entity_id_str().to_string(),
124            tenant_id: event.tenant_id_str().to_string(),
125            timestamp: event.timestamp(),
126            payload: Some(event.payload().clone()),
127        }
128    }
129}
130
131/// Response from semantic search
132#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct SemanticSearchResponse {
134    /// The search results
135    pub results: Vec<SemanticSearchResultItem>,
136    /// Total number of results
137    pub count: usize,
138    /// The metric used for scoring
139    pub metric: String,
140    /// Query execution stats
141    pub stats: SearchStats,
142}
143
144/// Statistics about the search operation
145#[derive(Debug, Clone, Serialize, Deserialize)]
146pub struct SearchStats {
147    /// Total vectors searched
148    pub vectors_searched: usize,
149    /// Time taken in microseconds
150    pub search_time_us: u64,
151}
152
153/// Vector Search Service
154///
155/// Orchestrates vector search operations including:
156/// - Indexing event embeddings
157/// - Semantic similarity search
158/// - Integration with event repository for full results
159///
160/// This service follows the application layer pattern, coordinating
161/// between the domain repositories without containing domain logic.
162pub struct VectorSearchService {
163    vector_repo: Arc<dyn VectorSearchRepository>,
164    event_repo: Option<Arc<dyn EventRepository>>,
165    config: VectorSearchConfig,
166}
167
168impl VectorSearchService {
169    pub fn new(vector_repo: Arc<dyn VectorSearchRepository>) -> Self {
170        Self {
171            vector_repo,
172            event_repo: None,
173            config: VectorSearchConfig::default(),
174        }
175    }
176
177    pub fn with_event_repo(mut self, event_repo: Arc<dyn EventRepository>) -> Self {
178        self.event_repo = Some(event_repo);
179        self
180    }
181
182    pub fn with_config(mut self, config: VectorSearchConfig) -> Self {
183        self.config = config;
184        self
185    }
186
187    /// Index a single event embedding
188    pub async fn index_event(&self, request: IndexEventRequest) -> Result<()> {
189        if let Some(source_text) = &request.source_text {
190            self.vector_repo
191                .store_with_text(
192                    request.event_id,
193                    &request.embedding,
194                    &request.tenant_id,
195                    source_text,
196                )
197                .await
198        } else {
199            self.vector_repo
200                .store(request.event_id, &request.embedding, &request.tenant_id)
201                .await
202        }
203    }
204
205    /// Index multiple events in batch
206    pub async fn index_events_batch(
207        &self,
208        requests: Vec<IndexEventRequest>,
209    ) -> Result<BatchIndexResult> {
210        if requests.is_empty() {
211            return Ok(BatchIndexResult {
212                indexed: 0,
213                failed: 0,
214                errors: vec![],
215            });
216        }
217
218        let entries: Vec<_> = requests
219            .iter()
220            .map(|r| (r.event_id, r.embedding.clone(), r.tenant_id.clone()))
221            .collect();
222
223        self.vector_repo.store_batch(&entries).await?;
224
225        Ok(BatchIndexResult {
226            indexed: requests.len(),
227            failed: 0,
228            errors: vec![],
229        })
230    }
231
232    /// Perform semantic search
233    pub async fn search(&self, request: SemanticSearchRequest) -> Result<SemanticSearchResponse> {
234        let start_time = std::time::Instant::now();
235
236        // Parse and validate query embedding
237        let query_embedding = request
238            .query_embedding
239            .ok_or_else(|| AllSourceError::InvalidInput("query_embedding is required".to_string()))?;
240
241        let query_vector = EmbeddingVector::new(query_embedding)?;
242
243        // Parse metric
244        let metric = match request.metric.as_deref() {
245            Some("cosine") | None => DistanceMetric::Cosine,
246            Some("euclidean") => DistanceMetric::Euclidean,
247            Some("dot_product") => DistanceMetric::DotProduct,
248            Some(m) => {
249                return Err(AllSourceError::InvalidInput(format!(
250                    "Unknown metric: {}. Supported: cosine, euclidean, dot_product",
251                    m
252                )));
253            }
254        };
255
256        // Build query
257        let k = request
258            .k
259            .unwrap_or(self.config.default_k)
260            .min(self.config.max_k);
261
262        let mut query = VectorSearchQuery::new(query_vector, k).with_metric(metric);
263
264        if let Some(tenant_id) = request.tenant_id {
265            query = query.with_tenant(tenant_id);
266        }
267
268        if let Some(event_type) = request.event_type {
269            query = query.with_event_type(event_type);
270        }
271
272        if let Some(min_sim) = request.min_similarity {
273            query = query.with_min_similarity(min_sim);
274        }
275
276        if let Some(max_dist) = request.max_distance {
277            query = query.with_max_distance(max_dist);
278        }
279
280        // Execute search
281        let search_results = self.vector_repo.search(&query).await?;
282        let vectors_searched = self.vector_repo.count(None).await.unwrap_or(0);
283
284        // Optionally fetch full events
285        let results = if request.include_events {
286            self.enrich_with_events(search_results).await?
287        } else {
288            search_results
289                .into_iter()
290                .map(|r| SemanticSearchResultItem {
291                    event_id: r.event_id,
292                    score: r.score.value(),
293                    source_text: r.source_text,
294                    event: None,
295                })
296                .collect()
297        };
298
299        let search_time_us = start_time.elapsed().as_micros() as u64;
300        let count = results.len();
301
302        Ok(SemanticSearchResponse {
303            results,
304            count,
305            metric: format!("{:?}", metric).to_lowercase(),
306            stats: SearchStats {
307                vectors_searched,
308                search_time_us,
309            },
310        })
311    }
312
313    /// Get embedding for a specific event
314    pub async fn get_embedding(&self, event_id: Uuid) -> Result<Option<VectorEntry>> {
315        self.vector_repo.get_by_event_id(event_id).await
316    }
317
318    /// Delete embedding for an event
319    pub async fn delete_embedding(&self, event_id: Uuid) -> Result<bool> {
320        self.vector_repo.delete(event_id).await
321    }
322
323    /// Delete all embeddings for a tenant
324    pub async fn delete_tenant_embeddings(&self, tenant_id: &str) -> Result<usize> {
325        self.vector_repo.delete_by_tenant(tenant_id).await
326    }
327
328    /// Get index statistics
329    pub async fn get_stats(&self) -> Result<IndexStats> {
330        let total_vectors = self.vector_repo.count(None).await?;
331        let dimensions = self.vector_repo.dimensions().await?;
332
333        Ok(IndexStats {
334            total_vectors,
335            dimensions,
336        })
337    }
338
339    /// Health check
340    pub async fn health_check(&self) -> Result<()> {
341        self.vector_repo.health_check().await
342    }
343
344    /// Enrich search results with full event data
345    async fn enrich_with_events(
346        &self,
347        results: Vec<SearchResult>,
348    ) -> Result<Vec<SemanticSearchResultItem>> {
349        let event_repo = match &self.event_repo {
350            Some(repo) => repo,
351            None => {
352                // No event repo, return without events
353                return Ok(results
354                    .into_iter()
355                    .map(|r| SemanticSearchResultItem {
356                        event_id: r.event_id,
357                        score: r.score.value(),
358                        source_text: r.source_text,
359                        event: None,
360                    })
361                    .collect());
362            }
363        };
364
365        let mut enriched = Vec::with_capacity(results.len());
366
367        for result in results {
368            let event = event_repo.find_by_id(result.event_id).await?;
369
370            enriched.push(SemanticSearchResultItem {
371                event_id: result.event_id,
372                score: result.score.value(),
373                source_text: result.source_text,
374                event: event.as_ref().map(EventSummary::from),
375            });
376        }
377
378        Ok(enriched)
379    }
380}
381
382/// Result of batch indexing operation
383#[derive(Debug, Clone, Serialize, Deserialize)]
384pub struct BatchIndexResult {
385    pub indexed: usize,
386    pub failed: usize,
387    pub errors: Vec<String>,
388}
389
390/// Index statistics
391#[derive(Debug, Clone, Serialize, Deserialize)]
392pub struct IndexStats {
393    pub total_vectors: usize,
394    pub dimensions: Option<usize>,
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use crate::infrastructure::repositories::InMemoryVectorSearchRepository;

    /// Builds a service over a fresh in-memory vector repository.
    fn create_test_service() -> VectorSearchService {
        VectorSearchService::new(Arc::new(InMemoryVectorSearchRepository::new()))
    }

    /// Produces a deterministic `dims`-long embedding whose i-th component
    /// is `(i + seed) / dims`.
    fn create_test_embedding(dims: usize, seed: f32) -> EmbeddingVector {
        let mut values: Vec<f32> = Vec::with_capacity(dims);
        for idx in 0..dims {
            values.push((idx as f32 + seed) / dims as f32);
        }
        EmbeddingVector::new(values).unwrap()
    }

    /// Indexes `embedding` for `event_id` under "tenant-1" with no source text.
    async fn index_without_text(
        service: &VectorSearchService,
        event_id: Uuid,
        embedding: EmbeddingVector,
    ) {
        let request = IndexEventRequest {
            event_id,
            tenant_id: "tenant-1".to_string(),
            embedding,
            source_text: None,
        };
        service.index_event(request).await.unwrap();
    }

    // The nearest neighbour of an indexed vector is that vector itself.
    #[tokio::test]
    async fn test_index_and_search() {
        let service = create_test_service();

        let indexed = vec![
            (Uuid::new_v4(), vec![1.0, 0.0, 0.0_f32]),
            (Uuid::new_v4(), vec![0.9, 0.1, 0.0]),
            (Uuid::new_v4(), vec![0.0, 1.0, 0.0]),
        ];

        for (event_id, components) in &indexed {
            let embedding = EmbeddingVector::new(components.clone()).unwrap();
            index_without_text(&service, *event_id, embedding).await;
        }

        let request = SemanticSearchRequest {
            query_embedding: Some(vec![1.0, 0.0, 0.0]),
            k: Some(2),
            tenant_id: Some("tenant-1".to_string()),
            ..Default::default()
        };
        let response = service.search(request).await.unwrap();

        assert_eq!(response.count, 2);
        assert_eq!(response.results[0].event_id, indexed[0].0);
    }

    // Batch indexing reports every request as indexed and the stats agree.
    #[tokio::test]
    async fn test_batch_index() {
        let service = create_test_service();

        let mut requests = Vec::new();
        for i in 0..10 {
            requests.push(IndexEventRequest {
                event_id: Uuid::new_v4(),
                tenant_id: "tenant-1".to_string(),
                embedding: create_test_embedding(384, i as f32),
                source_text: Some(format!("Document {}", i)),
            });
        }

        let outcome = service.index_events_batch(requests).await.unwrap();
        assert_eq!(outcome.indexed, 10);
        assert_eq!(outcome.failed, 0);

        let stats = service.get_stats().await.unwrap();
        assert_eq!(stats.total_vectors, 10);
        assert_eq!(stats.dimensions, Some(384));
    }

    // A similarity floor of 0.5 keeps the identical vector and drops the
    // orthogonal one.
    #[tokio::test]
    async fn test_search_with_min_similarity() {
        let service = create_test_service();

        for components in vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]] {
            let embedding = EmbeddingVector::new(components).unwrap();
            index_without_text(&service, Uuid::new_v4(), embedding).await;
        }

        let request = SemanticSearchRequest {
            query_embedding: Some(vec![1.0, 0.0, 0.0]),
            k: Some(10),
            tenant_id: Some("tenant-1".to_string()),
            min_similarity: Some(0.5),
            ..Default::default()
        };
        let response = service.search(request).await.unwrap();

        assert_eq!(response.count, 1);
    }

    // An embedding is retrievable after indexing and gone after deletion.
    #[tokio::test]
    async fn test_delete_embedding() {
        let service = create_test_service();
        let event_id = Uuid::new_v4();

        index_without_text(&service, event_id, create_test_embedding(384, 1.0)).await;

        let before = service.get_embedding(event_id).await.unwrap();
        assert!(before.is_some());

        let deleted = service.delete_embedding(event_id).await.unwrap();
        assert!(deleted);

        let after = service.get_embedding(event_id).await.unwrap();
        assert!(after.is_none());
    }

    // The in-memory repository reports healthy.
    #[tokio::test]
    async fn test_health_check() {
        let service = create_test_service();
        let status = service.health_check().await;
        assert!(status.is_ok());
    }

    // An unsupported metric name is rejected with a descriptive error.
    #[tokio::test]
    async fn test_invalid_metric() {
        let service = create_test_service();

        let request = SemanticSearchRequest {
            query_embedding: Some(vec![1.0, 0.0, 0.0]),
            metric: Some("invalid".to_string()),
            ..Default::default()
        };
        let failure = service.search(request).await.err();

        assert!(failure.is_some());
        assert!(failure.map_or(false, |e| e.to_string().contains("Unknown metric")));
    }

    // A request without a query embedding is rejected.
    #[tokio::test]
    async fn test_missing_query_embedding() {
        let service = create_test_service();

        let request = SemanticSearchRequest {
            query_embedding: None,
            ..Default::default()
        };
        let failure = service.search(request).await.err();

        assert!(failure.is_some());
        assert!(failure.map_or(false, |e| e
            .to_string()
            .contains("query_embedding is required")));
    }
}