use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::{debug, info};
use super::types::{QueryContext, RagDocument, SearchResult};
/// Feature switches and tuning values for [`AdvancedRetriever`].
///
/// Derives serde's `Serialize`/`Deserialize`, so it can be round-tripped
/// through any serde-supported format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedRetrievalConfig {
/// When true, `AdvancedRetriever::new` constructs the dense passage retriever.
pub enable_dpr: bool,
/// When true, `AdvancedRetriever::new` constructs the ColBERT-style retriever.
pub enable_colbert: bool,
/// When true, `AdvancedRetriever::retrieve` adds BM25+ results.
pub enable_bm25_plus: bool,
/// When true, `AdvancedRetriever::new` constructs the learning-to-rank re-ranker.
pub enable_ltr: bool,
/// When true, the query is passed through the query expander before retrieval.
pub enable_query_expansion: bool,
/// When true, `AdvancedRetriever::retrieve` adds multi-hop results.
pub enable_multi_hop: bool,
/// NOTE(review): not read anywhere in the visible code; presumably intended
/// to cap the number of expansion terms — confirm against callers.
pub max_expansion_terms: usize,
/// Upper bound on the number of hops performed by multi-hop retrieval.
pub multi_hop_depth: usize,
/// Passed to `LearningToRank::new` when the re-ranker is enabled.
pub ltr_model_path: Option<String>,
}
impl Default for AdvancedRetrievalConfig {
fn default() -> Self {
Self {
enable_dpr: true,
enable_colbert: false, enable_bm25_plus: true,
enable_ltr: true,
enable_query_expansion: true,
enable_multi_hop: false,
max_expansion_terms: 10,
multi_hop_depth: 2,
ltr_model_path: None,
}
}
}
/// Combines several retrieval strategies and an optional re-ranking step
/// behind a single `retrieve` call.
pub struct AdvancedRetriever {
/// Feature flags and tuning values consulted on every retrieval.
config: AdvancedRetrievalConfig,
/// Dense retriever; `None` when `enable_dpr` was false at construction.
dpr_retriever: Option<DensePassageRetriever>,
/// ColBERT-style retriever; `None` when `enable_colbert` was false at construction.
colbert_retriever: Option<ColBERTRetriever>,
/// BM25+ scorer; always constructed, used for BM25+ and multi-hop retrieval.
bm25_plus: BM25Plus,
/// Re-ranker; `None` when `enable_ltr` was false at construction.
ltr_ranker: Option<LearningToRank>,
/// Query expander; always constructed, used only when expansion is enabled.
query_expander: QueryExpander,
}
impl AdvancedRetriever {
    /// Creates a retriever, constructing only the optional components that
    /// `config` enables.
    ///
    /// The BM25+ scorer and the query expander are always constructed; whether
    /// they are used is decided per call in [`retrieve`](Self::retrieve).
    pub fn new(config: AdvancedRetrievalConfig) -> Self {
        info!("Initializing advanced retrieval strategies");

        let dpr_retriever = config.enable_dpr.then(DensePassageRetriever::new);
        let colbert_retriever = config.enable_colbert.then(ColBERTRetriever::new);
        let ltr_ranker = config
            .enable_ltr
            .then(|| LearningToRank::new(config.ltr_model_path.clone()));

        Self {
            config,
            dpr_retriever,
            colbert_retriever,
            bm25_plus: BM25Plus::new(),
            ltr_ranker,
            query_expander: QueryExpander::new(),
        }
    }

