1use anyhow::Result;
2
3#[cfg(feature = "fastembed")]
4use std::path::{Path, PathBuf};
5
6pub mod reranker;
7pub mod tokenizer;
8
9pub use reranker::{RerankResult, Reranker, create_reranker, create_reranker_with_progress};
10pub use tokenizer::TokenEstimator;
11
12pub trait Embedder: Send + Sync {
13 fn id(&self) -> &'static str;
14 fn dim(&self) -> usize;
15 fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
16}
17
/// Boxed, thread-safe callback that receives human-readable status messages
/// while an embedder is being set up.
pub type ModelDownloadCallback = Box<dyn Fn(&str) + Send + Sync>;
19
20pub fn create_embedder(model_name: Option<&str>) -> Result<Box<dyn Embedder>> {
21 create_embedder_with_progress(model_name, None)
22}
23
24pub fn create_embedder_with_progress(
25 model_name: Option<&str>,
26 progress_callback: Option<ModelDownloadCallback>,
27) -> Result<Box<dyn Embedder>> {
28 let model = model_name.unwrap_or("BAAI/bge-small-en-v1.5");
29
30 #[cfg(feature = "fastembed")]
31 {
32 Ok(Box::new(FastEmbedder::new_with_progress(
33 model,
34 progress_callback,
35 )?))
36 }
37
38 #[cfg(not(feature = "fastembed"))]
39 {
40 let _ = model; if let Some(callback) = progress_callback {
42 callback("Using dummy embedder (no model download required)");
43 }
44 Ok(Box::new(DummyEmbedder::new()))
45 }
46}
47
/// Placeholder embedder that needs no model.
///
/// It only records the vector width it reports; see its `Embedder`
/// implementation for what it produces.
#[derive(Debug, Clone)]
pub struct DummyEmbedder {
    /// Length reported by `dim()` and used for every produced vector.
    dim: usize,
}

impl Default for DummyEmbedder {
    /// Same as [`DummyEmbedder::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl DummyEmbedder {
    /// Creates a dummy embedder with a fixed width of 384 components.
    pub fn new() -> Self {
        Self { dim: 384 }
    }
}
63
64impl Embedder for DummyEmbedder {
65 fn id(&self) -> &'static str {
66 "dummy"
67 }
68
69 fn dim(&self) -> usize {
70 self.dim
71 }
72
73 fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
74 Ok(texts.iter().map(|_| vec![0.0; self.dim]).collect())
75 }
76}
77
/// Embedder backed by a `fastembed::TextEmbedding` model.
///
/// Only compiled when the `fastembed` feature is enabled.
#[cfg(feature = "fastembed")]
pub struct FastEmbedder {
    /// The loaded fastembed model used to compute embeddings.
    model: fastembed::TextEmbedding,
    /// Vector width reported by `dim()`, chosen per model at construction.
    dim: usize,
}
83
#[cfg(feature = "fastembed")]
impl FastEmbedder {
    /// Loads `model_name` without progress reporting.
    ///
    /// # Errors
    ///
    /// See [`FastEmbedder::new_with_progress`].
    pub fn new(model_name: &str) -> Result<Self> {
        Self::new_with_progress(model_name, None)
    }

    /// Loads `model_name`, sending status messages to `progress_callback`.
    ///
    /// Recognised names are mapped to the matching `fastembed::EmbeddingModel`
    /// variant. Any other name silently falls back to
    /// `EmbeddingModel::NomicEmbedTextV15`; the status messages still show the
    /// name that was requested.
    ///
    /// The model cache directory is created if missing, and fastembed's own
    /// download progress display is enabled only when a callback is supplied.
    ///
    /// # Errors
    ///
    /// Returns an error if the cache directory cannot be created or if
    /// `TextEmbedding::try_new` fails.
    pub fn new_with_progress(
        model_name: &str,
        progress_callback: Option<ModelDownloadCallback>,
    ) -> Result<Self> {
        use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};

        let model = match model_name {
            "BAAI/bge-small-en-v1.5" => EmbeddingModel::BGESmallENV15,
            "sentence-transformers/all-MiniLM-L6-v2" => EmbeddingModel::AllMiniLML6V2,

            "nomic-embed-text-v1" => EmbeddingModel::NomicEmbedTextV1,
            "nomic-embed-text-v1.5" => EmbeddingModel::NomicEmbedTextV15,
            "jina-embeddings-v2-base-code" => EmbeddingModel::JinaEmbeddingsV2BaseCode,

            "BAAI/bge-base-en-v1.5" => EmbeddingModel::BGEBaseENV15,
            "BAAI/bge-large-en-v1.5" => EmbeddingModel::BGELargeENV15,

            // Unknown names fall back to Nomic v1.5 rather than failing.
            _ => EmbeddingModel::NomicEmbedTextV15,
        };

