Skip to main content

ck_embed/
lib.rs

1use anyhow::{Result, bail};
2use ck_models::{ModelConfig, ModelRegistry};
3#[cfg(feature = "fastembed")]
4use std::path::Path;
5#[cfg(any(feature = "fastembed", feature = "mixedbread"))]
6use std::path::PathBuf;
7
8pub mod reranker;
9pub mod tokenizer;
10
11pub use reranker::{
12    RerankResult, Reranker, create_reranker, create_reranker_for_config,
13    create_reranker_with_progress,
14};
15pub use tokenizer::TokenEstimator;
16
17#[cfg(feature = "mixedbread")]
18mod mixedbread;
19#[cfg(feature = "mixedbread")]
20use mixedbread::MixedbreadEmbedder;
21
/// Common interface implemented by every text-embedding backend in this crate.
///
/// The `Send + Sync` supertraits let a boxed embedder be moved to, and shared
/// between, threads.
pub trait Embedder: Send + Sync {
    /// Returns a static identifier for the backend (the implementations in
    /// this file return `"dummy"` and `"fastembed"`).
    fn id(&self) -> &'static str;
    /// Returns the length of each embedding vector produced by this embedder.
    fn dim(&self) -> usize;
    /// Returns the name of the model this embedder was created for.
    fn model_name(&self) -> &str;
    /// Computes embeddings for `texts`.
    ///
    /// The implementations in this file return one `Vec<f32>` per input
    /// string. Takes `&mut self` so a backend may mutate its model state
    /// while embedding.
    fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
}
28
/// Boxed callback that receives human-readable status messages (for example
/// model initialization or download notices) while an embedder is created.
pub type ModelDownloadCallback = Box<dyn Fn(&str) + Send + Sync>;
30
/// Resolves the directory under which downloaded model files are cached.
///
/// The base directory is the first of the following that is available:
///
/// 1. `$XDG_CACHE_HOME/ck`
/// 2. `$HOME/.cache/ck`
/// 3. `%LOCALAPPDATA%/ck/cache`
/// 4. the relative path `.ck_models`
///
/// and `models` is appended to it. A variable that is set but empty is
/// treated as unset (as the XDG Base Directory specification requires for
/// `XDG_CACHE_HOME`); otherwise an empty value would silently produce a
/// relative path such as `ck/models`.
///
/// The directory is only computed here, not created.
///
/// # Errors
///
/// Currently never fails; the `Result` return type is kept so existing
/// callers can continue to use `?`.
#[cfg(any(feature = "fastembed", feature = "mixedbread"))]
pub(crate) fn model_cache_root() -> Result<PathBuf> {
    /// Reads `key` from the environment, ignoring unset and empty values.
    fn non_empty_dir(key: &str) -> Option<PathBuf> {
        std::env::var_os(key)
            .filter(|value| !value.is_empty())
            .map(PathBuf::from)
    }

    let base = non_empty_dir("XDG_CACHE_HOME")
        .map(|dir| dir.join("ck"))
        .or_else(|| non_empty_dir("HOME").map(|dir| dir.join(".cache").join("ck")))
        .or_else(|| non_empty_dir("LOCALAPPDATA").map(|dir| dir.join("ck").join("cache")))
        .unwrap_or_else(|| PathBuf::from(".ck_models"));

    Ok(base.join("models"))
}
45
46pub fn create_embedder(model_name: Option<&str>) -> Result<Box<dyn Embedder>> {
47    create_embedder_with_progress(model_name, None)
48}
49
50pub fn create_embedder_with_progress(
51    model_name: Option<&str>,
52    progress_callback: Option<ModelDownloadCallback>,
53) -> Result<Box<dyn Embedder>> {
54    let registry = ModelRegistry::default();
55    let (_, config) = registry.resolve(model_name)?;
56    create_embedder_for_config(&config, progress_callback)
57}
58
59#[allow(clippy::needless_return)]
60pub fn create_embedder_for_config(
61    config: &ModelConfig,
62    progress_callback: Option<ModelDownloadCallback>,
63) -> Result<Box<dyn Embedder>> {
64    match config.provider.as_str() {
65        "fastembed" => {
66            #[cfg(feature = "fastembed")]
67            {
68                return Ok(Box::new(FastEmbedder::new_with_progress(
69                    config.name.as_str(),
70                    progress_callback,
71                )?));
72            }
73
74            #[cfg(not(feature = "fastembed"))]
75            {
76                if let Some(callback) = progress_callback.as_ref() {
77                    callback("fastembed provider unavailable; using dummy embedder");
78                }
79                return Ok(Box::new(DummyEmbedder::new_with_model(
80                    config.name.as_str(),
81                )));
82            }
83        }
84        "mixedbread" => {
85            #[cfg(feature = "mixedbread")]
86            {
87                return Ok(Box::new(MixedbreadEmbedder::new(
88                    config,
89                    progress_callback,
90                )?));
91            }
92            #[cfg(not(feature = "mixedbread"))]
93            {
94                bail!(
95                    "Model '{}' requires the `mixedbread` feature. Rebuild ck with Mixedbread support.",
96                    config.name
97                );
98            }
99        }
100        provider => bail!("Unsupported embedding provider '{}'", provider),
101    }
102}
103
/// Placeholder embedder that needs no model files.
///
/// Its `Embedder` implementation returns an all-zero vector of length `dim`
/// for every input text.
#[derive(Debug, Clone)]
pub struct DummyEmbedder {
    /// Length of every vector this embedder produces.
    dim: usize,
    /// Name reported through `Embedder::model_name`.
    model_name: String,
}