    /// Runs every enabled strategy over `candidate_docs` and returns the
    /// combined results, best first.
    ///
    /// Steps:
    /// 1. Expand the query when `enable_query_expansion` is set.
    /// 2. Run DPR, ColBERT, BM25+ and multi-hop retrieval (each only when
    ///    enabled) and concatenate their outputs. A document scored by several
    ///    strategies therefore appears once per strategy; no merging is done.
    /// 3. Re-rank with the learning-to-rank component when present; otherwise
    ///    sort by descending score.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the query expander, a retriever, or
    /// the re-ranker.
    pub async fn retrieve(
        &self,
        query: &str,
        context: &QueryContext,
        candidate_docs: &[RagDocument],
    ) -> Result<Vec<SearchResult>> {
        let expanded_query = if self.config.enable_query_expansion {
            self.query_expander.expand(query, context).await?
        } else {
            query.to_string()
        };
        debug!("Expanded query: {} -> {}", query, expanded_query);

        let mut all_results = Vec::new();

        if let Some(dpr) = &self.dpr_retriever {
            all_results.extend(dpr.retrieve(&expanded_query, candidate_docs).await?);
        }
        if let Some(colbert) = &self.colbert_retriever {
            all_results.extend(colbert.retrieve(&expanded_query, candidate_docs).await?);
        }
        if self.config.enable_bm25_plus {
            all_results.extend(
                self.bm25_plus
                    .retrieve(&expanded_query, candidate_docs)
                    .await?,
            );
        }
        if self.config.enable_multi_hop {
            all_results.extend(
                self.multi_hop_retrieve(&expanded_query, candidate_docs)
                    .await?,
            );
        }

        match &self.ltr_ranker {
            Some(ltr) => ltr.rerank(all_results, &expanded_query, context).await,
            None => {
                // Stable sort, descending; incomparable scores (NaN) are
                // treated as equal so the sort never panics.
                all_results.sort_by(|a, b| {
                    b.score
                        .partial_cmp(&a.score)
                        .unwrap_or(std::cmp::Ordering::Equal)
                });
                Ok(all_results)
            }
        }
    }

    /// Performs up to `multi_hop_depth` rounds of BM25+ retrieval, extending
    /// the query after each round with the first 20 whitespace-separated words
    /// of that round's top document.
    ///
    /// Results from every hop are appended in hop order, so a document can
    /// appear once per hop. Stops early if a hop returns nothing.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the BM25+ retriever.
    async fn multi_hop_retrieve(
        &self,
        query: &str,
        candidate_docs: &[RagDocument],
    ) -> Result<Vec<SearchResult>> {
        let mut hop_results = Vec::new();
        let mut current_query = query.to_string();

        for hop in 0..self.config.multi_hop_depth {
            debug!("Multi-hop retrieval: hop {}", hop);
            let hop_docs = self
                .bm25_plus
                .retrieve(&current_query, candidate_docs)
                .await?;

            // `first()` replaces the is_empty check + `[0]` index pair; the
            // joined String ends the borrow before `hop_docs` is moved below.
            let lead_words = match hop_docs.first() {
                Some(top) => top
                    .document
                    .content
                    .split_whitespace()
                    .take(20)
                    .collect::<Vec<_>>()
                    .join(" "),
                None => break,
            };

            // Append in place rather than rebuilding the whole query string.
            current_query.push(' ');
            current_query.push_str(&lead_words);

            hop_results.extend(hop_docs);
        }

        Ok(hop_results)
    }
}
/// Retriever that ranks documents by cosine similarity between a query
/// vector and a per-document vector. Holds no state.
pub struct DensePassageRetriever {
}
impl DensePassageRetriever {
    /// Creates a retriever. There is no state to initialise.
    pub fn new() -> Self {
        Self {}
    }

    /// Scores every document against `query` and returns one result per
    /// document, sorted by descending score.
    ///
    /// Each result holds a clone of its document and is tagged with the
    /// relevance factor `"DPR"`.
    ///
    /// # Errors
    ///
    /// Propagates errors from the encoders (the current encoders never fail).
    pub async fn retrieve(
        &self,
        query: &str,
        documents: &[RagDocument],
    ) -> Result<Vec<SearchResult>> {
        let query_embedding = self.encode_query(query)?;

        let mut results = Vec::with_capacity(documents.len());
        for doc in documents {
            let doc_embedding = self.encode_passage(&doc.content)?;
            let score = self.compute_similarity(&query_embedding, &doc_embedding);
            let mut result = SearchResult::new(doc.clone(), score);
            result.relevance_factors.push("DPR".to_string());
            results.push(result);
        }

        // Stable sort, descending; incomparable scores are treated as equal.
        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        Ok(results)
    }

