Skip to main content

ares/rag/
reranker.rs

1//! Reranking for improving search result relevance.
2//!
3//! This module provides reranking capabilities using cross-encoder models
4//! to improve the quality of retrieved documents after initial retrieval.
5
6use std::cmp::Ordering;
7use std::str::FromStr;
8use std::sync::Arc;
9
10use fastembed::{RerankInitOptions, RerankerModel as FastEmbedRerankerModel, TextRerank};
11use serde::{Deserialize, Serialize};
12use tokio::sync::OnceCell;
13
14use crate::types::{AppError, Result};
15
16// ============================================================================
17// Reranker Model Types
18// ============================================================================
19
20/// Supported reranking models
21#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
22#[serde(rename_all = "kebab-case")]
23pub enum RerankerModelType {
24    /// BGE Reranker Base - English/Chinese, good balance of speed and quality
25    #[default]
26    BgeRerankerBase,
27    /// BGE Reranker v2 M3 - Multilingual reranker
28    BgeRerankerV2M3,
29    /// Jina Reranker v1 Turbo - Fast English reranker
30    JinaRerankerV1TurboEn,
31    /// Jina Reranker v2 Base - Multilingual reranker
32    JinaRerankerV2BaseMultilingual,
33}
34
35impl RerankerModelType {
36    /// Convert to fastembed's RerankerModel enum
37    pub fn to_fastembed_model(&self) -> FastEmbedRerankerModel {
38        match self {
39            Self::BgeRerankerBase => FastEmbedRerankerModel::BGERerankerBase,
40            Self::BgeRerankerV2M3 => FastEmbedRerankerModel::BGERerankerV2M3,
41            Self::JinaRerankerV1TurboEn => FastEmbedRerankerModel::JINARerankerV1TurboEn,
42            // Note: typo in fastembed - "Multiligual" instead of "Multilingual"
43            Self::JinaRerankerV2BaseMultilingual => {
44                FastEmbedRerankerModel::JINARerankerV2BaseMultiligual
45            }
46        }
47    }
48
49    /// Get all available models
50    pub fn all() -> Vec<Self> {
51        vec![
52            Self::BgeRerankerBase,
53            Self::BgeRerankerV2M3,
54            Self::JinaRerankerV1TurboEn,
55            Self::JinaRerankerV2BaseMultilingual,
56        ]
57    }
58
59    /// Check if this model is multilingual
60    pub fn is_multilingual(&self) -> bool {
61        matches!(
62            self,
63            Self::JinaRerankerV2BaseMultilingual | Self::BgeRerankerV2M3
64        )
65    }
66}
67
68impl FromStr for RerankerModelType {
69    type Err = AppError;
70
71    fn from_str(s: &str) -> Result<Self> {
72        match s.to_lowercase().as_str() {
73            "bge-reranker-base" | "bge-base" => Ok(Self::BgeRerankerBase),
74            "bge-reranker-v2-m3" | "bge-m3" => Ok(Self::BgeRerankerV2M3),
75            "jina-reranker-v1-turbo-en" | "jina-turbo" => Ok(Self::JinaRerankerV1TurboEn),
76            "jina-reranker-v2-base-multilingual" | "jina-multilingual" => {
77                Ok(Self::JinaRerankerV2BaseMultilingual)
78            }
79            _ => Err(AppError::Internal(format!(
80                "Unknown reranker model: {}. Use one of: bge-reranker-base, \
81                 bge-reranker-v2-m3, jina-reranker-v1-turbo-en, jina-reranker-v2-base-multilingual",
82                s
83            ))),
84        }
85    }
86}
87
88impl std::fmt::Display for RerankerModelType {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        let name = match self {
91            Self::BgeRerankerBase => "bge-reranker-base",
92            Self::BgeRerankerV2M3 => "bge-reranker-v2-m3",
93            Self::JinaRerankerV1TurboEn => "jina-reranker-v1-turbo-en",
94            Self::JinaRerankerV2BaseMultilingual => "jina-reranker-v2-base-multilingual",
95        };
96        write!(f, "{}", name)
97    }
98}
99
100// ============================================================================
101// Reranker Configuration
102// ============================================================================
103
104/// Configuration for the reranking service
105#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct RerankerConfig {
107    /// Model to use for reranking
108    #[serde(default)]
109    pub model: RerankerModelType,
110    /// Show download progress when fetching model weights
111    #[serde(default = "default_show_progress")]
112    pub show_download_progress: bool,
113    /// Number of top results to return after reranking
114    #[serde(default = "default_top_k")]
115    pub top_k: usize,
116}
117
/// Serde default for `RerankerConfig::show_download_progress`: progress is shown.
fn default_show_progress() -> bool {
    const SHOW_PROGRESS_BY_DEFAULT: bool = true;
    SHOW_PROGRESS_BY_DEFAULT
}
121
/// Serde default for `RerankerConfig::top_k`: keep the ten best results.
fn default_top_k() -> usize {
    const DEFAULT_TOP_K: usize = 10;
    DEFAULT_TOP_K
}
125
126impl Default for RerankerConfig {
127    fn default() -> Self {
128        Self {
129            model: RerankerModelType::default(),
130            show_download_progress: default_show_progress(),
131            top_k: default_top_k(),
132        }
133    }
134}
135
136// ============================================================================
137// Reranked Result
138// ============================================================================
139
140/// A reranked search result
141#[derive(Debug, Clone, Serialize, Deserialize)]
142pub struct RerankedResult {
143    /// Document ID
144    pub id: String,
145    /// Document content
146    pub content: String,
147    /// Original retrieval score
148    pub retrieval_score: f32,
149    /// Reranking score from cross-encoder
150    pub rerank_score: f32,
151    /// Final combined score (used for ranking)
152    pub final_score: f32,
153    /// Original rank before reranking
154    pub original_rank: usize,
155    /// New rank after reranking
156    pub new_rank: usize,
157}
158
159// ============================================================================
160// Reranker Service
161// ============================================================================
162
163/// Reranking service using cross-encoder models
164pub struct Reranker {
165    config: RerankerConfig,
166    model: OnceCell<Arc<tokio::sync::Mutex<TextRerank>>>,
167}
168
169impl Reranker {
170    /// Create a new reranker with the given configuration
171    pub fn new(config: RerankerConfig) -> Self {
172        Self {
173            config,
174            model: OnceCell::new(),
175        }
176    }
177
178    /// Create with default configuration
179    pub fn default_reranker() -> Self {
180        Self::new(RerankerConfig::default())
181    }
182
183    /// Get or initialize the reranking model
184    async fn get_model(&self) -> Result<Arc<tokio::sync::Mutex<TextRerank>>> {
185        self.model
186            .get_or_try_init(|| async {
187                let config = self.config.clone();
188                tokio::task::spawn_blocking(move || {
189                    let init_options = RerankInitOptions::new(config.model.to_fastembed_model())
190                        .with_show_download_progress(config.show_download_progress);
191                    let model = TextRerank::try_new(init_options).map_err(|e| {
192                        AppError::Internal(format!("Failed to load reranker: {}", e))
193                    })?;
194                    Ok(Arc::new(tokio::sync::Mutex::new(model)))
195                })
196                .await
197                .map_err(|e| AppError::Internal(format!("Reranker task failed: {}", e)))?
198            })
199            .await
200            .map(Arc::clone)
201    }
202
203    /// Rerank search results
204    ///
205    /// Takes a query and a list of (id, content, score) tuples and returns
206    /// reranked results sorted by relevance.
207    pub async fn rerank(
208        &self,
209        query: &str,
210        results: &[(String, String, f32)],
211        top_k: Option<usize>,
212    ) -> Result<Vec<RerankedResult>> {
213        if results.is_empty() {
214            return Ok(Vec::new());
215        }
216
217        let model = self.get_model().await?;
218        let documents: Vec<String> = results
219            .iter()
220            .map(|(_, content, _)| content.clone())
221            .collect();
222
223        let query = query.to_string();
224        let rerank_scores = tokio::task::spawn_blocking(move || {
225            let mut model = model.blocking_lock();
226            model.rerank(query, &documents, true, None)
227        })
228        .await
229        .map_err(|e| AppError::Internal(format!("Rerank task failed: {}", e)))?
230        .map_err(|e| AppError::Internal(format!("Reranking failed: {}", e)))?;
231
232        // Combine with original results
233        let mut reranked: Vec<RerankedResult> = results
234            .iter()
235            .enumerate()
236            .map(|(idx, (id, content, retrieval_score))| {
237                let rerank_score = rerank_scores
238                    .iter()
239                    .find(|r| r.index == idx)
240                    .map(|r| r.score)
241                    .unwrap_or(0.0);
242
243                RerankedResult {
244                    id: id.clone(),
245                    content: content.clone(),
246                    retrieval_score: *retrieval_score,
247                    rerank_score,
248                    // Use rerank score as final score (could be combined differently)
249                    final_score: rerank_score,
250                    original_rank: idx + 1,
251                    new_rank: 0, // Will be set after sorting
252                }
253            })
254            .collect();
255
256        // Sort by rerank score (higher is better)
257        reranked.sort_by(|a, b| {
258            b.final_score
259                .partial_cmp(&a.final_score)
260                .unwrap_or(Ordering::Equal)
261        });
262
263        // Assign new ranks
264        for (idx, result) in reranked.iter_mut().enumerate() {
265            result.new_rank = idx + 1;
266        }
267
268        // Truncate to top_k
269        let top_k = top_k.unwrap_or(self.config.top_k);
270        reranked.truncate(top_k);
271
272        Ok(reranked)
273    }
274
275    /// Rerank with hybrid scoring
276    ///
277    /// Combines retrieval score with rerank score using a configurable weight
278    pub async fn rerank_hybrid(
279        &self,
280        query: &str,
281        results: &[(String, String, f32)],
282        rerank_weight: f32,
283        top_k: Option<usize>,
284    ) -> Result<Vec<RerankedResult>> {
285        if results.is_empty() {
286            return Ok(Vec::new());
287        }
288
289        let model = self.get_model().await?;
290        let documents: Vec<String> = results
291            .iter()
292            .map(|(_, content, _)| content.clone())
293            .collect();
294
295        let query = query.to_string();
296        let rerank_scores = tokio::task::spawn_blocking(move || {
297            let mut model = model.blocking_lock();
298            model.rerank(query, &documents, true, None)
299        })
300        .await
301        .map_err(|e| AppError::Internal(format!("Rerank task failed: {}", e)))?
302        .map_err(|e| AppError::Internal(format!("Reranking failed: {}", e)))?;
303
304        // Normalize retrieval scores to 0-1 range
305        let max_retrieval = results
306            .iter()
307            .map(|(_, _, s)| *s)
308            .max_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal))
309            .unwrap_or(1.0);
310        let min_retrieval = results
311            .iter()
312            .map(|(_, _, s)| *s)
313            .min_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal))
314            .unwrap_or(0.0);
315        let retrieval_range = max_retrieval - min_retrieval;
316
317        // Combine with original results
318        let retrieval_weight = 1.0 - rerank_weight;
319        let mut reranked: Vec<RerankedResult> = results
320            .iter()
321            .enumerate()
322            .map(|(idx, (id, content, retrieval_score))| {
323                let rerank_score = rerank_scores
324                    .iter()
325                    .find(|r| r.index == idx)
326                    .map(|r| r.score)
327                    .unwrap_or(0.0);
328
329                // Normalize retrieval score
330                let normalized_retrieval = if retrieval_range > 0.0 {
331                    (retrieval_score - min_retrieval) / retrieval_range
332                } else {
333                    1.0
334                };
335
336                // Compute hybrid score
337                let final_score =
338                    retrieval_weight * normalized_retrieval + rerank_weight * rerank_score;
339
340                RerankedResult {
341                    id: id.clone(),
342                    content: content.clone(),
343                    retrieval_score: *retrieval_score,
344                    rerank_score,
345                    final_score,
346                    original_rank: idx + 1,
347                    new_rank: 0,
348                }
349            })
350            .collect();
351
352        // Sort by final score (higher is better)
353        reranked.sort_by(|a, b| {
354            b.final_score
355                .partial_cmp(&a.final_score)
356                .unwrap_or(Ordering::Equal)
357        });
358
359        // Assign new ranks
360        for (idx, result) in reranked.iter_mut().enumerate() {
361            result.new_rank = idx + 1;
362        }
363
364        // Truncate to top_k
365        let top_k = top_k.unwrap_or(self.config.top_k);
366        reranked.truncate(top_k);
367
368        Ok(reranked)
369    }
370
371    /// Get the model type
372    pub fn model_type(&self) -> RerankerModelType {
373        self.config.model
374    }
375}
376
377// ============================================================================
378// Tests
379// ============================================================================
380
#[cfg(test)]
mod tests {
    use super::*;

