rexis_rag/retrieval/
semantic.rs

1//! # Semantic Vector Search
2//!
3//! High-performance semantic search using vector embeddings and similarity metrics.
4//! Supports multiple similarity algorithms and optimization techniques.
5
6use crate::{Document, Embedding, EmbeddingProvider, RragResult, SearchResult};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
12/// Semantic retriever configuration
/// Semantic retriever configuration.
///
/// See the `Default` impl for the out-of-the-box settings (cosine metric,
/// 768 dimensions, normalization on, flat index).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticConfig {
    /// Similarity metric used when scoring query vs. document vectors
    pub similarity_metric: SimilarityMetric,

    /// Embedding dimension
    /// (used for the memory estimate in `stats()`; incoming vectors are not
    /// validated against it)
    pub embedding_dimension: usize,

    /// Whether to normalize embeddings to unit length before indexing and
    /// before querying
    pub normalize_embeddings: bool,

    /// Index type for efficient search
    pub index_type: IndexType,

    /// Number of clusters for IVF index
    /// NOTE(review): not read anywhere in this module yet — confirm intended use
    pub num_clusters: Option<usize>,

    /// Number of probes for IVF search
    /// NOTE(review): not read anywhere in this module yet
    pub num_probes: Option<usize>,

    /// Enable GPU acceleration if available
    /// NOTE(review): not consulted by this module's search path
    pub use_gpu: bool,
}
36
37impl Default for SemanticConfig {
38    fn default() -> Self {
39        Self {
40            similarity_metric: SimilarityMetric::Cosine,
41            embedding_dimension: 768,
42            normalize_embeddings: true,
43            index_type: IndexType::Flat,
44            num_clusters: None,
45            num_probes: None,
46            use_gpu: false,
47        }
48    }
49}
50
51/// Similarity metrics for vector comparison
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub enum SimilarityMetric {
54    /// Cosine similarity (angle between vectors)
55    Cosine,
56    /// Euclidean distance (L2 norm)
57    Euclidean,
58    /// Dot product (inner product)
59    DotProduct,
60    /// Manhattan distance (L1 norm)
61    Manhattan,
62}
63
64/// Index types for efficient search
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub enum IndexType {
67    /// Flat index (brute force search)
68    Flat,
69    /// Inverted File Index (clustering-based)
70    IVF,
71    /// Hierarchical Navigable Small World
72    HNSW,
73    /// Locality Sensitive Hashing
74    LSH,
75}
76
77/// Vector document for semantic search
/// Internal record pairing a document with its embedding, as stored in the
/// retriever's `documents` map.
#[derive(Debug, Clone)]
struct VectorDocument {
    /// Document ID (also the key in the retriever's `documents` map)
    id: String,

    /// Original content, returned verbatim in `SearchResult.content`
    content: String,

    /// Raw document embedding, echoed back in `SearchResult.embedding`
    embedding: Embedding,

    /// Unit-length copy of the embedding when normalization is enabled,
    /// `None` otherwise.
    /// NOTE(review): the search paths score against the index's own vector
    /// copy, so this field is never read here — confirm it is kept for
    /// future use.
    normalized_embedding: Option<Vec<f32>>,

    /// Metadata, copied into search results unchanged
    metadata: HashMap<String, serde_json::Value>,
}
95
96/// Semantic retriever implementation
/// Semantic retriever: embeds documents and queries through an
/// [`EmbeddingProvider`] and ranks documents by vector similarity.
pub struct SemanticRetriever {
    /// Configuration (metric, normalization, index settings)
    config: SemanticConfig,

    /// Document storage keyed by document ID
    documents: Arc<RwLock<HashMap<String, VectorDocument>>>,

    /// Embedding service used for both documents and queries
    embedding_service: Arc<dyn EmbeddingProvider>,

    /// In-memory vector index scanned at query time; kept in sync with
    /// `documents` by the indexing methods (simplified for this example)
    index: Arc<RwLock<VectorIndex>>,
}
110
111/// Simplified vector index
/// Simplified in-memory vector index.
///
/// Invariant: `doc_ids[i]` corresponds to `embeddings[i]`; both vectors grow
/// in lock-step as documents are indexed.
struct VectorIndex {
    /// Document IDs in insertion order
    doc_ids: Vec<String>,

    /// Embeddings matrix (row-major), parallel to `doc_ids`; rows are
    /// pre-normalized when the retriever's `normalize_embeddings` is set
    embeddings: Vec<Vec<f32>>,

