use std::sync::Arc;
use async_trait::async_trait;
use crate::schemas::{Document, Retriever};
/// Configuration for [`FlashRankReranker`].
///
/// The `Default` implementation in this file sets `model` to
/// `"ms-marco-MiniLM-L-12-v2"` and leaves `top_k` as `None`.
#[derive(Debug, Clone)]
pub struct FlashRankRerankerConfig {
/// Name of the reranking model.
///
/// NOTE(review): none of the code visible in this file reads this field;
/// confirm whether it is consumed elsewhere.
pub model: String,
/// Maximum number of documents kept after reranking.
///
/// `Some(k)` keeps at most the first `k` documents of the sorted list;
/// `None` keeps all of them.
pub top_k: Option<usize>,
}
impl Default for FlashRankRerankerConfig {
fn default() -> Self {
Self {
model: "ms-marco-MiniLM-L-12-v2".to_string(),
top_k: None,
}
}
}
/// A retriever wrapper that reorders the documents returned by another
/// retriever.
///
/// Documents are fetched from `base_retriever` and then passed through
/// `rerank_simple`, which orders them by how many query words they contain.
pub struct FlashRankReranker {
/// Retriever that supplies the candidate documents to be reranked.
base_retriever: Arc<dyn Retriever>,
/// Reranking options; `top_k` limits how many documents are returned.
config: FlashRankRerankerConfig,
}
impl FlashRankReranker {
    /// Creates a reranker around `base_retriever` using
    /// `FlashRankRerankerConfig::default()`.
    pub fn new(base_retriever: Arc<dyn Retriever>) -> Self {
        Self::with_config(base_retriever, FlashRankRerankerConfig::default())
    }

    /// Creates a reranker around `base_retriever` with an explicit `config`.
    pub fn with_config(
        base_retriever: Arc<dyn Retriever>,
        config: FlashRankRerankerConfig,
    ) -> Self {
        Self {
            base_retriever,
            config,
        }
    }

    /// Reorders `documents` by keyword overlap with `query`.
    ///
    /// The query is lowercased and split on whitespace. Each document is
    /// ranked by how many of those words occur as a substring of its
    /// lowercased `page_content` (a repeated query word counts each time it
    /// appears in the query). Documents are returned from most to fewest
    /// matches; ties keep their incoming order because the sort is stable.
    ///
    /// If the query contains no words (empty or whitespace-only), there is
    /// nothing to match against, so the documents are returned in their
    /// original order.
    ///
    /// When `config.top_k` is `Some(k)`, at most the first `k` documents of
    /// the resulting order are returned.
    fn rerank_simple(&self, query: &str, documents: Vec<Document>) -> Vec<Document> {
        let query_lower = query.to_lowercase();
        let query_words: Vec<&str> = query_lower.split_whitespace().collect();

        let mut ranked = if query_words.is_empty() {
            // No query words: skip scoring entirely. Normalising by the word
            // count here would divide zero by zero and yield NaN scores.
            documents
        } else {
            let mut scored: Vec<(usize, Document)> = documents
                .into_iter()
                .map(|doc| {
                    let doc_lower = doc.page_content.to_lowercase();
                    let hits = query_words
                        .iter()
                        .filter(|word| doc_lower.contains(**word))
                        .count();
                    (hits, doc)
                })
                .collect();
            // Ranking by the raw match count gives the same order as ranking
            // by `hits / query_words.len()`, since the divisor is the same
            // positive constant for every document. The stable sort keeps
            // the base retriever's order among equal counts.
            scored.sort_by_key(|entry| std::cmp::Reverse(entry.0));
            scored.into_iter().map(|(_, doc)| doc).collect()
        };

        if let Some(k) = self.config.top_k {
            ranked.truncate(k);
        }
        ranked
    }
}
#[async_trait]
impl Retriever for FlashRankReranker {
    /// Fetches documents for `query` from the base retriever and returns
    /// them reordered by `rerank_simple`.
    ///
    /// # Errors
    ///
    /// Returns the base retriever's error unchanged if its lookup fails; no
    /// reranking is attempted in that case.
    async fn get_relevant_documents(
        &self,
        query: &str,
    ) -> Result<Vec<Document>, crate::error::RetrieverError> {
        self.base_retriever
            .get_relevant_documents(query)
            .await
            .map(|candidates| self.rerank_simple(query, candidates))
    }
}