ck_embed/
reranker.rs

1use anyhow::Result;
2
3#[cfg(feature = "fastembed")]
4use std::path::PathBuf;
5
/// A single scored (query, document) pair produced by a reranker.
///
/// `PartialEq` is derived so results can be compared directly, e.g. with
/// `assert_eq!` in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct RerankResult {
    /// The query the document was scored against.
    pub query: String,
    /// The text of the scored document.
    pub document: String,
    /// The score the reranker assigned to `document` for `query`.
    pub score: f32,
}
12
/// Common interface for components that score documents against a query.
///
/// Implementors must be `Send + Sync`.
pub trait Reranker: Send + Sync {
    /// Returns a static identifier naming the reranker implementation.
    fn id(&self) -> &'static str;
    /// Scores `documents` against `query` and returns the scored results,
    /// or an error if scoring fails.
    fn rerank(&mut self, query: &str, documents: &[String]) -> Result<Vec<RerankResult>>;
}
17
/// Boxed, thread-safe callback that receives human-readable status messages
/// while a reranker is being created.
pub type RerankModelDownloadCallback = Box<dyn Fn(&str) + Send + Sync>;
19
20pub fn create_reranker(model_name: Option<&str>) -> Result<Box<dyn Reranker>> {
21    create_reranker_with_progress(model_name, None)
22}
23
24pub fn create_reranker_with_progress(
25    model_name: Option<&str>,
26    progress_callback: Option<RerankModelDownloadCallback>,
27) -> Result<Box<dyn Reranker>> {
28    let model = model_name.unwrap_or("jina-reranker-v1-turbo-en");
29
30    #[cfg(feature = "fastembed")]
31    {
32        Ok(Box::new(FastReranker::new_with_progress(
33            model,
34            progress_callback,
35        )?))
36    }
37
38    #[cfg(not(feature = "fastembed"))]
39    {
40        let _ = model; // Suppress unused variable warning
41        if let Some(callback) = progress_callback {
42            callback("Using dummy reranker (no model download required)");
43        }
44        Ok(Box::new(DummyReranker::new()))
45    }
46}
47
/// Placeholder reranker that needs no model files.
///
/// `Default` is derived rather than hand-written; for this unit struct it
/// yields the same value as [`DummyReranker::new`].
#[derive(Debug, Clone, Copy, Default)]
pub struct DummyReranker;

impl DummyReranker {
    /// Creates a new dummy reranker.
    pub fn new() -> Self {
        Self
    }
}
61
62impl Reranker for DummyReranker {
63    fn id(&self) -> &'static str {
64        "dummy_reranker"
65    }
66
67    fn rerank(&mut self, query: &str, documents: &[String]) -> Result<Vec<RerankResult>> {
68        // Dummy reranker just returns documents in original order with random scores
69        Ok(documents
70            .iter()
71            .enumerate()
72            .map(|(i, doc)| {
73                RerankResult {
74                    query: query.to_string(),
75                    document: doc.clone(),
76                    score: 0.5 + (i as f32 * 0.1) % 0.5, // Fake scores between 0.5-1.0
77                }
78            })
79            .collect())
80    }
81}
82
/// Reranker backed by a `fastembed` text-reranking model.
///
/// Only compiled when the `fastembed` feature is enabled.
#[cfg(feature = "fastembed")]
pub struct FastReranker {
    /// The loaded fastembed reranking model.
    model: fastembed::TextRerank,
    /// Model name exactly as it was passed to the constructor.
    #[allow(dead_code)] // Keep for future use (debugging, logging)
    model_name: String,
}
89
#[cfg(feature = "fastembed")]
impl FastReranker {
    /// Loads the reranking model called `model_name` without progress reporting.
    ///
    /// # Errors
    /// Returns an error if the cache directory cannot be created or the model
    /// fails to initialise.
    pub fn new(model_name: &str) -> Result<Self> {
        Self::new_with_progress(model_name, None)
    }