    /// Encodes a query; see [`positional_encoding`](Self::positional_encoding).
    fn encode_query(&self, query: &str) -> Result<Vec<f32>> {
        Ok(Self::positional_encoding(query))
    }

    /// Encodes a passage; see [`positional_encoding`](Self::positional_encoding).
    fn encode_passage(&self, passage: &str) -> Result<Vec<f32>> {
        Ok(Self::positional_encoding(passage))
    }

    /// Maps `text` to one value per whitespace-separated token: the token's
    /// 1-based position divided by the token count.
    ///
    /// The vector depends only on how many tokens there are, not on what they
    /// are. The token count is computed once instead of once per token, which
    /// made the original encoders quadratic in text length.
    fn positional_encoding(text: &str) -> Vec<f32> {
        let token_count = text.split_whitespace().count();
        let denominator = token_count as f32;
        (0..token_count)
            .map(|index| (index as f32 + 1.0) / denominator)
            .collect()
    }

    /// Cosine similarity of the two vectors, or `0.0` if either has zero norm
    /// (which includes the empty vector).
    ///
    /// The dot product runs over the common prefix (`zip` stops at the shorter
    /// vector), while each norm is taken over the full vector.
    fn compute_similarity(&self, query_emb: &[f32], doc_emb: &[f32]) -> f64 {
        let dot_product: f32 = query_emb
            .iter()
            .zip(doc_emb.iter())
            .map(|(a, b)| a * b)
            .sum();
        let query_norm: f32 = query_emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        let doc_norm: f32 = doc_emb.iter().map(|x| x * x).sum::<f32>().sqrt();

        if query_norm > 0.0 && doc_norm > 0.0 {
            (dot_product / (query_norm * doc_norm)) as f64
        } else {
            0.0
        }
    }
}
impl Default for DensePassageRetriever {
fn default() -> Self {
Self::new()
}
}
/// Retriever that scores documents by per-token maximum-similarity
/// interaction between query tokens and document tokens. Holds no state.
pub struct ColBERTRetriever {
}
impl ColBERTRetriever {
    /// Creates a retriever. There is no state to initialise.
    pub fn new() -> Self {
        Self {}
    }

    /// Scores every document against `query` and returns one result per
    /// document, sorted by descending score.
    ///
    /// Each result holds a clone of its document and is tagged with the
    /// relevance factor `"ColBERT"`. A query with no tokens scores `0.0`
    /// against every document.
    ///
    /// # Errors
    ///
    /// Propagates errors from tokenisation (the current encoder never fails).
    pub async fn retrieve(
        &self,
        query: &str,
        documents: &[RagDocument],
    ) -> Result<Vec<SearchResult>> {
        let query_tokens = self.tokenize_and_encode(query)?;

        let mut results = Vec::with_capacity(documents.len());
        for doc in documents {
            let doc_tokens = self.tokenize_and_encode(&doc.content)?;
            let score = self.max_sim_interaction(&query_tokens, &doc_tokens);
            let mut result = SearchResult::new(doc.clone(), score);
            result.relevance_factors.push("ColBERT".to_string());
            results.push(result);
        }

        // Stable sort, descending; incomparable scores are treated as equal.
        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        Ok(results)
    }

    /// Splits `text` on whitespace and encodes each token as one value per
    /// character: the character's 1-based position divided by the token's
    /// byte length.
    ///
    /// Positions are counted in `char`s while the divisor is `str::len`
    /// (bytes), so the two differ for non-ASCII tokens. That only rescales the
    /// vector, and the cosine similarity used downstream is scale-invariant.
    fn tokenize_and_encode(&self, text: &str) -> Result<Vec<Vec<f32>>> {
        Ok(text
            .split_whitespace()
            .map(|token| {
                let byte_len = token.len() as f32;
                token
                    .chars()
                    .enumerate()
                    .map(|(i, _)| (i as f32 + 1.0) / byte_len)
                    .collect()
            })
            .collect())
    }

