use std::cmp::Ordering;
use std::str::FromStr;
use std::sync::Arc;
use fastembed::{RerankInitOptions, RerankerModel as FastEmbedRerankerModel, TextRerank};
use serde::{Deserialize, Serialize};
use tokio::sync::OnceCell;
use crate::types::{AppError, Result};
/// Reranker models that can be selected through configuration.
///
/// Variants are (de)serialized by serde using kebab-case names, per the
/// `rename_all` attribute below. `BgeRerankerBase` is the `Default` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "kebab-case")]
pub enum RerankerModelType {
/// `BAAI/bge-reranker-base`; used when no model is specified.
#[default]
BgeRerankerBase,
/// `BAAI/bge-reranker-v2-m3`; reported as multilingual by `is_multilingual`.
BgeRerankerV2M3,
/// `jinaai/jina-reranker-v1-turbo-en`.
JinaRerankerV1TurboEn,
/// `jinaai/jina-reranker-v2-base-multilingual`; reported as multilingual by `is_multilingual`.
JinaRerankerV2BaseMultilingual,
}
impl RerankerModelType {
    /// Translates this variant into the matching `fastembed` reranker model.
    pub fn to_fastembed_model(&self) -> FastEmbedRerankerModel {
        match *self {
            Self::JinaRerankerV2BaseMultilingual => {
                // The misspelling ("Multiligual") is the identifier this file imports from fastembed.
                FastEmbedRerankerModel::JINARerankerV2BaseMultiligual
            }
            Self::JinaRerankerV1TurboEn => FastEmbedRerankerModel::JINARerankerV1TurboEn,
            Self::BgeRerankerV2M3 => FastEmbedRerankerModel::BGERerankerV2M3,
            Self::BgeRerankerBase => FastEmbedRerankerModel::BGERerankerBase,
        }
    }

    /// Lists every variant, in declaration order.
    pub fn all() -> Vec<Self> {
        Vec::from([
            Self::BgeRerankerBase,
            Self::BgeRerankerV2M3,
            Self::JinaRerankerV1TurboEn,
            Self::JinaRerankerV2BaseMultilingual,
        ])
    }

    /// Returns the Hugging Face repository id (`org/name`) associated with this model.
    pub fn hf_repo_id(&self) -> &'static str {
        let repo: &'static str = match *self {
            Self::JinaRerankerV2BaseMultilingual => "jinaai/jina-reranker-v2-base-multilingual",
            Self::JinaRerankerV1TurboEn => "jinaai/jina-reranker-v1-turbo-en",
            Self::BgeRerankerV2M3 => "BAAI/bge-reranker-v2-m3",
            Self::BgeRerankerBase => "BAAI/bge-reranker-base",
        };
        repo
    }

    /// Returns `true` for the variants this module classifies as multilingual.
    pub fn is_multilingual(&self) -> bool {
        // Exhaustive on purpose: a new variant must be classified explicitly.
        match self {
            Self::BgeRerankerV2M3 | Self::JinaRerankerV2BaseMultilingual => true,
            Self::BgeRerankerBase | Self::JinaRerankerV1TurboEn => false,
        }
    }
}
impl FromStr for RerankerModelType {
    type Err = AppError;