    /// Full names and short aliases both resolve to the expected variant.
    #[test]
    fn test_reranker_model_from_str() {
        let cases = [
            ("bge-reranker-base", RerankerModelType::BgeRerankerBase),
            ("bge-base", RerankerModelType::BgeRerankerBase),
            ("bge-reranker-v2-m3", RerankerModelType::BgeRerankerV2M3),
            ("bge-m3", RerankerModelType::BgeRerankerV2M3),
            (
                "jina-reranker-v1-turbo-en",
                RerankerModelType::JinaRerankerV1TurboEn,
            ),
            ("jina-turbo", RerankerModelType::JinaRerankerV1TurboEn),
            (
                "jina-reranker-v2-base-multilingual",
                RerankerModelType::JinaRerankerV2BaseMultilingual,
            ),
            (
                "jina-multilingual",
                RerankerModelType::JinaRerankerV2BaseMultilingual,
            ),
        ];

        for (name, expected) in cases {
            assert_eq!(
                name.parse::<RerankerModelType>().unwrap(),
                expected,
                "parsing {name:?}"
            );
        }
    }

    /// Parsing lowercases its input, so letter case must not matter.
    #[test]
    fn test_reranker_model_from_str_ignores_case() {
        assert_eq!(
            "BGE-Reranker-Base".parse::<RerankerModelType>().unwrap(),
            RerankerModelType::BgeRerankerBase
        );
        assert_eq!(
            "JINA-TURBO".parse::<RerankerModelType>().unwrap(),
            RerankerModelType::JinaRerankerV1TurboEn
        );
    }

