use crate::{GraphRAGResult, ScoredEntity};
use async_trait::async_trait;
/// Asynchronous interface for components that rerank scored candidates
/// against a query.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait RerankerTrait: Send + Sync {
/// Reranks `candidates` for `query`.
///
/// Takes ownership of the candidate list and returns the implementor's
/// resulting list, or an error through `GraphRAGResult`. How scores are
/// adjusted, filtered or ordered is left to the implementor; this trait
/// does not specify it.
async fn rerank(
&self,
query: &str,
candidates: Vec<ScoredEntity>,
) -> GraphRAGResult<Vec<ScoredEntity>>;
}
/// Synchronous, model-free reranker configured by a score threshold and a
/// boost for fused candidates.
#[derive(Debug, Clone, PartialEq)]
pub struct Reranker {
    /// Factor by which the score of candidates whose source is
    /// `ScoreSource::Fused` is multiplied during reranking.
    multi_source_boost: f64,
    /// Candidates whose incoming score is below this value are discarded
    /// during reranking.
    min_score: f64,
}
impl Default for Reranker {
fn default() -> Self {
Self::new()
}
}
impl Reranker {
    /// Creates a reranker with a multi-source boost of `1.2` and a minimum
    /// score of `0.1`.
    pub fn new() -> Self {
        Self {
            multi_source_boost: 1.2,
            min_score: 0.1,
        }
    }

    /// Sets the factor applied to the score of `ScoreSource::Fused`
    /// candidates and returns the updated reranker.
    pub fn with_multi_source_boost(mut self, boost: f64) -> Self {
        self.multi_source_boost = boost;
        self
    }

    /// Sets the minimum incoming score a candidate needs in order to be kept
    /// and returns the updated reranker.
    pub fn with_min_score(mut self, min_score: f64) -> Self {
        self.min_score = min_score;
        self
    }

    /// Filters, boosts and sorts `candidates`.
    ///
    /// 1. Candidates whose score is below the minimum score are dropped. The
    ///    threshold is checked against the incoming score, i.e. before any
    ///    boost. A NaN score fails the `>=` test and is dropped as well.
    /// 2. The score of every remaining `ScoreSource::Fused` candidate is
    ///    multiplied by the multi-source boost.
    /// 3. The result is sorted by score, highest first. The sort is stable,
    ///    so candidates with equal scores keep their input order. Scores that
    ///    became NaN through boosting are ranked last.
    pub fn rerank(&self, candidates: Vec<ScoredEntity>) -> Vec<ScoredEntity> {
        let mut reranked: Vec<ScoredEntity> = candidates
            .into_iter()
            .filter(|e| e.score >= self.min_score)
            .map(|mut e| {
                if e.source == crate::ScoreSource::Fused {
                    e.score *= self.multi_source_boost;
                }
                e
            })
            .collect();
        reranked.sort_by(Self::by_score_desc);
        reranked
    }

    /// Total-order comparator for a descending sort by score.
    ///
    /// Mapping incomparable (NaN) pairs to `Equal` would not be a total
    /// order, which `sort_by` requires; NaN is therefore treated as lower
    /// than every other score so it ends up at the back.
    fn by_score_desc(a: &ScoredEntity, b: &ScoredEntity) -> std::cmp::Ordering {
        use std::cmp::Ordering;

        match (a.score.is_nan(), b.score.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Both are ordinary numbers here, so `partial_cmp` is always `Some`.
            (false, false) => b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal),
        }
    }
}
/// Reranker that re-scores candidates with a [`CrossEncoderModel`].
pub struct CrossEncoderReranker<E: CrossEncoderModel> {
    /// Model that produces the query/document scores.
    model: E,
    /// Number of documents handed to the model per `score_batch` call.
    batch_size: usize,
}
/// Scoring backend used by `CrossEncoderReranker`.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait CrossEncoderModel: Send + Sync {
/// Scores a single `document` against `query`.
// NOTE(review): the score range is not defined here; `CrossEncoderReranker`
// blends it with the existing score and sorts descending, so it treats
// higher as better — confirm the expected scale with implementors.
async fn score(&self, query: &str, document: &str) -> GraphRAGResult<f64>;
/// Scores every entry of `documents` against `query`.
///
/// `CrossEncoderReranker::rerank` pairs the returned scores with the
/// documents by position, so it relies on one score per document, in
/// input order.
async fn score_batch(&self, query: &str, documents: &[&str]) -> GraphRAGResult<Vec<f64>>;
}
impl<E: CrossEncoderModel> CrossEncoderReranker<E> {
    /// Weight of a candidate's incoming score in the blended score.
    const ORIGINAL_WEIGHT: f64 = 0.3;
    /// Weight of the cross-encoder score in the blended score.
    const CROSS_ENCODER_WEIGHT: f64 = 0.7;

