use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::error::OxiRagError;
use crate::layer1_echo::EmbeddingProvider;
use crate::layer1_echo::similarity::cosine_similarity;
use crate::types::{Query, SearchResult};
/// Tuning knobs shared by the rerankers in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankerConfig {
    /// Maximum number of results a reranker returns after sorting.
    pub top_k: usize,
    /// Results whose reranked score is below this value are discarded.
    pub min_score_threshold: f32,
    /// Number of documents passed to each `embed_batch` call by the semantic reranker.
    pub batch_size: usize,
}

impl Default for RerankerConfig {
    /// `top_k = 10`, `min_score_threshold = 0.0`, `batch_size = 32`.
    fn default() -> Self {
        Self {
            batch_size: 32,
            min_score_threshold: 0.0,
            top_k: 10,
        }
    }
}

impl RerankerConfig {
    /// Returns this config with `top_k` replaced.
    #[must_use]
    pub fn with_top_k(self, top_k: usize) -> Self {
        Self { top_k, ..self }
    }

    /// Returns this config with `min_score_threshold` replaced.
    #[must_use]
    pub fn with_min_score_threshold(self, threshold: f32) -> Self {
        Self {
            min_score_threshold: threshold,
            ..self
        }
    }

    /// Returns this config with `batch_size` replaced.
    #[must_use]
    pub fn with_batch_size(self, batch_size: usize) -> Self {
        Self { batch_size, ..self }
    }
}
/// Re-scores and reorders retrieval results for a query.
#[async_trait]
pub trait Reranker: Send + Sync {
    /// The configuration (top-k, score threshold, batch size) this reranker uses.
    fn config(&self) -> &RerankerConfig;

    /// Scores one query/document pair; the score scale is implementation-defined.
    ///
    /// # Errors
    /// Returns an [`OxiRagError`] if the underlying scoring step fails.
    async fn score_pair(&self, query: &str, document: &str) -> Result<f32, OxiRagError>;

    /// Rescores `results` for `query` and returns the reranked list.
    ///
    /// # Errors
    /// Returns an [`OxiRagError`] if any underlying scoring step fails.
    async fn rerank(
        &self,
        query: &Query,
        results: Vec<SearchResult>,
    ) -> Result<Vec<SearchResult>, OxiRagError>;
}

/// Pairwise query/document scorer interface (named after cross-encoder models).
#[async_trait]
pub trait CrossEncoderReranker: Send + Sync {
    /// Scores many pairs; presumably one score per input pair, in order — confirm per implementation.
    async fn encode_pairs(&self, pairs: &[(&str, &str)]) -> Result<Vec<f32>, OxiRagError>;

    /// Scores a single query/document pair.
    async fn encode_pair(&self, query: &str, document: &str) -> Result<f32, OxiRagError>;
}
/// Deterministic test double that scales every incoming score by a constant.
#[derive(Debug, Clone)]
pub struct MockReranker {
    config: RerankerConfig,
    /// Factor applied to each incoming score (1.0 unless overridden).
    score_multiplier: f32,
}

impl MockReranker {
    /// Creates a mock with the given config and a multiplier of `1.0`.
    #[must_use]
    pub fn new(config: RerankerConfig) -> Self {
        let score_multiplier = 1.0;
        Self {
            config,
            score_multiplier,
        }
    }

    /// Replaces the factor every score is multiplied by.
    #[must_use]
    pub fn with_score_multiplier(self, multiplier: f32) -> Self {
        Self {
            score_multiplier: multiplier,
            ..self
        }
    }
}

impl Default for MockReranker {
    fn default() -> Self {
        Self::new(RerankerConfig::default())
    }
}

