1use anyhow::{Result, bail};
2use ck_models::{ModelConfig, ModelRegistry};
3#[cfg(feature = "fastembed")]
4use std::path::Path;
5#[cfg(any(feature = "fastembed", feature = "mixedbread"))]
6use std::path::PathBuf;
7
8pub mod reranker;
9pub mod tokenizer;
10
11pub use reranker::{
12 RerankResult, Reranker, create_reranker, create_reranker_for_config,
13 create_reranker_with_progress,
14};
15pub use tokenizer::TokenEstimator;
16
17#[cfg(feature = "mixedbread")]
18mod mixedbread;
19#[cfg(feature = "mixedbread")]
20use mixedbread::MixedbreadEmbedder;
21
/// A text-embedding backend.
///
/// Implementors must be `Send + Sync`; producing embeddings goes through
/// [`Embedder::embed`], which takes `&mut self` and therefore needs exclusive access.
pub trait Embedder: Send + Sync {
    /// Short, static identifier of the backend (for example `"fastembed"` or `"dummy"`).
    fn id(&self) -> &'static str;

    /// Name of the model this embedder was created for.
    fn model_name(&self) -> &str;

    /// Length of the vectors this embedder produces.
    fn dim(&self) -> usize;

    /// Embeds `texts`; the implementations in this module return one vector per input text.
    fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
}
28
/// Boxed, thread-safe callback that receives human-readable progress messages
/// (for example while a model is being located, downloaded or loaded).
pub type ModelDownloadCallback = Box<dyn Fn(&str) + Sync + Send>;
30
31#[cfg(any(feature = "fastembed", feature = "mixedbread"))]
/// Returns the directory under which model files are cached.
///
/// The base directory is taken from the first usable source below, and `models` is
/// appended to it:
///
/// 1. `$XDG_CACHE_HOME/ck`
/// 2. `$HOME/.cache/ck`
/// 3. `%LOCALAPPDATA%/ck/cache`
/// 4. `.ck_models`, relative to the current working directory
///
/// An environment variable that is set but empty is treated as unset. The XDG Base
/// Directory specification requires this for `XDG_CACHE_HOME`. Applying it to the other
/// variables as well prevents an empty value from turning into a cache path relative to
/// the current directory.
///
/// The directory is not created here.
///
/// # Errors
/// Currently always returns `Ok`.
pub(crate) fn model_cache_root() -> Result<PathBuf> {
    // Value of `key` as a path, ignoring unset *and* empty variables.
    let env_dir = |key: &str| {
        std::env::var_os(key)
            .filter(|value| !value.is_empty())
            .map(PathBuf::from)
    };

    let base = if let Some(cache_home) = env_dir("XDG_CACHE_HOME") {
        cache_home.join("ck")
    } else if let Some(home) = env_dir("HOME") {
        home.join(".cache").join("ck")
    } else if let Some(appdata) = env_dir("LOCALAPPDATA") {
        appdata.join("ck").join("cache")
    } else {
        PathBuf::from(".ck_models")
    };

    Ok(base.join("models"))
}
45
46pub fn create_embedder(model_name: Option<&str>) -> Result<Box<dyn Embedder>> {
47 create_embedder_with_progress(model_name, None)
48}
49
50pub fn create_embedder_with_progress(
51 model_name: Option<&str>,
52 progress_callback: Option<ModelDownloadCallback>,
53) -> Result<Box<dyn Embedder>> {
54 let registry = ModelRegistry::default();
55 let (_, config) = registry.resolve(model_name)?;
56 create_embedder_for_config(&config, progress_callback)
57}
58
59#[allow(clippy::needless_return)]
60pub fn create_embedder_for_config(
61 config: &ModelConfig,
62 progress_callback: Option<ModelDownloadCallback>,
63) -> Result<Box<dyn Embedder>> {
64 match config.provider.as_str() {
65 "fastembed" => {
66 #[cfg(feature = "fastembed")]
67 {
68 return Ok(Box::new(FastEmbedder::new_with_progress(
69 config.name.as_str(),
70 progress_callback,
71 )?));
72 }
73
74 #[cfg(not(feature = "fastembed"))]
75 {
76 if let Some(callback) = progress_callback.as_ref() {
77 callback("fastembed provider unavailable; using dummy embedder");
78 }
79 return Ok(Box::new(DummyEmbedder::new_with_model(
80 config.name.as_str(),
81 )));
82 }
83 }
84 "mixedbread" => {
85 #[cfg(feature = "mixedbread")]
86 {
87 return Ok(Box::new(MixedbreadEmbedder::new(
88 config,
89 progress_callback,
90 )?));
91 }
92 #[cfg(not(feature = "mixedbread"))]
93 {
94 bail!(
95 "Model '{}' requires the `mixedbread` feature. Rebuild ck with Mixedbread support.",
96 config.name
97 );
98 }
99 }
100 provider => bail!("Unsupported embedding provider '{}'", provider),
101 }
102}
103
/// Placeholder embedder whose vectors are all zeros.
///
/// `create_embedder_for_config` returns it when a `fastembed` model is requested in a
/// build without the `fastembed` feature.
#[derive(Debug, Clone)]
pub struct DummyEmbedder {
    /// Length of every vector this embedder produces (always `DummyEmbedder::DIM`).
    dim: usize,
    /// Name reported as the model name; `"dummy"` unless set via `new_with_model`.
    model_name: String,
}

