1use anyhow::Result;
2
3#[cfg(feature = "fastembed")]
4use std::path::PathBuf;
5
/// A single scored (query, document) pair produced by a [`Reranker`].
///
/// `PartialEq` is derived so results can be compared directly (for example
/// with `assert_eq!` in tests). `Eq` is intentionally not derived because
/// `score` is a float.
#[derive(Debug, Clone, PartialEq)]
pub struct RerankResult {
    /// The query the document was scored against.
    pub query: String,
    /// The text of the scored document.
    pub document: String,
    /// Relevance score assigned by the reranker; the scale depends on the
    /// implementation that produced it.
    pub score: f32,
}
12
/// Common interface for components that score documents against a query.
///
/// Implementors must be `Send + Sync` so a boxed reranker can be moved
/// between threads.
pub trait Reranker: Send + Sync {
    /// Returns a static identifier naming the implementation.
    fn id(&self) -> &'static str;

    /// Scores `documents` against `query` and returns one [`RerankResult`]
    /// per scored document.
    ///
    /// The trait itself does not specify the order of the returned results;
    /// callers that need a ranking should sort by `score`.
    fn rerank(&mut self, query: &str, documents: &[String]) -> Result<Vec<RerankResult>>;
}
17
/// Callback invoked with human-readable status messages while a reranker is
/// being created (for example model initialisation or download notices).
pub type RerankModelDownloadCallback = Box<dyn Fn(&str) + Send + Sync>;
19
20pub fn create_reranker(model_name: Option<&str>) -> Result<Box<dyn Reranker>> {
21 create_reranker_with_progress(model_name, None)
22}
23
24pub fn create_reranker_with_progress(
25 model_name: Option<&str>,
26 progress_callback: Option<RerankModelDownloadCallback>,
27) -> Result<Box<dyn Reranker>> {
28 let model = model_name.unwrap_or("jina-reranker-v1-turbo-en");
29
30 #[cfg(feature = "fastembed")]
31 {
32 Ok(Box::new(FastReranker::new_with_progress(
33 model,
34 progress_callback,
35 )?))
36 }
37
38 #[cfg(not(feature = "fastembed"))]
39 {
40 let _ = model; if let Some(callback) = progress_callback {
42 callback("Using dummy reranker (no model download required)");
43 }
44 Ok(Box::new(DummyReranker::new()))
45 }
46}
47
/// Stateless placeholder reranker used when no model backend is compiled in.
///
/// It holds no data, so the common traits are derived rather than written by
/// hand; `Default` yields the same value as [`DummyReranker::new`].
#[derive(Debug, Clone, Copy, Default)]
pub struct DummyReranker;

impl DummyReranker {
    /// Creates a new dummy reranker.
    pub fn new() -> Self {
        Self
    }
}
61
62impl Reranker for DummyReranker {
63 fn id(&self) -> &'static str {
64 "dummy_reranker"
65 }
66
67 fn rerank(&mut self, query: &str, documents: &[String]) -> Result<Vec<RerankResult>> {
68 Ok(documents
70 .iter()
71 .enumerate()
72 .map(|(i, doc)| {
73 RerankResult {
74 query: query.to_string(),
75 document: doc.clone(),
76 score: 0.5 + (i as f32 * 0.1) % 0.5, }
78 })
79 .collect())
80 }
81}
82
/// Reranker backed by a `fastembed::TextRerank` model.
///
/// Only compiled when the `fastembed` feature is enabled.
#[cfg(feature = "fastembed")]
pub struct FastReranker {
    /// The loaded fastembed reranking model that performs the scoring.
    model: fastembed::TextRerank,
    /// The model name passed to the constructor. Nothing in this file reads
    /// it back, which is why the dead-code lint is silenced.
    #[allow(dead_code)]
    model_name: String,
}
89
#[cfg(feature = "fastembed")]
impl FastReranker {
    /// Loads the named reranker model without progress reporting.
    ///
    /// See [`FastReranker::new_with_progress`] for the accepted names and
    /// the error cases.
    pub fn new(model_name: &str) -> Result<Self> {
        Self::new_with_progress(model_name, None)
    }