impl Default for DummyEmbedder {
    /// Same as [`DummyEmbedder::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl DummyEmbedder {
    /// Vector length shared by all dummy embedders; matches the default BGE
    /// model.
    const DEFAULT_DIM: usize = 384;

    /// Creates a dummy embedder whose model name is `"dummy"`.
    pub fn new() -> Self {
        Self::new_with_model("dummy")
    }

    /// Creates a dummy embedder that reports `model_name` as its model.
    pub fn new_with_model(model_name: &str) -> Self {
        Self {
            dim: Self::DEFAULT_DIM,
            model_name: model_name.to_string(),
        }
    }
}
130
131impl Embedder for DummyEmbedder {
132    fn id(&self) -> &'static str {
133        "dummy"
134    }
135
136    fn dim(&self) -> usize {
137        self.dim
138    }
139
140    fn model_name(&self) -> &str {
141        &self.model_name
142    }
143
144    fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
145        Ok(texts.iter().map(|_| vec![0.0; self.dim]).collect())
146    }
147}
148
/// Embedder backed by a `fastembed::TextEmbedding` model.
///
/// Only compiled when the `fastembed` feature is enabled.
#[cfg(feature = "fastembed")]
pub struct FastEmbedder {
    /// Loaded fastembed model used to compute embeddings.
    model: fastembed::TextEmbedding,
    /// Embedding vector length reported through `Embedder::dim`.
    dim: usize,
    /// Model name as requested by the caller at construction time.
    model_name: String,
}
155
#[cfg(feature = "fastembed")]
impl FastEmbedder {
    /// Loads `model_name` without progress reporting.
    ///
    /// See [`FastEmbedder::new_with_progress`] for details.
    pub fn new(model_name: &str) -> Result<Self> {
        Self::new_with_progress(model_name, None)
    }

    /// Loads `model_name` through fastembed, caching model files under
    /// `model_cache_root()`.
    ///
    /// Unrecognized names fall back to Nomic Embed Text v1.5, although the
    /// returned embedder still reports the requested `model_name`. When
    /// `progress_callback` is provided it receives status messages before and
    /// after the model is loaded, and fastembed's own download progress
    /// display is enabled.
    ///
    /// # Errors
    ///
    /// Fails if the cache directory cannot be resolved or created, or if
    /// fastembed cannot initialize the model.
    pub fn new_with_progress(
        model_name: &str,
        progress_callback: Option<ModelDownloadCallback>,
    ) -> Result<Self> {
        use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};