impl Default for DummyEmbedder {
    /// Same as [`DummyEmbedder::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl DummyEmbedder {
    /// Vector dimension used by every `DummyEmbedder`.
    const DIM: usize = 384;

    /// Creates a dummy embedder named `"dummy"`.
    pub fn new() -> Self {
        Self::new_with_model("dummy")
    }

    /// Creates a dummy embedder that reports `model_name` as its model name.
    pub fn new_with_model(model_name: &str) -> Self {
        Self {
            dim: Self::DIM,
            model_name: model_name.to_string(),
        }
    }
}
130
131impl Embedder for DummyEmbedder {
132 fn id(&self) -> &'static str {
133 "dummy"
134 }
135
136 fn dim(&self) -> usize {
137 self.dim
138 }
139
140 fn model_name(&self) -> &str {
141 &self.model_name
142 }
143
144 fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
145 Ok(texts.iter().map(|_| vec![0.0; self.dim]).collect())
146 }
147}
148
#[cfg(feature = "fastembed")]
/// Embedder backed by a `fastembed::TextEmbedding` model.
pub struct FastEmbedder {
    /// Loaded fastembed model used by `embed`.
    model: fastembed::TextEmbedding,
    /// Vector length reported by `dim`, taken from the fixed table in `model_spec`.
    dim: usize,
    /// Model name exactly as requested by the caller.
    model_name: String,
}

#[cfg(feature = "fastembed")]
impl FastEmbedder {
    /// Loads `model_name` without progress reporting.
    ///
    /// # Errors
    /// Same as [`FastEmbedder::new_with_progress`].
    pub fn new(model_name: &str) -> Result<Self> {
        Self::new_with_progress(model_name, None)
    }

    /// Loads `model_name` through fastembed, using `model_cache_root()` as the cache directory.
    ///
    /// Recognised names are listed in `model_spec`. Any other name silently loads
    /// `nomic-embed-text-v1.5`, while `model_name()` still reports the requested name.
    ///
    /// When `progress_callback` is given, it receives three messages in order:
    /// 1. an "Initializing model" message;
    /// 2. either a "Downloading model" or a "Using cached model" message;
    /// 3. "Model loaded successfully", once fastembed has initialised the model.
    ///
    /// Fastembed's own download-progress display is enabled in that case too.
    ///
    /// # Errors
    /// Fails if the cache directory cannot be determined or created, or if fastembed cannot
    /// initialise the model.
    pub fn new_with_progress(
        model_name: &str,
        progress_callback: Option<ModelDownloadCallback>,
    ) -> Result<Self> {
        use fastembed::{InitOptions, TextEmbedding};

        let (model, max_length, dim) = Self::model_spec(model_name);

        let model_cache_dir = model_cache_root()?;
        std::fs::create_dir_all(&model_cache_dir)?;

        if let Some(report) = progress_callback.as_ref() {
            report(&format!("Initializing model: {}", model_name));
            let status = if Self::check_model_exists(&model_cache_dir, model_name) {
                format!("Using cached model: {}", model_name)
            } else {
                format!(
                    "Downloading model {} to {}",
                    model_name,
                    model_cache_dir.display()
                )
            };
            report(&status);
        }

        let init_options = InitOptions::new(model)
            .with_show_download_progress(progress_callback.is_some())
            .with_cache_dir(model_cache_dir)
            .with_max_length(max_length);
        let embedding = TextEmbedding::try_new(init_options)?;

        if let Some(report) = progress_callback.as_ref() {
            report("Model loaded successfully");
        }

        Ok(Self {
            model: embedding,
            dim,
            model_name: model_name.to_owned(),
        })
    }

    /// Maps a model name to `(fastembed variant, value for with_max_length, vector dimension)`.
    fn model_spec(model_name: &str) -> (fastembed::EmbeddingModel, usize, usize) {
        use fastembed::EmbeddingModel as Model;

        match model_name {
            "BAAI/bge-small-en-v1.5" => (Model::BGESmallENV15, 512, 384),
            "sentence-transformers/all-MiniLM-L6-v2" => (Model::AllMiniLML6V2, 512, 384),
            "nomic-embed-text-v1" => (Model::NomicEmbedTextV1, 8192, 768),
            "nomic-embed-text-v1.5" => (Model::NomicEmbedTextV15, 8192, 768),
            "jina-embeddings-v2-base-code" => (Model::JinaEmbeddingsV2BaseCode, 8192, 768),
            "BAAI/bge-base-en-v1.5" => (Model::BGEBaseENV15, 512, 768),
            "BAAI/bge-large-en-v1.5" => (Model::BGELargeENV15, 512, 1024),
            // Unrecognised names fall back to nomic-embed-text-v1.5.
            _ => (Model::NomicEmbedTextV15, 8192, 768),
        }
    }

