Skip to main content

ck_embed/
reranker.rs

1use anyhow::{Result, bail};
2use ck_models::{RerankModelConfig, RerankModelRegistry};
3
4#[cfg(feature = "mixedbread")]
5use crate::mixedbread::MixedbreadReranker;
6
7#[cfg(feature = "fastembed")]
8use std::path::PathBuf;
9
/// A single document scored against a query by a [`Reranker`].
///
/// `PartialEq` is derived so callers and tests can compare whole results directly.
#[derive(Debug, Clone, PartialEq)]
pub struct RerankResult {
    /// The query the document was scored against.
    pub query: String,
    /// The document text that was scored.
    pub document: String,
    /// Relevance score assigned by the backend; its scale is backend-specific.
    pub score: f32,
}
16
/// A backend that scores documents by relevance to a query.
///
/// The `Send + Sync` bound allows a `Box<dyn Reranker>` (as returned by the factory
/// functions in this module) to be moved or shared across threads.
pub trait Reranker: Send + Sync {
    /// Short, stable identifier of the backend (e.g. `"dummy_reranker"`).
    fn id(&self) -> &'static str;
    /// Scores `documents` against `query`, returning one scored entry per result.
    ///
    /// The order of the returned results is implementation-defined; match results to
    /// inputs via `RerankResult::document` rather than by position.
    fn rerank(&mut self, query: &str, documents: &[String]) -> Result<Vec<RerankResult>>;
}
21
/// Callback that receives human-readable status messages (model initialisation, download,
/// fallback notices) while a reranker is being created.
pub type RerankModelDownloadCallback = Box<dyn Fn(&str) + Send + Sync>;
23
24pub fn create_reranker(model_name: Option<&str>) -> Result<Box<dyn Reranker>> {
25    create_reranker_with_progress(model_name, None)
26}
27
28pub fn create_reranker_with_progress(
29    model_name: Option<&str>,
30    progress_callback: Option<RerankModelDownloadCallback>,
31) -> Result<Box<dyn Reranker>> {
32    let registry = RerankModelRegistry::default();
33    let (_, config) = registry.resolve(model_name)?;
34    create_reranker_for_config(&config, progress_callback)
35}
36
37#[allow(clippy::needless_return)]
38pub fn create_reranker_for_config(
39    config: &RerankModelConfig,
40    progress_callback: Option<RerankModelDownloadCallback>,
41) -> Result<Box<dyn Reranker>> {
42    match config.provider.as_str() {
43        "fastembed" => {
44            #[cfg(feature = "fastembed")]
45            {
46                return Ok(Box::new(FastReranker::new_with_progress(
47                    config.name.as_str(),
48                    progress_callback,
49                )?));
50            }
51
52            #[cfg(not(feature = "fastembed"))]
53            {
54                if let Some(callback) = progress_callback.as_ref() {
55                    callback("fastembed reranker unavailable; using dummy reranker");
56                }
57                return Ok(Box::new(DummyReranker::new()));
58            }
59        }
60        "mixedbread" => {
61            #[cfg(feature = "mixedbread")]
62            {
63                return Ok(Box::new(MixedbreadReranker::new(
64                    config,
65                    progress_callback,
66                )?));
67            }
68            #[cfg(not(feature = "mixedbread"))]
69            {
70                bail!(
71                    "Reranking model '{}' requires the `mixedbread` feature. Rebuild ck with Mixedbread support.",
72                    config.name
73                );
74            }
75        }
76        provider => bail!("Unsupported reranker provider '{}'", provider),
77    }
78}
79
/// Placeholder reranker that performs no real scoring.
///
/// It is what `create_reranker_for_config` hands out for the fastembed provider when the
/// `fastembed` feature is compiled out.
#[derive(Default)]
pub struct DummyReranker;