    /// Creates a reranker around `model` with a batch size of 32.
    pub fn new(model: E) -> Self {
        Self {
            model,
            batch_size: 32,
        }
    }

    /// Sets how many documents are sent to the model per `score_batch` call
    /// and returns the updated reranker.
    ///
    /// A `batch_size` of `0` is stored as `1`: a zero chunk size would make
    /// `rerank` panic when it splits the documents into batches.
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size.max(1);
        self
    }

    /// Re-scores `candidates` against `query` with the model and returns
    /// them sorted by the blended score, highest first.
    ///
    /// The text handed to the model for each candidate is its `uri`. Each
    /// candidate's new score is `0.3 * old score + 0.7 * model score`. The
    /// sort is stable, and NaN scores are ranked last.
    ///
    /// An empty candidate list returns `Ok` with an empty vector without
    /// calling the model.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by the model's `score_batch`.
    // NOTE(review): scores are matched to candidates purely by position. If
    // the model returns fewer scores than documents, the trailing candidates
    // are silently dropped by `zip` (and later batches are misaligned).
    // Consider surfacing this as an error once a suitable error variant is
    // known.
    // NOTE(review): the model is given the entity URI as the "document";
    // confirm that this is the intended text to score.
    pub async fn rerank(
        &self,
        query: &str,
        candidates: Vec<ScoredEntity>,
    ) -> GraphRAGResult<Vec<ScoredEntity>> {
        if candidates.is_empty() {
            return Ok(Vec::new());
        }

        let docs: Vec<&str> = candidates.iter().map(|e| e.uri.as_str()).collect();

        // `slice::chunks` panics on a chunk size of zero, so guard here as
        // well in case the field was set without going through the builder.
        let batch_size = self.batch_size.max(1);
        let mut all_scores = Vec::with_capacity(candidates.len());
        for chunk in docs.chunks(batch_size) {
            all_scores.extend(self.model.score_batch(query, chunk).await?);
        }

        let mut reranked: Vec<ScoredEntity> = candidates
            .into_iter()
            .zip(all_scores)
            .map(|(mut e, cross_score)| {
                e.score =
                    e.score * Self::ORIGINAL_WEIGHT + cross_score * Self::CROSS_ENCODER_WEIGHT;
                e
            })
            .collect();
        reranked.sort_by(Self::by_score_desc);
        Ok(reranked)
    }

    /// Total-order comparator for a descending sort by score.
    ///
    /// Mapping incomparable (NaN) pairs to `Equal` would not be a total
    /// order, which `sort_by` requires; NaN is therefore treated as lower
    /// than every other score so it ends up at the back.
    fn by_score_desc(a: &ScoredEntity, b: &ScoredEntity) -> std::cmp::Ordering {
        use std::cmp::Ordering;

        match (a.score.is_nan(), b.score.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Both are ordinary numbers here, so `partial_cmp` is always `Some`.
            (false, false) => b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ScoreSource;
    use std::collections::HashMap;

    /// Builds a candidate with the given URI, score and source and no metadata.
    fn candidate(uri: &str, score: f64, source: ScoreSource) -> ScoredEntity {
        ScoredEntity {
            uri: uri.to_string(),
            score,
            source,
            metadata: HashMap::new(),
        }
    }

    /// The default reranker drops the candidate below the minimum score and
    /// ranks the boosted fused candidate first.
    #[test]
    fn test_simple_reranker() {
        let candidates = vec![
            candidate("http://a", 0.5, ScoreSource::Vector),
            candidate("http://b", 0.6, ScoreSource::Fused),
            candidate("http://c", 0.05, ScoreSource::Keyword),
        ];

        let reranker = Reranker::new();
        let reranked = reranker.rerank(candidates);

        assert_eq!(reranked.len(), 2);
        assert_eq!(reranked[0].uri, "http://b");
    }
}