    /// Loads the reranking model called `model_name`, sending status messages to
    /// `progress_callback` when one is supplied.
    ///
    /// Recognised names are `jina-reranker-v1-turbo-en`, `bge-reranker-base`,
    /// `jina-reranker-v2-base-multilingual` and `bge-reranker-v2-m3`. Any other
    /// name falls back to `jina-reranker-v1-turbo-en`; the fallback is reported
    /// through the callback so it is no longer silent.
    ///
    /// # Errors
    /// Returns an error if the cache directory cannot be created or the model
    /// fails to initialise.
    pub fn new_with_progress(
        model_name: &str,
        progress_callback: Option<RerankModelDownloadCallback>,
    ) -> Result<Self> {
        use anyhow::Context as _;
        use fastembed::{RerankInitOptions, RerankerModel, TextRerank};

        let known_model = match model_name {
            "jina-reranker-v1-turbo-en" => Some(RerankerModel::JINARerankerV1TurboEn),
            "bge-reranker-base" => Some(RerankerModel::BGERerankerBase),
            "jina-reranker-v2-base-multilingual" => {
                Some(RerankerModel::JINARerankerV2BaseMultiligual)
            }
            "bge-reranker-v2-m3" => Some(RerankerModel::BGERerankerV2M3),
            _ => None,
        };
        let recognised = known_model.is_some();
        let model = known_model.unwrap_or(RerankerModel::JINARerankerV1TurboEn);

        // Configure permanent model cache directory
        let model_cache_dir = Self::get_model_cache_dir()?;
        std::fs::create_dir_all(&model_cache_dir).with_context(|| {
            format!(
                "failed to create reranker cache directory {}",
                model_cache_dir.display()
            )
        })?;

        if let Some(callback) = &progress_callback {
            callback(&format!("Initializing reranker model: {}", model_name));

            if !recognised {
                callback(&format!(
                    "Unknown reranker model '{}'; falling back to jina-reranker-v1-turbo-en",
                    model_name
                ));
            }

            // Check if model already exists
            if Self::check_model_exists(&model_cache_dir, model_name) {
                callback(&format!("Using cached reranker model: {}", model_name));
            } else {
                callback(&format!(
                    "Downloading reranker model {} to {}",
                    model_name,
                    model_cache_dir.display()
                ));
            }
        }

        // `model` is not used again, so it is moved rather than cloned.
        let init_options = RerankInitOptions::new(model)
            .with_show_download_progress(progress_callback.is_some())
            .with_cache_dir(model_cache_dir);

        let reranker = TextRerank::try_new(init_options)?;

        if let Some(callback) = &progress_callback {
            callback("Reranker model loaded successfully");
        }

        Ok(Self {
            model: reranker,
            model_name: model_name.to_string(),
        })
    }

    /// Resolves the directory in which reranker models are cached.
    ///
    /// Lookup order (the original note says this matches the embedder):
    /// `$XDG_CACHE_HOME/ck`, `$HOME/.cache/ck`, `%LOCALAPPDATA%/ck/cache`, and
    /// finally `./.ck_models`; `rerankers` is appended to whichever is chosen.
    /// A variable that is set but empty is treated as unset, so an empty
    /// `XDG_CACHE_HOME` no longer produces a path relative to the working
    /// directory.
    fn get_model_cache_dir() -> Result<PathBuf> {
        let non_empty_var = |key: &str| std::env::var_os(key).filter(|value| !value.is_empty());

        let cache_dir = if let Some(cache_home) = non_empty_var("XDG_CACHE_HOME") {
            PathBuf::from(cache_home).join("ck")
        } else if let Some(home) = non_empty_var("HOME") {
            PathBuf::from(home).join(".cache").join("ck")
        } else if let Some(appdata) = non_empty_var("LOCALAPPDATA") {
            PathBuf::from(appdata).join("ck").join("cache")
        } else {
            // Fallback to current directory if no home found
            PathBuf::from(".ck_models")
        };

        Ok(cache_dir.join("rerankers"))
    }

