1use anyhow::{Result, bail};
2use ck_models::{RerankModelConfig, RerankModelRegistry};
3
4#[cfg(feature = "mixedbread")]
5use crate::mixedbread::MixedbreadReranker;
6
7#[cfg(feature = "fastembed")]
8use std::path::PathBuf;
9
/// One scored (query, document) pair produced by a [`Reranker`].
#[derive(Debug, Clone)]
pub struct RerankResult {
    /// The query text the document was scored against.
    pub query: String,
    /// The text of the document that was scored.
    pub document: String,
    /// Score assigned by the reranker.
    /// NOTE(review): presumably higher means more relevant; the scale depends on the backend.
    pub score: f32,
}
16
/// A backend that scores a set of documents against a query.
///
/// Implementors must be `Send + Sync` so boxed rerankers can cross threads.
pub trait Reranker: Send + Sync {
    /// Returns a static identifier naming this backend.
    fn id(&self) -> &'static str;

    /// Scores each of `documents` against `query`.
    ///
    /// Returns one [`RerankResult`] per scored document, or an error if the
    /// backend fails.
    ///
    /// NOTE(review): this trait does not specify the order of the returned
    /// vector; confirm what callers rely on before depending on it.
    fn rerank(&mut self, query: &str, documents: &[String]) -> Result<Vec<RerankResult>>;
}
21
/// Callback that receives human-readable status messages while a reranker
/// model is being initialised (for example download / cache notices).
pub type RerankModelDownloadCallback = Box<dyn Fn(&str) + Send + Sync>;
23
24pub fn create_reranker(model_name: Option<&str>) -> Result<Box<dyn Reranker>> {
25 create_reranker_with_progress(model_name, None)
26}
27
28pub fn create_reranker_with_progress(
29 model_name: Option<&str>,
30 progress_callback: Option<RerankModelDownloadCallback>,
31) -> Result<Box<dyn Reranker>> {
32 let registry = RerankModelRegistry::default();
33 let (_, config) = registry.resolve(model_name)?;
34 create_reranker_for_config(&config, progress_callback)
35}
36
37#[allow(clippy::needless_return)]
38pub fn create_reranker_for_config(
39 config: &RerankModelConfig,
40 progress_callback: Option<RerankModelDownloadCallback>,
41) -> Result<Box<dyn Reranker>> {
42 match config.provider.as_str() {
43 "fastembed" => {
44 #[cfg(feature = "fastembed")]
45 {
46 return Ok(Box::new(FastReranker::new_with_progress(
47 config.name.as_str(),
48 progress_callback,
49 )?));
50 }
51
52 #[cfg(not(feature = "fastembed"))]
53 {
54 if let Some(callback) = progress_callback.as_ref() {
55 callback("fastembed reranker unavailable; using dummy reranker");
56 }
57 return Ok(Box::new(DummyReranker::new()));
58 }
59 }
60 "mixedbread" => {
61 #[cfg(feature = "mixedbread")]
62 {
63 return Ok(Box::new(MixedbreadReranker::new(
64 config,
65 progress_callback,
66 )?));
67 }
68 #[cfg(not(feature = "mixedbread"))]
69 {
70 bail!(
71 "Reranking model '{}' requires the `mixedbread` feature. Rebuild ck with Mixedbread support.",
72 config.name
73 );
74 }
75 }
76 provider => bail!("Unsupported reranker provider '{}'", provider),
77 }
78}
79
/// Stateless placeholder reranker.
///
/// `create_reranker_for_config` returns it for the `"fastembed"` provider when
/// the `fastembed` feature is not compiled in.
pub struct DummyReranker;

impl DummyReranker {
    /// Creates the placeholder reranker.
    pub fn new() -> Self {
        DummyReranker
    }
}

impl Default for DummyReranker {
    /// Same as [`DummyReranker::new`].
    fn default() -> Self {
        DummyReranker
    }
}
93
94impl Reranker for DummyReranker {
95 fn id(&self) -> &'static str {
96 "dummy_reranker"
97 }
98
99 fn rerank(&mut self, query: &str, documents: &[String]) -> Result<Vec<RerankResult>> {
100 Ok(documents
102 .iter()
103 .enumerate()
104 .map(|(i, doc)| {
105 RerankResult {
106 query: query.to_string(),
107 document: doc.clone(),
108 score: 0.5 + (i as f32 * 0.1) % 0.5, }
110 })
111 .collect())
112 }
113}
114
/// Reranker backed by a `fastembed::TextRerank` model.
///
/// Only available when the `fastembed` feature is enabled.
#[cfg(feature = "fastembed")]
pub struct FastReranker {
    /// The loaded fastembed reranking model.
    model: fastembed::TextRerank,
    /// Model name supplied at construction. Nothing in this file reads it
    /// back, hence the `dead_code` allowance.
    #[allow(dead_code)]
    model_name: String,
}
121
#[cfg(feature = "fastembed")]
impl FastReranker {
    /// Name of the model that is loaded when the requested name is unknown.
    const FALLBACK_MODEL_NAME: &'static str = "jina-reranker-v1-turbo-en";

