use crate::{Document, Embedding, EmbeddingProvider, RragResult, SearchResult};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Tuning knobs for the semantic (vector similarity) retriever.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticConfig {
/// How query/document vectors are scored against each other.
pub similarity_metric: SimilarityMetric,
/// Expected embedding length; used for the memory estimate in `stats()`.
pub embedding_dimension: usize,
/// When true, vectors are L2-normalized at both index and query time.
pub normalize_embeddings: bool,
/// Index layout tag. NOTE(review): search in this file is always a linear
/// scan regardless of this setting — confirm other index types are built
/// elsewhere.
pub index_type: IndexType,
/// IVF cluster count. NOTE(review): not read anywhere in this file.
pub num_clusters: Option<usize>,
/// IVF probe count. NOTE(review): not read anywhere in this file.
pub num_probes: Option<usize>,
/// GPU flag. NOTE(review): not read anywhere in this file.
pub use_gpu: bool,
}
impl Default for SemanticConfig {
fn default() -> Self {
Self {
similarity_metric: SimilarityMetric::Cosine,
embedding_dimension: 768,
normalize_embeddings: true,
index_type: IndexType::Flat,
num_clusters: None,
num_probes: None,
use_gpu: false,
}
}
}
/// Similarity/distance functions supported by `calculate_similarity`.
/// The distance metrics (Euclidean, Manhattan) are converted to
/// similarities via `1 / (1 + distance)` so that higher is always better.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SimilarityMetric {
Cosine,
Euclidean,
DotProduct,
Manhattan,
}
/// Vector index layouts. NOTE(review): only `Flat` (exhaustive scan) is
/// exercised by the code in this file; the other variants are tags only
/// here — confirm they are implemented elsewhere.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum IndexType {
Flat,
IVF,
HNSW,
LSH,
}
/// Internal record pairing a stored document with its embedding.
#[derive(Debug, Clone)]
struct VectorDocument {
id: String,
content: String,
// Original embedding exactly as returned by the provider.
embedding: Embedding,
// L2-normalized copy, present only when `normalize_embeddings` is set.
// NOTE(review): written but never read in this file — confirm it is needed.
normalized_embedding: Option<Vec<f32>>,
metadata: HashMap<String, serde_json::Value>,
}
/// Embedding-based retriever: embeds documents and queries through the
/// configured provider and ranks documents by vector similarity.
pub struct SemanticRetriever {
config: SemanticConfig,
// Full document records keyed by document id.
documents: Arc<RwLock<HashMap<String, VectorDocument>>>,
embedding_service: Arc<dyn EmbeddingProvider>,
// Compact parallel-array index scanned at query time.
index: Arc<RwLock<VectorIndex>>,
}
/// Parallel arrays: `embeddings[i]` is the vector for `doc_ids[i]`.
struct VectorIndex {
doc_ids: Vec<String>,
embeddings: Vec<Vec<f32>>,
// NOTE(review): stored at construction but never read in this file.
index_type: IndexType,
}
impl SemanticRetriever {
/// Creates a retriever with an empty document store and index.
///
/// Fix: the index is now tagged with the `index_type` from `config`;
/// previously it was always created as `IndexType::Flat`, silently
/// ignoring the configured type.
pub fn new(config: SemanticConfig, embedding_service: Arc<dyn EmbeddingProvider>) -> Self {
    // Copy the configured index type out before `config` is moved into `Self`.
    let index_type = config.index_type.clone();
    Self {
        config,
        documents: Arc::new(RwLock::new(HashMap::new())),
        embedding_service,
        index: Arc::new(RwLock::new(VectorIndex {
            doc_ids: Vec::new(),
            embeddings: Vec::new(),
            index_type,
        })),
    }
}
/// Embeds and indexes a single document.
///
/// Fixes over the previous version:
/// - the normalized vector was computed twice (once for the document
///   record, once for the index); it is now computed once and shared.
/// - re-indexing an existing id used to push a duplicate row into the
///   parallel-array index while overwriting the map entry, leaving a
///   stale index row; it now replaces the existing index entry in place.
///
/// # Errors
/// Propagates any error from the embedding provider.
pub async fn index_document(&self, doc: &Document) -> RragResult<()> {
    let embedding = self.embedding_service.embed_text(&doc.content).await?;
    // Vector actually stored in the index (normalized iff configured).
    let index_vector = if self.config.normalize_embeddings {
        Self::normalize_vector(&embedding.vector)
    } else {
        embedding.vector.clone()
    };
    let vector_doc = VectorDocument {
        id: doc.id.clone(),
        content: doc.content.to_string(),
        embedding,
        normalized_embedding: self
            .config
            .normalize_embeddings
            .then(|| index_vector.clone()),
        metadata: doc.metadata.clone(),
    };
    // Same lock order as the other write paths: documents, then index.
    let mut documents = self.documents.write().await;
    let mut index = self.index.write().await;
    if let Some(pos) = index.doc_ids.iter().position(|id| id == &doc.id) {
        // Known id: update the existing row instead of duplicating it.
        index.embeddings[pos] = index_vector;
    } else {
        index.doc_ids.push(doc.id.clone());
        index.embeddings.push(index_vector);
    }
    documents.insert(doc.id.clone(), vector_doc);
    Ok(())
}
pub async fn search(
&self,
query: &str,
limit: usize,
min_score: Option<f32>,
) -> RragResult<Vec<SearchResult>> {
let query_embedding = self.embedding_service.embed_text(query).await?;
let query_vector = if self.config.normalize_embeddings {
Self::normalize_vector(&query_embedding.vector)
} else {
query_embedding.vector
};
let index = self.index.read().await;
let documents = self.documents.read().await;
let mut scores: Vec<(String, f32)> = Vec::new();
for (i, doc_embedding) in index.embeddings.iter().enumerate() {
let similarity = self.calculate_similarity(&query_vector, doc_embedding);
if let Some(threshold) = min_score {
if similarity < threshold {
continue;
}
}
scores.push((index.doc_ids[i].clone(), similarity));
}
scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
scores.truncate(limit);
let results: Vec<SearchResult> = scores
.into_iter()
.enumerate()
.filter_map(|(rank, (doc_id, score))| {
documents.get(&doc_id).map(|doc| SearchResult {
id: doc_id,
content: doc.content.clone(),
score,
rank,
metadata: doc.metadata.clone(),
embedding: Some(doc.embedding.clone()),
})
})
.collect();
Ok(results)
}
/// Searches the index with a precomputed query embedding.
///
/// The query vector is normalized when `normalize_embeddings` is set so
/// it is compared in the same space as the stored vectors. Results below
/// `min_score` (if given) are dropped; the rest are sorted by descending
/// score and truncated to `limit`. `rank` is the 0-based position in the
/// returned list.
///
/// Fix: sorting now uses `f32::total_cmp`, which is a total order and
/// cannot panic; the previous `partial_cmp().unwrap()` would panic if any
/// similarity came out NaN.
pub async fn search_by_embedding(
    &self,
    embedding: &Embedding,
    limit: usize,
    min_score: Option<f32>,
) -> RragResult<Vec<SearchResult>> {
    let query_vector = if self.config.normalize_embeddings {
        Self::normalize_vector(&embedding.vector)
    } else {
        embedding.vector.clone()
    };
    let index = self.index.read().await;
    let documents = self.documents.read().await;
    let mut scores: Vec<(String, f32)> = Vec::with_capacity(index.embeddings.len());
    for (i, doc_embedding) in index.embeddings.iter().enumerate() {
        let similarity = self.calculate_similarity(&query_vector, doc_embedding);
        // Skip candidates below the caller's score floor, if any.
        if min_score.map_or(false, |threshold| similarity < threshold) {
            continue;
        }
        scores.push((index.doc_ids[i].clone(), similarity));
    }
    // Descending by score; `total_cmp` is panic-free even on NaN.
    scores.sort_by(|a, b| b.1.total_cmp(&a.1));
    scores.truncate(limit);
    let results: Vec<SearchResult> = scores
        .into_iter()
        .enumerate()
        .filter_map(|(rank, (doc_id, score))| {
            documents.get(&doc_id).map(|doc| SearchResult {
                id: doc_id,
                content: doc.content.clone(),
                score,
                rank,
                metadata: doc.metadata.clone(),
                embedding: Some(doc.embedding.clone()),
            })
        })
        .collect();
    Ok(results)
}
/// Scores two vectors using the configured metric. Distance metrics are
/// mapped into (0, 1] via `1 / (1 + distance)` so a larger value always
/// means "more similar".
fn calculate_similarity(&self, vec1: &[f32], vec2: &[f32]) -> f32 {
    match self.config.similarity_metric {
        SimilarityMetric::Cosine => Self::cosine_similarity(vec1, vec2),
        SimilarityMetric::DotProduct => Self::dot_product(vec1, vec2),
        SimilarityMetric::Euclidean => {
            1.0 / (1.0 + Self::euclidean_distance(vec1, vec2))
        }
        SimilarityMetric::Manhattan => {
            1.0 / (1.0 + Self::manhattan_distance(vec1, vec2))
        }
    }
}
/// Cosine similarity: dot(a, b) / (‖a‖ ‖b‖). Returns 0.0 when either
/// vector has zero magnitude, since the angle is undefined there.
fn cosine_similarity(vec1: &[f32], vec2: &[f32]) -> f32 {
    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm1, norm2) = (magnitude(vec1), magnitude(vec2));
    if norm1 == 0.0 || norm2 == 0.0 {
        return 0.0;
    }
    Self::dot_product(vec1, vec2) / (norm1 * norm2)
}
/// Sum of element-wise products; extra elements in the longer slice are
/// ignored (the zip stops at the shorter one).
fn dot_product(vec1: &[f32], vec2: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    for (a, b) in vec1.iter().zip(vec2) {
        acc += a * b;
    }
    acc
}
/// Euclidean (L2) distance: sqrt of the sum of squared differences.
fn euclidean_distance(vec1: &[f32], vec2: &[f32]) -> f32 {
    let sum_of_squares: f32 = vec1
        .iter()
        .zip(vec2)
        .map(|(a, b)| {
            let d = a - b;
            d * d
        })
        .sum();
    sum_of_squares.sqrt()
}
/// Manhattan (L1) distance: sum of absolute element-wise differences.
fn manhattan_distance(vec1: &[f32], vec2: &[f32]) -> f32 {
    vec1.iter()
        .zip(vec2.iter())
        .fold(0.0, |acc, (a, b)| acc + (a - b).abs())
}
/// Scales `vec` to unit L2 length. A zero vector is returned unchanged:
/// it has no direction, and dividing by zero would yield NaNs.
fn normalize_vector(vec: &[f32]) -> Vec<f32> {
    let norm = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm == 0.0 {
        return vec.to_vec();
    }
    vec.iter().map(|x| x / norm).collect()
}
/// Embeds and indexes a batch of documents in one provider call.
///
/// Documents for which the provider returned no embedding are silently
/// skipped, matching the previous behavior.
///
/// Fixes over the previous version:
/// - the normalized vector is computed once and shared instead of the
///   `Option` + `clone` + `unwrap_or_else` dance.
/// - re-indexing an existing id replaces its index row in place instead
///   of appending a duplicate (which left a stale row behind).
///
/// # Errors
/// Propagates any error from the batch embedding call.
pub async fn index_batch(&self, documents: Vec<Document>) -> RragResult<()> {
    let requests: Vec<crate::EmbeddingRequest> = documents
        .iter()
        .map(|doc| crate::EmbeddingRequest::new(&doc.id, doc.content.as_ref()))
        .collect();
    let embedding_batch = self.embedding_service.embed_batch(requests).await?;
    // Same lock order as the other write paths: documents, then index.
    let mut docs_map = self.documents.write().await;
    let mut index = self.index.write().await;
    for doc in documents.iter() {
        if let Some(embedding) = embedding_batch.embeddings.get(&doc.id) {
            // Vector actually stored in the index (normalized iff configured).
            let index_vector = if self.config.normalize_embeddings {
                Self::normalize_vector(&embedding.vector)
            } else {
                embedding.vector.clone()
            };
            let vector_doc = VectorDocument {
                id: doc.id.clone(),
                content: doc.content.to_string(),
                embedding: embedding.clone(),
                normalized_embedding: self
                    .config
                    .normalize_embeddings
                    .then(|| index_vector.clone()),
                metadata: doc.metadata.clone(),
            };
            docs_map.insert(doc.id.clone(), vector_doc);
            if let Some(pos) = index.doc_ids.iter().position(|id| id == &doc.id) {
                // Known id: update the existing row instead of duplicating it.
                index.embeddings[pos] = index_vector;
            } else {
                index.doc_ids.push(doc.id.clone());
                index.embeddings.push(index_vector);
            }
        }
    }
    Ok(())
}
/// Empties the document store and the vector index.
///
/// Both write locks are taken up front (documents first, matching the
/// indexing paths) so no reader can observe a half-cleared state.
pub async fn clear(&self) -> RragResult<()> {
    let mut docs = self.documents.write().await;
    let mut idx = self.index.write().await;
    idx.doc_ids.clear();
    idx.embeddings.clear();
    docs.clear();
    Ok(())
}
/// Returns a snapshot of retriever statistics as JSON-compatible values.
pub async fn stats(&self) -> HashMap<String, serde_json::Value> {
    let documents = self.documents.read().await;
    // Held but unused; presumably taken so the snapshot is consistent
    // with concurrent index writers — confirm before removing.
    let _index = self.index.read().await;
    let doc_count = documents.len();
    // Rough estimate: one f32 (4 bytes) per dimension per document.
    let memory_bytes = doc_count * self.config.embedding_dimension * 4;
    let entries: [(&str, serde_json::Value); 5] = [
        ("total_documents", doc_count.into()),
        ("embedding_dimension", self.config.embedding_dimension.into()),
        ("index_type", format!("{:?}", self.config.index_type).into()),
        (
            "similarity_metric",
            format!("{:?}", self.config.similarity_metric).into(),
        ),
        ("index_memory_bytes", memory_bytes.into()),
    ];
    entries
        .into_iter()
        .map(|(key, value)| (key.to_string(), value))
        .collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::embeddings::MockEmbeddingService;
// Smoke test: batch-index three documents through the mock embedding
// service, then run a text query and expect at least one result above
// the 0.5 score threshold.
#[tokio::test]
async fn test_semantic_search() {
let mock_service = Arc::new(MockEmbeddingService::new());
let retriever = SemanticRetriever::new(SemanticConfig::default(), mock_service);
let docs = vec![
Document::with_id(
"1",
"Machine learning is a subset of artificial intelligence",
),
Document::with_id("2", "Deep learning uses neural networks"),
Document::with_id(
"3",
"Natural language processing enables computers to understand text",
),
];
retriever.index_batch(docs).await.unwrap();
let results = retriever
.search("AI and machine learning", 2, Some(0.5))
.await
.unwrap();
// NOTE(review): asserts non-emptiness only; scores depend on the
// mock's embedding behavior, which is defined elsewhere.
assert!(!results.is_empty());
}
}