use crate::{Document, DocumentChunk, Embedding, RragError, RragResult};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
/// A single ranked hit returned from a retrieval search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// Identifier of the matched document or chunk.
    pub id: String,
    /// Text content of the matched item.
    pub content: String,
    /// Similarity/relevance score (higher is better).
    pub score: f32,
    /// Zero-based position within the result list.
    pub rank: usize,
    /// Arbitrary key/value metadata copied from the source item.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Source embedding; populated only when the query asks for it.
    pub embedding: Option<Embedding>,
}
impl SearchResult {
    /// Constructs a result with the given identity, score and rank;
    /// metadata starts empty and no embedding is attached.
    pub fn new(id: impl Into<String>, content: impl Into<String>, score: f32, rank: usize) -> Self {
        let (id, content) = (id.into(), content.into());
        Self {
            id,
            content,
            score,
            rank,
            metadata: HashMap::new(),
            embedding: None,
        }
    }

    /// Builder: records one metadata entry, overwriting any prior value
    /// under the same key.
    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.metadata.insert(key.into(), value);
        self
    }

    /// Builder: attaches the source embedding to this result.
    pub fn with_embedding(self, embedding: Embedding) -> Self {
        Self {
            embedding: Some(embedding),
            ..self
        }
    }
}
/// A retrieval request: the query payload plus result-shaping options.
#[derive(Debug, Clone)]
pub struct SearchQuery {
    /// What to search for (raw text or a pre-computed embedding).
    pub query: QueryType,
    /// Maximum number of results to return.
    pub limit: usize,
    /// Results scoring below this threshold are dropped.
    pub min_score: f32,
    /// Exact-match metadata filters; all must match for an item to qualify.
    pub filters: HashMap<String, serde_json::Value>,
    /// Algorithm and scoring options for this search.
    pub config: SearchConfig,
}
/// The payload of a search: raw text (which a retriever may reject if it
/// cannot embed text itself) or a ready-made embedding.
#[derive(Debug, Clone)]
pub enum QueryType {
    /// A natural-language query string.
    Text(String),
    /// A pre-computed query embedding.
    Embedding(Embedding),
}
/// Options controlling how a search is executed and scored.
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// When true, each result carries its source embedding.
    pub include_embeddings: bool,
    /// When true, results are re-scored with `ScoringWeights` after the
    /// initial similarity pass.
    pub enable_reranking: bool,
    /// Similarity algorithm used to compare embeddings.
    pub algorithm: SearchAlgorithm,
    /// Weights applied during reranking.
    pub scoring_weights: ScoringWeights,
}
/// Similarity measures supported by the retriever.
#[derive(Debug, Clone)]
pub enum SearchAlgorithm {
    /// Cosine similarity between vectors.
    Cosine,
    /// Euclidean distance, mapped to a similarity via `1 / (1 + d)`.
    Euclidean,
    /// Raw dot product, clamped to `[0, 1]`.
    DotProduct,
    /// Weighted combination of several algorithms.
    Hybrid {
        /// Algorithms to combine (parallel to `weights`).
        methods: Vec<SearchAlgorithm>,
        /// Per-method weights (parallel to `methods`).
        weights: Vec<f32>,
    },
}
/// Relative weights used when reranking search results.
#[derive(Debug, Clone)]
pub struct ScoringWeights {
    /// Weight of the raw similarity score.
    pub semantic: f32,
    /// Bonus weight applied when a result carries any metadata.
    pub metadata: f32,
    /// Weight of the exponential recency bonus (from `created_at` metadata).
    pub recency: f32,
    /// Bonus weight for mid-length content.
    pub quality: f32,
}
impl Default for SearchConfig {
    /// Defaults: cosine similarity, reranking enabled, embeddings omitted
    /// from results, stock scoring weights.
    fn default() -> Self {
        Self {
            algorithm: SearchAlgorithm::Cosine,
            scoring_weights: ScoringWeights::default(),
            enable_reranking: true,
            include_embeddings: false,
        }
    }
}
impl Default for ScoringWeights {
    /// Defaults: the semantic score dominates; metadata, recency and
    /// quality contribute only small bonuses.
    fn default() -> Self {
        Self { semantic: 1.0, metadata: 0.1, recency: 0.05, quality: 0.1 }
    }
}
impl SearchQuery {
pub fn text(query: impl Into<String>) -> Self {
Self {
query: QueryType::Text(query.into()),
limit: 10,
min_score: 0.0,
filters: HashMap::new(),
config: SearchConfig::default(),
}
}
pub fn embedding(embedding: Embedding) -> Self {
Self {
query: QueryType::Embedding(embedding),
limit: 10,
min_score: 0.0,
filters: HashMap::new(),
config: SearchConfig::default(),
}
}
pub fn with_limit(mut self, limit: usize) -> Self {
self.limit = limit;
self
}
pub fn with_min_score(mut self, min_score: f32) -> Self {
self.min_score = min_score;
self
}
pub fn with_filter(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
self.filters.insert(key.into(), value);
self
}
pub fn with_config(mut self, config: SearchConfig) -> Self {
self.config = config;
self
}
}
/// Storage-agnostic interface for indexing and searching embedded content.
#[async_trait]
pub trait Retriever: Send + Sync {
    /// Short identifier for this retriever implementation.
    fn name(&self) -> &str;
    /// Runs a similarity search and returns ranked results.
    async fn search(&self, query: &SearchQuery) -> RragResult<Vec<SearchResult>>;
    /// Indexes whole documents with their embeddings.
    async fn add_documents(&self, documents: &[(Document, Embedding)]) -> RragResult<()>;
    /// Indexes document chunks with their embeddings.
    async fn add_chunks(&self, chunks: &[(DocumentChunk, Embedding)]) -> RragResult<()>;
    /// Removes documents (and any derived chunks) by document id.
    async fn remove_documents(&self, document_ids: &[String]) -> RragResult<()>;
    /// Removes everything from the index.
    async fn clear(&self) -> RragResult<()>;
    /// Reports index size and shape statistics.
    async fn stats(&self) -> RragResult<IndexStats>;
    /// Returns `Ok(true)` when the backing store is usable.
    async fn health_check(&self) -> RragResult<bool>;
}
/// Snapshot of an index's size and shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexStats {
    /// Total indexed items (documents plus chunks).
    pub total_items: usize,
    /// Estimated memory footprint in bytes.
    pub size_bytes: usize,
    /// Embedding dimensionality (0 when the index is empty).
    pub dimensions: usize,
    /// Implementation label, e.g. "in_memory".
    pub index_type: String,
    /// When these statistics were taken.
    pub last_updated: chrono::DateTime<chrono::Utc>,
}
/// An in-process retriever backed by `HashMap`s behind async read/write
/// locks. Similarity search is a full scan, so this suits tests and small
/// corpora rather than production-scale indexes.
pub struct InMemoryRetriever {
    /// Documents keyed by document id.
    documents: Arc<tokio::sync::RwLock<HashMap<String, (Document, Embedding)>>>,
    /// Chunks keyed by "<document_id>_<chunk_index>".
    chunks: Arc<tokio::sync::RwLock<HashMap<String, (DocumentChunk, Embedding)>>>,
    /// Behavioral settings for this instance.
    config: RetrieverConfig,
}
/// Settings for `InMemoryRetriever`.
#[derive(Debug, Clone)]
pub struct RetrieverConfig {
    /// Which granularities (documents, chunks, or both) a search scans.
    pub storage_mode: StorageMode,
    /// Default score threshold. NOTE(review): not consulted anywhere in
    /// this file — `search` uses `SearchQuery::min_score`; confirm intent.
    pub default_threshold: f32,
    /// Hard cap on the number of results returned by a search.
    pub max_results: usize,
}
/// Which item granularity a search should scan.
#[derive(Debug, Clone)]
pub enum StorageMode {
    /// Search whole documents only.
    DocumentsOnly,
    /// Search chunks only.
    ChunksOnly,
    /// Search both documents and chunks.
    Both,
}
impl Default for RetrieverConfig {
    /// Defaults: scan both documents and chunks, no score threshold,
    /// at most 1000 results.
    fn default() -> Self {
        Self {
            max_results: 1000,
            default_threshold: 0.0,
            storage_mode: StorageMode::Both,
        }
    }
}
impl InMemoryRetriever {
pub fn new() -> Self {
Self {
documents: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
chunks: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
config: RetrieverConfig::default(),
}
}
pub fn with_config(config: RetrieverConfig) -> Self {
Self {
documents: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
chunks: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
config,
}
}
fn calculate_similarity(
&self,
embedding1: &Embedding,
embedding2: &Embedding,
algorithm: &SearchAlgorithm,
) -> RragResult<f32> {
match algorithm {
SearchAlgorithm::Cosine => embedding1.cosine_similarity(embedding2),
SearchAlgorithm::Euclidean => {
let distance = embedding1.euclidean_distance(embedding2)?;
Ok(1.0 / (1.0 + distance))
}
SearchAlgorithm::DotProduct => {
if embedding1.dimensions != embedding2.dimensions {
return Err(RragError::retrieval(format!(
"Dimension mismatch: {} vs {}",
embedding1.dimensions, embedding2.dimensions
)));
}
let dot_product: f32 = embedding1
.vector
.iter()
.zip(embedding2.vector.iter())
.map(|(a, b)| a * b)
.sum();
Ok(dot_product.max(0.0).min(1.0)) }
SearchAlgorithm::Hybrid { methods, weights } => {
let mut total_score = 0.0;
let mut total_weight = 0.0;
for (method, weight) in methods.iter().zip(weights.iter()) {
let score = self.calculate_similarity(embedding1, embedding2, method)?;
total_score += score * weight;
total_weight += weight;
}
if total_weight > 0.0 {
Ok(total_score / total_weight)
} else {
Ok(0.0)
}
}
}
}
fn apply_filters(
&self,
metadata: &HashMap<String, serde_json::Value>,
filters: &HashMap<String, serde_json::Value>,
) -> bool {
for (key, expected_value) in filters {
match metadata.get(key) {
Some(actual_value) if actual_value == expected_value => continue,
_ => return false,
}
}
true
}
fn rerank_results(
&self,
mut results: Vec<SearchResult>,
weights: &ScoringWeights,
) -> Vec<SearchResult> {
for result in &mut results {
let mut enhanced_score = result.score * weights.semantic;
if !result.metadata.is_empty() {
enhanced_score += 0.1 * weights.metadata;
}
if let Some(timestamp_value) = result.metadata.get("created_at") {
if let Some(timestamp_str) = timestamp_value.as_str() {
if let Ok(timestamp) = chrono::DateTime::parse_from_rfc3339(timestamp_str) {
let age_days =
(chrono::Utc::now() - timestamp.with_timezone(&chrono::Utc)).num_days();
let recency_bonus = (-age_days as f32 / 30.0).exp() * weights.recency;
enhanced_score += recency_bonus;
}
}
}
let content_length = result.content.len();
if content_length > 100 && content_length < 2000 {
enhanced_score += 0.05 * weights.quality;
}
result.score = enhanced_score.min(1.0);
}
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
for (i, result) in results.iter_mut().enumerate() {
result.rank = i;
}
results
}
}
impl Default for InMemoryRetriever {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Retriever for InMemoryRetriever {
    fn name(&self) -> &str {
        "in_memory"
    }