    /// For each query token, takes the best cosine similarity against any
    /// document token (floor `0.0`), then averages those maxima over the
    /// query tokens.
    ///
    /// Returns `0.0` for an empty query. Without this guard the average was
    /// `0.0 / 0.0`, i.e. NaN, which then leaked into result scores.
    fn max_sim_interaction(&self, query_tokens: &[Vec<f32>], doc_tokens: &[Vec<f32>]) -> f64 {
        if query_tokens.is_empty() {
            return 0.0;
        }

        let mut total_score = 0.0;
        for q_token in query_tokens {
            let best = doc_tokens
                .iter()
                .map(|d_token| self.cosine_similarity(q_token, d_token))
                .fold(0.0, |best, sim| if sim > best { sim } else { best });
            total_score += best;
        }
        total_score / query_tokens.len() as f64
    }

    /// Cosine similarity of `a` and `b`, or `0.0` if either is empty or has
    /// zero norm.
    ///
    /// The dot product runs over the common prefix (`zip` stops at the shorter
    /// vector), while each norm is taken over the full vector.
    fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f64 {
        if a.is_empty() || b.is_empty() {
            return 0.0;
        }
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

        if norm_a > 0.0 && norm_b > 0.0 {
            (dot / (norm_a * norm_b)) as f64
        } else {
            0.0
        }
    }
}
impl Default for ColBERTRetriever {
fn default() -> Self {
Self::new()
}
}
/// Lexical scorer using a BM25+-style formula over whitespace-split terms.
///
/// Fields, as used by the scoring function in this file:
/// - `k1` scales the term-frequency numerator and the length-normalised
///   denominator.
/// - `b` weights the document-length / average-length ratio in the
///   denominator.
/// - `delta` is a constant added to each query term's term-frequency
///   component before it is multiplied by the IDF factor.
pub struct BM25Plus {
k1: f64, b: f64, delta: f64, }
impl BM25Plus {
    /// Creates a scorer with `k1 = 1.5`, `b = 0.75` and `delta = 0.5`.
    pub fn new() -> Self {
        Self {
            k1: 1.5,
            b: 0.75,
            delta: 0.5,
        }
    }

    /// Scores every document against the whitespace-split `query` and returns
    /// one result per document, sorted by descending score.
    ///
    /// Each result holds a clone of its document and is tagged with the
    /// relevance factor `"BM25+"`. Term matching is exact and case-sensitive.
    ///
    /// # Errors
    ///
    /// The current implementation never returns an error.
    pub async fn retrieve(
        &self,
        query: &str,
        documents: &[RagDocument],
    ) -> Result<Vec<SearchResult>> {
        let query_terms: Vec<&str> = query.split_whitespace().collect();
        let avg_doc_length = self.compute_avg_doc_length(documents);
        let corpus_size = documents.len();

        let mut results = Vec::with_capacity(corpus_size);
        for doc in documents {
            let score = self.compute_bm25_plus_score(
                &query_terms,
                &doc.content,
                avg_doc_length,
                corpus_size,
            );
            let mut result = SearchResult::new(doc.clone(), score);
            result.relevance_factors.push("BM25+".to_string());
            results.push(result);
        }

        // Stable sort, descending; incomparable scores are treated as equal.
        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        Ok(results)
    }

    /// Mean number of whitespace-separated terms per document, or `0.0` for
    /// an empty slice.
    fn compute_avg_doc_length(&self, documents: &[RagDocument]) -> f64 {
        if documents.is_empty() {
            return 0.0;
        }
        let total_length: usize = documents
            .iter()
            .map(|d| d.content.split_whitespace().count())
            .sum();
        total_length as f64 / documents.len() as f64
    }