    /// Loads the reranker model called `model_name` without progress reporting.
    ///
    /// # Errors
    ///
    /// See [`FastReranker::new_with_progress`].
    pub fn new(model_name: &str) -> Result<Self> {
        Self::new_with_progress(model_name, None)
    }

    /// Loads the reranker model called `model_name`, sending status messages
    /// to `progress_callback` when one is supplied.
    ///
    /// Unrecognised names fall back to the Jina v1 turbo model. When that
    /// happens the callback is told explicitly, and the remaining status
    /// messages name the model that is actually loaded rather than the one
    /// that was requested. The struct still records the requested name.
    ///
    /// # Errors
    ///
    /// Fails if the cache directory cannot be created or if fastembed cannot
    /// initialise the model.
    pub fn new_with_progress(
        model_name: &str,
        progress_callback: Option<RerankModelDownloadCallback>,
    ) -> Result<Self> {
        use fastembed::{RerankInitOptions, RerankerModel, TextRerank};

        let known_model = match model_name {
            "jina-reranker-v1-turbo-en" => Some(RerankerModel::JINARerankerV1TurboEn),
            "bge-reranker-base" => Some(RerankerModel::BGERerankerBase),
            "jina-reranker-v2-base-multilingual" => {
                Some(RerankerModel::JINARerankerV2BaseMultiligual)
            }
            "bge-reranker-v2-m3" => Some(RerankerModel::BGERerankerV2M3),
            _ => None,
        };
        let used_fallback = known_model.is_none();
        let effective_name = if used_fallback {
            Self::FALLBACK_MODEL_NAME
        } else {
            model_name
        };
        let model = known_model.unwrap_or(RerankerModel::JINARerankerV1TurboEn);

        let model_cache_dir = Self::get_model_cache_dir()?;
        std::fs::create_dir_all(&model_cache_dir)?;

        if let Some(ref callback) = progress_callback {
            if used_fallback {
                callback(&format!(
                    "Unknown reranker model '{}'; falling back to {}",
                    model_name, effective_name
                ));
            }

            callback(&format!("Initializing reranker model: {}", effective_name));

            if Self::check_model_exists(&model_cache_dir, effective_name) {
                callback(&format!("Using cached reranker model: {}", effective_name));
            } else {
                callback(&format!(
                    "Downloading reranker model {} to {}",
                    effective_name,
                    model_cache_dir.display()
                ));
            }
        }

        // `model` is not needed afterwards, so it is moved rather than cloned.
        let init_options = RerankInitOptions::new(model)
            .with_show_download_progress(progress_callback.is_some())
            .with_cache_dir(model_cache_dir);

        let reranker = TextRerank::try_new(init_options)?;

        if let Some(ref callback) = progress_callback {
            callback("Reranker model loaded successfully");
        }

        Ok(Self {
            model: reranker,
            model_name: model_name.to_string(),
        })
    }

    /// Picks the directory reranker models are cached in.
    ///
    /// Candidates, in priority order:
    /// 1. `$XDG_CACHE_HOME/ck`
    /// 2. `$HOME/.cache/ck`
    /// 3. `$LOCALAPPDATA/ck/cache`
    /// 4. `.ck_models` relative to the working directory
    ///
    /// `rerankers` is appended to whichever base is chosen. A variable that is
    /// set but empty is treated as unset, so it cannot silently produce a path
    /// relative to the working directory.
    fn get_model_cache_dir() -> Result<PathBuf> {
        let non_empty_var = |key: &str| std::env::var_os(key).filter(|value| !value.is_empty());

        let cache_dir = non_empty_var("XDG_CACHE_HOME")
            .map(|dir| PathBuf::from(dir).join("ck"))
            .or_else(|| non_empty_var("HOME").map(|dir| PathBuf::from(dir).join(".cache").join("ck")))
            .or_else(|| {
                non_empty_var("LOCALAPPDATA").map(|dir| PathBuf::from(dir).join("ck").join("cache"))
            })
            .unwrap_or_else(|| PathBuf::from(".ck_models"));

        Ok(cache_dir.join("rerankers"))
    }

