//! Vector search application service
//! (`allsource_core/application/services/vector_search.rs`).

1use crate::{
2    domain::{
3        entities::Event,
4        repositories::{
5            EventRepository, SearchResult, VectorEntry, VectorSearchQuery, VectorSearchRepository,
6        },
7        value_objects::{DistanceMetric, EmbeddingVector},
8    },
9    error::{AllSourceError, Result},
10};
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use uuid::Uuid;
14
/// Configuration for the vector search service.
///
/// Installed with `VectorSearchService::with_config`. The [`Default`] impl yields
/// `default_k = 10`, `max_k = 100`, `default_min_similarity = 0.0`, a cosine default
/// metric and `include_source_text = true`.
#[derive(Debug, Clone)]
pub struct VectorSearchConfig {
    /// Default number of results to return when a search request does not set `k`
    pub default_k: usize,
    /// Maximum number of results to return; a larger requested `k` is capped to this
    pub max_k: usize,
    /// Default similarity threshold for cosine similarity
    pub default_min_similarity: f32,
    /// Default distance metric
    pub default_metric: DistanceMetric,
    /// Whether to include source text in results
    pub include_source_text: bool,
}
29
30impl Default for VectorSearchConfig {
31    fn default() -> Self {
32        Self {
33            default_k: 10,
34            max_k: 100,
35            default_min_similarity: 0.0,
36            default_metric: DistanceMetric::Cosine,
37            include_source_text: true,
38        }
39    }
40}
41
/// Request to index an event's embedding.
///
/// Consumed by `VectorSearchService::index_event` and
/// `VectorSearchService::index_events_batch`.
#[derive(Debug, Clone)]
pub struct IndexEventRequest {
    /// ID of the event the embedding belongs to
    pub event_id: Uuid,
    /// Tenant that owns the event
    pub tenant_id: String,
    /// Embedding vector to store
    pub embedding: EmbeddingVector,
    /// Optional source text (presumably the text the embedding was generated from);
    /// when present, `index_event` stores it alongside the vector
    pub source_text: Option<String>,
}
50
51/// Request for semantic search
52#[derive(Debug, Clone, Default, Serialize, Deserialize)]
53pub struct SemanticSearchRequest {
54    /// The query embedding vector
55    #[serde(skip_serializing_if = "Option::is_none")]
56    pub query_embedding: Option<Vec<f32>>,
57    /// Number of results to return (default: 10)
58    #[serde(default)]
59    pub k: Option<usize>,
60    /// Tenant ID filter
61    #[serde(skip_serializing_if = "Option::is_none")]
62    pub tenant_id: Option<String>,
63    /// Event type filter
64    #[serde(skip_serializing_if = "Option::is_none")]
65    pub event_type: Option<String>,
66    /// Minimum similarity threshold (for cosine/dot product)
67    #[serde(skip_serializing_if = "Option::is_none")]
68    pub min_similarity: Option<f32>,
69    /// Maximum distance threshold (for euclidean)
70    #[serde(skip_serializing_if = "Option::is_none")]
71    pub max_distance: Option<f32>,
72    /// Distance metric (default: cosine)
73    #[serde(default)]
74    pub metric: Option<String>,
75    /// Whether to include full event data in results
76    #[serde(default)]
77    pub include_events: bool,
78}
79
/// A single semantic search result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticSearchResultItem {
    /// The event ID that matched
    pub event_id: Uuid,
    /// The similarity/distance score as reported by the vector repository; how to
    /// read it depends on the metric used
    pub score: f32,
    /// The source text stored with the vector (if available)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub source_text: Option<String>,
    /// Summary of the matching event; only populated when the request set
    /// `include_events`, the service has an event repository, and the event was found
    #[serde(skip_serializing_if = "Option::is_none")]
    pub event: Option<EventSummary>,
}
94
/// Summary of an event for search results, built from an [`Event`] via `From<&Event>`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EventSummary {
    /// Event ID
    pub id: Uuid,
    /// Event type name
    pub event_type: String,
    /// ID of the entity the event refers to
    pub entity_id: String,
    /// Tenant that owns the event
    pub tenant_id: String,
    /// Event timestamp (UTC)
    pub timestamp: chrono::DateTime<chrono::Utc>,
    /// Event payload; always `Some` when built via `From<&Event>`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub payload: Option<serde_json::Value>,
}
106
107impl From<&Event> for EventSummary {
108    fn from(event: &Event) -> Self {
109        Self {
110            id: event.id(),
111            event_type: event.event_type_str().to_string(),
112            entity_id: event.entity_id_str().to_string(),
113            tenant_id: event.tenant_id_str().to_string(),
114            timestamp: event.timestamp(),
115            payload: Some(event.payload().clone()),
116        }
117    }
118}
119
/// Response from semantic search
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticSearchResponse {
    /// The search results, in the order returned by the vector repository
    pub results: Vec<SemanticSearchResultItem>,
    /// Number of entries in `results`
    pub count: usize,
    /// Name of the metric used for scoring
    pub metric: String,
    /// Query execution stats
    pub stats: SearchStats,
}
132
/// Statistics about the search operation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchStats {
    /// Total vectors searched. `VectorSearchService::search` fills this from the
    /// repository's `count(None)`, which is presumably the size of the whole index
    /// rather than only the vectors matching the request filters (TODO confirm); it is
    /// 0 if the count call fails.
    pub vectors_searched: usize,
    /// Wall-clock time of the whole `search` call in microseconds, including request
    /// validation and optional event enrichment
    pub search_time_us: u64,
}
141
/// Vector Search Service
///
/// Orchestrates vector search operations including:
/// - Indexing event embeddings
/// - Semantic similarity search
/// - Integration with event repository for full results
///
/// This service follows the application layer pattern, coordinating
/// between the domain repositories without containing domain logic.
pub struct VectorSearchService {
    /// Storage and similarity search for embeddings
    vector_repo: Arc<dyn VectorSearchRepository>,
    /// Optional source of full events; when `None`, search results are never
    /// enriched with event data
    event_repo: Option<Arc<dyn EventRepository>>,
    /// Defaults and limits applied to requests
    config: VectorSearchConfig,
}
156
157impl VectorSearchService {
158    pub fn new(vector_repo: Arc<dyn VectorSearchRepository>) -> Self {
159        Self {
160            vector_repo,
161            event_repo: None,
162            config: VectorSearchConfig::default(),
163        }
164    }
165
166    pub fn with_event_repo(mut self, event_repo: Arc<dyn EventRepository>) -> Self {
167        self.event_repo = Some(event_repo);
168        self
169    }
170
171    pub fn with_config(mut self, config: VectorSearchConfig) -> Self {
172        self.config = config;
173        self
174    }
175
176    /// Index a single event embedding
177    pub async fn index_event(&self, request: IndexEventRequest) -> Result<()> {
178        if let Some(source_text) = &request.source_text {
179            self.vector_repo
180                .store_with_text(
181                    request.event_id,
182                    &request.embedding,
183                    &request.tenant_id,
184                    source_text,
185                )
186                .await
187        } else {
188            self.vector_repo
189                .store(request.event_id, &request.embedding, &request.tenant_id)
190                .await
191        }
192    }
193
194    /// Index multiple events in batch
195    pub async fn index_events_batch(
196        &self,
197        requests: Vec<IndexEventRequest>,
198    ) -> Result<BatchIndexResult> {
199        if requests.is_empty() {
200            return Ok(BatchIndexResult {
201                indexed: 0,
202                failed: 0,
203                errors: vec![],
204            });
205        }
206
207        let entries: Vec<_> = requests
208            .iter()
209            .map(|r| (r.event_id, r.embedding.clone(), r.tenant_id.clone()))
210            .collect();
211
212        self.vector_repo.store_batch(&entries).await?;
213
214        Ok(BatchIndexResult {
215            indexed: requests.len(),
216            failed: 0,
217            errors: vec![],
218        })
219    }
220
221    /// Perform semantic search
222    pub async fn search(&self, request: SemanticSearchRequest) -> Result<SemanticSearchResponse> {
223        let start_time = std::time::Instant::now();
224
225        // Parse and validate query embedding
226        let query_embedding = request.query_embedding.ok_or_else(|| {
227            AllSourceError::InvalidInput("query_embedding is required".to_string())
228        })?;
229
230        let query_vector = EmbeddingVector::new(query_embedding)?;
231
232        // Parse metric
233        let metric = match request.metric.as_deref() {
234            Some("cosine") | None => DistanceMetric::Cosine,
235            Some("euclidean") => DistanceMetric::Euclidean,
236            Some("dot_product") => DistanceMetric::DotProduct,
237            Some(m) => {
238                return Err(AllSourceError::InvalidInput(format!(
239                    "Unknown metric: {}. Supported: cosine, euclidean, dot_product",
240                    m
241                )));
242            }
243        };
244
245        // Build query
246        let k = request
247            .k
248            .unwrap_or(self.config.default_k)
249            .min(self.config.max_k);
250
251        let mut query = VectorSearchQuery::new(query_vector, k).with_metric(metric);
252
253        if let Some(tenant_id) = request.tenant_id {
254            query = query.with_tenant(tenant_id);
255        }
256
257        if let Some(event_type) = request.event_type {
258            query = query.with_event_type(event_type);
259        }
260
261        if let Some(min_sim) = request.min_similarity {
262            query = query.with_min_similarity(min_sim);
263        }
264
265        if let Some(max_dist) = request.max_distance {
266            query = query.with_max_distance(max_dist);
267        }
268
269        // Execute search
270        let search_results = self.vector_repo.search(&query).await?;
271        let vectors_searched = self.vector_repo.count(None).await.unwrap_or(0);
272
273        // Optionally fetch full events
274        let results = if request.include_events {
275            self.enrich_with_events(search_results).await?
276        } else {
277            search_results
278                .into_iter()
279                .map(|r| SemanticSearchResultItem {
280                    event_id: r.event_id,
281                    score: r.score.value(),
282                    source_text: r.source_text,
283                    event: None,
284                })
285                .collect()
286        };
287
288        let search_time_us = start_time.elapsed().as_micros() as u64;
289        let count = results.len();
290
291        Ok(SemanticSearchResponse {
292            results,
293            count,
294            metric: format!("{:?}", metric).to_lowercase(),
295            stats: SearchStats {
296                vectors_searched,
297                search_time_us,
298            },
299        })
300    }
301
302    /// Get embedding for a specific event
303    pub async fn get_embedding(&self, event_id: Uuid) -> Result<Option<VectorEntry>> {
304        self.vector_repo.get_by_event_id(event_id).await
305    }
306
307    /// Delete embedding for an event
308    pub async fn delete_embedding(&self, event_id: Uuid) -> Result<bool> {
309        self.vector_repo.delete(event_id).await
310    }
311
312    /// Delete all embeddings for a tenant
313    pub async fn delete_tenant_embeddings(&self, tenant_id: &str) -> Result<usize> {
314        self.vector_repo.delete_by_tenant(tenant_id).await
315    }
316
317    /// Get index statistics
318    pub async fn get_stats(&self) -> Result<IndexStats> {
319        let total_vectors = self.vector_repo.count(None).await?;
320        let dimensions = self.vector_repo.dimensions().await?;
321
322        Ok(IndexStats {
323            total_vectors,
324            dimensions,
325        })
326    }
327
328    /// Health check
329    pub async fn health_check(&self) -> Result<()> {
330        self.vector_repo.health_check().await
331    }
332
333    /// Enrich search results with full event data
334    async fn enrich_with_events(
335        &self,
336        results: Vec<SearchResult>,
337    ) -> Result<Vec<SemanticSearchResultItem>> {
338        let event_repo = match &self.event_repo {
339            Some(repo) => repo,
340            None => {
341                // No event repo, return without events
342                return Ok(results
343                    .into_iter()
344                    .map(|r| SemanticSearchResultItem {
345                        event_id: r.event_id,
346                        score: r.score.value(),
347                        source_text: r.source_text,
348                        event: None,
349                    })
350                    .collect());
351            }
352        };
353
354        let mut enriched = Vec::with_capacity(results.len());
355
356        for result in results {
357            let event = event_repo.find_by_id(result.event_id).await?;
358
359            enriched.push(SemanticSearchResultItem {
360                event_id: result.event_id,
361                score: result.score.value(),
362                source_text: result.source_text,
363                event: event.as_ref().map(EventSummary::from),
364            });
365        }
366
367        Ok(enriched)
368    }
369}
370
/// Result of batch indexing operation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchIndexResult {
    /// Number of requests successfully indexed
    pub indexed: usize,
    /// Number of requests that failed to index
    pub failed: usize,
    /// Error messages for the failed requests
    pub errors: Vec<String>,
}
378
/// Index statistics, as returned by `VectorSearchService::get_stats`
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexStats {
    /// Total number of vectors reported by the repository's `count(None)`
    pub total_vectors: usize,
    /// Embedding dimensionality reported by the repository, or `None` when it reports
    /// none (presumably an empty index — TODO confirm)
    pub dimensions: Option<usize>,
}
385
#[cfg(test)]
mod tests {
    use super::*;
    use crate::infrastructure::repositories::InMemoryVectorSearchRepository;