    /// Parses a reranker model name.
    ///
    /// Matching is case-insensitive and ignores leading/trailing whitespace, so
    /// values read from environment variables or config files (which often carry
    /// stray spaces or a trailing newline) still resolve. Each model accepts its
    /// full kebab-case name plus one short alias:
    ///
    /// | Full name                            | Alias               |
    /// |--------------------------------------|---------------------|
    /// | `bge-reranker-base`                  | `bge-base`          |
    /// | `bge-reranker-v2-m3`                 | `bge-m3`            |
    /// | `jina-reranker-v1-turbo-en`          | `jina-turbo`        |
    /// | `jina-reranker-v2-base-multilingual` | `jina-multilingual` |
    ///
    /// # Errors
    ///
    /// Returns `AppError::Internal` naming the rejected input (as originally
    /// given) when it matches none of the names above.
    fn from_str(s: &str) -> Result<Self> {
        let normalized = s.trim().to_lowercase();
        match normalized.as_str() {
            "bge-reranker-base" | "bge-base" => Ok(Self::BgeRerankerBase),
            "bge-reranker-v2-m3" | "bge-m3" => Ok(Self::BgeRerankerV2M3),
            "jina-reranker-v1-turbo-en" | "jina-turbo" => Ok(Self::JinaRerankerV1TurboEn),
            "jina-reranker-v2-base-multilingual" | "jina-multilingual" => {
                Ok(Self::JinaRerankerV2BaseMultilingual)
            }
            _ => Err(AppError::Internal(format!(
                "Unknown reranker model: {}. Use one of: bge-reranker-base, \
                 bge-reranker-v2-m3, jina-reranker-v1-turbo-en, jina-reranker-v2-base-multilingual",
                s
            ))),
        }
    }
}
impl std::fmt::Display for RerankerModelType {
    /// Writes the model's full kebab-case name (e.g. `bge-reranker-base`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match *self {
            Self::JinaRerankerV2BaseMultilingual => "jina-reranker-v2-base-multilingual",
            Self::JinaRerankerV1TurboEn => "jina-reranker-v1-turbo-en",
            Self::BgeRerankerV2M3 => "bge-reranker-v2-m3",
            Self::BgeRerankerBase => "bge-reranker-base",
        })
    }
}
/// Settings used to construct a `Reranker`.
///
/// Every field carries a serde default, so fields omitted from the serialized
/// form fall back to the values noted below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankerConfig {
/// Model to load; defaults to `RerankerModelType::default()`.
#[serde(default)]
pub model: RerankerModelType,
/// Forwarded to fastembed's `with_show_download_progress` when the model is
/// initialised; defaults to `true`.
#[serde(default = "default_show_progress")]
pub show_download_progress: bool,
/// Maximum number of results returned when a rerank call is not given an
/// explicit `top_k`; defaults to `10`.
#[serde(default = "default_top_k")]
pub top_k: usize,
}
/// Serde fallback for `RerankerConfig::show_download_progress`: progress
/// reporting is enabled unless the configuration says otherwise.
fn default_show_progress() -> bool {
    const SHOW_PROGRESS_BY_DEFAULT: bool = true;
    SHOW_PROGRESS_BY_DEFAULT
}
/// Serde fallback for `RerankerConfig::top_k`: how many results to keep when
/// the configuration does not specify a limit.
fn default_top_k() -> usize {
    const DEFAULT_TOP_K: usize = 10;
    DEFAULT_TOP_K
}
impl Default for RerankerConfig {
fn default() -> Self {
Self {
model: RerankerModelType::default(),
show_download_progress: default_show_progress(),
top_k: default_top_k(),
}
}
}
/// One document after reranking, carrying both its original and its new
/// position together with the scores that produced the new ordering.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankedResult {
/// Identifier taken from the first element of the input tuple.
pub id: String,
/// Document text (second element of the input tuple) that was scored.
pub content: String,
/// Score supplied by the caller (third element of the input tuple), unmodified.
pub retrieval_score: f32,
/// Score produced by the reranker model; `0.0` if the model returned no
/// entry for this document.
pub rerank_score: f32,
/// Score the results are sorted by: equal to `rerank_score` for `rerank`,
/// or the weighted blend computed by `rerank_hybrid`.
pub final_score: f32,
/// 1-based position of the document in the input slice.
pub original_rank: usize,
/// 1-based position of the document after sorting by `final_score`.
pub new_rank: usize,
}
/// Reranks retrieved documents against a query using a fastembed `TextRerank` model.
///
/// The model is not loaded at construction time; it is initialised on first
/// use and then reused for subsequent calls.
pub struct Reranker {
// Settings chosen at construction: model selection, download progress, default `top_k`.
config: RerankerConfig,
// Lazily initialised model handle. It is wrapped in `Arc` so it can be moved
// into blocking tasks, and in a mutex that is locked for each inference call.
model: OnceCell<Arc<tokio::sync::Mutex<TextRerank>>>,
}
impl Reranker {
    /// Creates a reranker from `config`.
    ///
    /// No model is loaded here; loading is deferred to the first call that
    /// actually needs to score documents.
    pub fn new(config: RerankerConfig) -> Self {
        Self {
            config,
            model: OnceCell::new(),
        }
    }