    /// Reports whether `<cache_dir>/<model_name with '/' replaced by '_'>` exists.
    ///
    /// NOTE(review): confirm this matches the directory layout fastembed actually writes
    /// into the cache. If it does not, the "Using cached model" message will never be shown.
    /// The result only affects that message, not loading.
    fn check_model_exists(cache_dir: &Path, model_name: &str) -> bool {
        cache_dir.join(model_name.replace('/', "_")).exists()
    }
}
263
#[cfg(feature = "fastembed")]
impl Embedder for FastEmbedder {
    fn id(&self) -> &'static str {
        "fastembed"
    }

    fn dim(&self) -> usize {
        self.dim
    }

    fn model_name(&self) -> &str {
        self.model_name.as_str()
    }

    /// Embeds all of `texts` with a single call to the fastembed model.
    ///
    /// `None` is passed as fastembed's optional second argument, which is presumably the
    /// batch size (confirm against the fastembed version in use).
    fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let borrowed: Vec<&str> = texts.iter().map(String::as_str).collect();
        let vectors = self.model.embed(borrowed, None)?;
        Ok(vectors)
    }
}
284
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dummy_embedder() {
        let mut embedder = DummyEmbedder::new();

        assert_eq!(embedder.id(), "dummy");
        assert_eq!(embedder.dim(), 384);

        let texts = vec!["hello".to_string(), "world".to_string()];
        let embeddings = embedder.embed(&texts).unwrap();

        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].len(), 384);
        assert_eq!(embeddings[1].len(), 384);

        assert!(embeddings[0].iter().all(|&x| x == 0.0));
        assert!(embeddings[1].iter().all(|&x| x == 0.0));
    }

    // The expectation of a `DummyEmbedder` only holds when the `fastembed` feature is off.
    // Gating the whole test (rather than just its body) keeps it from passing vacuously in
    // fastembed builds.
    #[cfg(not(feature = "fastembed"))]
    #[test]
    fn test_create_embedder_dummy() {
        let embedder = create_embedder(None).unwrap();
        assert_eq!(embedder.id(), "dummy");
        assert_eq!(embedder.dim(), 384);
    }

    #[test]
    fn test_embedder_trait_object() {
        let mut embedder: Box<dyn Embedder> = Box::new(DummyEmbedder::new());

        let texts = vec!["test".to_string()];
        let embeddings = embedder
            .embed(&texts)
            .expect("dummy embedder never fails");

        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), 384);
    }

    #[cfg(feature = "fastembed")]
    #[test]
    fn test_fastembed_creation() {
        // Skipped when `CI` is set (presumably to avoid downloading model files there).
        if std::env::var("CI").is_ok() {
            return;
        }

        match FastEmbedder::new("BAAI/bge-small-en-v1.5") {
            Ok(mut embedder) => {
                assert_eq!(embedder.id(), "fastembed");
                assert_eq!(embedder.dim(), 384);

                let texts = vec!["hello world".to_string()];
                let embeddings = embedder
                    .embed(&texts)
                    .expect("embedding with a loaded model should succeed");

                assert_eq!(embeddings.len(), 1);
                assert_eq!(embeddings[0].len(), 384);

                assert!(!embeddings[0].iter().all(|&x| x == 0.0));
            }
            // Loading may fail for environmental reasons (e.g. no network on first download).
            // Treat that as a skip, but say so instead of passing silently.
            Err(err) => eprintln!(
                "skipping test_fastembed_creation: model unavailable: {:#}",
                err
            ),
        }
    }

    #[cfg(feature = "fastembed")]
    #[test]
    fn test_create_embedder_fastembed() {
        // Skipped when `CI` is set (presumably to avoid downloading model files there).
        if std::env::var("CI").is_ok() {
            return;
        }

        match create_embedder(Some("BAAI/bge-small-en-v1.5")) {
            Ok(embedder) => {
                assert_eq!(embedder.id(), "fastembed");
                assert_eq!(embedder.dim(), 384);
            }
            // See `test_fastembed_creation`: report the skip rather than swallowing it.
            Err(err) => eprintln!(
                "skipping test_create_embedder_fastembed: model unavailable: {:#}",
                err
            ),
        }
    }

    #[test]
    fn test_embedder_empty_input() {
        let mut embedder = DummyEmbedder::new();
        let texts: Vec<String> = vec![];
        let embeddings = embedder.embed(&texts).unwrap();
        assert_eq!(embeddings.len(), 0);
    }

    #[test]
    fn test_embedder_single_text() {
        let mut embedder = DummyEmbedder::new();
        let texts = vec!["single text".to_string()];
        let embeddings = embedder.embed(&texts).unwrap();

        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), 384);
    }

    #[test]
    fn test_embedder_multiple_texts() {
        let mut embedder = DummyEmbedder::new();
        let texts = vec![
            "first text".to_string(),
            "second text".to_string(),
            "third text".to_string(),
        ];
        let embeddings = embedder.embed(&texts).unwrap();

        assert_eq!(embeddings.len(), 3);
        for embedding in &embeddings {
            assert_eq!(embedding.len(), 384);
        }
    }
}
419}