impl DummyReranker {
    /// Creates a dummy reranker; equivalent to [`Default::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
93
94impl Reranker for DummyReranker {
95    fn id(&self) -> &'static str {
96        "dummy_reranker"
97    }
98
99    fn rerank(&mut self, query: &str, documents: &[String]) -> Result<Vec<RerankResult>> {
100        // Dummy reranker just returns documents in original order with random scores
101        Ok(documents
102            .iter()
103            .enumerate()
104            .map(|(i, doc)| {
105                RerankResult {
106                    query: query.to_string(),
107                    document: doc.clone(),
108                    score: 0.5 + (i as f32 * 0.1) % 0.5, // Fake scores between 0.5-1.0
109                }
110            })
111            .collect())
112    }
113}
114
/// Reranker backed by a fastembed `TextRerank` model (only built with the `fastembed` feature).
#[cfg(feature = "fastembed")]
pub struct FastReranker {
    /// Loaded fastembed reranking model used by `rerank`.
    model: fastembed::TextRerank,
    /// Model name recorded at construction; not currently read anywhere in this file.
    #[allow(dead_code)] // Keep for future use (debugging, logging)
    model_name: String,
}
121
#[cfg(feature = "fastembed")]
impl FastReranker {
    /// Model loaded when the requested name is not one this backend recognises.
    const DEFAULT_MODEL_NAME: &'static str = "jina-reranker-v1-turbo-en";

    /// Creates a reranker for `model_name` without progress reporting.
    ///
    /// # Errors
    /// Same as [`FastReranker::new_with_progress`].
    pub fn new(model_name: &str) -> Result<Self> {
        Self::new_with_progress(model_name, None)
    }

    /// Creates a fastembed-backed reranker that caches models in the ck cache directory.
    ///
    /// Supported names are:
    /// * `jina-reranker-v1-turbo-en`
    /// * `bge-reranker-base`
    /// * `jina-reranker-v2-base-multilingual`
    /// * `bge-reranker-v2-m3`
    ///
    /// Any other name falls back to `jina-reranker-v1-turbo-en`.
    ///
    /// When `progress_callback` is given, it receives status messages, including a notice
    /// whenever that fallback happens, and fastembed's own download-progress display is
    /// switched on.
    ///
    /// # Errors
    /// Fails if the cache directory cannot be created or fastembed cannot download/load the
    /// model.
    pub fn new_with_progress(
        model_name: &str,
        progress_callback: Option<RerankModelDownloadCallback>,
    ) -> Result<Self> {
        use anyhow::Context as _;
        use fastembed::{RerankInitOptions, RerankerModel, TextRerank};

        // `effective_name` is what is actually loaded; it differs from `model_name` only on
        // fallback, so messages and the cache check refer to the right model.
        let (model, effective_name) = match Self::lookup_model(model_name) {
            Some(model) => (model, model_name),
            None => (
                RerankerModel::JINARerankerV1TurboEn,
                Self::DEFAULT_MODEL_NAME,
            ),
        };

        // Configure permanent model cache directory
        let model_cache_dir = Self::get_model_cache_dir()?;
        std::fs::create_dir_all(&model_cache_dir).with_context(|| {
            format!(
                "Failed to create reranker cache directory {}",
                model_cache_dir.display()
            )
        })?;

        if let Some(callback) = progress_callback.as_ref() {
            if effective_name != model_name {
                callback(&format!(
                    "Unknown reranker model '{}'; falling back to {}",
                    model_name, effective_name
                ));
            }
            callback(&format!("Initializing reranker model: {}", effective_name));

            if Self::check_model_exists(&model_cache_dir, effective_name) {
                callback(&format!("Using cached reranker model: {}", effective_name));
            } else {
                callback(&format!(
                    "Downloading reranker model {} to {}",
                    effective_name,
                    model_cache_dir.display()
                ));
            }
        }

        let init_options = RerankInitOptions::new(model)
            .with_show_download_progress(progress_callback.is_some())
            .with_cache_dir(model_cache_dir);

        let reranker = TextRerank::try_new(init_options)
            .with_context(|| format!("Failed to load reranker model {}", effective_name))?;

        if let Some(callback) = progress_callback.as_ref() {
            callback("Reranker model loaded successfully");
        }

        Ok(Self {
            model: reranker,
            model_name: effective_name.to_string(),
        })
    }