    /// Heuristically reports whether `model_name` is already cached under `cache_dir`.
    ///
    /// Only checks that a directory named after the model (with `/` replaced by
    /// `_`) exists. The result is used solely to pick a progress message.
    // NOTE(review): confirm this directory name matches the layout fastembed
    // actually writes; if it does not, cached models will be reported as downloads.
    fn check_model_exists(cache_dir: &std::path::Path, model_name: &str) -> bool {
        cache_dir.join(model_name.replace('/', "_")).exists()
    }
}
168
#[cfg(feature = "fastembed")]
impl Reranker for FastReranker {
    /// Identifier reported by the fastembed-backed implementation.
    fn id(&self) -> &'static str {
        "fastembed_reranker"
    }

    /// Scores `documents` against `query` with the loaded fastembed model.
    ///
    /// Results are returned in the order fastembed produces them. fastembed
    /// sorts its results by score and tags each one with the `index` of the
    /// input document it refers to, so every score is paired with
    /// `documents[result.index]` rather than with the document at the same
    /// position in the result list.
    ///
    /// An empty `documents` slice yields an empty result without invoking the
    /// model.
    ///
    /// # Errors
    /// Returns an error if the model fails, or if it reports an index outside
    /// `documents`.
    fn rerank(&mut self, query: &str, documents: &[String]) -> Result<Vec<RerankResult>> {
        if documents.is_empty() {
            return Ok(Vec::new());
        }

        // Convert documents to string references
        let docs: Vec<&str> = documents.iter().map(String::as_str).collect();

        // `return_documents` is false: the text is taken from `documents` via
        // the returned index, so fastembed need not copy it as well.
        let results = self.model.rerank(query, docs, false, None)?;

        results
            .into_iter()
            .map(|result| {
                documents
                    .get(result.index)
                    .map(|document| RerankResult {
                        query: query.to_string(),
                        document: document.clone(),
                        score: result.score,
                    })
                    .ok_or_else(|| {
                        anyhow::anyhow!(
                            "reranker returned document index {} but only {} documents were supplied",
                            result.index,
                            documents.len()
                        )
                    })
            })
            .collect()
    }
}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// The dummy reranker reports its id and scores every document within 0.5..=1.0.
    #[test]
    fn test_dummy_reranker() {
        let mut dummy = DummyReranker::new();
        assert_eq!(dummy.id(), "dummy_reranker");

        let query = "find error handling";
        let documents: Vec<String> = [
            "try catch block",
            "function definition",
            "error handling code",
        ]
        .iter()
        .map(|text| text.to_string())
        .collect();

        let ranked = dummy.rerank(query, &documents).unwrap();
        assert_eq!(ranked.len(), 3);

        for entry in &ranked {
            assert_eq!(entry.query, query);
            assert!((0.5..=1.0).contains(&entry.score));
        }
    }

    /// Without the `fastembed` feature the factory hands back the dummy reranker.
    #[test]
    fn test_create_reranker_dummy() {
        #[cfg(not(feature = "fastembed"))]
        {
            let created = create_reranker(None).unwrap();
            assert_eq!(created.id(), "dummy_reranker");
        }
    }

    /// Smoke test for the fastembed-backed reranker.
    #[cfg(feature = "fastembed")]
    #[test]
    fn test_fastembed_reranker_creation() {
        // Requires a model download, so it is skipped whenever `CI` is set.
        if std::env::var("CI").is_ok() {
            return;
        }

        // A failed initialisation is tolerated: FastEmbed might not be
        // available in the test environment.
        if let Ok(mut fast) = FastReranker::new("jina-reranker-v1-turbo-en") {
            assert_eq!(fast.id(), "fastembed_reranker");

            let query = "error handling";
            let documents = vec![
                "try catch exception handling".to_string(),
                "user interface design".to_string(),
            ];

            let outcome = fast.rerank(query, &documents);
            assert!(outcome.is_ok());

            let ranked = outcome.unwrap();
            assert_eq!(ranked.len(), 2);

            // First result should be more relevant to query
            assert!(ranked[0].score > ranked[1].score);
        }
    }

    /// An empty document list produces an empty result.
    #[test]
    fn test_reranker_empty_input() {
        let mut dummy = DummyReranker::new();
        let no_documents: Vec<String> = Vec::new();

        let ranked = dummy.rerank("test query", &no_documents).unwrap();
        assert!(ranked.is_empty());
    }

    /// A single document is echoed back together with the query.
    #[test]
    fn test_reranker_single_document() {
        let mut dummy = DummyReranker::new();
        let query = "test query";
        let documents = vec![String::from("single document")];

        let ranked = dummy.rerank(query, &documents).unwrap();

        assert_eq!(ranked.len(), 1);
        assert_eq!(ranked[0].query, query);
        assert_eq!(ranked[0].document, "single document");
    }
}