    /// Unknown and empty names are rejected instead of silently defaulting.
    #[test]
    fn test_reranker_model_from_str_rejects_unknown() {
        assert!("not-a-model".parse::<RerankerModelType>().is_err());
        assert!("".parse::<RerankerModelType>().is_err());
    }

    #[test]
    fn test_reranker_model_display() {
        assert_eq!(
            RerankerModelType::BgeRerankerBase.to_string(),
            "bge-reranker-base"
        );
        assert_eq!(
            RerankerModelType::BgeRerankerV2M3.to_string(),
            "bge-reranker-v2-m3"
        );
        assert_eq!(
            RerankerModelType::JinaRerankerV1TurboEn.to_string(),
            "jina-reranker-v1-turbo-en"
        );
        assert_eq!(
            RerankerModelType::JinaRerankerV2BaseMultilingual.to_string(),
            "jina-reranker-v2-base-multilingual"
        );
    }

    /// Whatever `Display` prints must parse back to the same variant.
    #[test]
    fn test_reranker_model_display_round_trip() {
        for model in RerankerModelType::all() {
            let parsed: RerankerModelType = model.to_string().parse().unwrap();
            assert_eq!(parsed, model);
        }
    }

    #[test]
    fn test_reranker_model_multilingual() {
        assert!(!RerankerModelType::BgeRerankerBase.is_multilingual());
        assert!(!RerankerModelType::JinaRerankerV1TurboEn.is_multilingual());
        assert!(RerankerModelType::JinaRerankerV2BaseMultilingual.is_multilingual());
        assert!(RerankerModelType::BgeRerankerV2M3.is_multilingual());
    }