    /// Loads the named reranker model, sending status messages to
    /// `progress_callback` when one is supplied.
    ///
    /// Recognised names are `"jina-reranker-v1-turbo-en"`,
    /// `"bge-reranker-base"`, `"jina-reranker-v2-base-multilingual"` and
    /// `"bge-reranker-v2-m3"`. Any other name falls back to
    /// `"jina-reranker-v1-turbo-en"`; the fallback is reported through the
    /// callback so a mistyped name does not go unnoticed.
    ///
    /// # Errors
    ///
    /// Fails if the cache directory cannot be created or if fastembed cannot
    /// initialise the model.
    pub fn new_with_progress(
        model_name: &str,
        progress_callback: Option<RerankModelDownloadCallback>,
    ) -> Result<Self> {
        use fastembed::{RerankInitOptions, RerankerModel, TextRerank};

        // Name of the model substituted for unrecognised requests.
        const FALLBACK_MODEL_NAME: &str = "jina-reranker-v1-turbo-en";

        let recognised = match model_name {
            "jina-reranker-v1-turbo-en" => Some(RerankerModel::JINARerankerV1TurboEn),
            "bge-reranker-base" => Some(RerankerModel::BGERerankerBase),
            "jina-reranker-v2-base-multilingual" => {
                Some(RerankerModel::JINARerankerV2BaseMultiligual)
            }
            "bge-reranker-v2-m3" => Some(RerankerModel::BGERerankerV2M3),
            _ => None,
        };
        let is_fallback = recognised.is_none();
        let model = recognised.unwrap_or(RerankerModel::JINARerankerV1TurboEn);

        let model_cache_dir = Self::get_model_cache_dir()?;
        std::fs::create_dir_all(&model_cache_dir)?;

        if let Some(callback) = &progress_callback {
            if is_fallback {
                callback(&format!(
                    "Unknown reranker model '{}', falling back to {}",
                    model_name, FALLBACK_MODEL_NAME
                ));
            }

            callback(&format!("Initializing reranker model: {}", model_name));

            if Self::check_model_exists(&model_cache_dir, model_name) {
                callback(&format!("Using cached reranker model: {}", model_name));
            } else {
                callback(&format!(
                    "Downloading reranker model {} to {}",
                    model_name,
                    model_cache_dir.display()
                ));
            }
        }

        // `model` is not used again, so it is moved rather than cloned.
        let init_options = RerankInitOptions::new(model)
            .with_show_download_progress(progress_callback.is_some())
            .with_cache_dir(model_cache_dir);

        let reranker = TextRerank::try_new(init_options)?;

        if let Some(callback) = &progress_callback {
            callback("Reranker model loaded successfully");
        }

        Ok(Self {
            model: reranker,
            model_name: model_name.to_string(),
        })
    }

    /// Resolves the directory in which reranker models are cached.
    ///
    /// The base directory is the first available of `$XDG_CACHE_HOME/ck`,
    /// `$HOME/.cache/ck`, `%LOCALAPPDATA%/ck/cache`, or the relative path
    /// `.ck_models`; `rerankers` is appended to it. The directory is not
    /// created here, and the current implementation never returns an error.
    fn get_model_cache_dir() -> Result<PathBuf> {
        let cache_dir = if let Some(cache_home) = std::env::var_os("XDG_CACHE_HOME") {
            PathBuf::from(cache_home).join("ck")
        } else if let Some(home) = std::env::var_os("HOME") {
            PathBuf::from(home).join(".cache").join("ck")
        } else if let Some(appdata) = std::env::var_os("LOCALAPPDATA") {
            PathBuf::from(appdata).join("ck").join("cache")
        } else {
            // No known per-user location: fall back to a relative directory.
            PathBuf::from(".ck_models")
        };

        Ok(cache_dir.join("rerankers"))
    }