    /// Maps a ck reranker model name to the corresponding fastembed model, if supported.
    fn lookup_model(model_name: &str) -> Option<fastembed::RerankerModel> {
        use fastembed::RerankerModel;

        let model = match model_name {
            "jina-reranker-v1-turbo-en" => RerankerModel::JINARerankerV1TurboEn,
            "bge-reranker-base" => RerankerModel::BGERerankerBase,
            "jina-reranker-v2-base-multilingual" => RerankerModel::JINARerankerV2BaseMultiligual,
            "bge-reranker-v2-m3" => RerankerModel::BGERerankerV2M3,
            _ => return None,
        };
        Some(model)
    }

    /// Returns `<cache root>/rerankers`, where the cache root is chosen in this order:
    /// 1. `$XDG_CACHE_HOME/ck`
    /// 2. `$HOME/.cache/ck`
    /// 3. `%LOCALAPPDATA%/ck/cache`
    /// 4. `./.ck_models`
    fn get_model_cache_dir() -> Result<PathBuf> {
        // Use platform-appropriate cache directory (same as embedder)
        let cache_dir = if let Some(cache_home) = std::env::var_os("XDG_CACHE_HOME") {
            PathBuf::from(cache_home).join("ck")
        } else if let Some(home) = std::env::var_os("HOME") {
            PathBuf::from(home).join(".cache").join("ck")
        } else if let Some(appdata) = std::env::var_os("LOCALAPPDATA") {
            PathBuf::from(appdata).join("ck").join("cache")
        } else {
            // Fallback to current directory if no home found
            PathBuf::from(".ck_models")
        };

        Ok(cache_dir.join("rerankers"))
    }

    /// Best-effort check for whether `model_name` is already in `cache_dir`.
    ///
    /// Two layouts are accepted:
    /// * a flat `<model_name>` directory (with `/` replaced by `_`);
    /// * a Hugging Face hub style `models--<org>--<model_name>` directory, which is
    ///   presumably what fastembed's downloader creates (NOTE(review): confirm against the
    ///   pinned fastembed/hf-hub version).
    ///
    /// The result only selects which progress message is shown; it never affects loading.
    fn check_model_exists(cache_dir: &std::path::Path, model_name: &str) -> bool {
        if cache_dir.join(model_name.replace('/', "_")).exists() {
            return true;
        }

        let hub_suffix = format!("--{}", model_name.replace('/', "--"));
        std::fs::read_dir(cache_dir)
            .map(|entries| {
                entries.flatten().any(|entry| {
                    let file_name = entry.file_name();
                    let file_name = file_name.to_string_lossy();
                    file_name.starts_with("models--") && file_name.ends_with(hub_suffix.as_str())
                })
            })
            .unwrap_or(false)
    }
}
200
#[cfg(feature = "fastembed")]
impl Reranker for FastReranker {
    fn id(&self) -> &'static str {
        "fastembed_reranker"
    }

    /// Scores `documents` against `query` with the fastembed model.
    ///
    /// Results keep the order fastembed reports them in (presumably descending score;
    /// NOTE(review): confirm for the pinned fastembed version). Each score is attached to
    /// its source document through fastembed's `index` field, not by output position.
    ///
    /// # Errors
    /// Propagates inference errors, and errors if fastembed reports an out-of-range index.
    fn rerank(&mut self, query: &str, documents: &[String]) -> Result<Vec<RerankResult>> {
        // Nothing to score; don't run the model on an empty batch.
        if documents.is_empty() {
            return Ok(Vec::new());
        }

        let docs: Vec<&str> = documents.iter().map(String::as_str).collect();

        // `return_documents = false`: the caller's own strings are re-attached via `index`
        // below, so asking fastembed to copy them into its results would be wasted work.
        let scored = self.model.rerank(query, docs, false, None)?;

        scored
            .into_iter()
            .map(|hit| -> Result<RerankResult> {
                // Pairing by enumeration position would attach scores to the wrong documents
                // whenever fastembed's output order differs from the input order.
                let document = documents.get(hit.index).cloned().ok_or_else(|| {
                    anyhow::anyhow!(
                        "fastembed returned document index {} for {} documents",
                        hit.index,
                        documents.len()
                    )
                })?;
                Ok(RerankResult {
                    query: query.to_string(),
                    document,
                    score: hit.score,
                })
            })
            .collect()
    }
}
228
// Unit tests for the reranker factory and backends. The dummy-reranker tests need no model
// files; the fastembed test is feature-gated and skipped when `CI` is set.
#[cfg(test)]
mod tests {
    use super::*;