    /// Build a service over a fresh, empty in-memory vector repository.
    fn create_test_service() -> VectorSearchService {
        VectorSearchService::new(Arc::new(InMemoryVectorSearchRepository::new()))
    }

    /// Deterministic `dims`-dimensional embedding whose i-th component is
    /// `(i + seed) / dims`.
    fn create_test_embedding(dims: usize, seed: f32) -> EmbeddingVector {
        let scale = dims as f32;
        let components = (0..dims)
            .map(|i| (i as f32 + seed) / scale)
            .collect::<Vec<f32>>();
        EmbeddingVector::new(components).unwrap()
    }

    /// Index `values` for `event_id` under `tenant-1`, without source text.
    async fn index_plain(service: &VectorSearchService, event_id: Uuid, values: Vec<f32>) {
        let request = IndexEventRequest {
            event_id,
            tenant_id: "tenant-1".to_string(),
            embedding: EmbeddingVector::new(values).unwrap(),
            source_text: None,
        };
        service.index_event(request).await.unwrap();
    }

    #[tokio::test]
    async fn test_index_and_search() {
        let service = create_test_service();

        // Two vectors near the x-axis and one orthogonal to it.
        let exact_match = Uuid::new_v4();
        let fixtures = vec![
            (exact_match, vec![1.0, 0.0, 0.0_f32]),
            (Uuid::new_v4(), vec![0.9, 0.1, 0.0]),
            (Uuid::new_v4(), vec![0.0, 1.0, 0.0]),
        ];
        for (event_id, values) in fixtures {
            index_plain(&service, event_id, values).await;
        }

        let request = SemanticSearchRequest {
            query_embedding: Some(vec![1.0, 0.0, 0.0]),
            k: Some(2),
            tenant_id: Some("tenant-1".to_string()),
            ..Default::default()
        };
        let response = service.search(request).await.unwrap();

        assert_eq!(response.count, 2);
        assert_eq!(response.results[0].event_id, exact_match);
    }