    /// Computes the score of `document` for `query_terms`.
    ///
    /// For each query term (duplicates count each time) the contribution is
    /// `idf * (tf * (k1 + 1) / (tf + k1 * (1 - b + b * len / avg_len)) + delta)`.
    /// Here `tf` is the exact-match count in the document and
    /// `idf = ln((corpus_size + 1) / 2)`.
    ///
    /// The IDF factor depends only on `corpus_size`, not on how many documents
    /// contain the term, so it is the same for every term.
    ///
    /// When `avg_doc_length` is not positive (for example every document is
    /// empty) the length ratio is taken as `1.0`. The original computed
    /// `0.0 / 0.0` there and returned NaN.
    fn compute_bm25_plus_score(
        &self,
        query_terms: &[&str],
        document: &str,
        avg_doc_length: f64,
        corpus_size: usize,
    ) -> f64 {
        let doc_terms: Vec<&str> = document.split_whitespace().collect();
        let doc_length = doc_terms.len() as f64;

        // Both factors below are identical for every query term, so compute
        // them once outside the loop.
        let weighted_length_ratio = if avg_doc_length > 0.0 {
            self.b * doc_length / avg_doc_length
        } else {
            self.b
        };
        let length_norm = self.k1 * (1.0 - self.b + weighted_length_ratio);
        let idf = ((corpus_size as f64 + 1.0) / 2.0).ln();

        let mut score = 0.0;
        for term in query_terms {
            let term_freq = doc_terms.iter().filter(|&&t| t == *term).count() as f64;
            let tf_component = (term_freq * (self.k1 + 1.0)) / (term_freq + length_norm);
            score += idf * (tf_component + self.delta);
        }
        score
    }
}
impl Default for BM25Plus {
fn default() -> Self {
Self::new()
}
}
/// Re-ranker that replaces each result's score with a weighted sum of
/// per-document features.
pub struct LearningToRank {
/// Model location supplied at construction. It is stored but not read by any
/// code visible in this file.
model_path: Option<String>,
/// Weight applied to each feature, keyed by feature name.
feature_weights: HashMap<String, f64>,
}
impl LearningToRank {
    /// Creates a ranker with fixed feature weights: `semantic_score` 0.4,
    /// `bm25_score` 0.3, `query_term_coverage` 0.2, `document_freshness` 0.1.
    ///
    /// `model_path` is stored as given; nothing here loads a model from it.
    pub fn new(model_path: Option<String>) -> Self {
        let feature_weights: HashMap<String, f64> = [
            ("semantic_score", 0.4),
            ("bm25_score", 0.3),
            ("query_term_coverage", 0.2),
            ("document_freshness", 0.1),
        ]
        .iter()
        .map(|&(name, weight)| (name.to_string(), weight))
        .collect();

        Self {
            model_path,
            feature_weights,
        }
    }

    /// Overwrites each result's score with its learning-to-rank score and
    /// returns the results sorted by descending score.
    ///
    /// The incoming retriever scores are discarded, not blended. `_context`
    /// is currently unused.
    ///
    /// # Errors
    ///
    /// The current implementation never returns an error.
    pub async fn rerank(
        &self,
        results: Vec<SearchResult>,
        query: &str,
        _context: &QueryContext,
    ) -> Result<Vec<SearchResult>> {
        let mut reranked = results;
        for result in &mut reranked {
            let features = self.extract_features(&result.document, query);
            result.score = self.compute_ltr_score(&features);
        }

        // Stable sort, descending; incomparable scores are treated as equal.
        reranked.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        Ok(reranked)
    }

    /// Builds the feature map for one document.
    ///
    /// - `query_term_coverage`: fraction of whitespace-split query terms that
    ///   occur verbatim in the document. It is `0.0` for an empty query; the
    ///   original divided by zero there and produced NaN.
    /// - `document_freshness`: `1 / (1 + age_days / 365)`. Age is clamped to
    ///   zero so a timestamp in the future cannot push the value above 1,
    ///   to infinity, or negative.
    /// - `semantic_score`, `bm25_score`: fixed at `0.5`.
    fn extract_features(&self, document: &RagDocument, query: &str) -> HashMap<String, f64> {
        let mut features = HashMap::new();

        // Set lookup replaces a linear scan of the document per query term.
        let doc_terms: std::collections::HashSet<&str> =
            document.content.split_whitespace().collect();
        let mut query_term_count = 0usize;
        let mut matched_count = 0usize;
        for term in query.split_whitespace() {
            query_term_count += 1;
            if doc_terms.contains(term) {
                matched_count += 1;
            }
        }
        let coverage = if query_term_count == 0 {
            0.0
        } else {
            matched_count as f64 / query_term_count as f64
        };
        features.insert("query_term_coverage".to_string(), coverage);

        let age_days = (chrono::Utc::now() - document.timestamp)
            .num_days()
            .max(0) as f64;
        let freshness = 1.0 / (1.0 + age_days / 365.0);
        features.insert("document_freshness".to_string(), freshness);

        features.insert("semantic_score".to_string(), 0.5);
        features.insert("bm25_score".to_string(), 0.5);
        features
    }