        // Requested name -> fastembed model.
        let selected = match model_name {
            // BGE family
            "BAAI/bge-small-en-v1.5" => EmbeddingModel::BGESmallENV15,
            "BAAI/bge-base-en-v1.5" => EmbeddingModel::BGEBaseENV15,
            "BAAI/bge-large-en-v1.5" => EmbeddingModel::BGELargeENV15,

            // MiniLM
            "sentence-transformers/all-MiniLM-L6-v2" => EmbeddingModel::AllMiniLML6V2,

            // Longer-context models
            "nomic-embed-text-v1" => EmbeddingModel::NomicEmbedTextV1,
            "nomic-embed-text-v1.5" => EmbeddingModel::NomicEmbedTextV15,
            "jina-embeddings-v2-base-code" => EmbeddingModel::JinaEmbeddingsV2BaseCode,

            // Anything else falls back to Nomic v1.5.
            _ => EmbeddingModel::NomicEmbedTextV15,
        };

        // Per-model (max input length, embedding dimension).
        let (token_limit, vector_len) = match selected {
            EmbeddingModel::BGESmallENV15 | EmbeddingModel::AllMiniLML6V2 => (512, 384),
            EmbeddingModel::BGEBaseENV15 => (512, 768),
            // Kept conservatively at 512 for BGE large.
            EmbeddingModel::BGELargeENV15 => (512, 1024),
            // Large-context models use their full 8192 capacity.
            EmbeddingModel::NomicEmbedTextV1
            | EmbeddingModel::NomicEmbedTextV15
            | EmbeddingModel::JinaEmbeddingsV2BaseCode => (8192, 768),
            // Safe defaults for any other variant.
            _ => (512, 384),
        };

        // Model files live in a permanent cache directory.
        let cache_root = model_cache_root()?;
        std::fs::create_dir_all(&cache_root)?;

        if let Some(report) = progress_callback.as_ref() {
            report(&format!("Initializing model: {}", model_name));

            if Self::check_model_exists(&cache_root, model_name) {
                report(&format!("Using cached model: {}", model_name));
            } else {
                report(&format!(
                    "Downloading model {} to {}",
                    model_name,
                    cache_root.display()
                ));
            }
        }

        let options = InitOptions::new(selected)
            .with_show_download_progress(progress_callback.is_some())
            .with_cache_dir(cache_root)
            .with_max_length(token_limit);
        let loaded = TextEmbedding::try_new(options)?;

        if let Some(report) = progress_callback.as_ref() {
            report("Model loaded successfully");
        }