        let model_cache_dir = Self::get_model_cache_dir()?;
        std::fs::create_dir_all(&model_cache_dir)?;

        if let Some(callback) = progress_callback.as_ref() {
            callback(&format!("Initializing model: {}", model_name));

            if Self::check_model_exists(&model_cache_dir, model_name) {
                callback(&format!("Using cached model: {}", model_name));
            } else {
                callback(&format!(
                    "Downloading model {} to {}",
                    model_name,
                    model_cache_dir.display()
                ));
            }
        }

        // Maximum sequence length handed to fastembed: the Nomic and Jina
        // models get 8192, every other model gets 512.
        let max_length = match model {
            EmbeddingModel::NomicEmbedTextV1
            | EmbeddingModel::NomicEmbedTextV15
            | EmbeddingModel::JinaEmbeddingsV2BaseCode => 8192,
            _ => 512,
        };

        // Vector width reported by `dim()`. Computed before `model` is moved
        // into the init options so no clone is needed.
        let dim = match model {
            EmbeddingModel::BGESmallENV15 | EmbeddingModel::AllMiniLML6V2 => 384,

            EmbeddingModel::NomicEmbedTextV1
            | EmbeddingModel::NomicEmbedTextV15
            | EmbeddingModel::JinaEmbeddingsV2BaseCode
            | EmbeddingModel::BGEBaseENV15 => 768,

            EmbeddingModel::BGELargeENV15 => 1024,

            // Not reachable from the name table above; kept so the match
            // stays exhaustive over the external enum.
            _ => 384,
        };

        let init_options = InitOptions::new(model)
            .with_show_download_progress(progress_callback.is_some())
            .with_cache_dir(model_cache_dir)
            .with_max_length(max_length);

        let embedding = TextEmbedding::try_new(init_options)?;

        if let Some(callback) = progress_callback.as_ref() {
            callback("Model loaded successfully");
        }

