1use crate::{Document, Embedding, EmbeddingProvider, RragResult, SearchResult};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
/// Configuration for the semantic (dense-vector) retriever.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticConfig {
    /// Similarity/distance function used to compare vectors.
    pub similarity_metric: SimilarityMetric,

    /// Expected embedding dimensionality (used for stats/memory estimates).
    pub embedding_dimension: usize,

    /// When true, vectors are L2-normalized before being stored and queried.
    pub normalize_embeddings: bool,

    /// Index structure to use for lookups.
    pub index_type: IndexType,

    /// Cluster count for IVF-style indexes; not consulted by the search path
    /// in this module — presumably reserved for a future IVF implementation.
    pub num_clusters: Option<usize>,

    /// Probe count for IVF-style indexes; likewise not consulted here.
    pub num_probes: Option<usize>,

    /// GPU acceleration flag; not acted on anywhere in this module.
    pub use_gpu: bool,
}
36
37impl Default for SemanticConfig {
38 fn default() -> Self {
39 Self {
40 similarity_metric: SimilarityMetric::Cosine,
41 embedding_dimension: 768,
42 normalize_embeddings: true,
43 index_type: IndexType::Flat,
44 num_clusters: None,
45 num_probes: None,
46 use_gpu: false,
47 }
48 }
49}
50
/// Similarity/distance measures supported by the retriever.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SimilarityMetric {
    /// Cosine of the angle between vectors (magnitude-invariant).
    Cosine,
    /// Euclidean (L2) distance, converted to a similarity as `1 / (1 + d)`.
    Euclidean,
    /// Raw dot product (magnitude-sensitive).
    DotProduct,
    /// Manhattan (L1) distance, converted to a similarity as `1 / (1 + d)`.
    Manhattan,
}
63
/// Vector index structures. NOTE(review): the search path in this module
/// always performs an exhaustive (flat) scan regardless of variant; the
/// non-Flat variants are declared but not implemented here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum IndexType {
    /// Exact brute-force scan over all stored vectors.
    Flat,
    /// Inverted file index (cluster-based approximate search).
    IVF,
    /// Hierarchical Navigable Small World graph.
    HNSW,
    /// Locality-Sensitive Hashing.
    LSH,
}
76
/// Internal record pairing a document with its (optionally normalized) embedding.
#[derive(Debug, Clone)]
struct VectorDocument {
    /// Document identifier (key in the retriever's document map).
    id: String,

    /// Raw document text, returned verbatim in search results.
    content: String,

    /// Original embedding as produced by the provider.
    embedding: Embedding,

    /// L2-normalized copy of the vector when normalization is enabled;
    /// `None` otherwise.
    normalized_embedding: Option<Vec<f32>>,

    /// Arbitrary document metadata, propagated into search results.
    metadata: HashMap<String, serde_json::Value>,
}
95
/// Dense-vector retriever: embeds documents via an `EmbeddingProvider`
/// and serves similarity search over the stored vectors.
pub struct SemanticRetriever {
    /// Retriever configuration (metric, normalization, index settings).
    config: SemanticConfig,

    /// Document store keyed by document id.
    documents: Arc<RwLock<HashMap<String, VectorDocument>>>,

    /// Provider used to embed both documents and queries.
    embedding_service: Arc<dyn EmbeddingProvider>,

    /// Flat in-memory vector index (parallel id/embedding arrays).
    index: Arc<RwLock<VectorIndex>>,
}
110
/// In-memory vector index stored as two parallel arrays:
/// `doc_ids[i]` identifies the owner of `embeddings[i]`.
struct VectorIndex {
    /// Document ids, parallel to `embeddings`.
    doc_ids: Vec<String>,

    /// Stored vectors (L2-normalized when the config requests it).
    embeddings: Vec<Vec<f32>>,