        Ok(Self {
            model: loaded,
            dim: vector_len,
            model_name: model_name.to_owned(),
        })
    }

    /// Heuristic used only to choose the progress message: reports whether a
    /// directory named after `model_name` (with `/` replaced by `_`) exists
    /// under `cache_dir`.
    fn check_model_exists(cache_dir: &Path, model_name: &str) -> bool {
        let dir_name = model_name.replace('/', "_");
        cache_dir.join(dir_name).exists()
    }
}
263
/// `Embedder` implementation that delegates to the loaded fastembed model.
#[cfg(feature = "fastembed")]
impl Embedder for FastEmbedder {
    /// Always `"fastembed"`.
    fn id(&self) -> &'static str {
        "fastembed"
    }

    /// Embedding dimension recorded when the model was loaded.
    fn dim(&self) -> usize {
        self.dim
    }

    /// The model name given at construction time.
    fn model_name(&self) -> &str {
        self.model_name.as_str()
    }

    /// Embeds `texts` with the underlying model, passing `None` as the batch
    /// size argument.
    fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        // fastembed takes string slices, so borrow each input as `&str`.
        let borrowed: Vec<&str> = texts.iter().map(String::as_str).collect();
        Ok(self.model.embed(borrowed, None)?)
    }
}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Vector length every `DummyEmbedder` is expected to produce.
    const DUMMY_DIM: usize = 384;

    /// Turns string literals into the owned `String`s that `embed` accepts.
    fn owned(texts: &[&str]) -> Vec<String> {
        texts.iter().map(|text| text.to_string()).collect()
    }

    #[test]
    fn test_dummy_embedder() {
        let mut embedder = DummyEmbedder::new();

        assert_eq!(embedder.id(), "dummy");
        assert_eq!(embedder.dim(), DUMMY_DIM);

        let embeddings = embedder.embed(&owned(&["hello", "world"])).unwrap();
        assert_eq!(embeddings.len(), 2);

        for vector in &embeddings {
            assert_eq!(vector.len(), DUMMY_DIM);
            // The dummy embedder only ever emits zeros.
            assert!(vector.iter().all(|component| *component == 0.0));
        }
    }

    #[test]
    fn test_create_embedder_dummy() {
        // Without the fastembed feature the factory falls back to the dummy.
        #[cfg(not(feature = "fastembed"))]
        {
            let embedder = create_embedder(None).unwrap();
            assert_eq!(embedder.id(), "dummy");
            assert_eq!(embedder.dim(), DUMMY_DIM);
        }
    }

    #[test]
    fn test_embedder_trait_object() {
        let mut embedder: Box<dyn Embedder> = Box::new(DummyEmbedder::new());

        let outcome = embedder.embed(&owned(&["test"]));
        assert!(outcome.is_ok());

        let embeddings = outcome.unwrap();
        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), DUMMY_DIM);
    }

    #[cfg(feature = "fastembed")]
    #[test]
    fn test_fastembed_creation() {
        // Needs to download a model, so it is skipped in CI.
        if std::env::var("CI").is_ok() {
            return;
        }

        // Creation may fail (no network, missing model files); that is
        // tolerated here, so assertions only run when a model was loaded.
        if let Ok(mut embedder) = FastEmbedder::new("BAAI/bge-small-en-v1.5") {
            assert_eq!(embedder.id(), "fastembed");
            assert_eq!(embedder.dim(), 384);

            let outcome = embedder.embed(&owned(&["hello world"]));
            assert!(outcome.is_ok());

            let embeddings = outcome.unwrap();
            assert_eq!(embeddings.len(), 1);
            assert_eq!(embeddings[0].len(), 384);

            // A real model should not produce an all-zero vector.
            assert!(!embeddings[0].iter().all(|component| *component == 0.0));
        }
    }

    #[cfg(feature = "fastembed")]
    #[test]
    fn test_create_embedder_fastembed() {
        if std::env::var("CI").is_ok() {
            return;
        }

        // The model might not be available in the test environment, so a
        // failed creation is not treated as a test failure.
        if let Ok(embedder) = create_embedder(Some("BAAI/bge-small-en-v1.5")) {
            assert_eq!(embedder.id(), "fastembed");
            assert_eq!(embedder.dim(), 384);
        }
    }

    #[test]
    fn test_embedder_empty_input() {
        let mut embedder = DummyEmbedder::new();
        let no_texts: Vec<String> = Vec::new();

        let embeddings = embedder.embed(&no_texts).unwrap();
        assert_eq!(embeddings.len(), 0);
    }

    #[test]
    fn test_embedder_single_text() {
        let mut embedder = DummyEmbedder::new();

        let embeddings = embedder.embed(&owned(&["single text"])).unwrap();

        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), DUMMY_DIM);
    }

    #[test]
    fn test_embedder_multiple_texts() {
        let mut embedder = DummyEmbedder::new();
        let texts = owned(&["first text", "second text", "third text"]);

        let embeddings = embedder.embed(&texts).unwrap();

        assert_eq!(embeddings.len(), 3);
        for vector in embeddings.iter() {
            assert_eq!(vector.len(), DUMMY_DIM);
        }
    }
}