    /// Creates a reranker from `RerankerConfig::default()`.
    pub fn default_reranker() -> Self {
        Self::new(RerankerConfig::default())
    }

    /// Returns the shared model handle, initialising it on first use.
    ///
    /// Initialisation is run via `spawn_blocking` so the model load does not
    /// run on the async executor thread.
    ///
    /// # Errors
    ///
    /// Returns `AppError::Internal` if the model cannot be loaded or if the
    /// blocking task does not complete.
    async fn get_model(&self) -> Result<Arc<tokio::sync::Mutex<TextRerank>>> {
        self.model
            .get_or_try_init(|| async {
                let config = self.config.clone();
                tokio::task::spawn_blocking(move || {
                    let repo_id = config.model.hf_repo_id();
                    let onnx_files = &["onnx/model.onnx", "tokenizer.json", "config.json"];
                    // Cache location: $FASTEMBED_CACHE_DIR if set, otherwise
                    // $HOME/.cache/fastembed, with /tmp standing in for a missing $HOME.
                    let cache_dir = std::env::var("FASTEMBED_CACHE_DIR")
                        .map(std::path::PathBuf::from)
                        .unwrap_or_else(|_| {
                            let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
                            std::path::PathBuf::from(home).join(".cache").join("fastembed")
                        });
                    // Best effort: a failed pre-download is only logged, and the
                    // load below is still attempted.
                    if let Err(e) =
                        super::embeddings::pre_download_model(repo_id, onnx_files, &cache_dir)
                    {
                        tracing::warn!(
                            "Reranker pre-download failed (may already be cached): {}",
                            e
                        );
                    }
                    let init_options = RerankInitOptions::new(config.model.to_fastembed_model())
                        .with_show_download_progress(config.show_download_progress);
                    let model = TextRerank::try_new(init_options).map_err(|e| {
                        AppError::Internal(format!("Failed to load reranker: {}", e))
                    })?;
                    Ok(Arc::new(tokio::sync::Mutex::new(model)))
                })
                .await
                .map_err(|e| AppError::Internal(format!("Reranker task failed: {}", e)))?
            })
            .await
            .map(Arc::clone)
    }

    /// Scores every document in `results` against `query`.
    ///
    /// Returns exactly one score per input, in input order. A document the
    /// model returned no entry for keeps the score `0.0`.
    ///
    /// # Errors
    ///
    /// Returns `AppError::Internal` if the model cannot be initialised, the
    /// blocking task does not complete, or inference itself fails.
    async fn score_documents(
        &self,
        query: &str,
        results: &[(String, String, f32)],
    ) -> Result<Vec<f32>> {
        let model = self.get_model().await?;
        let documents: Vec<String> = results
            .iter()
            .map(|(_, content, _)| content.clone())
            .collect();
        let document_count = documents.len();
        let query = query.to_string();

        let raw_scores = tokio::task::spawn_blocking(move || {
            let mut model = model.blocking_lock();
            // Only `index` and `score` are read below, so the model is not asked
            // to copy each document into its results.
            model.rerank(query, &documents, false, None)
        })
        .await
        .map_err(|e| AppError::Internal(format!("Rerank task failed: {}", e)))?
        .map_err(|e| AppError::Internal(format!("Reranking failed: {}", e)))?;

        // Place scores by input index so callers can look them up in O(1)
        // instead of scanning the model output once per document.
        let mut scores = vec![0.0_f32; document_count];
        for scored in raw_scores {
            if let Some(slot) = scores.get_mut(scored.index) {
                *slot = scored.score;
            }
        }
        Ok(scores)
    }