    /// Weighted sum of the features that have a configured weight; features
    /// without a weight are ignored.
    ///
    /// Iterates the ranker's own weight table rather than the per-document
    /// feature map. A `HashMap`'s iteration order is arbitrary and each
    /// document gets a fresh feature map, so the summation order could vary
    /// between documents and float rounding could make equal documents score
    /// differently. One shared table fixes the order for all documents scored
    /// by this ranker.
    fn compute_ltr_score(&self, features: &HashMap<String, f64>) -> f64 {
        let mut score = 0.0;
        for (feature, weight) in &self.feature_weights {
            if let Some(value) = features.get(feature) {
                score += weight * value;
            }
        }
        score
    }
}
/// Expands queries by appending related terms from a fixed lookup table.
pub struct QueryExpander {
/// Related terms for each known word. Lookups use the lowercased query term.
expansion_terms: HashMap<String, Vec<String>>,
}
impl QueryExpander {
    /// How many related terms are appended after each matching query term.
    const MAX_RELATED_PER_TERM: usize = 2;

    /// Creates an expander with a built-in table covering "search" and
    /// "person".
    pub fn new() -> Self {
        let table: [(&str, [&str; 3]); 2] = [
            ("search", ["find", "look", "query"]),
            ("person", ["people", "individual", "human"]),
        ];

        let expansion_terms: HashMap<String, Vec<String>> = table
            .iter()
            .map(|(term, related)| {
                let related = related
                    .iter()
                    .map(|word| (*word).to_string())
                    .collect::<Vec<String>>();
                ((*term).to_string(), related)
            })
            .collect();

        Self { expansion_terms }
    }

    /// Returns `query` with, after each whitespace-separated term, up to two
    /// related terms from the table. All terms are joined by single spaces.
    ///
    /// The table lookup uses the lowercased term; the term itself is emitted
    /// unchanged. `_context` is currently unused.
    ///
    /// # Errors
    ///
    /// The current implementation never returns an error.
    pub async fn expand(&self, query: &str, _context: &QueryContext) -> Result<String> {
        let expanded_terms: Vec<String> = query
            .split_whitespace()
            .flat_map(|term| {
                let related = self
                    .expansion_terms
                    .get(&term.to_lowercase())
                    .into_iter()
                    .flat_map(|list| list.iter().take(Self::MAX_RELATED_PER_TERM).cloned());
                std::iter::once(term.to_string()).chain(related)
            })
            .collect();

        Ok(expanded_terms.join(" "))
    }
}
impl Default for QueryExpander {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Builds a document with the given id and content, no embedding, empty
    /// metadata, the current time as timestamp and `"test"` as source.
    fn sample_document(id: &str, content: &str) -> RagDocument {
        RagDocument {
            id: id.to_string(),
            content: content.to_string(),
            embedding: None,
            metadata: HashMap::new(),
            timestamp: Utc::now(),
            source: "test".to_string(),
        }
    }

    /// A document containing both query terms must score above zero.
    #[test]
    fn test_bm25_plus_scoring() {
        let bm25 = BM25Plus::new();
        let docs = [
            sample_document("1", "machine learning artificial intelligence"),
            sample_document("2", "deep learning neural networks"),
        ];
        let query_terms = ["machine", "learning"];

        let score = bm25.compute_bm25_plus_score(&query_terms, &docs[0].content, 10.0, 2);

        assert!(score > 0.0);
    }

    /// Expansion keeps the original terms and adds at least one more.
    #[tokio::test]
    async fn test_query_expansion() {
        let expander = QueryExpander::new();
        let context = QueryContext::new("test".to_string());

        let expanded = expander
            .expand("search person", &context)
            .await
            .expect("should succeed");

        assert!(expanded.contains("search"));
        assert!(expanded.contains("person"));
        let term_count = expanded.split_whitespace().count();
        assert!(term_count > 2);
    }
}