// allsource_core/application/services/vector_search.rs

use crate::{
    domain::{
        entities::Event,
        repositories::{
            EventRepository, SearchResult, VectorEntry, VectorSearchQuery, VectorSearchRepository,
        },
        value_objects::{DistanceMetric, EmbeddingVector},
    },
    error::{AllSourceError, Result},
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Instant;
use uuid::Uuid;

15/// Configuration for the vector search service
16#[derive(Debug, Clone)]
17pub struct VectorSearchConfig {
18    /// Default number of results to return
19    pub default_k: usize,
20    /// Maximum number of results to return
21    pub max_k: usize,
22    /// Default similarity threshold for cosine similarity
23    pub default_min_similarity: f32,
24    /// Default distance metric
25    pub default_metric: DistanceMetric,
26    /// Whether to include source text in results
27    pub include_source_text: bool,
28}
29
30impl Default for VectorSearchConfig {
31    fn default() -> Self {
32        Self {
33            default_k: 10,
34            max_k: 100,
35            default_min_similarity: 0.0,
36            default_metric: DistanceMetric::Cosine,
37            include_source_text: true,
38        }
39    }
40}
41
42/// Request to index an event's embedding
43#[derive(Debug, Clone)]
44pub struct IndexEventRequest {
45    pub event_id: Uuid,
46    pub tenant_id: String,
47    pub embedding: EmbeddingVector,
48    pub source_text: Option<String>,
49}
50
51/// Request for semantic search
52#[derive(Debug, Clone, Default, Serialize, Deserialize)]
53pub struct SemanticSearchRequest {
54    /// The query embedding vector
55    #[serde(skip_serializing_if = "Option::is_none")]
56    pub query_embedding: Option<Vec<f32>>,
57    /// Number of results to return (default: 10)
58    #[serde(default)]
59    pub k: Option<usize>,
60    /// Tenant ID filter
61    #[serde(skip_serializing_if = "Option::is_none")]
62    pub tenant_id: Option<String>,
63    /// Event type filter
64    #[serde(skip_serializing_if = "Option::is_none")]
65    pub event_type: Option<String>,
66    /// Minimum similarity threshold (for cosine/dot product)
67    #[serde(skip_serializing_if = "Option::is_none")]
68    pub min_similarity: Option<f32>,
69    /// Maximum distance threshold (for euclidean)
70    #[serde(skip_serializing_if = "Option::is_none")]
71    pub max_distance: Option<f32>,
72    /// Distance metric (default: cosine)
73    #[serde(default)]
74    pub metric: Option<String>,
75    /// Whether to include full event data in results
76    #[serde(default)]
77    pub include_events: bool,
78}
79
80/// A single semantic search result
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct SemanticSearchResultItem {
83    /// The event ID that matched
84    pub event_id: Uuid,
85    /// The similarity/distance score
86    pub score: f32,
87    /// The source text (if available)
88    #[serde(skip_serializing_if = "Option::is_none")]
89    pub source_text: Option<String>,
90    /// The full event (if requested)
91    #[serde(skip_serializing_if = "Option::is_none")]
92    pub event: Option<EventSummary>,
93}
94
95/// Summary of an event for search results
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct EventSummary {
98    pub id: Uuid,
99    pub event_type: String,
100    pub entity_id: String,
101    pub tenant_id: String,
102    pub timestamp: chrono::DateTime<chrono::Utc>,
103    #[serde(skip_serializing_if = "Option::is_none")]
104    pub payload: Option<serde_json::Value>,
105}
106
107impl From<&Event> for EventSummary {
108    fn from(event: &Event) -> Self {
109        Self {
110            id: event.id(),
111            event_type: event.event_type_str().to_string(),
112            entity_id: event.entity_id_str().to_string(),
113            tenant_id: event.tenant_id_str().to_string(),
114            timestamp: event.timestamp(),
115            payload: Some(event.payload().clone()),
116        }
117    }
118}
119
120/// Response from semantic search
121#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct SemanticSearchResponse {
123    /// The search results
124    pub results: Vec<SemanticSearchResultItem>,
125    /// Total number of results
126    pub count: usize,
127    /// The metric used for scoring
128    pub metric: String,
129    /// Query execution stats
130    pub stats: SearchStats,
131}
132
133/// Statistics about the search operation
134#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct SearchStats {
136    /// Total vectors searched
137    pub vectors_searched: usize,
138    /// Time taken in microseconds
139    pub search_time_us: u64,
140}
141
142/// Vector Search Service
143///
144/// Orchestrates vector search operations including:
145/// - Indexing event embeddings
146/// - Semantic similarity search
147/// - Integration with event repository for full results
148///
149/// This service follows the application layer pattern, coordinating
150/// between the domain repositories without containing domain logic.
151pub struct VectorSearchService {
152    vector_repo: Arc<dyn VectorSearchRepository>,
153    event_repo: Option<Arc<dyn EventRepository>>,
154    config: VectorSearchConfig,
155}
156
157impl VectorSearchService {
158    pub fn new(vector_repo: Arc<dyn VectorSearchRepository>) -> Self {
159        Self {
160            vector_repo,
161            event_repo: None,
162            config: VectorSearchConfig::default(),
163        }
164    }
165
166    pub fn with_event_repo(mut self, event_repo: Arc<dyn EventRepository>) -> Self {
167        self.event_repo = Some(event_repo);
168        self
169    }
170
171    pub fn with_config(mut self, config: VectorSearchConfig) -> Self {
172        self.config = config;
173        self
174    }
175
176    /// Index a single event embedding
177    pub async fn index_event(&self, request: IndexEventRequest) -> Result<()> {
178        if let Some(source_text) = &request.source_text {
179            self.vector_repo
180                .store_with_text(
181                    request.event_id,
182                    &request.embedding,
183                    &request.tenant_id,
184                    source_text,
185                )
186                .await
187        } else {
188            self.vector_repo
189                .store(request.event_id, &request.embedding, &request.tenant_id)
190                .await
191        }
192    }
193
194    /// Index multiple events in batch
195    pub async fn index_events_batch(
196        &self,
197        requests: Vec<IndexEventRequest>,
198    ) -> Result<BatchIndexResult> {
199        if requests.is_empty() {
200            return Ok(BatchIndexResult {
201                indexed: 0,
202                failed: 0,
203                errors: vec![],
204            });
205        }
206
207        let entries: Vec<_> = requests
208            .iter()
209            .map(|r| (r.event_id, r.embedding.clone(), r.tenant_id.clone()))
210            .collect();
211
212        self.vector_repo.store_batch(&entries).await?;
213
214        Ok(BatchIndexResult {
215            indexed: requests.len(),
216            failed: 0,
217            errors: vec![],
218        })
219    }
220
221    /// Perform semantic search
222    pub async fn search(&self, request: SemanticSearchRequest) -> Result<SemanticSearchResponse> {
223        let start_time = std::time::Instant::now();
224
225        // Parse and validate query embedding
226        let query_embedding = request.query_embedding.ok_or_else(|| {
227            AllSourceError::InvalidInput("query_embedding is required".to_string())
228        })?;
229
230        let query_vector = EmbeddingVector::new(query_embedding)?;
231
232        // Parse metric
233        let metric = match request.metric.as_deref() {
234            Some("cosine") | None => DistanceMetric::Cosine,
235            Some("euclidean") => DistanceMetric::Euclidean,
236            Some("dot_product") => DistanceMetric::DotProduct,
237            Some(m) => {
238                return Err(AllSourceError::InvalidInput(format!(
239                    "Unknown metric: {m}. Supported: cosine, euclidean, dot_product"
240                )));
241            }
242        };
243
244        // Build query
245        let k = request
246            .k
247            .unwrap_or(self.config.default_k)
248            .min(self.config.max_k);
249
250        let mut query = VectorSearchQuery::new(query_vector, k).with_metric(metric);
251
252        if let Some(tenant_id) = request.tenant_id {
253            query = query.with_tenant(tenant_id);
254        }
255
256        if let Some(event_type) = request.event_type {
257            query = query.with_event_type(event_type);
258        }
259
260        if let Some(min_sim) = request.min_similarity {
261            query = query.with_min_similarity(min_sim);
262        }
263
264        if let Some(max_dist) = request.max_distance {
265            query = query.with_max_distance(max_dist);
266        }
267
268        // Execute search
269        let search_results = self.vector_repo.search(&query).await?;
270        let vectors_searched = self.vector_repo.count(None).await.unwrap_or(0);
271
272        // Optionally fetch full events
273        let results = if request.include_events {
274            self.enrich_with_events(search_results).await?
275        } else {
276            search_results
277                .into_iter()
278                .map(|r| SemanticSearchResultItem {
279                    event_id: r.event_id,
280                    score: r.score.value(),
281                    source_text: r.source_text,
282                    event: None,
283                })
284                .collect()
285        };
286
287        let search_time_us = start_time.elapsed().as_micros() as u64;
288        let count = results.len();
289
290        Ok(SemanticSearchResponse {
291            results,
292            count,
293            metric: format!("{metric:?}").to_lowercase(),
294            stats: SearchStats {
295                vectors_searched,
296                search_time_us,
297            },
298        })
299    }
300
301    /// Get embedding for a specific event
302    pub async fn get_embedding(&self, event_id: Uuid) -> Result<Option<VectorEntry>> {
303        self.vector_repo.get_by_event_id(event_id).await
304    }
305
306    /// Delete embedding for an event
307    pub async fn delete_embedding(&self, event_id: Uuid) -> Result<bool> {
308        self.vector_repo.delete(event_id).await
309    }
310
311    /// Delete all embeddings for a tenant
312    pub async fn delete_tenant_embeddings(&self, tenant_id: &str) -> Result<usize> {
313        self.vector_repo.delete_by_tenant(tenant_id).await
314    }
315
316    /// Get index statistics
317    pub async fn get_stats(&self) -> Result<IndexStats> {
318        let total_vectors = self.vector_repo.count(None).await?;
319        let dimensions = self.vector_repo.dimensions().await?;
320
321        Ok(IndexStats {
322            total_vectors,
323            dimensions,
324        })
325    }
326
327    /// Health check
328    pub async fn health_check(&self) -> Result<()> {
329        self.vector_repo.health_check().await
330    }
331
332    /// Enrich search results with full event data
333    async fn enrich_with_events(
334        &self,
335        results: Vec<SearchResult>,
336    ) -> Result<Vec<SemanticSearchResultItem>> {
337        let Some(event_repo) = &self.event_repo else {
338            // No event repo, return without events
339            return Ok(results
340                .into_iter()
341                .map(|r| SemanticSearchResultItem {
342                    event_id: r.event_id,
343                    score: r.score.value(),
344                    source_text: r.source_text,
345                    event: None,
346                })
347                .collect());
348        };
349
350        let mut enriched = Vec::with_capacity(results.len());
351
352        for result in results {
353            let event = event_repo.find_by_id(result.event_id).await?;
354
355            enriched.push(SemanticSearchResultItem {
356                event_id: result.event_id,
357                score: result.score.value(),
358                source_text: result.source_text,
359                event: event.as_ref().map(EventSummary::from),
360            });
361        }
362
363        Ok(enriched)
364    }
365}
366
367/// Result of batch indexing operation
368#[derive(Debug, Clone, Serialize, Deserialize)]
369pub struct BatchIndexResult {
370    pub indexed: usize,
371    pub failed: usize,
372    pub errors: Vec<String>,
373}
374
375/// Index statistics
376#[derive(Debug, Clone, Serialize, Deserialize)]
377pub struct IndexStats {
378    pub total_vectors: usize,
379    pub dimensions: Option<usize>,
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use crate::infrastructure::repositories::InMemoryVectorSearchRepository;

    /// Service over a fresh in-memory repository with default configuration.
    fn create_test_service() -> VectorSearchService {
        VectorSearchService::new(Arc::new(InMemoryVectorSearchRepository::new()))
    }

    /// Deterministic `dims`-component embedding derived from `seed`.
    fn create_test_embedding(dims: usize, seed: f32) -> EmbeddingVector {
        let values: Vec<f32> = (0..dims).map(|i| (i as f32 + seed) / dims as f32).collect();
        EmbeddingVector::new(values).unwrap()
    }

    /// Indexes `values` for `tenant-1` without source text, failing the test on error.
    async fn index_plain(service: &VectorSearchService, event_id: Uuid, values: Vec<f32>) {
        service
            .index_event(IndexEventRequest {
                event_id,
                tenant_id: "tenant-1".to_string(),
                embedding: EmbeddingVector::new(values).expect("valid test embedding"),
                source_text: None,
            })
            .await
            .expect("indexing should succeed");
    }

    #[tokio::test]
    async fn test_index_and_search() {
        let service = create_test_service();

        let embeddings = [
            (Uuid::new_v4(), vec![1.0, 0.0, 0.0_f32]),
            (Uuid::new_v4(), vec![0.9, 0.1, 0.0]),
            (Uuid::new_v4(), vec![0.0, 1.0, 0.0]),
        ];
        for (id, values) in &embeddings {
            index_plain(&service, *id, values.clone()).await;
        }

        let response = service
            .search(SemanticSearchRequest {
                query_embedding: Some(vec![1.0, 0.0, 0.0]),
                k: Some(2),
                tenant_id: Some("tenant-1".to_string()),
                ..Default::default()
            })
            .await
            .unwrap();

        assert_eq!(response.count, 2);
        assert_eq!(response.results.len(), response.count);
        // Exact match first, then the nearly-parallel vector.
        assert_eq!(response.results[0].event_id, embeddings[0].0);
        assert_eq!(response.results[1].event_id, embeddings[1].0);
        assert_eq!(response.metric, "cosine");
    }

    #[tokio::test]
    async fn test_batch_index() {
        let service = create_test_service();

        let requests: Vec<_> = (0..10)
            .map(|i| IndexEventRequest {
                event_id: Uuid::new_v4(),
                tenant_id: "tenant-1".to_string(),
                embedding: create_test_embedding(384, i as f32),
                source_text: Some(format!("Document {i}")),
            })
            .collect();

        let result = service.index_events_batch(requests).await.unwrap();
        assert_eq!(result.indexed, 10);
        assert_eq!(result.failed, 0);

        let stats = service.get_stats().await.unwrap();
        assert_eq!(stats.total_vectors, 10);
        assert_eq!(stats.dimensions, Some(384));
    }

    #[tokio::test]
    async fn test_search_with_min_similarity() {
        let service = create_test_service();

        let exact_id = Uuid::new_v4();
        index_plain(&service, exact_id, vec![1.0, 0.0, 0.0]).await;
        index_plain(&service, Uuid::new_v4(), vec![0.0, 1.0, 0.0]).await;

        // The orthogonal vector scores 0 and must fall below the 0.5 threshold.
        let response = service
            .search(SemanticSearchRequest {
                query_embedding: Some(vec![1.0, 0.0, 0.0]),
                k: Some(10),
                tenant_id: Some("tenant-1".to_string()),
                min_similarity: Some(0.5),
                ..Default::default()
            })
            .await
            .unwrap();

        assert_eq!(response.count, 1);
        assert_eq!(response.results[0].event_id, exact_id);
    }

    #[tokio::test]
    async fn test_delete_embedding() {
        let service = create_test_service();

        let event_id = Uuid::new_v4();
        service
            .index_event(IndexEventRequest {
                event_id,
                tenant_id: "tenant-1".to_string(),
                embedding: create_test_embedding(384, 1.0),
                source_text: None,
            })
            .await
            .unwrap();

        assert!(service.get_embedding(event_id).await.unwrap().is_some());

        let deleted = service.delete_embedding(event_id).await.unwrap();
        assert!(deleted);

        assert!(service.get_embedding(event_id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_health_check() {
        let service = create_test_service();
        assert!(service.health_check().await.is_ok());
    }

    #[tokio::test]
    async fn test_invalid_metric() {
        let service = create_test_service();

        let err = service
            .search(SemanticSearchRequest {
                query_embedding: Some(vec![1.0, 0.0, 0.0]),
                metric: Some("invalid".to_string()),
                ..Default::default()
            })
            .await
            .expect_err("an unknown metric must be rejected");

        assert!(err.to_string().contains("Unknown metric"));
    }

    #[tokio::test]
    async fn test_missing_query_embedding() {
        let service = create_test_service();

        let err = service
            .search(SemanticSearchRequest {
                query_embedding: None,
                ..Default::default()
            })
            .await
            .expect_err("a request without an embedding must be rejected");

        assert!(err.to_string().contains("query_embedding is required"));
    }
}