    /// Orders two scores so that higher values sort first and NaN sorts last.
    ///
    /// Unlike `partial_cmp(..).unwrap_or(Equal)`, this is a consistent ordering
    /// even when NaN is present, which `sort_by` requires.
    fn cmp_score_desc(a: f32, b: f32) -> Ordering {
        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
        }
    }

    /// Sorts by `final_score` (highest first), assigns 1-based `new_rank`
    /// values, and keeps at most `top_k` results (falling back to the
    /// configured `top_k`).
    ///
    /// The sort is stable, so documents with equal scores keep their input order.
    fn finalize(
        &self,
        mut reranked: Vec<RerankedResult>,
        top_k: Option<usize>,
    ) -> Vec<RerankedResult> {
        reranked.sort_by(|a, b| Self::cmp_score_desc(a.final_score, b.final_score));
        for (position, result) in reranked.iter_mut().enumerate() {
            result.new_rank = position + 1;
        }
        reranked.truncate(top_k.unwrap_or(self.config.top_k));
        reranked
    }

    /// Reranks `results` for `query` using the model score alone.
    ///
    /// Each input tuple is `(id, content, retrieval_score)`. The returned
    /// entries are sorted by rerank score, highest first, and truncated to
    /// `top_k` (or the configured `top_k` when `None`).
    ///
    /// An empty `results` slice returns an empty vector without loading the model.
    ///
    /// # Errors
    ///
    /// Returns `AppError::Internal` if the model cannot be loaded or inference fails.
    pub async fn rerank(
        &self,
        query: &str,
        results: &[(String, String, f32)],
        top_k: Option<usize>,
    ) -> Result<Vec<RerankedResult>> {
        if results.is_empty() {
            return Ok(Vec::new());
        }
        let scores = self.score_documents(query, results).await?;

        let reranked: Vec<RerankedResult> = results
            .iter()
            .zip(scores)
            .enumerate()
            .map(
                |(idx, ((id, content, retrieval_score), rerank_score))| RerankedResult {
                    id: id.clone(),
                    content: content.clone(),
                    retrieval_score: *retrieval_score,
                    rerank_score,
                    final_score: rerank_score,
                    original_rank: idx + 1,
                    // Assigned once the final order is known.
                    new_rank: 0,
                },
            )
            .collect();

        Ok(self.finalize(reranked, top_k))
    }

    /// Reranks `results` for `query` using a blend of retrieval and model scores.
    ///
    /// Retrieval scores are min-max normalised across `results` (all `1.0`
    /// when every score is equal) and combined as
    /// `(1 - rerank_weight) * normalized_retrieval + rerank_weight * rerank_score`.
    /// NaN retrieval scores are ignored when computing the min/max.
    ///
    /// `rerank_weight` is used as given and is not validated.
    /// NOTE(review): callers presumably pass a value in `0.0..=1.0` — confirm.
    /// NOTE(review): the model score is blended as returned by the model; if it
    /// is not on a 0..1 scale it will not be comparable to the normalised
    /// retrieval score — confirm against the model in use.
    ///
    /// An empty `results` slice returns an empty vector without loading the model.
    ///
    /// # Errors
    ///
    /// Returns `AppError::Internal` if the model cannot be loaded or inference fails.
    pub async fn rerank_hybrid(
        &self,
        query: &str,
        results: &[(String, String, f32)],
        rerank_weight: f32,
        top_k: Option<usize>,
    ) -> Result<Vec<RerankedResult>> {
        if results.is_empty() {
            return Ok(Vec::new());
        }
        let scores = self.score_documents(query, results).await?;

        // One pass over the retrieval scores; `f32::min`/`f32::max` skip NaN.
        let (min_retrieval, max_retrieval) = results.iter().fold(
            (f32::INFINITY, f32::NEG_INFINITY),
            |(lowest, highest), (_, _, score)| (lowest.min(*score), highest.max(*score)),
        );
        let retrieval_range = max_retrieval - min_retrieval;
        let retrieval_weight = 1.0 - rerank_weight;

        let reranked: Vec<RerankedResult> = results
            .iter()
            .zip(scores)
            .enumerate()
            .map(|(idx, ((id, content, retrieval_score), rerank_score))| {
                let normalized_retrieval = if retrieval_range > 0.0 {
                    (retrieval_score - min_retrieval) / retrieval_range
                } else {
                    // All retrieval scores are equal: treat them as equally good.
                    1.0
                };
                let final_score =
                    retrieval_weight * normalized_retrieval + rerank_weight * rerank_score;
                RerankedResult {
                    id: id.clone(),
                    content: content.clone(),
                    retrieval_score: *retrieval_score,
                    rerank_score,
                    final_score,
                    original_rank: idx + 1,
                    // Assigned once the final order is known.
                    new_rank: 0,
                }
            })
            .collect();

        Ok(self.finalize(reranked, top_k))
    }

    /// Returns the model variant this reranker was configured with.
    pub fn model_type(&self) -> RerankerModelType {
        self.config.model
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Full names and short aliases both parse to the expected variant.
    #[test]
    fn test_reranker_model_from_str() {
        let cases = [
            ("bge-reranker-base", RerankerModelType::BgeRerankerBase),
            ("bge-m3", RerankerModelType::BgeRerankerV2M3),
            (
                "jina-multilingual",
                RerankerModelType::JinaRerankerV2BaseMultilingual,
            ),
        ];
        for (input, expected) in cases {
            assert_eq!(input.parse::<RerankerModelType>().unwrap(), expected);
        }
    }

    /// `Display` prints the full kebab-case model name.
    #[test]
    fn test_reranker_model_display() {
        let base = RerankerModelType::BgeRerankerBase;
        assert_eq!(base.to_string(), "bge-reranker-base");

        let jina_multilingual = RerankerModelType::JinaRerankerV2BaseMultilingual;
        assert_eq!(
            jina_multilingual.to_string(),
            "jina-reranker-v2-base-multilingual"
        );
    }

    /// Only the M3 and Jina v2 models are flagged as multilingual.
    #[test]
    fn test_reranker_model_multilingual() {
        assert!(!RerankerModelType::BgeRerankerBase.is_multilingual());
        for model in [
            RerankerModelType::JinaRerankerV2BaseMultilingual,
            RerankerModelType::BgeRerankerV2M3,
        ] {
            assert!(model.is_multilingual());
        }
    }

    /// `all()` enumerates every variant.
    #[test]
    fn test_all_models() {
        let variant_count = RerankerModelType::all().len();
        assert_eq!(variant_count, 4);
    }

    /// The default configuration uses the base BGE model, keeps ten results,
    /// and shows download progress.
    #[test]
    fn test_default_config() {
        let RerankerConfig {
            model,
            show_download_progress,
            top_k,
        } = RerankerConfig::default();
        assert_eq!(model, RerankerModelType::BgeRerankerBase);
        assert_eq!(top_k, 10);
        assert!(show_download_progress);
    }

    /// Reranking an empty input returns an empty output.
    #[tokio::test]
    async fn test_rerank_empty() {
        let reranked = Reranker::default_reranker()
            .rerank("test query", &[], None)
            .await
            .unwrap();
        assert!(reranked.is_empty());
    }

    /// Whatever `Display` prints must parse back to the same variant.
    #[test]
    fn test_display_roundtrip_all_reranker_models() {
        for model in RerankerModelType::all() {
            let rendered = model.to_string();
            let reparsed = match rendered.parse::<RerankerModelType>() {
                Ok(found) => found,
                Err(_) => panic!(
                    "Display→FromStr roundtrip failed for {:?} ('{}')",
                    model, rendered
                ),
            };
            assert_eq!(reparsed, model);
        }
    }

    /// Every short alias resolves to its model.
    #[test]
    fn test_reranker_from_str_aliases() {
        let alias_table = [
            ("bge-base", RerankerModelType::BgeRerankerBase),
            ("bge-m3", RerankerModelType::BgeRerankerV2M3),
            ("jina-turbo", RerankerModelType::JinaRerankerV1TurboEn),
            (
                "jina-multilingual",
                RerankerModelType::JinaRerankerV2BaseMultilingual,
            ),
        ];
        for (alias, expected) in alias_table {
            let resolved = alias.parse::<RerankerModelType>().unwrap();
            assert_eq!(resolved, expected, "Alias '{}' mismatch", alias);
        }
    }

    /// Parsing ignores letter case.
    #[test]
    fn test_reranker_from_str_case_insensitive() {
        let resolved = "BGE-RERANKER-BASE".parse::<RerankerModelType>().unwrap();
        assert_eq!(resolved, RerankerModelType::BgeRerankerBase);
    }

    /// Unknown names are rejected with a descriptive error.
    #[test]
    fn test_reranker_from_str_invalid() {
        let outcome = "fake-reranker".parse::<RerankerModelType>();
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Unknown reranker model"));
    }

    /// Every model maps to a non-empty `org/model` repository id.
    #[test]
    fn test_hf_repo_id_all_models() {
        for model in RerankerModelType::all() {
            let repo = model.hf_repo_id();
            assert!(!repo.is_empty(), "{:?} has empty repo ID", model);
            assert!(
                repo.contains('/'),
                "{:?} repo '{}' should have org/model format",
                model,
                repo
            );
        }
    }

    /// A non-default configuration survives a JSON round trip.
    #[test]
    fn test_reranker_config_serialization() {
        let original = RerankerConfig {
            model: RerankerModelType::JinaRerankerV1TurboEn,
            show_download_progress: false,
            top_k: 5,
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: RerankerConfig = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.model, RerankerModelType::JinaRerankerV1TurboEn);
        assert_eq!(decoded.top_k, 5);
        assert!(!decoded.show_download_progress);
    }

    /// A result serializes with the expected field names and round-trips.
    #[test]
    fn test_reranked_result_serialization() {
        let original = RerankedResult {
            id: "doc-1".to_string(),
            content: "test content".to_string(),
            retrieval_score: 0.8,
            rerank_score: 0.95,
            final_score: 0.9,
            original_rank: 3,
            new_rank: 1,
        };
        let encoded = serde_json::to_string(&original).unwrap();
        for expected_fragment in ["\"id\":\"doc-1\"", "\"new_rank\":1"] {
            assert!(encoded.contains(expected_fragment));
        }
        let decoded: RerankedResult = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.id, "doc-1");
        assert_eq!(decoded.original_rank, 3);
        assert_eq!(decoded.new_rank, 1);
    }

    /// Every variant has a fastembed counterpart (the conversion must not panic).
    #[test]
    fn test_to_fastembed_reranker_all_variants() {
        RerankerModelType::all().iter().for_each(|model| {
            let _ = model.to_fastembed_model();
        });
    }

    /// The reranker reports the model it was configured with.
    #[test]
    fn test_reranker_new_and_model_type() {
        let reranker = Reranker::new(RerankerConfig {
            model: RerankerModelType::BgeRerankerV2M3,
            ..Default::default()
        });
        assert_eq!(reranker.model_type(), RerankerModelType::BgeRerankerV2M3);
    }
}