    /// Reports whether `cache_dir` contains an entry named after the model
    /// (with `/` replaced by `_`).
    ///
    /// NOTE(review): this only drives which status message is shown. It
    /// assumes the on-disk directory is named after `model_name`; confirm
    /// that matches the layout fastembed actually uses for its cache.
    fn check_model_exists(cache_dir: &std::path::Path, model_name: &str) -> bool {
        let model_dir = cache_dir.join(model_name.replace('/', "_"));
        model_dir.exists()
    }
}
168
#[cfg(feature = "fastembed")]
impl Reranker for FastReranker {
    /// Identifies this implementation as `"fastembed_reranker"`.
    fn id(&self) -> &'static str {
        "fastembed_reranker"
    }

    /// Scores `documents` against `query` with the loaded fastembed model.
    ///
    /// Results are returned in the order fastembed produces them. Each
    /// fastembed result carries the `index` of the input document it refers
    /// to, and that index — not the result's position in the output — is
    /// used to attach the document text, so scores stay paired with the
    /// right document even when the output order differs from the input.
    ///
    /// # Errors
    ///
    /// Fails if the model call fails or if fastembed reports an index that
    /// is outside `documents`.
    fn rerank(&mut self, query: &str, documents: &[String]) -> Result<Vec<RerankResult>> {
        let docs: Vec<&str> = documents.iter().map(String::as_str).collect();

        // The document text is taken from `documents` below, so fastembed is
        // not asked to copy it into every result (`return_documents = false`).
        let ranked = self.model.rerank(query, docs, false, None)?;

        ranked
            .into_iter()
            .map(|hit| -> Result<RerankResult> {
                let document = documents.get(hit.index).cloned().ok_or_else(|| {
                    anyhow::anyhow!(
                        "reranker returned out-of-range document index {} (have {} documents)",
                        hit.index,
                        documents.len()
                    )
                })?;

                Ok(RerankResult {
                    query: query.to_string(),
                    document,
                    score: hit.score,
                })
            })
            .collect()
    }
}
196
/// Unit tests for the dummy backend and, when the `fastembed` feature is
/// enabled, a smoke test for the real backend.
#[cfg(test)]
mod tests {
    use super::*;

    /// The dummy backend returns one in-range score per document and echoes
    /// the query into every result.
    #[test]
    fn test_dummy_reranker() {
        let mut dummy = DummyReranker::new();
        assert_eq!(dummy.id(), "dummy_reranker");

        let query = "find error handling";
        let documents: Vec<String> = [
            "try catch block",
            "function definition",
            "error handling code",
        ]
        .iter()
        .map(|text| text.to_string())
        .collect();

        let ranked = dummy.rerank(query, &documents).unwrap();
        assert_eq!(ranked.len(), 3);

        for entry in &ranked {
            assert_eq!(entry.query, query);
            assert!((0.5..=1.0).contains(&entry.score));
        }
    }

    /// Without the `fastembed` feature the factory hands back the dummy
    /// backend. With the feature enabled this test has nothing to check.
    #[test]
    fn test_create_reranker_dummy() {
        #[cfg(not(feature = "fastembed"))]
        {
            let built = create_reranker(None).unwrap();
            assert_eq!(built.id(), "dummy_reranker");
        }
    }

    /// Smoke test for the fastembed backend. It is skipped when `CI` is set,
    /// and a failure to construct the model is tolerated rather than treated
    /// as a test failure.
    #[cfg(feature = "fastembed")]
    #[test]
    fn test_fastembed_reranker_creation() {
        if std::env::var("CI").is_ok() {
            return;
        }

        if let Ok(mut model) = FastReranker::new("jina-reranker-v1-turbo-en") {
            assert_eq!(model.id(), "fastembed_reranker");

            let query = "error handling";
            let documents = vec![
                String::from("try catch exception handling"),
                String::from("user interface design"),
            ];

            let outcome = model.rerank(query, &documents);
            assert!(outcome.is_ok());

            let ranked = outcome.unwrap();
            assert_eq!(ranked.len(), 2);

            // The first result is expected to outscore the second.
            assert!(ranked[0].score > ranked[1].score);
        }
    }

    /// Reranking no documents yields no results.
    #[test]
    fn test_reranker_empty_input() {
        let mut dummy = DummyReranker::new();
        let no_documents: Vec<String> = Vec::new();

        let ranked = dummy.rerank("test query", &no_documents).unwrap();
        assert!(ranked.is_empty());
    }

    /// A single document comes back unchanged alongside the query.
    #[test]
    fn test_reranker_single_document() {
        let mut dummy = DummyReranker::new();
        let query = "test query";
        let documents = vec![String::from("single document")];

        let ranked = dummy.rerank(query, &documents).unwrap();

        assert_eq!(ranked.len(), 1);
        let only = &ranked[0];
        assert_eq!(only.query, query);
        assert_eq!(only.document, "single document");
    }
}