1use anyhow::Result;
2
3#[cfg(feature = "fastembed")]
4use std::path::{Path, PathBuf};
5
6pub mod reranker;
7pub mod tokenizer;
8
9pub use reranker::{RerankResult, Reranker, create_reranker, create_reranker_with_progress};
10pub use tokenizer::TokenEstimator;
11
12pub trait Embedder: Send + Sync {
13 fn id(&self) -> &'static str;
14 fn dim(&self) -> usize;
15 fn model_name(&self) -> &str;
16 fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
17}
18
/// Boxed, thread-safe callback that receives human-readable status messages while an
/// embedder is being created (for example model download / load progress).
pub type ModelDownloadCallback = Box<dyn for<'msg> Fn(&'msg str) + Sync + Send>;
20
21pub fn create_embedder(model_name: Option<&str>) -> Result<Box<dyn Embedder>> {
22 create_embedder_with_progress(model_name, None)
23}
24
25pub fn create_embedder_with_progress(
26 model_name: Option<&str>,
27 progress_callback: Option<ModelDownloadCallback>,
28) -> Result<Box<dyn Embedder>> {
29 let model = model_name.unwrap_or("BAAI/bge-small-en-v1.5");
30
31 #[cfg(feature = "fastembed")]
32 {
33 Ok(Box::new(FastEmbedder::new_with_progress(
34 model,
35 progress_callback,
36 )?))
37 }
38
39 #[cfg(not(feature = "fastembed"))]
40 {
41 if let Some(callback) = progress_callback {
42 callback("Using dummy embedder (no model download required)");
43 }
44 Ok(Box::new(DummyEmbedder::new_with_model(model)))
45 }
46}
47
/// Placeholder [`Embedder`] that needs no model and returns all-zero vectors.
///
/// Used when the `fastembed` feature is disabled. It still records the requested model name
/// so callers can report it. Derives `Debug`/`Clone` like other public value types.
#[derive(Debug, Clone)]
pub struct DummyEmbedder {
    /// Width of every vector returned by `embed`.
    dim: usize,
    /// Model name reported by `model_name()`.
    model_name: String,
}
52
53impl Default for DummyEmbedder {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59impl DummyEmbedder {
60 pub fn new() -> Self {
61 Self {
62 dim: 384, model_name: "dummy".to_string(),
64 }
65 }
66
67 pub fn new_with_model(model_name: &str) -> Self {
68 Self {
69 dim: 384, model_name: model_name.to_string(),
71 }
72 }
73}
74
75impl Embedder for DummyEmbedder {
76 fn id(&self) -> &'static str {
77 "dummy"
78 }
79
80 fn dim(&self) -> usize {
81 self.dim
82 }
83
84 fn model_name(&self) -> &str {
85 &self.model_name
86 }
87
88 fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
89 Ok(texts.iter().map(|_| vec![0.0; self.dim]).collect())
90 }
91}
92
/// [`Embedder`] backed by a local fastembed `TextEmbedding` model.
///
/// Only available with the `fastembed` feature.
#[cfg(feature = "fastembed")]
pub struct FastEmbedder {
    // Loaded fastembed model used for all `embed` calls.
    model: fastembed::TextEmbedding,
    // Output vector width, set in `new_with_progress` from the selected model.
    dim: usize,
    // Name exactly as passed by the caller; returned verbatim by `model_name()`.
    model_name: String,
}
99
#[cfg(feature = "fastembed")]
impl FastEmbedder {
    /// Loads `model_name` without progress reporting.
    ///
    /// # Errors
    /// Same as [`FastEmbedder::new_with_progress`].
    pub fn new(model_name: &str) -> Result<Self> {
        Self::new_with_progress(model_name, None)
    }

    /// Loads the fastembed model for `model_name`, using the crate's model cache directory.
    ///
    /// Supported names are those in the `lookup_model` table. Any other name falls back to
    /// `nomic-embed-text-v1.5`, while [`Embedder::model_name`] still reports `model_name`
    /// as given. When `progress_callback` is set it receives status messages, including a
    /// notice when that fallback is taken.
    ///
    /// # Errors
    /// Fails if the cache directory cannot be created or fastembed cannot initialise the model.
    pub fn new_with_progress(
        model_name: &str,
        progress_callback: Option<ModelDownloadCallback>,
    ) -> Result<Self> {
        use anyhow::Context;
        use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};

        // Unrecognised names fall back to nomic-embed-text-v1.5 (768 dims, 8192 max length).
        let (model, dim, max_length, recognised) = match Self::lookup_model(model_name) {
            Some((model, dim, max_length)) => (model, dim, max_length, true),
            None => (EmbeddingModel::NomicEmbedTextV15, 768, 8192, false),
        };

        let model_cache_dir = Self::get_model_cache_dir()?;
        std::fs::create_dir_all(&model_cache_dir).with_context(|| {
            format!(
                "failed to create model cache directory {}",
                model_cache_dir.display()
            )
        })?;