    #[tokio::test]
    async fn test_batch_index() {
        let service = create_test_service();

        let mut requests = Vec::with_capacity(10);
        for i in 0..10 {
            requests.push(IndexEventRequest {
                event_id: Uuid::new_v4(),
                tenant_id: "tenant-1".to_string(),
                embedding: create_test_embedding(384, i as f32),
                source_text: Some(format!("Document {}", i)),
            });
        }

        let outcome = service.index_events_batch(requests).await.unwrap();
        assert_eq!(outcome.indexed, 10);
        assert_eq!(outcome.failed, 0);

        let stats = service.get_stats().await.unwrap();
        assert_eq!(stats.total_vectors, 10);
        assert_eq!(stats.dimensions, Some(384));
    }

    #[tokio::test]
    async fn test_search_with_min_similarity() {
        let service = create_test_service();

        index_plain(&service, Uuid::new_v4(), vec![1.0, 0.0, 0.0]).await;
        index_plain(&service, Uuid::new_v4(), vec![0.0, 1.0, 0.0]).await;

        // The orthogonal vector scores 0 against the query, below the 0.5 floor.
        let request = SemanticSearchRequest {
            query_embedding: Some(vec![1.0, 0.0, 0.0]),
            k: Some(10),
            tenant_id: Some("tenant-1".to_string()),
            min_similarity: Some(0.5),
            ..Default::default()
        };
        let response = service.search(request).await.unwrap();

        // Only the exact match survives the threshold.
        assert_eq!(response.count, 1);
    }