#[async_trait]
impl Reranker for MockReranker {
    /// Multiplies each score by the multiplier, drops results under the
    /// threshold, sorts descending, keeps `top_k` and renumbers ranks from 0.
    /// The query is ignored.
    async fn rerank(
        &self,
        _query: &Query,
        results: Vec<SearchResult>,
    ) -> Result<Vec<SearchResult>, OxiRagError> {
        let factor = self.score_multiplier;
        let floor = self.config.min_score_threshold;

        let mut survivors: Vec<SearchResult> = results
            .into_iter()
            .map(|mut candidate| {
                candidate.score *= factor;
                candidate
            })
            .filter(|candidate| candidate.score >= floor)
            .collect();

        survivors.sort_by(|lhs, rhs| {
            rhs.score
                .partial_cmp(&lhs.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        survivors.truncate(self.config.top_k);
        survivors
            .iter_mut()
            .enumerate()
            .for_each(|(position, candidate)| candidate.rank = position);

        Ok(survivors)
    }

    /// Always `0.5 * score_multiplier`, regardless of the inputs.
    async fn score_pair(&self, _query: &str, _document: &str) -> Result<f32, OxiRagError> {
        let fixed = 0.5 * self.score_multiplier;
        Ok(fixed)
    }

    fn config(&self) -> &RerankerConfig {
        &self.config
    }
}
/// Test double for [`CrossEncoderReranker`] scoring by word-set overlap.
#[derive(Debug, Clone)]
pub struct MockCrossEncoderReranker {
    /// Score floor returned for pairs with no word overlap.
    base_score: f32,
}

impl MockCrossEncoderReranker {
    /// Creates a mock whose scores start from `base_score`.
    #[must_use]
    pub fn new(base_score: f32) -> Self {
        Self {
            base_score: base_score,
        }
    }
}

impl Default for MockCrossEncoderReranker {
    /// A mock with a base score of `0.5`.
    fn default() -> Self {
        let base_score = 0.5;
        Self::new(base_score)
    }
}
#[async_trait]
impl CrossEncoderReranker for MockCrossEncoderReranker {
    /// Returns `base_score + jaccard * (1 - base_score)`, where `jaccard` is the
    /// Jaccard index of the lower-cased, whitespace-split word sets. Returns
    /// `base_score` when both texts have no words.
    async fn encode_pair(&self, query: &str, document: &str) -> Result<f32, OxiRagError> {
        let q_text = query.to_lowercase();
        let d_text = document.to_lowercase();
        let q_words: std::collections::HashSet<&str> = q_text.split_whitespace().collect();
        let d_words: std::collections::HashSet<&str> = d_text.split_whitespace().collect();

        let shared = q_words.intersection(&d_words).count();
        // |A ∪ B| = |A| + |B| - |A ∩ B|
        let combined = q_words.len() + d_words.len() - shared;
        if combined == 0 {
            return Ok(self.base_score);
        }

        #[allow(clippy::cast_precision_loss)]
        let jaccard = shared as f32 / combined as f32;
        Ok(self.base_score + jaccard * (1.0 - self.base_score))
    }

    /// Scores each pair sequentially with [`Self::encode_pair`], stopping at the first error.
    async fn encode_pairs(&self, pairs: &[(&str, &str)]) -> Result<Vec<f32>, OxiRagError> {
        let mut out = Vec::with_capacity(pairs.len());
        for &(q, d) in pairs {
            let pair_score = self.encode_pair(q, d).await?;
            out.push(pair_score);
        }
        Ok(out)
    }
}
/// Lexical reranker scoring documents with the term-frequency part of BM25.
#[derive(Debug, Clone)]
pub struct KeywordReranker {
    config: RerankerConfig,
    /// BM25 term-frequency saturation parameter (1.2 by default).
    k1: f32,
    /// BM25 length-normalisation strength (0.75 by default; 0 disables it).
    b: f32,
}

impl KeywordReranker {
    /// Creates a keyword reranker with `k1 = 1.2` and `b = 0.75`.
    #[must_use]
    pub fn new(config: RerankerConfig) -> Self {
        Self {
            config,
            k1: 1.2,
            b: 0.75,
        }
    }

    /// Overrides the `k1` term-frequency saturation parameter.
    #[must_use]
    pub fn with_k1(mut self, k1: f32) -> Self {
        self.k1 = k1;
        self
    }

    /// Overrides the `b` length-normalisation parameter.
    #[must_use]
    pub fn with_b(mut self, b: f32) -> Self {
        self.b = b;
        self
    }

    /// Computes a BM25-style score of `document` for `query`.
    ///
    /// Both texts are lower-cased and split on whitespace. Every query token
    /// (duplicates included) that occurs in the document adds
    /// `tf * (k1 + 1) / (tf + k1 * (1 - b + b * doc_len / avg_doc_len))`.
    /// There is no IDF factor, so the score is unbounded above and only
    /// comparable between documents scored with the same `avg_doc_len`.
    ///
    /// `avg_doc_len` is measured in whitespace-separated terms. A non-positive
    /// or non-finite value disables length normalisation (ratio treated as 1)
    /// instead of producing an infinite or NaN ratio.
    #[allow(clippy::cast_precision_loss)]
    fn compute_bm25(&self, query: &str, document: &str, avg_doc_len: f32) -> f32 {
        let query_lower = query.to_lowercase();
        let doc_lower = document.to_lowercase();

        let mut term_freqs: HashMap<&str, usize> = HashMap::new();
        let mut doc_len = 0_usize;
        for term in doc_lower.split_whitespace() {
            *term_freqs.entry(term).or_default() += 1;
            doc_len += 1;
        }

        let length_ratio = if avg_doc_len > 0.0 && avg_doc_len.is_finite() {
            doc_len as f32 / avg_doc_len
        } else {
            1.0
        };
        // The length-dependent part of the denominator is the same for every
        // query term, so compute it once.
        let length_norm = self.k1 * (1.0 - self.b + self.b * length_ratio);

        query_lower
            .split_whitespace()
            .filter_map(|term| term_freqs.get(term))
            .fold(0.0, |score, &tf| {
                let tf = tf as f32;
                score + tf * (self.k1 + 1.0) / (tf + length_norm)
            })
    }
}

impl Default for KeywordReranker {
    fn default() -> Self {
        Self::new(RerankerConfig::default())
    }
}
#[async_trait]
impl Reranker for KeywordReranker {
    /// Blends each incoming score 50/50 with its BM25 score (capped at 1.0).
    /// Then filters by `min_score_threshold`, sorts descending, keeps `top_k`
    /// and renumbers ranks from 0.
    ///
    /// The BM25 average document length is the mean term count of this
    /// candidate list.
    async fn rerank(
        &self,
        query: &Query,
        mut results: Vec<SearchResult>,
    ) -> Result<Vec<SearchResult>, OxiRagError> {
        if results.is_empty() {
            return Ok(results);
        }
        // `compute_bm25` measures a document's length in whitespace-separated
        // terms, so the average must use the same unit. Using byte length here
        // compared a term count against a byte count, which made every document
        // look far shorter than "average" and broke length normalisation.
        let total_terms: usize = results
            .iter()
            .map(|r| r.document.content.split_whitespace().count())
            .sum();
        #[allow(clippy::cast_precision_loss)]
        let avg_doc_len = total_terms as f32 / results.len() as f32;
        for result in &mut results {
            let bm25_score = self.compute_bm25(&query.text, &result.document.content, avg_doc_len);
            result.score = 0.5 * result.score + 0.5 * bm25_score.min(1.0);
        }
        results.retain(|r| r.score >= self.config.min_score_threshold);
        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        results.truncate(self.config.top_k);
        for (i, result) in results.iter_mut().enumerate() {
            result.rank = i;
        }
        Ok(results)
    }

    /// Raw (uncapped) BM25 score for a single pair. There is no candidate list
    /// to average over, so a fixed reference length of 100 terms is assumed.
    async fn score_pair(&self, query: &str, document: &str) -> Result<f32, OxiRagError> {
        let avg_doc_len = 100.0;
        Ok(self.compute_bm25(query, document, avg_doc_len))
    }

    fn config(&self) -> &RerankerConfig {
        &self.config
    }
}
/// Blends incoming scores with query/document embedding similarity.
pub struct SemanticReranker<E: EmbeddingProvider> {
    config: RerankerConfig,
    /// Provider used to embed both the query and the candidate documents.
    embedding_provider: E,
}

impl<E: EmbeddingProvider> SemanticReranker<E> {
    /// Wraps `embedding_provider` with the given reranking config.
    #[must_use]
    pub fn new(config: RerankerConfig, embedding_provider: E) -> Self {
        Self {
            embedding_provider,
            config,
        }
    }

    /// Borrows the wrapped embedding provider.
    #[must_use]
    pub fn embedding_provider(&self) -> &E {
        &self.embedding_provider
    }
}
#[async_trait]
impl<E: EmbeddingProvider> Reranker for SemanticReranker<E> {
    /// Blends each incoming score 50/50 with the cosine similarity between the
    /// query embedding and the document embedding. Then filters by
    /// `min_score_threshold`, sorts descending, keeps `top_k` and renumbers ranks.
    ///
    /// Documents are embedded in batches of `batch_size` (treated as at least 1).
    ///
    /// # Errors
    /// Returns [`OxiRagError::Embedding`] if embedding the query or any batch fails.
    async fn rerank(
        &self,
        query: &Query,
        mut results: Vec<SearchResult>,
    ) -> Result<Vec<SearchResult>, OxiRagError> {
        if results.is_empty() {
            return Ok(results);
        }
        let query_embedding = self
            .embedding_provider
            .embed(&query.text)
            .await
            .map_err(OxiRagError::Embedding)?;
        // `slice::chunks` panics on a chunk size of 0, and `batch_size` is a
        // public, deserializable config field, so clamp it to at least 1.
        let batch_size = self.config.batch_size.max(1);
        let mut doc_embeddings = Vec::with_capacity(results.len());
        for chunk in results.chunks(batch_size) {
            let texts: Vec<&str> = chunk.iter().map(|r| r.document.content.as_str()).collect();
            let embeddings = self
                .embedding_provider
                .embed_batch(&texts)
                .await
                .map_err(OxiRagError::Embedding)?;
            doc_embeddings.extend(embeddings);
        }
        // NOTE(review): assumes `embed_batch` returns one embedding per input
        // text. If it returns fewer, `zip` stops early and the trailing results
        // keep their unblended incoming score.
        for (result, doc_embedding) in results.iter_mut().zip(doc_embeddings.iter()) {
            let semantic_score = cosine_similarity(&query_embedding, doc_embedding);
            result.score = 0.5 * result.score + 0.5 * semantic_score;
        }
        results.retain(|r| r.score >= self.config.min_score_threshold);
        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        results.truncate(self.config.top_k);
        for (i, result) in results.iter_mut().enumerate() {
            result.rank = i;
        }
        Ok(results)
    }

    /// Cosine similarity between the embeddings of `query` and `document`,
    /// obtained with two separate `embed` calls.
    ///
    /// # Errors
    /// Returns [`OxiRagError::Embedding`] if either embedding call fails.
    async fn score_pair(&self, query: &str, document: &str) -> Result<f32, OxiRagError> {
        let query_embedding = self
            .embedding_provider
            .embed(query)
            .await
            .map_err(OxiRagError::Embedding)?;
        let doc_embedding = self
            .embedding_provider
            .embed(document)
            .await
            .map_err(OxiRagError::Embedding)?;
        Ok(cosine_similarity(&query_embedding, &doc_embedding))
    }

    fn config(&self) -> &RerankerConfig {
        &self.config
    }
}
pub struct HybridReranker<E: EmbeddingProvider> {
config: RerankerConfig,
keyword_reranker: KeywordReranker,
semantic_reranker: SemanticReranker<E>,
keyword_weight: f32,
}
impl<E: EmbeddingProvider> HybridReranker<E> {
#[must_use]
pub fn new(config: RerankerConfig, embedding_provider: E) -> Self {
let keyword_reranker = KeywordReranker::new(config.clone());
let semantic_reranker = SemanticReranker::new(config.clone(), embedding_provider);
Self {
config,
keyword_reranker,
semantic_reranker,
keyword_weight: 0.3,
}
}
#[must_use]
pub fn with_keyword_weight(mut self, weight: f32) -> Self {
self.keyword_weight = weight.clamp(0.0, 1.0);
self
}
#[must_use]
pub fn keyword_weight(&self) -> f32 {
self.keyword_weight
}
#[must_use]
pub fn semantic_weight(&self) -> f32 {
1.0 - self.keyword_weight
}
}
#[async_trait]
impl<E: EmbeddingProvider> Reranker for HybridReranker<E> {
    /// Runs the keyword and semantic components over the same candidates and
    /// combines `keyword_weight * keyword + semantic_weight * semantic` per
    /// document id. Documents missing from the keyword output count as 0.
    /// The fused list is filtered by threshold, sorted descending, cut to
    /// `top_k` and re-ranked from 0.
    async fn rerank(
        &self,
        query: &Query,
        results: Vec<SearchResult>,
    ) -> Result<Vec<SearchResult>, OxiRagError> {
        if results.is_empty() {
            return Ok(results);
        }

        let keyword_by_id: HashMap<String, f32> = self
            .keyword_reranker
            .rerank(query, results.clone())
            .await?
            .iter()
            .map(|hit| (hit.document.id.as_str().to_string(), hit.score))
            .collect();
        let semantic_ranked = self.semantic_reranker.rerank(query, results).await?;

        let keyword_weight = self.keyword_weight;
        let semantic_weight = self.semantic_weight();
        let threshold = self.config.min_score_threshold;

        let mut fused = Vec::with_capacity(semantic_ranked.len());
        for mut hit in semantic_ranked {
            let keyword_score = keyword_by_id
                .get(hit.document.id.as_str())
                .copied()
                .unwrap_or(0.0);
            hit.score = keyword_weight * keyword_score + semantic_weight * hit.score;
            if hit.score >= threshold {
                fused.push(hit);
            }
        }

        fused.sort_by(|lhs, rhs| {
            rhs.score
                .partial_cmp(&lhs.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        fused.truncate(self.config.top_k);
        for (position, hit) in fused.iter_mut().enumerate() {
            hit.rank = position;
        }
        Ok(fused)
    }

    /// Weighted combination of the keyword and semantic pair scores.
    async fn score_pair(&self, query: &str, document: &str) -> Result<f32, OxiRagError> {
        let keyword = self.keyword_reranker.score_pair(query, document).await?;
        let semantic = self.semantic_reranker.score_pair(query, document).await?;
        Ok(self.keyword_weight * keyword + self.semantic_weight() * semantic)
    }

    fn config(&self) -> &RerankerConfig {
        &self.config
    }
}
/// How a `RerankerPipeline` combines the outputs of its rerankers.
// NOTE(review): variant order may matter for index-based serde formats; do not reorder.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum FusionStrategy {
    /// Every reranker scores the full candidate list. Per-document scores are
    /// summed using the pipeline weights normalised to sum to 1 (uniform
    /// weights if they sum to 0 or less). This is the default.
    #[default]
    Weighted,
    /// Rerankers run one after another, each on the previous stage's output
    /// truncated to the pipeline's cascade size.
    Cascade,
    /// Reciprocal rank fusion: each reranker adds `1 / (60 + rank + 1)` for every
    /// document it returns, and the contributions are summed.
    ReciprocalRankFusion,
}
/// Combines several boxed [`Reranker`]s according to a [`FusionStrategy`].
pub struct RerankerPipeline {
    /// Final `top_k` / threshold applied after fusion.
    config: RerankerConfig,
    /// Component rerankers, in insertion order.
    rerankers: Vec<Box<dyn Reranker>>,
    /// Weight of the reranker at the same index in `rerankers` (used by `Weighted`).
    weights: Vec<f32>,
    fusion_strategy: FusionStrategy,
    /// Number of results each stage receives under the `Cascade` strategy.
    cascade_top_k: usize,
}

impl RerankerPipeline {
    /// Creates an empty pipeline using `Weighted` fusion and a cascade stage size of 20.
    #[must_use]
    pub fn new(config: RerankerConfig) -> Self {
        Self {
            config,
            rerankers: Vec::new(),
            weights: Vec::new(),
            fusion_strategy: FusionStrategy::Weighted,
            cascade_top_k: 20,
        }
    }

    /// Appends a reranker together with its fusion weight.
    pub fn add_reranker(&mut self, reranker: Box<dyn Reranker>, weight: f32) {
        self.rerankers.push(reranker);
        self.weights.push(weight);
    }

    /// Selects how reranker outputs are combined.
    #[must_use]
    pub fn with_fusion_strategy(mut self, strategy: FusionStrategy) -> Self {
        self.fusion_strategy = strategy;
        self
    }

    /// Sets how many results each stage receives under `Cascade`.
    #[must_use]
    pub fn with_cascade_top_k(mut self, top_k: usize) -> Self {
        self.cascade_top_k = top_k;
        self
    }

    /// Number of registered rerankers.
    #[must_use]
    pub fn len(&self) -> usize {
        self.rerankers.len()
    }

    /// Whether no rerankers are registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.rerankers.is_empty()
    }

    /// Weighted fusion. Every reranker scores the full candidate list, and each
    /// document's final score is the weight-normalised sum of those scores.
    ///
    /// A document that a reranker did not return (e.g. it was cut by that
    /// reranker's own `top_k` or threshold) contributes `0.0` for that reranker.
    /// With no rerankers the input is returned unchanged, without final filtering.
    async fn execute_weighted(
        &self,
        query: &Query,
        results: Vec<SearchResult>,
    ) -> Result<Vec<SearchResult>, OxiRagError> {
        if self.rerankers.is_empty() {
            return Ok(results);
        }
        let reranker_count = self.rerankers.len();
        let total_weight: f32 = self.weights.iter().sum();
        #[allow(clippy::cast_precision_loss)]
        let normalized_weights: Vec<f32> = if total_weight > 0.0 {
            self.weights.iter().map(|w| w / total_weight).collect()
        } else {
            vec![1.0 / reranker_count as f32; reranker_count]
        };
        // One slot per reranker, indexed by reranker position. Pushing scores
        // only when a document was returned would shift later rerankers' scores
        // onto the wrong weight whenever an earlier reranker dropped the document.
        let mut all_scores: HashMap<String, Vec<f32>> = results
            .iter()
            .map(|r| (r.document.id.as_str().to_string(), vec![0.0; reranker_count]))
            .collect();
        for (slot, reranker) in self.rerankers.iter().enumerate() {
            let reranked = reranker.rerank(query, results.clone()).await?;
            for result in &reranked {
                if let Some(scores) = all_scores.get_mut(result.document.id.as_str()) {
                    scores[slot] = result.score;
                }
            }
        }
        let mut final_results: Vec<SearchResult> = results
            .into_iter()
            .map(|mut r| {
                if let Some(scores) = all_scores.get(r.document.id.as_str()) {
                    let combined: f32 = scores
                        .iter()
                        .zip(normalized_weights.iter())
                        .map(|(s, w)| s * w)
                        .sum();
                    r.score = combined;
                }
                r
            })
            .collect();
        self.finalize_results(&mut final_results);
        Ok(final_results)
    }

    /// Cascade fusion. Each reranker in turn receives the previous output
    /// truncated to `cascade_top_k`. The last stage's output is then finalised.
    async fn execute_cascade(
        &self,
        query: &Query,
        mut results: Vec<SearchResult>,
    ) -> Result<Vec<SearchResult>, OxiRagError> {
        for reranker in &self.rerankers {
            results.truncate(self.cascade_top_k);
            results = reranker.rerank(query, results).await?;
        }
        self.finalize_results(&mut results);
        Ok(results)
    }

    /// Reciprocal rank fusion. Every reranker scores the full candidate list,
    /// and each document it returns gains `1 / (K + rank + 1)` with `K = 60`,
    /// using the rank that reranker assigned. The summed contributions replace
    /// the scores before final filtering.
    ///
    /// With no rerankers the input is returned unchanged, without final filtering.
    async fn execute_rrf(
        &self,
        query: &Query,
        results: Vec<SearchResult>,
    ) -> Result<Vec<SearchResult>, OxiRagError> {
        const K: f32 = 60.0;
        if self.rerankers.is_empty() {
            return Ok(results);
        }
        let mut rrf_scores: HashMap<String, f32> = HashMap::new();
        for reranker in &self.rerankers {
            let reranked = reranker.rerank(query, results.clone()).await?;
            for result in &reranked {
                let doc_id = result.document.id.as_str().to_string();
                #[allow(clippy::cast_precision_loss)]
                let rank = result.rank as f32;
                let rrf_contribution = 1.0 / (K + rank + 1.0);
                *rrf_scores.entry(doc_id).or_insert(0.0) += rrf_contribution;
            }
        }
        let mut final_results: Vec<SearchResult> = results
            .into_iter()
            .map(|mut r| {
                // A document no reranker returned gets 0.0. Its incoming score is
                // on a different scale from RRF sums (at most 1/61 per reranker),
                // so keeping it could let it outrank every fused document.
                r.score = rrf_scores
                    .get(r.document.id.as_str())
                    .copied()
                    .unwrap_or(0.0);
                r
            })
            .collect();
        self.finalize_results(&mut final_results);
        Ok(final_results)
    }

    /// Applies the pipeline's threshold, sorts descending by score, truncates to
    /// `top_k` and renumbers ranks from 0.
    fn finalize_results(&self, results: &mut Vec<SearchResult>) {
        results.retain(|r| r.score >= self.config.min_score_threshold);
        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        results.truncate(self.config.top_k);
        for (i, result) in results.iter_mut().enumerate() {
            result.rank = i;
        }
    }
}
#[async_trait]
impl Reranker for RerankerPipeline {
    /// Dispatches to the fusion routine chosen by the pipeline's strategy.
    async fn rerank(
        &self,
        query: &Query,
        results: Vec<SearchResult>,
    ) -> Result<Vec<SearchResult>, OxiRagError> {
        let strategy = self.fusion_strategy;
        match strategy {
            FusionStrategy::ReciprocalRankFusion => self.execute_rrf(query, results).await,
            FusionStrategy::Cascade => self.execute_cascade(query, results).await,
            FusionStrategy::Weighted => self.execute_weighted(query, results).await,
        }
    }

    /// Delegates to the first registered reranker only, ignoring weights and
    /// strategy. Returns `0.0` when the pipeline is empty.
    async fn score_pair(&self, query: &str, document: &str) -> Result<f32, OxiRagError> {
        match self.rerankers.first() {
            Some(primary) => primary.score_pair(query, document).await,
            None => Ok(0.0),
        }
    }

    fn config(&self) -> &RerankerConfig {
        &self.config
    }
}
/// Fluent builder for a `RerankerPipeline`.
pub struct RerankerPipelineBuilder {
    config: RerankerConfig,
    /// Rerankers paired with their fusion weights, in insertion order.
    rerankers: Vec<(Box<dyn Reranker>, f32)>,
    fusion_strategy: FusionStrategy,
    cascade_top_k: usize,
}

impl RerankerPipelineBuilder {
    /// Starts from the default config, `Weighted` fusion and a cascade stage size of 20.
    #[must_use]
    pub fn new() -> Self {
        Self {
            cascade_top_k: 20,
            fusion_strategy: FusionStrategy::Weighted,
            rerankers: Vec::new(),
            config: RerankerConfig::default(),
        }
    }

    /// Replaces the final (post-fusion) configuration.
    #[must_use]
    pub fn with_config(self, config: RerankerConfig) -> Self {
        Self { config, ..self }
    }

    /// Queues a reranker with its fusion weight.
    #[must_use]
    pub fn add_reranker(mut self, reranker: Box<dyn Reranker>, weight: f32) -> Self {
        self.rerankers.push((reranker, weight));
        self
    }

    /// Selects the fusion strategy.
    #[must_use]
    pub fn with_fusion_strategy(self, strategy: FusionStrategy) -> Self {
        Self {
            fusion_strategy: strategy,
            ..self
        }
    }

    /// Sets the per-stage size used by the `Cascade` strategy.
    #[must_use]
    pub fn with_cascade_top_k(self, top_k: usize) -> Self {
        Self {
            cascade_top_k: top_k,
            ..self
        }
    }

    /// Consumes the builder and assembles the pipeline, keeping reranker order.
    #[must_use]
    pub fn build(self) -> RerankerPipeline {
        let Self {
            config,
            rerankers,
            fusion_strategy,
            cascade_top_k,
        } = self;
        let base = RerankerPipeline::new(config)
            .with_fusion_strategy(fusion_strategy)
            .with_cascade_top_k(cascade_top_k);
        rerankers
            .into_iter()
            .fold(base, |mut pipeline, (reranker, weight)| {
                pipeline.add_reranker(reranker, weight);
                pipeline
            })
    }
}

impl Default for RerankerPipelineBuilder {
    fn default() -> Self {
        Self::new()
    }
}
// Unit tests for the rerankers, the fusion strategies and the pipeline builder.
#[cfg(test)]
#[allow(clippy::similar_names)]
mod tests {
    use super::*;
    use crate::layer1_echo::MockEmbeddingProvider;
    use crate::types::Document;

    // Four results with incoming scores 0.9, 0.7, 0.5, 0.8 and ranks 0..=3;
    // two of the documents mention "Rust".
    fn create_test_results() -> Vec<SearchResult> {
        vec![
            SearchResult::new(
                Document::new("Rust is a systems programming language"),
                0.9,
                0,
            ),
            SearchResult::new(Document::new("Python is great for data science"), 0.7, 1),
            SearchResult::new(Document::new("JavaScript runs in browsers"), 0.5, 2),
            SearchResult::new(Document::new("Rust prevents memory safety issues"), 0.8, 3),
        ]
    }

    // top_k truncation, descending order and rank renumbering.
    #[tokio::test]
    async fn test_mock_reranker() {
        let config = RerankerConfig::default().with_top_k(3);
        let reranker = MockReranker::new(config);
        let query = Query::new("programming languages");
        let results = create_test_results();
        let reranked = reranker.rerank(&query, results).await.unwrap();
        assert_eq!(reranked.len(), 3);
        assert!(reranked[0].score >= reranked[1].score);
        assert!(reranked[1].score >= reranked[2].score);
        for (i, r) in reranked.iter().enumerate() {
            assert_eq!(r.rank, i);
        }
    }

    // The multiplier scales the incoming score (1.0 * 0.5).
    #[tokio::test]
    async fn test_mock_reranker_score_multiplier() {
        let config = RerankerConfig::default();
        let reranker = MockReranker::new(config).with_score_multiplier(0.5);
        let query = Query::new("test");
        let results = vec![SearchResult::new(Document::new("test document"), 1.0, 0)];
        let reranked = reranker.rerank(&query, results).await.unwrap();
        assert!((reranked[0].score - 0.5).abs() < 0.001);
    }

    // Results below min_score_threshold are dropped.
    #[tokio::test]
    async fn test_mock_reranker_min_score_threshold() {
        let config = RerankerConfig::default().with_min_score_threshold(0.6);
        let reranker = MockReranker::new(config);
        let query = Query::new("test");
        let results = create_test_results();
        let reranked = reranker.rerank(&query, results).await.unwrap();
        for r in &reranked {
            assert!(r.score >= 0.6);
        }
    }

    // Word overlap raises the mock cross-encoder score above the base score.
    #[tokio::test]
    async fn test_mock_cross_encoder() {
        let encoder = MockCrossEncoderReranker::default();
        let score = encoder
            .encode_pair("Rust programming", "Rust is great")
            .await
            .unwrap();
        assert!(score > 0.5);
        let score_no_overlap = encoder
            .encode_pair("Rust programming", "cats and dogs")
            .await
            .unwrap();
        assert!(score_no_overlap < score);
    }

    // Batch scoring returns one score per pair, in order.
    #[tokio::test]
    async fn test_mock_cross_encoder_batch() {
        let encoder = MockCrossEncoderReranker::default();
        let pairs = vec![("query1", "document1 query1"), ("query2", "unrelated text")];
        let scores = encoder.encode_pairs(&pairs).await.unwrap();
        assert_eq!(scores.len(), 2);
        assert!(scores[0] > scores[1]);
    }

    // The top keyword-reranked result mentions the query term "rust".
    #[tokio::test]
    async fn test_keyword_reranker() {
        let config = RerankerConfig::default().with_top_k(3);
        let reranker = KeywordReranker::new(config);
        let query = Query::new("Rust programming language");
        let results = create_test_results();
        let reranked = reranker.rerank(&query, results).await.unwrap();
        assert_eq!(reranked.len(), 3);
        assert!(reranked[0].document.content.to_lowercase().contains("rust"));
    }

    // Matching terms score higher than no match.
    #[tokio::test]
    async fn test_keyword_reranker_score_pair() {
        let reranker = KeywordReranker::default();
        let score = reranker
            .score_pair("Rust programming", "Rust is a programming language")
            .await
            .unwrap();
        assert!(score > 0.0);
        let score_no_match = reranker
            .score_pair("Rust programming", "cats and dogs")
            .await
            .unwrap();
        assert!(score_no_match < score);
    }

    // Semantic reranking with the mock embedding provider (64 dimensions).
    #[tokio::test]
    async fn test_semantic_reranker() {
        let config = RerankerConfig::default().with_top_k(3);
        let embedding_provider = MockEmbeddingProvider::new(64);
        let reranker = SemanticReranker::new(config, embedding_provider);
        let query = Query::new("programming");
        let results = create_test_results();
        let reranked = reranker.rerank(&query, results).await.unwrap();
        assert_eq!(reranked.len(), 3);
        assert!(reranked[0].score >= reranked[1].score);
        assert!(reranked[1].score >= reranked[2].score);
    }

    // Weight accessors and final top_k for the hybrid reranker.
    #[tokio::test]
    async fn test_hybrid_reranker() {
        let config = RerankerConfig::default().with_top_k(3);
        let embedding_provider = MockEmbeddingProvider::new(64);
        let reranker = HybridReranker::new(config, embedding_provider).with_keyword_weight(0.4);
        assert!((reranker.keyword_weight() - 0.4).abs() < 0.001);
        assert!((reranker.semantic_weight() - 0.6).abs() < 0.001);
        let query = Query::new("Rust programming");
        let results = create_test_results();
        let reranked = reranker.rerank(&query, results).await.unwrap();
        assert_eq!(reranked.len(), 3);
    }

    // Weighted fusion of two rerankers respects the pipeline's top_k.
    #[tokio::test]
    async fn test_reranker_pipeline_weighted() {
        let config = RerankerConfig::default().with_top_k(3);
        let mut pipeline =
            RerankerPipeline::new(config).with_fusion_strategy(FusionStrategy::Weighted);
        pipeline.add_reranker(Box::new(MockReranker::default()), 1.0);
        pipeline.add_reranker(Box::new(KeywordReranker::default()), 1.0);
        assert_eq!(pipeline.len(), 2);
        assert!(!pipeline.is_empty());
        let query = Query::new("Rust");
        let results = create_test_results();
        let reranked = pipeline.rerank(&query, results).await.unwrap();
        assert_eq!(reranked.len(), 3);
    }

    // Cascade: stage size 3, final top_k 2.
    #[tokio::test]
    async fn test_reranker_pipeline_cascade() {
        let config = RerankerConfig::default().with_top_k(2);
        let mut pipeline = RerankerPipeline::new(config)
            .with_fusion_strategy(FusionStrategy::Cascade)
            .with_cascade_top_k(3);
        pipeline.add_reranker(Box::new(MockReranker::default()), 1.0);
        pipeline.add_reranker(Box::new(KeywordReranker::default()), 1.0);
        let query = Query::new("Rust");
        let results = create_test_results();
        let reranked = pipeline.rerank(&query, results).await.unwrap();
        assert_eq!(reranked.len(), 2);
    }

    // Reciprocal rank fusion respects the pipeline's top_k.
    #[tokio::test]
    async fn test_reranker_pipeline_rrf() {
        let config = RerankerConfig::default().with_top_k(3);
        let mut pipeline = RerankerPipeline::new(config)
            .with_fusion_strategy(FusionStrategy::ReciprocalRankFusion);
        pipeline.add_reranker(Box::new(MockReranker::default()), 1.0);
        pipeline.add_reranker(Box::new(KeywordReranker::default()), 1.0);
        let query = Query::new("Rust");
        let results = create_test_results();
        let reranked = pipeline.rerank(&query, results).await.unwrap();
        assert_eq!(reranked.len(), 3);
    }

    // The builder carries config, strategy, stage size and rerankers into the pipeline.
    #[tokio::test]
    async fn test_reranker_pipeline_builder() {
        let config = RerankerConfig::default().with_top_k(2);
        let pipeline = RerankerPipelineBuilder::new()
            .with_config(config)
            .with_fusion_strategy(FusionStrategy::Cascade)
            .with_cascade_top_k(3)
            .add_reranker(Box::new(MockReranker::default()), 1.0)
            .add_reranker(Box::new(KeywordReranker::default()), 1.0)
            .build();
        assert_eq!(pipeline.len(), 2);
        let query = Query::new("test");
        let results = create_test_results();
        let reranked = pipeline.rerank(&query, results).await.unwrap();
        assert!(reranked.len() <= 2);
    }

    // An empty (default Weighted) pipeline passes results through.
    #[tokio::test]
    async fn test_reranker_pipeline_empty() {
        let config = RerankerConfig::default();
        let pipeline = RerankerPipeline::new(config);
        assert!(pipeline.is_empty());
        let query = Query::new("test");
        let results = create_test_results();
        let reranked = pipeline.rerank(&query, results.clone()).await.unwrap();
        assert_eq!(reranked.len(), results.len());
    }

    // Reranking an empty list yields an empty list.
    #[tokio::test]
    async fn test_reranker_empty_results() {
        let reranker = MockReranker::default();
        let query = Query::new("test");
        let results: Vec<SearchResult> = vec![];
        let reranked = reranker.rerank(&query, results).await.unwrap();
        assert!(reranked.is_empty());
    }

    // Builder-style setters on RerankerConfig.
    #[tokio::test]
    async fn test_reranker_config_builder() {
        let config = RerankerConfig::default()
            .with_top_k(5)
            .with_min_score_threshold(0.3)
            .with_batch_size(16);
        assert_eq!(config.top_k, 5);
        assert!((config.min_score_threshold - 0.3).abs() < 0.001);
        assert_eq!(config.batch_size, 16);
    }

    // Results are reordered by descending score with ranks renumbered.
    #[tokio::test]
    async fn test_result_reordering() {
        let config = RerankerConfig::default().with_top_k(10);
        let reranker = MockReranker::new(config);
        let query = Query::new("test");
        let results = vec![
            SearchResult::new(Document::new("doc1"), 0.3, 0),
            SearchResult::new(Document::new("doc2"), 0.9, 1),
            SearchResult::new(Document::new("doc3"), 0.6, 2),
        ];
        let reranked = reranker.rerank(&query, results).await.unwrap();
        assert!((reranked[0].score - 0.9).abs() < 0.001);
        assert!((reranked[1].score - 0.6).abs() < 0.001);
        assert!((reranked[2].score - 0.3).abs() < 0.001);
        assert_eq!(reranked[0].rank, 0);
        assert_eq!(reranked[1].rank, 1);
        assert_eq!(reranked[2].rank, 2);
    }

    // A multiplier above 1 increases scores (0.4 * 2.0).
    #[tokio::test]
    async fn test_score_adjustments() {
        let config = RerankerConfig::default();
        let reranker = MockReranker::new(config).with_score_multiplier(2.0);
        let query = Query::new("test");
        let results = vec![SearchResult::new(Document::new("doc"), 0.4, 0)];
        let reranked = reranker.rerank(&query, results).await.unwrap();
        assert!((reranked[0].score - 0.8).abs() < 0.001);
    }

    // top_k keeps only the highest-scoring results.
    #[tokio::test]
    async fn test_top_k_filtering() {
        let config = RerankerConfig::default().with_top_k(2);
        let reranker = MockReranker::new(config);
        let query = Query::new("test");
        let results = vec![
            SearchResult::new(Document::new("doc1"), 0.9, 0),
            SearchResult::new(Document::new("doc2"), 0.8, 1),
            SearchResult::new(Document::new("doc3"), 0.7, 2),
            SearchResult::new(Document::new("doc4"), 0.6, 3),
        ];
        let reranked = reranker.rerank(&query, results).await.unwrap();
        assert_eq!(reranked.len(), 2);
        assert!((reranked[0].score - 0.9).abs() < 0.001);
        assert!((reranked[1].score - 0.8).abs() < 0.001);
    }

    // Weighted is the default fusion strategy.
    #[test]
    fn test_fusion_strategy_default() {
        let strategy = FusionStrategy::default();
        assert_eq!(strategy, FusionStrategy::Weighted);
    }

    // MockReranker::score_pair is 0.5 * multiplier (0.5 * 1.5).
    #[tokio::test]
    async fn test_score_pair_mock() {
        let reranker = MockReranker::default().with_score_multiplier(1.5);
        let score = reranker.score_pair("query", "document").await.unwrap();
        assert!((score - 0.75).abs() < 0.001);
    }

    // Empty pipeline scores 0.0; otherwise the first reranker's pair score is used.
    #[tokio::test]
    async fn test_pipeline_score_pair() {
        let config = RerankerConfig::default();
        let mut pipeline = RerankerPipeline::new(config);
        let score = pipeline.score_pair("query", "document").await.unwrap();
        assert!((score - 0.0).abs() < 0.001);
        pipeline.add_reranker(Box::new(KeywordReranker::default()), 1.0);
        let score = pipeline
            .score_pair("Rust programming", "Rust is great")
            .await
            .unwrap();
        assert!(score > 0.0);
    }
}