use std::sync::Mutex;
use fastembed::{RerankInitOptions, RerankerModel, TextRerank};
/// Errors produced while constructing a [`Reranker`] or running a rerank call.
#[derive(Debug, thiserror::Error)]
pub enum RerankerError {
    /// `TextRerank::try_new` failed; the payload is the underlying error's message.
    #[error("Model load error: {0}")]
    ModelLoad(String),
    /// The model's `rerank` call failed, or the model mutex was poisoned by an
    /// earlier panic; the payload describes which.
    #[error("Inference error: {0}")]
    Inference(String),
    /// The requested model name is not one of the names accepted by `Reranker::new`.
    /// The payload is the rejected name.
    #[error("Unknown model: '{0}'. Supported: BGERerankerBase, JINARerankerV1TurboEn")]
    UnknownModel(String),
}
/// One scored document returned by `Reranker::rerank`.
///
/// `PartialEq` is derived so callers (and tests) can compare results directly.
#[derive(Debug, Clone, PartialEq)]
pub struct RerankResult {
    /// Position of the document in the `documents` slice passed to `rerank`.
    pub index: usize,
    /// Relevance score assigned by the backend. In mock mode this is
    /// `(n - index) / n` for `n` documents; for real models the scale is
    /// whatever the fastembed model reports.
    pub score: f32,
}
/// Backend chosen when the [`Reranker`] is constructed.
enum RerankerInner {
    /// Model-free stand-in created by `Reranker::new_mock`; scores documents by
    /// input position only.
    Mock,
    /// A loaded fastembed model. The `Mutex` gives `Reranker::rerank` (which
    /// takes `&self`) exclusive, mutable access to the model for each call.
    /// NOTE(review): the `Box` presumably keeps this enum pointer-sized next to
    /// the empty `Mock` variant — confirm before removing it.
    Real(Box<Mutex<TextRerank>>),
}
/// Scores a set of documents against a query and returns them as
/// [`RerankResult`]s.
///
/// Build one with [`Reranker::new`] to load a fastembed model, or with
/// [`Reranker::new_mock`] for a deterministic stand-in that loads nothing.
pub struct Reranker {
    /// Selected backend (mock or real model).
    inner: RerankerInner,
}
impl Reranker {
pub fn new(model_name: &str) -> Result<Self, RerankerError> {
let model_enum = match model_name {
"BGERerankerBase" => RerankerModel::BGERerankerBase,
"JINARerankerV1TurboEn" => RerankerModel::JINARerankerV1TurboEn,
other => return Err(RerankerError::UnknownModel(other.to_string())),
};
let text_rerank = TextRerank::try_new(
RerankInitOptions::new(model_enum).with_show_download_progress(true),
)
.map_err(|e| RerankerError::ModelLoad(e.to_string()))?;
Ok(Self {
inner: RerankerInner::Real(Box::new(Mutex::new(text_rerank))),
})
}
pub fn new_mock() -> Self {
Self {
inner: RerankerInner::Mock,
}
}
#[tracing::instrument(skip_all)]
pub fn rerank(
&self,
query: &str,
documents: &[&str],
top_k: usize,
) -> Result<Vec<RerankResult>, RerankerError> {
if documents.is_empty() {
return Ok(vec![]);
}
let effective_k = if top_k == 0 || top_k > documents.len() {
documents.len()
} else {
top_k
};
match &self.inner {
RerankerInner::Mock => {
let n = documents.len() as f32;
let mut results: Vec<RerankResult> = documents
.iter()
.enumerate()
.map(|(i, _)| RerankResult {
index: i,
score: (n - i as f32) / n,
})
.collect();
results.truncate(effective_k);
Ok(results)
}
RerankerInner::Real(mutex) => {
let mut model = mutex
.lock()
.map_err(|e| RerankerError::Inference(format!("Mutex poisoned: {e}")))?;
let fastembed_results = model
.rerank(query, documents, false, None)
.map_err(|e| RerankerError::Inference(e.to_string()))?;
let results: Vec<RerankResult> = fastembed_results
.into_iter()
.take(effective_k)
.map(|r| RerankResult {
index: r.index,
score: r.score,
})
.collect();
Ok(results)
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The mock returns one result per document when `top_k` equals the corpus size.
    #[test]
    fn test_reranker_mock_passthrough() {
        let reranker = Reranker::new_mock();
        let results = reranker
            .rerank("query", &["doc1", "doc2", "doc3"], 3)
            .unwrap();
        assert_eq!(results.len(), 3);
        let mut indices: Vec<usize> = results.iter().map(|r| r.index).collect();
        indices.sort_unstable();
        assert_eq!(indices, vec![0, 1, 2]);
    }

    /// Truncation keeps the highest-ranked entries; the mock ranks by input
    /// position, so the first two documents must survive.
    #[test]
    fn test_reranker_mock_top_k_truncates() {
        let reranker = Reranker::new_mock();
        let results = reranker
            .rerank("query", &["doc1", "doc2", "doc3", "doc4"], 2)
            .unwrap();
        assert_eq!(results.len(), 2);
        let indices: Vec<usize> = results.iter().map(|r| r.index).collect();
        assert_eq!(indices, vec![0, 1]);
    }

    #[test]
    fn test_reranker_mock_empty_documents() {
        let reranker = Reranker::new_mock();
        let results = reranker.rerank("query", &[], 3).unwrap();
        assert!(results.is_empty());
    }

    /// `top_k == 0` is the "no limit" sentinel.
    #[test]
    fn test_reranker_mock_top_k_zero_returns_all() {
        let reranker = Reranker::new_mock();
        let results = reranker
            .rerank("query", &["doc1", "doc2", "doc3"], 0)
            .unwrap();
        assert_eq!(results.len(), 3);
    }

    /// A `top_k` larger than the corpus is clamped rather than rejected.
    #[test]
    fn test_reranker_mock_top_k_larger_than_corpus_returns_all() {
        let reranker = Reranker::new_mock();
        let results = reranker.rerank("query", &["doc1", "doc2"], 10).unwrap();
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_reranker_mock_scores_decrease() {
        let reranker = Reranker::new_mock();
        let results = reranker
            .rerank("query", &["doc1", "doc2", "doc3"], 3)
            .unwrap();
        for pair in results.windows(2) {
            assert!(
                pair[0].score >= pair[1].score,
                "Scores should be non-increasing: {} < {}",
                pair[0].score,
                pair[1].score
            );
        }
    }

    /// Unknown names are rejected with the `UnknownModel` variant carrying the
    /// offending name, before any model download is attempted.
    #[test]
    fn test_unknown_model_returns_error() {
        // `Reranker` is not `Debug`, so `unwrap_err` is unavailable; match instead.
        match Reranker::new("nonexistent-model") {
            Ok(_) => panic!("an unsupported model name must be rejected"),
            Err(err) => {
                assert!(err.to_string().contains("Unknown model"));
                assert!(
                    matches!(&err, RerankerError::UnknownModel(name) if name == "nonexistent-model"),
                    "unexpected error variant: {err:?}"
                );
            }
        }
    }

    #[test]
    #[ignore = "loads the BGERerankerBase model (may download it); run with `cargo test -- --ignored`"]
    fn test_reranker_real_bge() {
        let reranker = Reranker::new("BGERerankerBase").unwrap();
        let results = reranker
            .rerank(
                "What is Python?",
                &[
                    "Python is a programming language",
                    "The weather is sunny today",
                    "Python was created by Guido van Rossum",
                ],
                3,
            )
            .unwrap();
        assert_eq!(results.len(), 3);
        let top_index = results[0].index;
        assert!(
            top_index == 0 || top_index == 2,
            "Top result should be a programming doc, got index {top_index}"
        );
    }
}