    /// Brute-force similarity search over the stored documents and/or
    /// chunks (per `storage_mode`), with metadata filtering, optional
    /// reranking, and truncation to `min(query.limit, config.max_results)`.
    ///
    /// # Errors
    /// Text queries are rejected (this retriever cannot embed text);
    /// similarity-computation errors propagate.
    async fn search(&self, query: &SearchQuery) -> RragResult<Vec<SearchResult>> {
        let query_embedding = match &query.query {
            QueryType::Text(_) => {
                return Err(RragError::retrieval(
                    "Text queries require pre-computed embeddings for in-memory retriever"
                        .to_string(),
                ));
            }
            QueryType::Embedding(emb) => emb,
        };
        let mut results = Vec::new();
        if matches!(
            self.config.storage_mode,
            StorageMode::DocumentsOnly | StorageMode::Both
        ) {
            let documents = self.documents.read().await;
            for (doc_id, (document, embedding)) in documents.iter() {
                if !self.apply_filters(&document.metadata, &query.filters) {
                    continue;
                }
                let similarity =
                    self.calculate_similarity(query_embedding, embedding, &query.config.algorithm)?;
                if similarity >= query.min_score {
                    // Rank placeholder 0; real ranks are assigned after the
                    // final sort below.
                    let mut result =
                        SearchResult::new(doc_id, document.content_str(), similarity, 0)
                            .with_metadata(
                                "type",
                                serde_json::Value::String("document".to_string()),
                            );
                    for (key, value) in &document.metadata {
                        result = result.with_metadata(key, value.clone());
                    }
                    if query.config.include_embeddings {
                        result = result.with_embedding(embedding.clone());
                    }
                    results.push(result);
                }
            }
        }
        if matches!(
            self.config.storage_mode,
            StorageMode::ChunksOnly | StorageMode::Both
        ) {
            let chunks = self.chunks.read().await;
            for (chunk_id, (chunk, embedding)) in chunks.iter() {
                if !self.apply_filters(&chunk.metadata, &query.filters) {
                    continue;
                }
                let similarity =
                    self.calculate_similarity(query_embedding, embedding, &query.config.algorithm)?;
                if similarity >= query.min_score {
                    let mut result = SearchResult::new(chunk_id, &chunk.content, similarity, 0)
                        .with_metadata("type", serde_json::Value::String("chunk".to_string()))
                        .with_metadata(
                            "document_id",
                            serde_json::Value::String(chunk.document_id.clone()),
                        )
                        .with_metadata(
                            "chunk_index",
                            serde_json::Value::Number(chunk.chunk_index.into()),
                        );
                    for (key, value) in &chunk.metadata {
                        result = result.with_metadata(key, value.clone());
                    }
                    if query.config.include_embeddings {
                        result = result.with_embedding(embedding.clone());
                    }
                    results.push(result);
                }
            }
        }
        // Perf fix: the previous version always sorted by raw score and then
        // `rerank_results` sorted again. Sort exactly once, on whichever
        // score is final for this query.
        if query.config.enable_reranking {
            results = self.rerank_results(results, &query.config.scoring_weights);
        } else {
            results.sort_by(|a, b| {
                b.score
                    .partial_cmp(&a.score)
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
        }
        for (i, result) in results.iter_mut().enumerate() {
            result.rank = i;
        }
        results.truncate(query.limit.min(self.config.max_results));
        Ok(results)
    }

    /// Inserts or replaces documents, keyed by document id.
    async fn add_documents(&self, documents: &[(Document, Embedding)]) -> RragResult<()> {
        let mut docs = self.documents.write().await;
        for (document, embedding) in documents {
            docs.insert(document.id.clone(), (document.clone(), embedding.clone()));
        }
        Ok(())
    }

    /// Inserts or replaces chunks, keyed by `"<document_id>_<chunk_index>"`.
    async fn add_chunks(&self, chunks: &[(DocumentChunk, Embedding)]) -> RragResult<()> {
        let mut chunk_store = self.chunks.write().await;
        for (chunk, embedding) in chunks {
            let chunk_id = format!("{}_{}", chunk.document_id, chunk.chunk_index);
            chunk_store.insert(chunk_id, (chunk.clone(), embedding.clone()));
        }
        Ok(())
    }

    /// Removes the named documents and every chunk belonging to them.
    async fn remove_documents(&self, document_ids: &[String]) -> RragResult<()> {
        {
            let mut docs = self.documents.write().await;
            for doc_id in document_ids {
                docs.remove(doc_id);
            }
        }
        // Idiom/perf fix: `retain` removes matching chunks in one in-place
        // pass instead of collecting ids and removing them one by one.
        let mut chunk_store = self.chunks.write().await;
        chunk_store.retain(|_, (chunk, _)| !document_ids.contains(&chunk.document_id));
        Ok(())
    }

    /// Drops all stored documents and chunks.
    async fn clear(&self) -> RragResult<()> {
        self.documents.write().await.clear();
        self.chunks.write().await.clear();
        Ok(())
    }

    /// Reports a consistent snapshot of index statistics.
    ///
    /// Fix: the previous version re-acquired each read lock up to twice,
    /// so counts and dimensions could be read from different states of the
    /// index; hold each guard once for the whole computation instead.
    async fn stats(&self) -> RragResult<IndexStats> {
        let docs = self.documents.read().await;
        let chunks = self.chunks.read().await;
        let doc_count = docs.len();
        let chunk_count = chunks.len();
        // Dimensionality of an arbitrary stored embedding; 0 when empty.
        let dimensions = docs
            .values()
            .map(|(_, emb)| emb.dimensions)
            .next()
            .or_else(|| chunks.values().map(|(_, emb)| emb.dimensions).next())
            .unwrap_or(0);
        Ok(IndexStats {
            total_items: doc_count + chunk_count,
            // Rough estimate: 4 bytes per f32 vector component.
            size_bytes: (doc_count + chunk_count) * dimensions * 4,
            dimensions,
            index_type: "in_memory".to_string(),
            last_updated: chrono::Utc::now(),
        })
    }

    /// The in-memory store has no external dependencies to probe.
    async fn health_check(&self) -> RragResult<bool> {
        Ok(true)
    }
}
/// High-level facade over a `Retriever` implementation.
pub struct RetrievalService {
    /// The underlying retriever doing the actual work.
    retriever: Arc<dyn Retriever>,
    /// Service-level options.
    config: RetrievalServiceConfig,
}
/// Options for `RetrievalService`.
#[derive(Debug, Clone)]
pub struct RetrievalServiceConfig {
    /// Search configuration applied when callers don't supply one.
    pub default_search_config: SearchConfig,
    /// Whether to cache search results. NOTE(review): no caching logic is
    /// visible in this file — flag appears unused here; confirm elsewhere.
    pub enable_caching: bool,
    /// Cache entry time-to-live, in seconds.
    pub cache_ttl_seconds: u64,
}
impl Default for RetrievalServiceConfig {
    /// Defaults: caching off, 5-minute TTL, stock search configuration.
    fn default() -> Self {
        Self {
            cache_ttl_seconds: 300,
            enable_caching: false,
            default_search_config: SearchConfig::default(),
        }
    }
}
impl RetrievalService {
    /// Wraps `retriever` with default service options.
    pub fn new(retriever: Arc<dyn Retriever>) -> Self {
        Self {
            config: RetrievalServiceConfig::default(),
            retriever,
        }
    }

    /// Wraps `retriever` with explicit service options.
    pub fn with_config(retriever: Arc<dyn Retriever>, config: RetrievalServiceConfig) -> Self {
        Self { retriever, config }
    }

    /// Text search is unsupported: this service has no embedder wired in,
    /// so it always returns a retrieval error.
    pub async fn search_text(
        &self,
        _query: &str,
        _limit: Option<usize>,
    ) -> RragResult<Vec<SearchResult>> {
        Err(RragError::retrieval(
            "Text search requires embedding service integration".to_string(),
        ))
    }

    /// Searches by embedding, using the service's default search config
    /// and a limit of 10 when none is supplied.
    pub async fn search_embedding(
        &self,
        embedding: Embedding,
        limit: Option<usize>,
    ) -> RragResult<Vec<SearchResult>> {
        let effective_limit = limit.unwrap_or(10);
        let query = SearchQuery::embedding(embedding)
            .with_config(self.config.default_search_config.clone())
            .with_limit(effective_limit);
        self.retriever.search(&query).await
    }

    /// Searches with a fully specified query.
    pub async fn search(&self, query: SearchQuery) -> RragResult<Vec<SearchResult>> {
        self.retriever.search(&query).await
    }

    /// Indexes whole documents together with their embeddings.
    pub async fn index_documents(
        &self,
        documents_with_embeddings: &[(Document, Embedding)],
    ) -> RragResult<()> {
        self.retriever
            .add_documents(documents_with_embeddings)
            .await
    }

    /// Indexes chunks together with their embeddings.
    pub async fn index_chunks(
        &self,
        chunks_with_embeddings: &[(DocumentChunk, Embedding)],
    ) -> RragResult<()> {
        self.retriever.add_chunks(chunks_with_embeddings).await
    }

    /// Delegates to the underlying retriever's statistics.
    pub async fn get_stats(&self) -> RragResult<IndexStats> {
        self.retriever.stats().await
    }

    /// Delegates to the underlying retriever's health probe.
    pub async fn health_check(&self) -> RragResult<bool> {
        self.retriever.health_check().await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Document;

    /// Index two orthogonal documents and check the nearer one ranks first.
    #[tokio::test]
    async fn test_in_memory_retriever() {
        let store = InMemoryRetriever::new();
        let doc_a = Document::new("First test document");
        let emb_a = Embedding::new(vec![1.0, 0.0, 0.0], "test-model", &doc_a.id);
        let doc_b = Document::new("Second test document");
        let emb_b = Embedding::new(vec![0.0, 1.0, 0.0], "test-model", &doc_b.id);
        store
            .add_documents(&[(doc_a.clone(), emb_a.clone()), (doc_b, emb_b)])
            .await
            .unwrap();

        // The probe vector leans heavily toward the first document.
        let probe = Embedding::new(vec![0.8, 0.2, 0.0], "test-model", "query");
        let hits = store
            .search(&SearchQuery::embedding(probe).with_limit(5))
            .await
            .unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].id, doc_a.id);
    }

    /// Metadata filters must exclude non-matching documents entirely.
    #[tokio::test]
    async fn test_search_filters() {
        let store = InMemoryRetriever::new();
        let tech_doc = Document::new("Test document")
            .with_metadata("category", serde_json::Value::String("tech".to_string()));
        let tech_emb = Embedding::new(vec![1.0, 0.0], "test-model", &tech_doc.id);
        let sci_doc = Document::new("Another document")
            .with_metadata("category", serde_json::Value::String("science".to_string()));
        let sci_emb = Embedding::new(vec![0.9, 0.1], "test-model", &sci_doc.id);
        store
            .add_documents(&[(tech_doc.clone(), tech_emb), (sci_doc, sci_emb)])
            .await
            .unwrap();

        let probe = Embedding::new(vec![1.0, 0.0], "test-model", "query");
        let hits = store
            .search(
                &SearchQuery::embedding(probe)
                    .with_filter("category", serde_json::Value::String("tech".to_string())),
            )
            .await
            .unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, tech_doc.id);
    }

    /// The builder methods should set exactly what they advertise.
    #[test]
    fn test_search_query_builder() {
        let query = SearchQuery::text("test query")
            .with_limit(20)
            .with_min_score(0.5)
            .with_filter("type", serde_json::Value::String("article".to_string()));
        assert_eq!(query.limit, 20);
        assert_eq!(query.min_score, 0.5);
        assert_eq!(query.filters.len(), 1);
    }

    /// A fresh service reports an empty, healthy index.
    #[tokio::test]
    async fn test_retrieval_service() {
        let service = RetrievalService::new(Arc::new(InMemoryRetriever::new()));
        assert_eq!(service.get_stats().await.unwrap().total_items, 0);
        assert!(service.health_check().await.unwrap());
    }
}