1use std::cmp::Ordering;
7use std::str::FromStr;
8use std::sync::Arc;
9
10use fastembed::{RerankInitOptions, RerankerModel as FastEmbedRerankerModel, TextRerank};
11use serde::{Deserialize, Serialize};
12use tokio::sync::OnceCell;
13
14use crate::types::{AppError, Result};
15
/// Reranker models that can be selected through configuration.
///
/// Variants are (de)serialized by serde using kebab-case names, and each one
/// maps to a fastembed reranker model via
/// [`RerankerModelType::to_fastembed_model`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "kebab-case")]
pub enum RerankerModelType {
    /// BGE reranker (base). This is the default model.
    #[default]
    BgeRerankerBase,
    /// BGE reranker v2 M3; reported as multilingual by `is_multilingual`.
    BgeRerankerV2M3,
    /// Jina reranker v1 turbo (English).
    JinaRerankerV1TurboEn,
    /// Jina reranker v2 base; reported as multilingual by `is_multilingual`.
    JinaRerankerV2BaseMultilingual,
}
34
35impl RerankerModelType {
36 pub fn to_fastembed_model(&self) -> FastEmbedRerankerModel {
38 match self {
39 Self::BgeRerankerBase => FastEmbedRerankerModel::BGERerankerBase,
40 Self::BgeRerankerV2M3 => FastEmbedRerankerModel::BGERerankerV2M3,
41 Self::JinaRerankerV1TurboEn => FastEmbedRerankerModel::JINARerankerV1TurboEn,
42 Self::JinaRerankerV2BaseMultilingual => {
44 FastEmbedRerankerModel::JINARerankerV2BaseMultiligual
45 }
46 }
47 }
48
49 pub fn all() -> Vec<Self> {
51 vec![
52 Self::BgeRerankerBase,
53 Self::BgeRerankerV2M3,
54 Self::JinaRerankerV1TurboEn,
55 Self::JinaRerankerV2BaseMultilingual,
56 ]
57 }
58
59 pub fn is_multilingual(&self) -> bool {
61 matches!(
62 self,
63 Self::JinaRerankerV2BaseMultilingual | Self::BgeRerankerV2M3
64 )
65 }
66}
67
68impl FromStr for RerankerModelType {
69 type Err = AppError;
70
71 fn from_str(s: &str) -> Result<Self> {
72 match s.to_lowercase().as_str() {
73 "bge-reranker-base" | "bge-base" => Ok(Self::BgeRerankerBase),
74 "bge-reranker-v2-m3" | "bge-m3" => Ok(Self::BgeRerankerV2M3),
75 "jina-reranker-v1-turbo-en" | "jina-turbo" => Ok(Self::JinaRerankerV1TurboEn),
76 "jina-reranker-v2-base-multilingual" | "jina-multilingual" => {
77 Ok(Self::JinaRerankerV2BaseMultilingual)
78 }
79 _ => Err(AppError::Internal(format!(
80 "Unknown reranker model: {}. Use one of: bge-reranker-base, \
81 bge-reranker-v2-m3, jina-reranker-v1-turbo-en, jina-reranker-v2-base-multilingual",
82 s
83 ))),
84 }
85 }
86}
87
88impl std::fmt::Display for RerankerModelType {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 let name = match self {
91 Self::BgeRerankerBase => "bge-reranker-base",
92 Self::BgeRerankerV2M3 => "bge-reranker-v2-m3",
93 Self::JinaRerankerV1TurboEn => "jina-reranker-v1-turbo-en",
94 Self::JinaRerankerV2BaseMultilingual => "jina-reranker-v2-base-multilingual",
95 };
96 write!(f, "{}", name)
97 }
98}
99
/// Settings used to construct a [`Reranker`].
///
/// Every field has a serde default, so any of them may be omitted when
/// deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankerConfig {
    /// Model to load; falls back to `RerankerModelType::default()` when absent.
    #[serde(default)]
    pub model: RerankerModelType,
    /// Forwarded to fastembed's `with_show_download_progress` when the model
    /// is initialized. Defaults to `true`.
    #[serde(default = "default_show_progress")]
    pub show_download_progress: bool,
    /// Number of results kept when a rerank call does not pass its own
    /// `top_k`. Defaults to `10`.
    #[serde(default = "default_top_k")]
    pub top_k: usize,
}
117
/// Serde default for `RerankerConfig::show_download_progress`.
fn default_show_progress() -> bool {
    const SHOW_PROGRESS_BY_DEFAULT: bool = true;
    SHOW_PROGRESS_BY_DEFAULT
}
121
/// Serde default for `RerankerConfig::top_k`.
fn default_top_k() -> usize {
    const DEFAULT_TOP_K: usize = 10;
    DEFAULT_TOP_K
}
125
126impl Default for RerankerConfig {
127 fn default() -> Self {
128 Self {
129 model: RerankerModelType::default(),
130 show_download_progress: default_show_progress(),
131 top_k: default_top_k(),
132 }
133 }
134}
135
/// One document after reranking, carrying both its original and new scores
/// and positions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankedResult {
    /// Identifier copied from the first element of the input tuple.
    pub id: String,
    /// Document text copied from the second element of the input tuple.
    pub content: String,
    /// Score supplied by the caller (third element of the input tuple).
    pub retrieval_score: f32,
    /// Score returned by the reranker model; `0.0` when the model output has
    /// no entry for this document.
    pub rerank_score: f32,
    /// Score the results are sorted by: the rerank score in `rerank`, or a
    /// weighted combination with the retrieval score in `rerank_hybrid`.
    pub final_score: f32,
    /// 1-based position of the document in the input slice.
    pub original_rank: usize,
    /// 1-based position after sorting by `final_score` (highest first).
    pub new_rank: usize,
}
158
/// Reranks retrieved documents against a query using a fastembed
/// `TextRerank` model.
pub struct Reranker {
    /// Settings used to load the model and to pick the default `top_k`.
    config: RerankerConfig,
    /// Model handle, initialized on first use and then shared; the mutex
    /// serializes access to the underlying `TextRerank`.
    model: OnceCell<Arc<tokio::sync::Mutex<TextRerank>>>,
}
168
169impl Reranker {
170 pub fn new(config: RerankerConfig) -> Self {
172 Self {
173 config,
174 model: OnceCell::new(),
175 }
176 }
177
178 pub fn default_reranker() -> Self {
180 Self::new(RerankerConfig::default())
181 }
182
183 async fn get_model(&self) -> Result<Arc<tokio::sync::Mutex<TextRerank>>> {
185 self.model
186 .get_or_try_init(|| async {
187 let config = self.config.clone();
188 tokio::task::spawn_blocking(move || {
189 let init_options = RerankInitOptions::new(config.model.to_fastembed_model())
190 .with_show_download_progress(config.show_download_progress);
191 let model = TextRerank::try_new(init_options).map_err(|e| {
192 AppError::Internal(format!("Failed to load reranker: {}", e))
193 })?;
194 Ok(Arc::new(tokio::sync::Mutex::new(model)))
195 })
196 .await
197 .map_err(|e| AppError::Internal(format!("Reranker task failed: {}", e)))?
198 })
199 .await
200 .map(Arc::clone)
201 }
202
203 pub async fn rerank(
208 &self,
209 query: &str,
210 results: &[(String, String, f32)],
211 top_k: Option<usize>,
212 ) -> Result<Vec<RerankedResult>> {
213 if results.is_empty() {
214 return Ok(Vec::new());
215 }
216
217 let model = self.get_model().await?;
218 let documents: Vec<String> = results
219 .iter()
220 .map(|(_, content, _)| content.clone())
221 .collect();
222
223 let query = query.to_string();
224 let rerank_scores = tokio::task::spawn_blocking(move || {
225 let mut model = model.blocking_lock();
226 model.rerank(query, &documents, true, None)
227 })
228 .await
229 .map_err(|e| AppError::Internal(format!("Rerank task failed: {}", e)))?
230 .map_err(|e| AppError::Internal(format!("Reranking failed: {}", e)))?;
231
232 let mut reranked: Vec<RerankedResult> = results
234 .iter()
235 .enumerate()
236 .map(|(idx, (id, content, retrieval_score))| {
237 let rerank_score = rerank_scores
238 .iter()
239 .find(|r| r.index == idx)
240 .map(|r| r.score)
241 .unwrap_or(0.0);
242
243 RerankedResult {
244 id: id.clone(),
245 content: content.clone(),
246 retrieval_score: *retrieval_score,
247 rerank_score,
248 final_score: rerank_score,
250 original_rank: idx + 1,
251 new_rank: 0, }
253 })
254 .collect();
255
256 reranked.sort_by(|a, b| {
258 b.final_score
259 .partial_cmp(&a.final_score)
260 .unwrap_or(Ordering::Equal)
261 });
262
263 for (idx, result) in reranked.iter_mut().enumerate() {
265 result.new_rank = idx + 1;
266 }
267
268 let top_k = top_k.unwrap_or(self.config.top_k);
270 reranked.truncate(top_k);
271
272 Ok(reranked)
273 }
274
275 pub async fn rerank_hybrid(
279 &self,
280 query: &str,
281 results: &[(String, String, f32)],
282 rerank_weight: f32,
283 top_k: Option<usize>,
284 ) -> Result<Vec<RerankedResult>> {
285 if results.is_empty() {
286 return Ok(Vec::new());
287 }
288
289 let model = self.get_model().await?;
290 let documents: Vec<String> = results
291 .iter()
292 .map(|(_, content, _)| content.clone())
293 .collect();
294
295 let query = query.to_string();
296 let rerank_scores = tokio::task::spawn_blocking(move || {
297 let mut model = model.blocking_lock();
298 model.rerank(query, &documents, true, None)
299 })
300 .await
301 .map_err(|e| AppError::Internal(format!("Rerank task failed: {}", e)))?
302 .map_err(|e| AppError::Internal(format!("Reranking failed: {}", e)))?;
303
304 let max_retrieval = results
306 .iter()
307 .map(|(_, _, s)| *s)
308 .max_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal))
309 .unwrap_or(1.0);
310 let min_retrieval = results
311 .iter()
312 .map(|(_, _, s)| *s)
313 .min_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal))
314 .unwrap_or(0.0);
315 let retrieval_range = max_retrieval - min_retrieval;
316
317 let retrieval_weight = 1.0 - rerank_weight;
319 let mut reranked: Vec<RerankedResult> = results
320 .iter()
321 .enumerate()
322 .map(|(idx, (id, content, retrieval_score))| {
323 let rerank_score = rerank_scores
324 .iter()
325 .find(|r| r.index == idx)
326 .map(|r| r.score)
327 .unwrap_or(0.0);
328
329 let normalized_retrieval = if retrieval_range > 0.0 {
331 (retrieval_score - min_retrieval) / retrieval_range
332 } else {
333 1.0
334 };
335
336 let final_score =
338 retrieval_weight * normalized_retrieval + rerank_weight * rerank_score;
339
340 RerankedResult {
341 id: id.clone(),
342 content: content.clone(),
343 retrieval_score: *retrieval_score,
344 rerank_score,
345 final_score,
346 original_rank: idx + 1,
347 new_rank: 0,
348 }
349 })
350 .collect();
351
352 reranked.sort_by(|a, b| {
354 b.final_score
355 .partial_cmp(&a.final_score)
356 .unwrap_or(Ordering::Equal)
357 });
358
359 for (idx, result) in reranked.iter_mut().enumerate() {
361 result.new_rank = idx + 1;
362 }
363
364 let top_k = top_k.unwrap_or(self.config.top_k);
366 reranked.truncate(top_k);
367
368 Ok(reranked)
369 }
370
371 pub fn model_type(&self) -> RerankerModelType {
373 self.config.model
374 }
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// Full names and short aliases parse to the expected variants.
    #[test]
    fn test_reranker_model_from_str() {
        let cases = [
            ("bge-reranker-base", RerankerModelType::BgeRerankerBase),
            ("bge-m3", RerankerModelType::BgeRerankerV2M3),
            (
                "jina-multilingual",
                RerankerModelType::JinaRerankerV2BaseMultilingual,
            ),
        ];

        for (input, expected) in cases.iter() {
            let parsed: RerankerModelType = input.parse().unwrap();
            assert_eq!(parsed, *expected);
        }
    }

    /// `Display` renders the canonical kebab-case names.
    #[test]
    fn test_reranker_model_display() {
        let cases = [
            (RerankerModelType::BgeRerankerBase, "bge-reranker-base"),
            (
                RerankerModelType::JinaRerankerV2BaseMultilingual,
                "jina-reranker-v2-base-multilingual",
            ),
        ];

        for (model, expected) in cases.iter() {
            assert_eq!(model.to_string(), *expected);
        }
    }

    /// Only the M3 and Jina v2 models report as multilingual.
    #[test]
    fn test_reranker_model_multilingual() {
        let cases = [
            (RerankerModelType::BgeRerankerBase, false),
            (RerankerModelType::JinaRerankerV2BaseMultilingual, true),
            (RerankerModelType::BgeRerankerV2M3, true),
        ];

        for (model, expected) in cases.iter() {
            assert_eq!(model.is_multilingual(), *expected);
        }
    }

    /// `all` lists each of the four supported models.
    #[test]
    fn test_all_models() {
        assert_eq!(RerankerModelType::all().len(), 4);
    }

    /// The default configuration uses the documented defaults.
    #[test]
    fn test_default_config() {
        let RerankerConfig {
            model,
            show_download_progress,
            top_k,
        } = RerankerConfig::default();

        assert_eq!(model, RerankerModelType::BgeRerankerBase);
        assert_eq!(top_k, 10);
        assert!(show_download_progress);
    }

    /// Empty input short-circuits before any model is loaded.
    #[tokio::test]
    async fn test_rerank_empty() {
        let reranker = Reranker::default_reranker();
        let ranked = reranker
            .rerank("test query", &[], None)
            .await
            .unwrap();
        assert!(ranked.is_empty());
    }
}