        if let Some(ref callback) = progress_callback {
            callback(&format!("Initializing model: {}", model_name));

            if !recognised {
                callback(&format!(
                    "Unknown model '{}', falling back to nomic-embed-text-v1.5",
                    model_name
                ));
            }

            // The existence check only selects which status message is shown; it is not
            // passed to fastembed.
            if Self::check_model_exists(&model_cache_dir, model_name) {
                callback(&format!("Using cached model: {}", model_name));
            } else {
                callback(&format!(
                    "Downloading model {} to {}",
                    model_name,
                    model_cache_dir.display()
                ));
            }
        }

        let init_options = InitOptions::new(model)
            .with_show_download_progress(progress_callback.is_some())
            .with_cache_dir(model_cache_dir)
            .with_max_length(max_length);

        let embedding = TextEmbedding::try_new(init_options)
            .with_context(|| format!("failed to initialise embedding model {}", model_name))?;

        if let Some(ref callback) = progress_callback {
            callback("Model loaded successfully");
        }

        Ok(Self {
            model: embedding,
            dim,
            model_name: model_name.to_string(),
        })
    }

    /// Maps a supported model name to `(fastembed model, output dimension, max input length)`.
    ///
    /// Keeping all three values in one table prevents them from drifting apart.
    /// Returns `None` for names not listed here.
    fn lookup_model(model_name: &str) -> Option<(fastembed::EmbeddingModel, usize, usize)> {
        use fastembed::EmbeddingModel;

        let spec = match model_name {
            "BAAI/bge-small-en-v1.5" => (EmbeddingModel::BGESmallENV15, 384, 512),
            "sentence-transformers/all-MiniLM-L6-v2" => (EmbeddingModel::AllMiniLML6V2, 384, 512),

            "nomic-embed-text-v1" => (EmbeddingModel::NomicEmbedTextV1, 768, 8192),
            "nomic-embed-text-v1.5" => (EmbeddingModel::NomicEmbedTextV15, 768, 8192),
            "jina-embeddings-v2-base-code" => {
                (EmbeddingModel::JinaEmbeddingsV2BaseCode, 768, 8192)
            }

            "BAAI/bge-base-en-v1.5" => (EmbeddingModel::BGEBaseENV15, 768, 512),
            "BAAI/bge-large-en-v1.5" => (EmbeddingModel::BGELargeENV15, 1024, 512),

            _ => return None,
        };
        Some(spec)
    }

    /// Returns `<cache root>/models`, where the cache root is the first usable candidate:
    /// `$XDG_CACHE_HOME/ck`, `$HOME/.cache/ck`, `%LOCALAPPDATA%/ck/cache`, or `.ck_models`.
    ///
    /// Empty variables are treated as unset. A relative `XDG_CACHE_HOME` is also ignored,
    /// as the XDG Base Directory spec requires.
    fn get_model_cache_dir() -> Result<PathBuf> {
        let non_empty = |key: &str| std::env::var_os(key).filter(|value| !value.is_empty());

        let cache_dir = if let Some(cache_home) = non_empty("XDG_CACHE_HOME")
            .map(PathBuf::from)
            .filter(|path| path.is_absolute())
        {
            cache_home.join("ck")
        } else if let Some(home) = non_empty("HOME") {
            PathBuf::from(home).join(".cache").join("ck")
        } else if let Some(appdata) = non_empty("LOCALAPPDATA") {
            PathBuf::from(appdata).join("ck").join("cache")
        } else {
            PathBuf::from(".ck_models")
        };

        Ok(cache_dir.join("models"))
    }

    /// Returns whether `cache_dir` contains an entry named after `model_name`, with `/`
    /// replaced by `_` (e.g. `BAAI_bge-small-en-v1.5`).
    ///
    /// NOTE(review): this assumes that cache layout. If fastembed/hf-hub stores models under
    /// another naming scheme (e.g. `models--<org>--<repo>`), this always returns `false` and the
    /// "Downloading model" message appears even for cached models. Confirm.
    fn check_model_exists(cache_dir: &Path, model_name: &str) -> bool {
        cache_dir.join(model_name.replace('/', "_")).exists()
    }
}
223
/// [`Embedder`] implementation that forwards to the wrapped fastembed model.
#[cfg(feature = "fastembed")]
impl Embedder for FastEmbedder {
    fn id(&self) -> &'static str {
        "fastembed"
    }

    fn model_name(&self) -> &str {
        self.model_name.as_str()
    }

    fn dim(&self) -> usize {
        self.dim
    }

    fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        // Borrow the inputs as `&str` instead of cloning each `String`. `None` is passed for
        // fastembed's optional second argument (presumably the batch size).
        let borrowed = texts.iter().map(String::as_str).collect::<Vec<&str>>();
        let vectors = self.model.embed(borrowed, None)?;
        Ok(vectors)
    }
}
244
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dummy_embedder() {
        let mut embedder = DummyEmbedder::new();

        assert_eq!(embedder.id(), "dummy");
        assert_eq!(embedder.dim(), 384);

        let texts = vec!["hello".to_string(), "world".to_string()];
        let embeddings = embedder.embed(&texts).unwrap();

        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].len(), 384);
        assert_eq!(embeddings[1].len(), 384);

        assert!(embeddings[0].iter().all(|&x| x == 0.0));
        assert!(embeddings[1].iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_dummy_embedder_model_name() {
        assert_eq!(DummyEmbedder::new().model_name(), "dummy");
        assert_eq!(DummyEmbedder::default().model_name(), "dummy");

        let named = DummyEmbedder::new_with_model("BAAI/bge-small-en-v1.5");
        assert_eq!(named.model_name(), "BAAI/bge-small-en-v1.5");
        assert_eq!(named.dim(), 384);
    }

    #[test]
    fn test_create_embedder_dummy() {
        #[cfg(not(feature = "fastembed"))]
        {
            let embedder = create_embedder(None).unwrap();
            assert_eq!(embedder.id(), "dummy");
            assert_eq!(embedder.dim(), 384);
        }
    }

    #[cfg(not(feature = "fastembed"))]
    #[test]
    fn test_create_embedder_with_progress_dummy_message() {
        use std::sync::{Arc, Mutex};

        let messages = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&messages);
        let callback: ModelDownloadCallback = Box::new(move |msg: &str| {
            sink.lock().unwrap().push(msg.to_string());
        });

        let embedder =
            create_embedder_with_progress(Some("custom-model"), Some(callback)).unwrap();
        assert_eq!(embedder.id(), "dummy");
        assert_eq!(embedder.model_name(), "custom-model");
        assert_eq!(
            *messages.lock().unwrap(),
            vec!["Using dummy embedder (no model download required)".to_string()]
        );
    }

    #[test]
    fn test_embedder_trait_object() {
        let mut embedder: Box<dyn Embedder> = Box::new(DummyEmbedder::new());

        let texts = vec!["test".to_string()];
        let result = embedder.embed(&texts);
        assert!(result.is_ok());

        let embeddings = result.unwrap();
        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), 384);
    }

    #[cfg(feature = "fastembed")]
    #[test]
    fn test_fastembed_creation() {
        // Skipped on CI, presumably because loading may need to download the model.
        if std::env::var("CI").is_ok() {
            return;
        }

        let embedder = FastEmbedder::new("BAAI/bge-small-en-v1.5");

        match embedder {
            Ok(mut embedder) => {
                assert_eq!(embedder.id(), "fastembed");
                assert_eq!(embedder.dim(), 384);

                let texts = vec!["hello world".to_string()];
                let result = embedder.embed(&texts);
                assert!(result.is_ok());

                let embeddings = result.unwrap();
                assert_eq!(embeddings.len(), 1);
                assert_eq!(embeddings[0].len(), 384);

                assert!(!embeddings[0].iter().all(|&x| x == 0.0));
            }
            Err(err) => {
                // Best-effort: a model that cannot be loaded (e.g. offline) skips the test,
                // but the reason is reported instead of passing silently.
                eprintln!("skipping test_fastembed_creation: {:#}", err);
            }
        }
    }

    #[cfg(feature = "fastembed")]
    #[test]
    fn test_create_embedder_fastembed() {
        // Skipped on CI, presumably because loading may need to download the model.
        if std::env::var("CI").is_ok() {
            return;
        }

        let embedder = create_embedder(Some("BAAI/bge-small-en-v1.5"));

        match embedder {
            Ok(embedder) => {
                assert_eq!(embedder.id(), "fastembed");
                assert_eq!(embedder.dim(), 384);
            }
            Err(err) => {
                eprintln!("skipping test_create_embedder_fastembed: {:#}", err);
            }
        }
    }

    #[test]
    fn test_embedder_empty_input() {
        let mut embedder = DummyEmbedder::new();
        let texts: Vec<String> = vec![];
        let embeddings = embedder.embed(&texts).unwrap();
        assert_eq!(embeddings.len(), 0);
    }

    #[test]
    fn test_embedder_single_text() {
        let mut embedder = DummyEmbedder::new();
        let texts = vec!["single text".to_string()];
        let embeddings = embedder.embed(&texts).unwrap();

        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), 384);
    }

    #[test]
    fn test_embedder_multiple_texts() {
        let mut embedder = DummyEmbedder::new();
        let texts = vec![
            "first text".to_string(),
            "second text".to_string(),
            "third text".to_string(),
        ];
        let embeddings = embedder.embed(&texts).unwrap();

        assert_eq!(embeddings.len(), 3);
        for embedding in &embeddings {
            assert_eq!(embedding.len(), 384);
        }
    }
}
379}