    /// Declared index type; informational only — lookups in this module
    /// are always a brute-force scan.
    index_type: IndexType,
}
122
123impl SemanticRetriever {
124 pub fn new(config: SemanticConfig, embedding_service: Arc<dyn EmbeddingProvider>) -> Self {
126 Self {
127 config,
128 documents: Arc::new(RwLock::new(HashMap::new())),
129 embedding_service,
130 index: Arc::new(RwLock::new(VectorIndex {
131 doc_ids: Vec::new(),
132 embeddings: Vec::new(),
133 index_type: IndexType::Flat,
134 })),
135 }
136 }
137
138 pub async fn index_document(&self, doc: &Document) -> RragResult<()> {
140 let embedding = self.embedding_service.embed_text(&doc.content).await?;
142
143 let normalized = if self.config.normalize_embeddings {
145 Some(Self::normalize_vector(&embedding.vector))
146 } else {
147 None
148 };
149
150 let vector_doc = VectorDocument {
151 id: doc.id.clone(),
152 content: doc.content.to_string(),
153 embedding: embedding.clone(),
154 normalized_embedding: normalized,
155 metadata: doc.metadata.clone(),
156 };
157
158 let mut documents = self.documents.write().await;
160 documents.insert(doc.id.clone(), vector_doc);
161
162 let mut index = self.index.write().await;
164 index.doc_ids.push(doc.id.clone());
165 index.embeddings.push(if self.config.normalize_embeddings {
166 Self::normalize_vector(&embedding.vector)
167 } else {
168 embedding.vector
169 });
170
171 Ok(())
172 }
173
174 pub async fn search(
176 &self,
177 query: &str,
178 limit: usize,
179 min_score: Option<f32>,
180 ) -> RragResult<Vec<SearchResult>> {
181 let query_embedding = self.embedding_service.embed_text(query).await?;
183
184 let query_vector = if self.config.normalize_embeddings {
185 Self::normalize_vector(&query_embedding.vector)
186 } else {
187 query_embedding.vector
188 };
189
190 let index = self.index.read().await;
192 let documents = self.documents.read().await;
193
194 let mut scores: Vec<(String, f32)> = Vec::new();
195
196 for (i, doc_embedding) in index.embeddings.iter().enumerate() {
198 let similarity = self.calculate_similarity(&query_vector, doc_embedding);
199
200 if let Some(threshold) = min_score {
201 if similarity < threshold {
202 continue;
203 }
204 }
205
206 scores.push((index.doc_ids[i].clone(), similarity));
207 }
208
209 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
211 scores.truncate(limit);
212
213 let results: Vec<SearchResult> = scores
215 .into_iter()
216 .enumerate()
217 .filter_map(|(rank, (doc_id, score))| {
218 documents.get(&doc_id).map(|doc| SearchResult {
219 id: doc_id,
220 content: doc.content.clone(),
221 score,
222 rank,
223 metadata: doc.metadata.clone(),
224 embedding: Some(doc.embedding.clone()),
225 })
226 })
227 .collect();
228
229 Ok(results)
230 }
231
232 pub async fn search_by_embedding(
234 &self,
235 embedding: &Embedding,
236 limit: usize,
237 min_score: Option<f32>,
238 ) -> RragResult<Vec<SearchResult>> {
239 let query_vector = if self.config.normalize_embeddings {
240 Self::normalize_vector(&embedding.vector)
241 } else {
242 embedding.vector.clone()
243 };
244
245 let index = self.index.read().await;
246 let documents = self.documents.read().await;
247
248 let mut scores: Vec<(String, f32)> = Vec::new();
249
250 for (i, doc_embedding) in index.embeddings.iter().enumerate() {
251 let similarity = self.calculate_similarity(&query_vector, doc_embedding);
252
253 if let Some(threshold) = min_score {
254 if similarity < threshold {
255 continue;
256 }
257 }
258
259 scores.push((index.doc_ids[i].clone(), similarity));
260 }
261
262 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
263 scores.truncate(limit);
264
265 let results: Vec<SearchResult> = scores
266 .into_iter()
267 .enumerate()
268 .filter_map(|(rank, (doc_id, score))| {
269 documents.get(&doc_id).map(|doc| SearchResult {
270 id: doc_id,
271 content: doc.content.clone(),
272 score,
273 rank,
274 metadata: doc.metadata.clone(),
275 embedding: Some(doc.embedding.clone()),
276 })
277 })
278 .collect();
279
280 Ok(results)
281 }
282
283 fn calculate_similarity(&self, vec1: &[f32], vec2: &[f32]) -> f32 {
285 match self.config.similarity_metric {
286 SimilarityMetric::Cosine => Self::cosine_similarity(vec1, vec2),
287 SimilarityMetric::Euclidean => {
288 let distance = Self::euclidean_distance(vec1, vec2);
289 1.0 / (1.0 + distance) }
291 SimilarityMetric::DotProduct => Self::dot_product(vec1, vec2),
292 SimilarityMetric::Manhattan => {
293 let distance = Self::manhattan_distance(vec1, vec2);
294 1.0 / (1.0 + distance) }
296 }
297 }
298
299 fn cosine_similarity(vec1: &[f32], vec2: &[f32]) -> f32 {
301 let dot = Self::dot_product(vec1, vec2);
302 let norm1 = vec1.iter().map(|x| x * x).sum::<f32>().sqrt();
303 let norm2 = vec2.iter().map(|x| x * x).sum::<f32>().sqrt();
304
305 if norm1 == 0.0 || norm2 == 0.0 {
306 0.0
307 } else {
308 dot / (norm1 * norm2)
309 }
310 }
311
312 fn dot_product(vec1: &[f32], vec2: &[f32]) -> f32 {
314 vec1.iter().zip(vec2.iter()).map(|(a, b)| a * b).sum()
315 }
316
317 fn euclidean_distance(vec1: &[f32], vec2: &[f32]) -> f32 {
319 vec1.iter()
320 .zip(vec2.iter())
321 .map(|(a, b)| (a - b).powi(2))
322 .sum::<f32>()
323 .sqrt()
324 }
325
326 fn manhattan_distance(vec1: &[f32], vec2: &[f32]) -> f32 {
328 vec1.iter()
329 .zip(vec2.iter())
330 .map(|(a, b)| (a - b).abs())
331 .sum()
332 }
333
334 fn normalize_vector(vec: &[f32]) -> Vec<f32> {
336 let norm = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
337
338 if norm == 0.0 {
339 vec.to_vec()
340 } else {
341 vec.iter().map(|x| x / norm).collect()
342 }
343 }
344
345 pub async fn index_batch(&self, documents: Vec<Document>) -> RragResult<()> {
347 let requests: Vec<crate::EmbeddingRequest> = documents
349 .iter()
350 .map(|doc| crate::EmbeddingRequest::new(&doc.id, doc.content.as_ref()))
351 .collect();
352
353 let embedding_batch = self.embedding_service.embed_batch(requests).await?;
354
355 let mut docs_map = self.documents.write().await;
356 let mut index = self.index.write().await;
357
358 for doc in documents.iter() {
359 if let Some(embedding) = embedding_batch.embeddings.get(&doc.id) {
360 let normalized = if self.config.normalize_embeddings {
361 Some(Self::normalize_vector(&embedding.vector))
362 } else {
363 None
364 };
365
366 let vector_doc = VectorDocument {
367 id: doc.id.clone(),
368 content: doc.content.to_string(),
369 embedding: embedding.clone(),
370 normalized_embedding: normalized.clone(),
371 metadata: doc.metadata.clone(),
372 };
373
374 docs_map.insert(doc.id.clone(), vector_doc);
375 index.doc_ids.push(doc.id.clone());
376 index
377 .embeddings
378 .push(normalized.unwrap_or_else(|| embedding.vector.clone()));
379 }
380 }
381
382 Ok(())
383 }
384
385 pub async fn clear(&self) -> RragResult<()> {
387 let mut documents = self.documents.write().await;
388 let mut index = self.index.write().await;
389
390 documents.clear();
391 index.doc_ids.clear();
392 index.embeddings.clear();
393
394 Ok(())
395 }
396
397 pub async fn stats(&self) -> HashMap<String, serde_json::Value> {
399 let documents = self.documents.read().await;
400 let _index = self.index.read().await;
401
402 let mut stats = HashMap::new();
403 stats.insert("total_documents".to_string(), documents.len().into());
404 stats.insert(
405 "embedding_dimension".to_string(),
406 self.config.embedding_dimension.into(),
407 );
408 stats.insert(
409 "index_type".to_string(),
410 format!("{:?}", self.config.index_type).into(),
411 );
412 stats.insert(
413 "similarity_metric".to_string(),
414 format!("{:?}", self.config.similarity_metric).into(),
415 );
416
417 let memory_size = documents.len() * self.config.embedding_dimension * 4; stats.insert("index_memory_bytes".to_string(), memory_size.into());
419
420 stats
421 }
422}
423
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embeddings::MockEmbeddingService;

    /// End-to-end smoke test: batch-index three documents through the mock
    /// embedding service, then run a thresholded top-k search.
    #[tokio::test]
    async fn test_semantic_search() {
        let service = Arc::new(MockEmbeddingService::new());
        let retriever = SemanticRetriever::new(SemanticConfig::default(), service);

        let corpus = vec![
            Document::with_id(
                "1",
                "Machine learning is a subset of artificial intelligence",
            ),
            Document::with_id("2", "Deep learning uses neural networks"),
            Document::with_id(
                "3",
                "Natural language processing enables computers to understand text",
            ),
        ];
        retriever.index_batch(corpus).await.unwrap();

        let hits = retriever
            .search("AI and machine learning", 2, Some(0.5))
            .await
            .unwrap();
        assert!(!hits.is_empty());
    }
}
454}