    #[tokio::test]
    async fn test_delete_embedding() {
        let service = create_test_service();
        let event_id = Uuid::new_v4();

        let request = IndexEventRequest {
            event_id,
            tenant_id: "tenant-1".to_string(),
            embedding: create_test_embedding(384, 1.0),
            source_text: None,
        };
        service.index_event(request).await.unwrap();

        let before = service.get_embedding(event_id).await.unwrap();
        assert!(before.is_some());

        assert!(service.delete_embedding(event_id).await.unwrap());

        let after = service.get_embedding(event_id).await.unwrap();
        assert!(after.is_none());
    }

    #[tokio::test]
    async fn test_health_check() {
        let service = create_test_service();
        let health = service.health_check().await;
        assert!(health.is_ok());
    }

    #[tokio::test]
    async fn test_invalid_metric() {
        let service = create_test_service();

        let request = SemanticSearchRequest {
            query_embedding: Some(vec![1.0, 0.0, 0.0]),
            metric: Some("invalid".to_string()),
            ..Default::default()
        };

        match service.search(request).await {
            Ok(_) => panic!("search with an unknown metric should fail"),
            Err(e) => assert!(e.to_string().contains("Unknown metric")),
        }
    }

    #[tokio::test]
    async fn test_missing_query_embedding() {
        let service = create_test_service();

        let request = SemanticSearchRequest {
            query_embedding: None,
            ..Default::default()
        };

        match service.search(request).await {
            Ok(_) => panic!("search without a query embedding should fail"),
            Err(e) => assert!(e.to_string().contains("query_embedding is required")),
        }
    }
}
565}