    /// `all()` lists each of the four models exactly once.
    #[test]
    fn test_all_models() {
        let all = RerankerModelType::all();
        assert_eq!(all.len(), 4);

        let unique: std::collections::HashSet<_> = all.iter().copied().collect();
        assert_eq!(unique.len(), all.len());
    }

    #[test]
    fn test_default_config() {
        let config = RerankerConfig::default();
        assert_eq!(config.model, RerankerModelType::BgeRerankerBase);
        assert_eq!(config.model, RerankerModelType::default());
        assert_eq!(config.top_k, 10);
        assert!(config.show_download_progress);
    }

    /// The accessor reports the configured model without loading it.
    #[test]
    fn test_model_type_reflects_config() {
        let reranker = Reranker::new(RerankerConfig {
            model: RerankerModelType::JinaRerankerV1TurboEn,
            ..RerankerConfig::default()
        });
        assert_eq!(
            reranker.model_type(),
            RerankerModelType::JinaRerankerV1TurboEn
        );
        assert_eq!(
            Reranker::default_reranker().model_type(),
            RerankerModelType::BgeRerankerBase
        );
    }

    /// Empty input short-circuits before any model is loaded.
    #[tokio::test]
    async fn test_rerank_empty() {
        let reranker = Reranker::default_reranker();
        let results = reranker.rerank("test query", &[], None).await.unwrap();
        assert!(results.is_empty());
    }

    /// The hybrid path short-circuits on empty input in the same way.
    #[tokio::test]
    async fn test_rerank_hybrid_empty() {
        let reranker = Reranker::default_reranker();
        let results = reranker
            .rerank_hybrid("test query", &[], 0.5, None)
            .await
            .unwrap();
        assert!(results.is_empty());
    }
}