use async_trait::async_trait;
use super::bm25::BM25Reranker;
use super::result::RerankResult;
use super::traits::Reranker;
use crate::error::Result;
/// Reranker based on Reciprocal Rank Fusion (RRF).
///
/// `k` is the smoothing constant added to every rank in the RRF formula, and
/// `model` is the identifier reported through the `Reranker::model` method.
#[derive(Debug, Clone)]
pub struct RRFReranker {
    /// RRF smoothing constant; larger values shrink the score gap between
    /// top-ranked and lower-ranked documents.
    k: u32,
    /// Model identifier returned by `Reranker::model`.
    model: String,
}
impl RRFReranker {
    /// Smoothing constant used by [`RRFReranker::new`].
    const DEFAULT_K: u32 = 60;

    /// Creates a reranker with the default smoothing constant `k = 60`.
    pub fn new() -> Self {
        Self::with_k(Self::DEFAULT_K)
    }

    /// Creates a reranker with a custom smoothing constant.
    ///
    /// `k` is clamped to a minimum of `1`.
    pub fn with_k(k: u32) -> Self {
        Self {
            k: k.max(1),
            model: "rrf-reranker".to_string(),
        }
    }

    /// Fuses several ranked lists of document indices with Reciprocal Rank Fusion.
    ///
    /// Every list adds `1 / (k + rank + 1)` to the score of each document it
    /// contains, where `rank` is the zero-based position inside that list.
    ///
    /// - A document repeated within a single list is credited only once, at its
    ///   first (best) position.
    /// - Indices `>= num_docs` are ignored, but they still occupy a rank position.
    /// - Documents that appear in no list are omitted from the output.
    ///
    /// The result is sorted by descending score. Ties keep ascending index order
    /// because the sort is stable.
    pub fn fuse(&self, ranked_lists: &[Vec<usize>], num_docs: usize) -> Vec<RerankResult> {
        let k = f64::from(self.k);
        let mut scores = vec![0.0f64; num_docs];
        // `credited_by[d]` holds the index of the last list that already scored
        // document `d`. This stops a duplicate entry in one list from being
        // counted twice, without allocating a fresh set per list.
        let mut credited_by = vec![usize::MAX; num_docs];

        for (list_idx, ranked_list) in ranked_lists.iter().enumerate() {
            for (rank, &doc_idx) in ranked_list.iter().enumerate() {
                if doc_idx >= num_docs || credited_by[doc_idx] == list_idx {
                    continue;
                }
                credited_by[doc_idx] = list_idx;
                scores[doc_idx] += 1.0 / (k + rank as f64 + 1.0);
            }
        }

        let mut results: Vec<RerankResult> = scores
            .into_iter()
            .enumerate()
            // Any document that appeared in some list has a strictly positive score.
            .filter(|(_, score)| *score > 0.0)
            .map(|(idx, score)| RerankResult {
                index: idx,
                relevance_score: score,
            })
            .collect();

        // Scores are finite and positive, so `total_cmp` matches numeric order.
        results.sort_by(|a, b| b.relevance_score.total_cmp(&a.relevance_score));
        results
    }
}
impl Default for RRFReranker {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Reranker for RRFReranker {
    /// Short identifier of this reranker: `"rrf"`.
    fn name(&self) -> &str {
        "rrf"
    }

    /// Model identifier stored on the instance.
    fn model(&self) -> &str {
        self.model.as_str()
    }

    /// Scores `documents` against `query` by delegating to a fresh `BM25Reranker`,
    /// then keeps at most `top_n` results when a limit is given.
    ///
    /// NOTE(review): this path never calls `fuse` or reads `self.k`. The returned
    /// scores are whatever `BM25Reranker::rerank` produces, not RRF scores.
    /// Confirm that this is intended.
    async fn rerank(
        &self,
        query: &str,
        documents: &[String],
        top_n: Option<usize>,
    ) -> Result<Vec<RerankResult>> {
        let lexical = BM25Reranker::new();
        let mut scored = lexical.rerank(query, documents, None).await?;
        // With no limit, truncating to the current length is a no-op.
        let keep = top_n.unwrap_or(scored.len());
        scored.truncate(keep);
        Ok(scored)
    }
}