    /// Index type
    /// NOTE(review): currently informational only — queries always perform a
    /// brute-force scan regardless of this value
    index_type: IndexType,
}
122
123impl SemanticRetriever {
124    /// Create a new semantic retriever
125    pub fn new(config: SemanticConfig, embedding_service: Arc<dyn EmbeddingProvider>) -> Self {
126        Self {
127            config,
128            documents: Arc::new(RwLock::new(HashMap::new())),
129            embedding_service,
130            index: Arc::new(RwLock::new(VectorIndex {
131                doc_ids: Vec::new(),
132                embeddings: Vec::new(),
133                index_type: IndexType::Flat,
134            })),
135        }
136    }
137
138    /// Index a document with semantic embedding
139    pub async fn index_document(&self, doc: &Document) -> RragResult<()> {
140        // Generate embedding for the document
141        let embedding = self.embedding_service.embed_text(&doc.content).await?;
142
143        // Normalize if configured
144        let normalized = if self.config.normalize_embeddings {
145            Some(Self::normalize_vector(&embedding.vector))
146        } else {
147            None
148        };
149
150        let vector_doc = VectorDocument {
151            id: doc.id.clone(),
152            content: doc.content.to_string(),
153            embedding: embedding.clone(),
154            normalized_embedding: normalized,
155            metadata: doc.metadata.clone(),
156        };
157
158        // Store document
159        let mut documents = self.documents.write().await;
160        documents.insert(doc.id.clone(), vector_doc);
161
162        // Update index
163        let mut index = self.index.write().await;
164        index.doc_ids.push(doc.id.clone());
165        index.embeddings.push(if self.config.normalize_embeddings {
166            Self::normalize_vector(&embedding.vector)
167        } else {
168            embedding.vector
169        });
170
171        Ok(())
172    }
173
174    /// Search for similar documents
175    pub async fn search(
176        &self,
177        query: &str,
178        limit: usize,
179        min_score: Option<f32>,
180    ) -> RragResult<Vec<SearchResult>> {
181        // Generate query embedding
182        let query_embedding = self.embedding_service.embed_text(query).await?;
183
184        let query_vector = if self.config.normalize_embeddings {
185            Self::normalize_vector(&query_embedding.vector)
186        } else {
187            query_embedding.vector
188        };
189
190        // Perform search
191        let index = self.index.read().await;
192        let documents = self.documents.read().await;
193
194        let mut scores: Vec<(String, f32)> = Vec::new();
195
196        // Calculate similarities
197        for (i, doc_embedding) in index.embeddings.iter().enumerate() {
198            let similarity = self.calculate_similarity(&query_vector, doc_embedding);
199
200            if let Some(threshold) = min_score {
201                if similarity < threshold {
202                    continue;
203                }
204            }
205
206            scores.push((index.doc_ids[i].clone(), similarity));
207        }
208
209        // Sort by similarity
210        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
211        scores.truncate(limit);
212
213        // Build results
214        let results: Vec<SearchResult> = scores
215            .into_iter()
216            .enumerate()
217            .filter_map(|(rank, (doc_id, score))| {
218                documents.get(&doc_id).map(|doc| SearchResult {
219                    id: doc_id,
220                    content: doc.content.clone(),
221                    score,
222                    rank,
223                    metadata: doc.metadata.clone(),
224                    embedding: Some(doc.embedding.clone()),
225                })
226            })
227            .collect();
228
229        Ok(results)
230    }
231
232    /// Search with pre-computed embedding
233    pub async fn search_by_embedding(
234        &self,
235        embedding: &Embedding,
236        limit: usize,
237        min_score: Option<f32>,
238    ) -> RragResult<Vec<SearchResult>> {
239        let query_vector = if self.config.normalize_embeddings {
240            Self::normalize_vector(&embedding.vector)
241        } else {
242            embedding.vector.clone()
243        };
244
245        let index = self.index.read().await;
246        let documents = self.documents.read().await;
247
248        let mut scores: Vec<(String, f32)> = Vec::new();
249
250        for (i, doc_embedding) in index.embeddings.iter().enumerate() {
251            let similarity = self.calculate_similarity(&query_vector, doc_embedding);
252
253            if let Some(threshold) = min_score {
254                if similarity < threshold {
255                    continue;
256                }
257            }
258
259            scores.push((index.doc_ids[i].clone(), similarity));
260        }
261
262        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
263        scores.truncate(limit);
264
265        let results: Vec<SearchResult> = scores
266            .into_iter()
267            .enumerate()
268            .filter_map(|(rank, (doc_id, score))| {
269                documents.get(&doc_id).map(|doc| SearchResult {
270                    id: doc_id,
271                    content: doc.content.clone(),
272                    score,
273                    rank,
274                    metadata: doc.metadata.clone(),
275                    embedding: Some(doc.embedding.clone()),
276                })
277            })
278            .collect();
279
280        Ok(results)
281    }
282
283    /// Calculate similarity between two vectors
284    fn calculate_similarity(&self, vec1: &[f32], vec2: &[f32]) -> f32 {
285        match self.config.similarity_metric {
286            SimilarityMetric::Cosine => Self::cosine_similarity(vec1, vec2),
287            SimilarityMetric::Euclidean => {
288                let distance = Self::euclidean_distance(vec1, vec2);
289                1.0 / (1.0 + distance) // Convert distance to similarity
290            }
291            SimilarityMetric::DotProduct => Self::dot_product(vec1, vec2),
292            SimilarityMetric::Manhattan => {
293                let distance = Self::manhattan_distance(vec1, vec2);
294                1.0 / (1.0 + distance) // Convert distance to similarity
295            }
296        }
297    }
298
299    /// Cosine similarity between two vectors
300    fn cosine_similarity(vec1: &[f32], vec2: &[f32]) -> f32 {
301        let dot = Self::dot_product(vec1, vec2);
302        let norm1 = vec1.iter().map(|x| x * x).sum::<f32>().sqrt();
303        let norm2 = vec2.iter().map(|x| x * x).sum::<f32>().sqrt();
304
305        if norm1 == 0.0 || norm2 == 0.0 {
306            0.0
307        } else {
308            dot / (norm1 * norm2)
309        }
310    }
311
312    /// Dot product of two vectors
313    fn dot_product(vec1: &[f32], vec2: &[f32]) -> f32 {
314        vec1.iter().zip(vec2.iter()).map(|(a, b)| a * b).sum()
315    }
316
317    /// Euclidean distance between two vectors
318    fn euclidean_distance(vec1: &[f32], vec2: &[f32]) -> f32 {
319        vec1.iter()
320            .zip(vec2.iter())
321            .map(|(a, b)| (a - b).powi(2))
322            .sum::<f32>()
323            .sqrt()
324    }
325
326    /// Manhattan distance between two vectors
327    fn manhattan_distance(vec1: &[f32], vec2: &[f32]) -> f32 {
328        vec1.iter()
329            .zip(vec2.iter())
330            .map(|(a, b)| (a - b).abs())
331            .sum()
332    }
333
334    /// Normalize a vector to unit length
335    fn normalize_vector(vec: &[f32]) -> Vec<f32> {
336        let norm = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
337
338        if norm == 0.0 {
339            vec.to_vec()
340        } else {
341            vec.iter().map(|x| x / norm).collect()
342        }
343    }
344
345    /// Batch index multiple documents
346    pub async fn index_batch(&self, documents: Vec<Document>) -> RragResult<()> {
347        // Generate embedding requests
348        let requests: Vec<crate::EmbeddingRequest> = documents
349            .iter()
350            .map(|doc| crate::EmbeddingRequest::new(&doc.id, doc.content.as_ref()))
351            .collect();
352
353        let embedding_batch = self.embedding_service.embed_batch(requests).await?;
354
355        let mut docs_map = self.documents.write().await;
356        let mut index = self.index.write().await;
357
358        for doc in documents.iter() {
359            if let Some(embedding) = embedding_batch.embeddings.get(&doc.id) {
360                let normalized = if self.config.normalize_embeddings {
361                    Some(Self::normalize_vector(&embedding.vector))
362                } else {
363                    None
364                };
365
366                let vector_doc = VectorDocument {
367                    id: doc.id.clone(),
368                    content: doc.content.to_string(),
369                    embedding: embedding.clone(),
370                    normalized_embedding: normalized.clone(),
371                    metadata: doc.metadata.clone(),
372                };
373
374                docs_map.insert(doc.id.clone(), vector_doc);
375                index.doc_ids.push(doc.id.clone());
376                index
377                    .embeddings
378                    .push(normalized.unwrap_or_else(|| embedding.vector.clone()));
379            }
380        }
381
382        Ok(())
383    }
384
385    /// Clear the index
386    pub async fn clear(&self) -> RragResult<()> {
387        let mut documents = self.documents.write().await;
388        let mut index = self.index.write().await;
389
390        documents.clear();
391        index.doc_ids.clear();
392        index.embeddings.clear();
393
394        Ok(())
395    }
396
397    /// Get index statistics
398    pub async fn stats(&self) -> HashMap<String, serde_json::Value> {
399        let documents = self.documents.read().await;
400        let _index = self.index.read().await;
401
402        let mut stats = HashMap::new();
403        stats.insert("total_documents".to_string(), documents.len().into());
404        stats.insert(
405            "embedding_dimension".to_string(),
406            self.config.embedding_dimension.into(),
407        );
408        stats.insert(
409            "index_type".to_string(),
410            format!("{:?}", self.config.index_type).into(),
411        );
412        stats.insert(
413            "similarity_metric".to_string(),
414            format!("{:?}", self.config.similarity_metric).into(),
415        );
416
417        let memory_size = documents.len() * self.config.embedding_dimension * 4; // 4 bytes per f32
418        stats.insert("index_memory_bytes".to_string(), memory_size.into());
419
420        stats
421    }
422}
423
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embeddings::MockEmbeddingService;

    /// End-to-end smoke test: batch-index a small corpus, then run a
    /// thresholded semantic query and expect at least one hit.
    #[tokio::test]
    async fn test_semantic_search() {
        let embedder = Arc::new(MockEmbeddingService::new());
        let retriever = SemanticRetriever::new(SemanticConfig::default(), embedder);

        let corpus = vec![
            Document::with_id(
                "1",
                "Machine learning is a subset of artificial intelligence",
            ),
            Document::with_id("2", "Deep learning uses neural networks"),
            Document::with_id(
                "3",
                "Natural language processing enables computers to understand text",
            ),
        ];
        retriever.index_batch(corpus).await.unwrap();

        let hits = retriever
            .search("AI and machine learning", 2, Some(0.5))
            .await
            .unwrap();
        assert!(!hits.is_empty());
    }
}