        Ok(Self {
            model: embedding,
            dim,
        })
    }

    /// Picks the directory under which models are cached.
    ///
    /// The first of these that is set to a non-empty value wins:
    ///
    /// 1. `$XDG_CACHE_HOME/ck`
    /// 2. `$HOME/.cache/ck`
    /// 3. `$LOCALAPPDATA/ck/cache`
    /// 4. `.ck_models` relative to the working directory
    ///
    /// `models` is then appended to the chosen base. This currently never
    /// returns `Err`; the `Result` is kept so callers need not change.
    fn get_model_cache_dir() -> Result<PathBuf> {
        // An empty variable is treated as unset. Otherwise `XDG_CACHE_HOME=""`
        // would produce the relative path `ck/models` in whatever directory
        // the process was started from.
        let env_dir = |name: &str| {
            std::env::var_os(name)
                .filter(|value| !value.is_empty())
                .map(PathBuf::from)
        };

        let cache_dir = if let Some(cache_home) = env_dir("XDG_CACHE_HOME") {
            cache_home.join("ck")
        } else if let Some(home) = env_dir("HOME") {
            home.join(".cache").join("ck")
        } else if let Some(appdata) = env_dir("LOCALAPPDATA") {
            appdata.join("ck").join("cache")
        } else {
            PathBuf::from(".ck_models")
        };

        Ok(cache_dir.join("models"))
    }

    /// Reports whether `cache_dir` contains an entry named after `model_name`
    /// with every `/` replaced by `_`.
    ///
    /// The result only selects which progress message is shown.
    // NOTE(review): assumes fastembed names its cache entries this way —
    // confirm against the directory layout it actually writes.
    fn check_model_exists(cache_dir: &Path, model_name: &str) -> bool {
        cache_dir.join(model_name.replace('/', "_")).exists()
    }
}
206
#[cfg(feature = "fastembed")]
impl Embedder for FastEmbedder {
    /// Always `"fastembed"`.
    fn id(&self) -> &'static str {
        "fastembed"
    }

    /// Vector width recorded for the model when it was loaded.
    fn dim(&self) -> usize {
        self.dim
    }

    /// Embeds `texts` with the loaded fastembed model.
    ///
    /// The batch-size argument is left as `None`.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the model's `embed` call.
    fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        // The owned strings are passed to fastembed as borrowed `&str` slices.
        let borrowed: Vec<&str> = texts.iter().map(String::as_str).collect();
        let vectors = self.model.embed(borrowed, None)?;
        Ok(vectors)
    }
}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// The dummy backend reports its identity and yields all-zero vectors.
    #[test]
    fn test_dummy_embedder() {
        let mut dummy = DummyEmbedder::new();

        assert_eq!(dummy.id(), "dummy");
        assert_eq!(dummy.dim(), 384);

        let inputs = vec![String::from("hello"), String::from("world")];
        let vectors = dummy.embed(&inputs).unwrap();

        assert_eq!(vectors.len(), 2);
        for vector in &vectors {
            assert_eq!(vector.len(), 384);
        }
        for vector in &vectors {
            assert!(vector.iter().all(|&x| x == 0.0));
        }
    }

    /// Without the `fastembed` feature the factory hands back the dummy.
    #[test]
    fn test_create_embedder_dummy() {
        #[cfg(not(feature = "fastembed"))]
        {
            let created = create_embedder(None).unwrap();
            assert_eq!(created.id(), "dummy");
            assert_eq!(created.dim(), 384);
        }
    }

    /// The trait can be used through a boxed trait object.
    #[test]
    fn test_embedder_trait_object() {
        let mut boxed: Box<dyn Embedder> = Box::new(DummyEmbedder::new());

        let inputs = vec![String::from("test")];
        let outcome = boxed.embed(&inputs);
        assert!(outcome.is_ok());

        let vectors = outcome.unwrap();
        assert_eq!(vectors.len(), 1);
        assert_eq!(vectors[0].len(), 384);
    }

    /// Loads a real model when possible; skipped when `CI` is set, and a
    /// failure to load the model is tolerated.
    #[cfg(feature = "fastembed")]
    #[test]
    fn test_fastembed_creation() {
        if std::env::var("CI").is_ok() {
            return;
        }

        // A load failure is deliberately not treated as a test failure.
        if let Ok(mut loaded) = FastEmbedder::new("BAAI/bge-small-en-v1.5") {
            assert_eq!(loaded.id(), "fastembed");
            assert_eq!(loaded.dim(), 384);

            let inputs = vec![String::from("hello world")];
            let outcome = loaded.embed(&inputs);
            assert!(outcome.is_ok());

            let vectors = outcome.unwrap();
            assert_eq!(vectors.len(), 1);
            assert_eq!(vectors[0].len(), 384);

            // A real model must not return the all-zero vector.
            assert!(!vectors[0].iter().all(|&x| x == 0.0));
        }
    }

    /// The factory builds the fastembed backend when the feature is on;
    /// skipped when `CI` is set, and a failure to load is tolerated.
    #[cfg(feature = "fastembed")]
    #[test]
    fn test_create_embedder_fastembed() {
        if std::env::var("CI").is_ok() {
            return;
        }

        // A load failure is deliberately not treated as a test failure.
        if let Ok(created) = create_embedder(Some("BAAI/bge-small-en-v1.5")) {
            assert_eq!(created.id(), "fastembed");
            assert_eq!(created.dim(), 384);
        }
    }

    /// An empty batch yields an empty result.
    #[test]
    fn test_embedder_empty_input() {
        let mut dummy = DummyEmbedder::new();
        let inputs: Vec<String> = Vec::new();
        let vectors = dummy.embed(&inputs).unwrap();
        assert_eq!(vectors.len(), 0);
    }

    /// A single text yields a single vector of the expected width.
    #[test]
    fn test_embedder_single_text() {
        let mut dummy = DummyEmbedder::new();
        let inputs = vec![String::from("single text")];
        let vectors = dummy.embed(&inputs).unwrap();

        assert_eq!(vectors.len(), 1);
        assert_eq!(vectors[0].len(), 384);
    }

    /// Several texts yield one vector each, all of the expected width.
    #[test]
    fn test_embedder_multiple_texts() {
        let mut dummy = DummyEmbedder::new();
        let inputs: Vec<String> = ["first text", "second text", "third text"]
            .iter()
            .map(|text| text.to_string())
            .collect();
        let vectors = dummy.embed(&inputs).unwrap();

        assert_eq!(vectors.len(), 3);
        for vector in vectors.iter() {
            assert_eq!(vector.len(), 384);
        }
    }
}