    // Checks the dummy backend's id, one result per document, echoed query and score bounds.
    #[test]
    fn test_dummy_reranker() {
        let mut reranker = DummyReranker::new();
        assert_eq!(reranker.id(), "dummy_reranker");

        let query = "find error handling";
        let documents = vec![
            "try catch block".to_string(),
            "function definition".to_string(),
            "error handling code".to_string(),
        ];

        let results = reranker.rerank(query, &documents).unwrap();
        assert_eq!(results.len(), 3);

        for result in &results {
            assert_eq!(result.query, query);
            assert!(result.score >= 0.5 && result.score <= 1.0);
        }
    }

    // Without the `fastembed` feature, the factory is expected to yield the dummy backend.
    // With the feature enabled, the body compiles to nothing and the test trivially passes.
    // NOTE(review): assumes the registry default resolves to the "fastembed" provider; a
    // "mixedbread" default without that feature would error and make `unwrap` panic.
    #[test]
    fn test_create_reranker_dummy() {
        #[cfg(not(feature = "fastembed"))]
        {
            let reranker = create_reranker(None).unwrap();
            assert_eq!(reranker.id(), "dummy_reranker");
        }
    }

    // Smoke test against a real fastembed model; tolerates construction failure (e.g. no
    // network) and does nothing when `CI` is set.
    #[cfg(feature = "fastembed")]
    #[test]
    fn test_fastembed_reranker_creation() {
        // This test requires downloading models, so we'll skip it in CI
        if std::env::var("CI").is_ok() {
            return;
        }

        let reranker = FastReranker::new("jina-reranker-v1-turbo-en");

        match reranker {
            Ok(mut reranker) => {
                assert_eq!(reranker.id(), "fastembed_reranker");

                let query = "error handling";
                let documents = vec![
                    "try catch exception handling".to_string(),
                    "user interface design".to_string(),
                ];

                let result = reranker.rerank(query, &documents);
                assert!(result.is_ok());

                let results = result.unwrap();
                assert_eq!(results.len(), 2);

                // First result should be more relevant to query
                // NOTE(review): this only compares scores; it does not check that each
                // score is attached to the correct document.
                assert!(results[0].score > results[1].score);
            }
            Err(_) => {
                // In test environments, FastEmbed might not be available
                // This is acceptable for unit tests
            }
        }
    }

    // Empty input must produce an empty result list rather than an error.
    #[test]
    fn test_reranker_empty_input() {
        let mut reranker = DummyReranker::new();
        let query = "test query";
        let documents: Vec<String> = vec![];
        let results = reranker.rerank(query, &documents).unwrap();
        assert_eq!(results.len(), 0);
    }

    // A single document is returned unchanged alongside the query.
    #[test]
    fn test_reranker_single_document() {
        let mut reranker = DummyReranker::new();
        let query = "test query";
        let documents = vec!["single document".to_string()];
        let results = reranker.rerank(query, &documents).unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].query, query);
        assert_eq!(results[0].document, "single document");
    }
}