use super::quantum::QuantumRanker;
use super::types::*;
use anyhow::Result;
use oxirs_core::Store;
use oxirs_vec::{Vector, VectorIndex};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info};
pub struct MultiStageRetrieval {
config: RetrievalConfig,
semantic_retriever: SemanticRetriever,
graph_retriever: GraphRetriever,
hybrid_retriever: HybridRetriever,
result_cache: Arc<RwLock<HashMap<String, Vec<SearchResult>>>>,
}
impl MultiStageRetrieval {
pub fn new(rag_config: &super::RAGConfig) -> Self {
let config = RetrievalConfig {
max_documents: rag_config.retrieval.max_results,
similarity_threshold: rag_config.retrieval.similarity_threshold as f64,
enable_reranking: true,
reranking_model: None,
enable_temporal_filtering: true,
temporal_window: Some(std::time::Duration::from_secs(365 * 24 * 3600)),
};
Self {
semantic_retriever: SemanticRetriever::new(),
graph_retriever: GraphRetriever::new(),
hybrid_retriever: HybridRetriever::new(),
result_cache: Arc::new(RwLock::new(HashMap::new())),
config,
}
}
pub async fn retrieve(
&self,
query: &str,
context: &QueryContext,
vector_index: &Arc<dyn VectorIndex>,
store: &Arc<dyn Store>,
) -> Result<Vec<SearchResult>> {
let start_time = std::time::Instant::now();
let cache_key = format!("{}:{}", query, context.session_id);
let cache = self.result_cache.read().await;
if let Some(cached_results) = cache.get(&cache_key) {
debug!("Cache hit for query: {}", query);
return Ok(cached_results.clone());
}
let mut all_results = Vec::new();
info!("Starting semantic retrieval for query: {}", query);
let semantic_results = self
.semantic_retriever
.retrieve(query, context, vector_index)
.await?;
all_results.extend(semantic_results);
if context.query_intent != QueryIntent::Information {
info!("Starting graph traversal");
let graph_results = self.graph_retriever.retrieve(query, context, store).await?;
all_results.extend(graph_results);
}
info!("Starting hybrid retrieval");
let hybrid_results = self
.hybrid_retriever
.retrieve(query, context, vector_index, store)
.await?;
all_results.extend(hybrid_results);
let final_results = self.deduplicate_and_rank(all_results, query).await?;
let filtered_results = self.apply_filters(final_results, context).await?;
let mut cache = self.result_cache.write().await;
cache.insert(cache_key, filtered_results.clone());
if cache.len() > 1000 {
cache.clear();
}
let retrieval_time = start_time.elapsed();
info!(
"Multi-stage retrieval completed in {:?}, found {} documents",
retrieval_time,
filtered_results.len()
);
Ok(filtered_results)
}
pub fn update_config(&mut self, rag_config: &super::RAGConfig) {
self.config.max_documents = rag_config.retrieval.max_results;
self.config.similarity_threshold = rag_config.retrieval.similarity_threshold as f64;
}
async fn deduplicate_and_rank(
&self,
mut results: Vec<SearchResult>,
query: &str,
) -> Result<Vec<SearchResult>> {
let mut seen_ids = HashSet::new();
results.retain(|result| seen_ids.insert(result.document.id.clone()));
let query_complexity = query.split_whitespace().count() as f64 / 10.0;
if query_complexity > 0.5 {
let documents: Vec<RagDocument> = results.iter().map(|r| r.document.clone()).collect();
let mut quantum_ranker = QuantumRanker::new(query_complexity);
let quantum_results = quantum_ranker.rank_documents(&documents);
results = quantum_results
.into_iter()
.map(|qr| qr.to_classical())
.collect();
} else {
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
}
results.truncate(self.config.max_documents);
Ok(results)
}
async fn apply_filters(
&self,
mut results: Vec<SearchResult>,
context: &QueryContext,
) -> Result<Vec<SearchResult>> {
results.retain(|result| result.score >= self.config.similarity_threshold);
if self.config.enable_temporal_filtering {
if let Some(window) = self.config.temporal_window {
let cutoff = chrono::Utc::now() - chrono::Duration::from_std(window)?;
results.retain(|result| result.document.timestamp >= cutoff);
}
}
if !context.domain_constraints.is_empty() {
results.retain(|result| {
context.domain_constraints.iter().any(|constraint| {
result.document.content.contains(constraint)
|| result
.document
.metadata
.values()
.any(|v| v.contains(constraint))
})
});
}
Ok(results)
}
}
pub struct SemanticRetriever {
embedding_cache: Arc<RwLock<HashMap<String, Vec<f32>>>>,
}
impl Default for SemanticRetriever {
fn default() -> Self {
Self::new()
}
}
impl SemanticRetriever {
pub fn new() -> Self {
Self {
embedding_cache: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn retrieve(
&self,
query: &str,
_context: &QueryContext,
vector_index: &Arc<dyn VectorIndex>,
) -> Result<Vec<SearchResult>> {
let query_embedding_vec = self.get_query_embedding(query).await?;
let query_embedding = Vector::new(query_embedding_vec);
let vector_results = vector_index.search_knn(&query_embedding, 20)?;
let search_results = vector_results
.into_iter()
.map(|(id, score)| {
SearchResult::new(
RagDocument::new(
id.clone(),
format!("Content for document {id}"),
"vector_index".to_string(),
),
score.into(),
)
.add_relevance_factor("semantic_similarity".to_string())
})
.collect();
Ok(search_results)
}
async fn get_query_embedding(&self, query: &str) -> Result<Vec<f32>> {
let cache = self.embedding_cache.read().await;
{
if let Some(embedding) = cache.get(query) {
return Ok(embedding.clone());
}
}
let embedding = query
.chars()
.map(|c| (c as u8 as f32) / 255.0)
.take(128)
.collect::<Vec<f32>>();
let mut padded = embedding;
padded.resize(128, 0.0);
let mut cache = self.embedding_cache.write().await;
{
cache.insert(query.to_string(), padded.clone());
}
Ok(padded)
}
}
pub struct GraphRetriever {
query_cache: Arc<RwLock<HashMap<String, Vec<SearchResult>>>>,
}
impl Default for GraphRetriever {
fn default() -> Self {
Self::new()
}
}
impl GraphRetriever {
pub fn new() -> Self {
Self {
query_cache: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn retrieve(
&self,
query: &str,
context: &QueryContext,
_store: &Arc<dyn Store>,
) -> Result<Vec<SearchResult>> {
let cache_key = format!("graph:{}:{}", query, context.session_id);
let cache = self.query_cache.read().await;
{
if let Some(cached) = cache.get(&cache_key) {
return Ok(cached.clone());
}
}
let sparql_query = self.generate_sparql_query(query, context)?;
let results = self.execute_sparql_query(&sparql_query).await?;
{
let mut cache = self.query_cache.write().await;
cache.insert(cache_key, results.clone());
}
Ok(results)
}
fn generate_sparql_query(&self, query: &str, _context: &QueryContext) -> Result<String> {
let query_terms: Vec<&str> = query.split_whitespace().collect();
let sparql = format!(
r#"
SELECT ?subject ?predicate ?object WHERE {{
?subject ?predicate ?object .
FILTER(contains(lcase(str(?object)), "{}"))
}}
LIMIT 10
"#,
query_terms.join(" ").to_lowercase()
);
Ok(sparql)
}
async fn execute_sparql_query(&self, _query: &str) -> Result<Vec<SearchResult>> {
Ok(vec![SearchResult::new(
RagDocument::new(
"graph_result_1".to_string(),
"Graph-based result content".to_string(),
"knowledge_graph".to_string(),
),
0.8,
)
.add_relevance_factor("graph_connectivity".to_string())])
}
}
pub struct HybridRetriever {
fusion_weights: HashMap<String, f64>,
}
impl Default for HybridRetriever {
fn default() -> Self {
Self::new()
}
}
impl HybridRetriever {
pub fn new() -> Self {
let mut weights = HashMap::new();
weights.insert("semantic".to_string(), 0.6);
weights.insert("keyword".to_string(), 0.3);
weights.insert("graph".to_string(), 0.1);
Self {
fusion_weights: weights,
}
}
pub async fn retrieve(
&self,
query: &str,
context: &QueryContext,
vector_index: &Arc<dyn VectorIndex>,
_store: &Arc<dyn Store>,
) -> Result<Vec<SearchResult>> {
let mut combined_results = Vec::new();
let keyword_results = self.keyword_search(query, context).await?;
combined_results.extend(keyword_results);
let additional_semantic = self.enhanced_semantic_search(query, vector_index).await?;
combined_results.extend(additional_semantic);
self.apply_fusion_scoring(&mut combined_results);
Ok(combined_results)
}
async fn keyword_search(
&self,
query: &str,
_context: &QueryContext,
) -> Result<Vec<SearchResult>> {
let keywords: Vec<&str> = query.split_whitespace().collect();
let results = keywords
.iter()
.enumerate()
.map(|(i, &keyword)| {
SearchResult::new(
RagDocument::new(
format!("keyword_result_{i}"),
format!("Document containing keyword: {keyword}"),
"keyword_search".to_string(),
),
0.7 - (i as f64 * 0.1),
)
.add_relevance_factor(format!("keyword_match: {keyword}"))
})
.collect();
Ok(results)
}
async fn enhanced_semantic_search(
&self,
query: &str,
vector_index: &Arc<dyn VectorIndex>,
) -> Result<Vec<SearchResult>> {
let expanded_query = format!("{query} related context information");
let semantic_retriever = SemanticRetriever::new();
let dummy_context = QueryContext::new("hybrid_search".to_string());
semantic_retriever
.retrieve(&expanded_query, &dummy_context, vector_index)
.await
}
fn apply_fusion_scoring(&self, results: &mut [SearchResult]) {
for result in results.iter_mut() {
let mut fused_score = 0.0;
let mut total_weight = 0.0;
for factor in &result.relevance_factors {
if factor.contains("semantic") {
fused_score +=
result.score * self.fusion_weights.get("semantic").unwrap_or(&0.0);
total_weight += self.fusion_weights.get("semantic").unwrap_or(&0.0);
} else if factor.contains("keyword") {
fused_score +=
result.score * self.fusion_weights.get("keyword").unwrap_or(&0.0);
total_weight += self.fusion_weights.get("keyword").unwrap_or(&0.0);
} else if factor.contains("graph") {
fused_score += result.score * self.fusion_weights.get("graph").unwrap_or(&0.0);
total_weight += self.fusion_weights.get("graph").unwrap_or(&0.0);
}
}
if total_weight > 0.0 {
result.score = fused_score / total_weight;
}
}
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_semantic_retriever_creation() {
let retriever = SemanticRetriever::new();
assert!(retriever.embedding_cache.try_read().is_ok());
}
#[test]
fn test_graph_retriever_sparql_generation() {
let retriever = GraphRetriever::new();
let context = QueryContext::new("test_session".to_string());
let sparql = retriever.generate_sparql_query("test query", &context);
assert!(sparql.is_ok());
assert!(sparql.expect("should succeed").contains("SELECT"));
}
#[test]
fn test_hybrid_retriever_weights() {
let retriever = HybridRetriever::new();
assert_eq!(retriever.fusion_weights.len(), 3);
assert!(retriever.fusion_weights.contains_key("semantic"));
}
}