use async_trait::async_trait;
use std::collections::HashSet;
use super::result::RerankResult;
use super::traits::Reranker;
use crate::error::Result;
/// Reranker that scores documents by how many whitespace-separated query
/// terms they share with the query.
///
/// It carries no state beyond a model identifier string.
#[derive(Debug, Clone)]
pub struct TermOverlapReranker {
    /// Model identifier handed back by the `model()` accessor.
    model: String,
}
impl TermOverlapReranker {
    /// Builds a reranker whose model identifier is `"term-overlap-reranker"`.
    pub fn new() -> Self {
        let model = String::from("term-overlap-reranker");
        Self { model }
    }
}
impl Default for TermOverlapReranker {
fn default() -> Self {
Self::new()
}
}
/// Alias that exposes [`TermOverlapReranker`] under the name `MockReranker`.
// NOTE(review): presumably kept so callers/tests written against a "mock"
// reranker name keep compiling — confirm against usages.
pub type MockReranker = TermOverlapReranker;
#[async_trait]
impl Reranker for TermOverlapReranker {
    /// Returns the fixed reranker name, `"term-overlap"`.
    fn name(&self) -> &str {
        "term-overlap"
    }

    /// Returns the model identifier stored on this instance.
    fn model(&self) -> &str {
        &self.model
    }

    /// Scores every document by the fraction of distinct query terms it contains.
    ///
    /// The query and each document are lowercased and split on whitespace into
    /// sets of distinct terms. A document's score is
    /// `|query terms ∩ document terms| / max(|query terms|, 1)`, so an empty
    /// query gives every document a score of `0.0`.
    ///
    /// Results carry the document's position in `documents` as `index` and are
    /// sorted by descending score; the sort is stable, so equal scores keep
    /// their input order. When `top_n` is `Some(n)`, at most `n` results are
    /// returned.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`.
    async fn rerank(
        &self,
        query: &str,
        documents: &[String],
        top_n: Option<usize>,
    ) -> Result<Vec<RerankResult>> {
        let query_lower = query.to_lowercase();
        // Borrow the terms from the lowercased text instead of allocating an
        // owned `String` per term; the set only needs to live as long as
        // `query_lower` does.
        let query_terms: HashSet<&str> = query_lower.split_whitespace().collect();
        // Loop-invariant denominator. `max(1)` avoids 0/0 for an empty query.
        let denominator = query_terms.len().max(1) as f64;

        let mut results: Vec<RerankResult> = documents
            .iter()
            .enumerate()
            .map(|(idx, doc)| {
                let doc_lower = doc.to_lowercase();
                let doc_terms: HashSet<&str> = doc_lower.split_whitespace().collect();
                let overlap = query_terms.intersection(&doc_terms).count();
                RerankResult {
                    index: idx,
                    relevance_score: overlap as f64 / denominator,
                }
            })
            .collect();

        // Highest score first. Scores are finite ratios, so `partial_cmp` only
        // falls back to `Equal` defensively.
        results.sort_by(|a, b| {
            b.relevance_score
                .partial_cmp(&a.relevance_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        if let Some(n) = top_n {
            results.truncate(n);
        }

        Ok(results)
    }
}