    /// Reports whether `cache_dir` contains an entry named after `model_name`
    /// (with `/` replaced by `_`).
    ///
    /// The result only selects which status message is shown.
    /// NOTE(review): this assumes the backend stores each model under a
    /// directory with exactly this name — confirm against fastembed's cache
    /// layout, otherwise a download is always announced.
    fn check_model_exists(cache_dir: &std::path::Path, model_name: &str) -> bool {
        cache_dir.join(model_name.replace('/', "_")).exists()
    }
}
200
#[cfg(feature = "fastembed")]
impl Reranker for FastReranker {
    /// Identifier of the fastembed backend.
    fn id(&self) -> &'static str {
        "fastembed_reranker"
    }

    /// Scores `documents` against `query` with the loaded fastembed model.
    ///
    /// Results are returned in the order fastembed produces them. Each result
    /// is paired with the document it was computed for by using the `index`
    /// fastembed reports for that hit, not the hit's position in the output:
    /// the two differ whenever fastembed reorders its results, and pairing by
    /// position would attach scores to the wrong document text.
    ///
    /// # Errors
    ///
    /// Fails if the model call fails or if fastembed reports an index outside
    /// `documents`.
    fn rerank(&mut self, query: &str, documents: &[String]) -> Result<Vec<RerankResult>> {
        let docs: Vec<&str> = documents.iter().map(String::as_str).collect();

        // Document text is taken from `documents` via `hit.index`, so fastembed
        // is not asked to return its own copies.
        let hits = self.model.rerank(query, docs, false, None)?;

        hits.into_iter()
            .map(|hit| match documents.get(hit.index) {
                Some(document) => Ok(RerankResult {
                    query: query.to_string(),
                    document: document.clone(),
                    score: hit.score,
                }),
                None => bail!(
                    "fastembed returned document index {} but only {} documents were supplied",
                    hit.index,
                    documents.len()
                ),
            })
            .collect()
    }
}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned documents `rerank` expects.
    fn owned(texts: &[&str]) -> Vec<String> {
        texts.iter().map(|text| text.to_string()).collect()
    }

    /// The dummy backend echoes the query and keeps scores within 0.5..=1.0.
    #[test]
    fn test_dummy_reranker() {
        let mut dummy = DummyReranker::new();
        assert_eq!(dummy.id(), "dummy_reranker");

        let query = "find error handling";
        let corpus = owned(&[
            "try catch block",
            "function definition",
            "error handling code",
        ]);

        let ranked = dummy.rerank(query, &corpus).unwrap();
        assert_eq!(ranked.len(), 3);

        for entry in &ranked {
            assert_eq!(entry.query, query);
            assert!((0.5..=1.0).contains(&entry.score));
        }
    }

    /// Without the `fastembed` feature the factory hands back the dummy backend.
    #[test]
    fn test_create_reranker_dummy() {
        #[cfg(not(feature = "fastembed"))]
        {
            let fallback = create_reranker(None).unwrap();
            assert_eq!(fallback.id(), "dummy_reranker");
        }
    }

    /// Smoke test for the real backend; skipped when `CI` is set, and a failed
    /// model construction is tolerated rather than treated as a test failure.
    #[cfg(feature = "fastembed")]
    #[test]
    fn test_fastembed_reranker_creation() {
        if std::env::var("CI").is_ok() {
            return;
        }

        if let Ok(mut fast) = FastReranker::new("jina-reranker-v1-turbo-en") {
            assert_eq!(fast.id(), "fastembed_reranker");

            let query = "error handling";
            let corpus = owned(&["try catch exception handling", "user interface design"]);

            let outcome = fast.rerank(query, &corpus);
            assert!(outcome.is_ok());

            let ranked = outcome.unwrap();
            assert_eq!(ranked.len(), 2);

            assert!(ranked[0].score > ranked[1].score);
        }
    }

    /// No documents in, no results out.
    #[test]
    fn test_reranker_empty_input() {
        let mut dummy = DummyReranker::new();
        let corpus: Vec<String> = Vec::new();

        let ranked = dummy.rerank("test query", &corpus).unwrap();
        assert!(ranked.is_empty());
    }

    /// A single document comes back unchanged alongside the query.
    #[test]
    fn test_reranker_single_document() {
        let mut dummy = DummyReranker::new();
        let query = "test query";
        let corpus = owned(&["single document"]);

        let ranked = dummy.rerank(query, &corpus).unwrap();

        assert_eq!(ranked.len(), 1);
        assert_eq!(ranked[0].query, query);
        assert_eq!(ranked[0].document, "single document");
    }
}