ck_embed/
lib.rs

1use anyhow::Result;
2
3#[cfg(feature = "fastembed")]
4use std::path::{Path, PathBuf};
5
6pub mod reranker;
7pub mod tokenizer;
8
9pub use reranker::{RerankResult, Reranker, create_reranker, create_reranker_with_progress};
10pub use tokenizer::TokenEstimator;
11
12pub trait Embedder: Send + Sync {
13    fn id(&self) -> &'static str;
14    fn dim(&self) -> usize;
15    fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
16}
17
/// Boxed callback that receives human-readable status messages while an
/// embedder is being set up (for example model initialisation or download).
///
/// It must be `Send + Sync`, so it may be invoked from any thread.
pub type ModelDownloadCallback = Box<dyn Fn(&str) + Send + Sync>;
19
20pub fn create_embedder(model_name: Option<&str>) -> Result<Box<dyn Embedder>> {
21    create_embedder_with_progress(model_name, None)
22}
23
24pub fn create_embedder_with_progress(
25    model_name: Option<&str>,
26    progress_callback: Option<ModelDownloadCallback>,
27) -> Result<Box<dyn Embedder>> {
28    let model = model_name.unwrap_or("BAAI/bge-small-en-v1.5");
29
30    #[cfg(feature = "fastembed")]
31    {
32        Ok(Box::new(FastEmbedder::new_with_progress(
33            model,
34            progress_callback,
35        )?))
36    }
37
38    #[cfg(not(feature = "fastembed"))]
39    {
40        let _ = model; // Suppress unused variable warning
41        if let Some(callback) = progress_callback {
42            callback("Using dummy embedder (no model download required)");
43        }
44        Ok(Box::new(DummyEmbedder::new()))
45    }
46}
47
/// Placeholder embedder that needs no model.
pub struct DummyEmbedder {
    /// Length reported by `dim()` and used for every vector it returns.
    dim: usize,
}

impl DummyEmbedder {
    /// Creates a dummy embedder with 384 dimensions, matching the default
    /// BGE model.
    pub fn new() -> Self {
        Self::default()
    }
}

impl Default for DummyEmbedder {
    /// Same as [`DummyEmbedder::new`].
    fn default() -> Self {
        // 384 mirrors the dimensionality of the default BGE model.
        Self { dim: 384 }
    }
}
63
64impl Embedder for DummyEmbedder {
65    fn id(&self) -> &'static str {
66        "dummy"
67    }
68
69    fn dim(&self) -> usize {
70        self.dim
71    }
72
73    fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
74        Ok(texts.iter().map(|_| vec![0.0; self.dim]).collect())
75    }
76}
77
/// Embedder backed by a `fastembed` text-embedding model.
///
/// Only compiled when the `fastembed` feature is enabled.
#[cfg(feature = "fastembed")]
pub struct FastEmbedder {
    /// Loaded fastembed model used to compute embeddings.
    model: fastembed::TextEmbedding,
    /// Embedding dimensionality recorded for the selected model when the
    /// embedder was constructed.
    dim: usize,
}
83
#[cfg(feature = "fastembed")]
impl FastEmbedder {
    /// Loads `model_name` without progress reporting.
    ///
    /// Equivalent to [`FastEmbedder::new_with_progress`] with no callback.
    ///
    /// # Errors
    ///
    /// See [`FastEmbedder::new_with_progress`].
    pub fn new(model_name: &str) -> Result<Self> {
        Self::new_with_progress(model_name, None)
    }

    /// Loads `model_name` into a new embedder, reporting progress through
    /// `progress_callback` when one is supplied.
    ///
    /// Model files are cached under the directory chosen by
    /// `get_model_cache_dir`, which is created if missing. Names that are not
    /// recognised silently fall back to Nomic Embed Text v1.5; note that the
    /// progress messages still show the name that was requested.
    ///
    /// # Errors
    ///
    /// Returns an error if the cache directory cannot be created or if
    /// fastembed fails to initialise the model.
    pub fn new_with_progress(
        model_name: &str,
        progress_callback: Option<ModelDownloadCallback>,
    ) -> Result<Self> {
        use fastembed::{InitOptions, TextEmbedding};

        let model = Self::resolve_model(model_name);

        // Configure permanent model cache directory
        let model_cache_dir = Self::get_model_cache_dir()?;
        std::fs::create_dir_all(&model_cache_dir)?;

        if let Some(callback) = &progress_callback {
            callback(&format!("Initializing model: {}", model_name));

            if Self::check_model_exists(&model_cache_dir, model_name) {
                callback(&format!("Using cached model: {}", model_name));
            } else {
                callback(&format!(
                    "Downloading model {} to {}",
                    model_name,
                    model_cache_dir.display()
                ));
            }
        }

        // Both lookups only borrow the model, so it can be moved into the
        // init options afterwards without cloning.
        let max_length = Self::max_length_for(&model);
        let dim = Self::dim_for(&model);

        let init_options = InitOptions::new(model)
            .with_show_download_progress(progress_callback.is_some())
            .with_cache_dir(model_cache_dir)
            .with_max_length(max_length);

        let embedding = TextEmbedding::try_new(init_options)?;

        if let Some(callback) = &progress_callback {
            callback("Model loaded successfully");
        }

        Ok(Self {
            model: embedding,
            dim,
        })
    }

    /// Maps a user-facing model name to the fastembed model it selects.
    fn resolve_model(model_name: &str) -> fastembed::EmbeddingModel {
        use fastembed::EmbeddingModel;

        match model_name {
            // Current models
            "BAAI/bge-small-en-v1.5" => EmbeddingModel::BGESmallENV15,
            "sentence-transformers/all-MiniLM-L6-v2" => EmbeddingModel::AllMiniLML6V2,

            // Enhanced models with longer context
            "nomic-embed-text-v1" => EmbeddingModel::NomicEmbedTextV1,
            "nomic-embed-text-v1.5" => EmbeddingModel::NomicEmbedTextV15,
            "jina-embeddings-v2-base-code" => EmbeddingModel::JinaEmbeddingsV2BaseCode,

            // BGE variants
            "BAAI/bge-base-en-v1.5" => EmbeddingModel::BGEBaseENV15,
            "BAAI/bge-large-en-v1.5" => EmbeddingModel::BGELargeENV15,

            // Unknown names default to Nomic v1.5 for better performance
            _ => EmbeddingModel::NomicEmbedTextV15,
        }
    }

    /// Maximum input length configured for `model`.
    fn max_length_for(model: &fastembed::EmbeddingModel) -> usize {
        use fastembed::EmbeddingModel;

        match model {
            // Large context models - use their full capacity
            EmbeddingModel::NomicEmbedTextV1
            | EmbeddingModel::NomicEmbedTextV15
            | EmbeddingModel::JinaEmbeddingsV2BaseCode => 8192,

            // The BGE family, MiniLM and anything else stay at a
            // conservative 512.
            _ => 512,
        }
    }

    /// Embedding dimensionality recorded for `model`.
    fn dim_for(model: &fastembed::EmbeddingModel) -> usize {
        use fastembed::EmbeddingModel;

        match model {
            // Small models (384 dimensions)
            EmbeddingModel::BGESmallENV15 | EmbeddingModel::AllMiniLML6V2 => 384,

            // Base / large-context models (768 dimensions)
            EmbeddingModel::NomicEmbedTextV1
            | EmbeddingModel::NomicEmbedTextV15
            | EmbeddingModel::JinaEmbeddingsV2BaseCode
            | EmbeddingModel::BGEBaseENV15 => 768,

            // Large models (1024 dimensions)
            EmbeddingModel::BGELargeENV15 => 1024,

            // Needed for exhaustiveness only: `resolve_model` never yields
            // any other variant.
            _ => 384,
        }
    }

    /// Returns the directory in which model files are cached.
    ///
    /// This is `<cache root>/models`, where the cache root is chosen from the
    /// `XDG_CACHE_HOME`, `HOME` and `LOCALAPPDATA` environment variables by
    /// `cache_root_from_env`.
    fn get_model_cache_dir() -> Result<PathBuf> {
        let cache_root = cache_root_from_env(
            std::env::var_os("XDG_CACHE_HOME"),
            std::env::var_os("HOME"),
            std::env::var_os("LOCALAPPDATA"),
        );

        Ok(cache_root.join("models"))
    }

    /// Heuristically reports whether `model_name` appears to be cached.
    ///
    /// Only checks that `<cache_dir>/<model_name with '/' replaced by '_'>`
    /// exists. The result is used solely to pick a progress message.
    // NOTE(review): fastembed decides the real on-disk layout of downloaded
    // models; this directory name may not match it — TODO confirm.
    fn check_model_exists(cache_dir: &Path, model_name: &str) -> bool {
        cache_dir.join(model_name.replace('/', "_")).exists()
    }
}

/// Picks the `ck` cache root from the values of the relevant environment
/// variables, in priority order:
///
/// 1. `xdg_cache_home` (`$XDG_CACHE_HOME`) → `<dir>/ck`
/// 2. `home` (`$HOME`) → `<dir>/.cache/ck`
/// 3. `local_app_data` (`%LOCALAPPDATA%`) → `<dir>/ck/cache`
/// 4. otherwise the relative directory `.ck_models`
///
/// An empty or relative `XDG_CACHE_HOME` is ignored, as the XDG Base
/// Directory specification requires, and empty `HOME` / `LOCALAPPDATA`
/// values are treated as unset. Previously an empty variable produced a
/// cache path relative to the current working directory.
// Deliberately not gated on the `fastembed` feature so this pure path logic
// can be unit-tested without it; it is simply unused when the feature is off.
#[cfg_attr(not(feature = "fastembed"), allow(dead_code))]
fn cache_root_from_env(
    xdg_cache_home: Option<std::ffi::OsString>,
    home: Option<std::ffi::OsString>,
    local_app_data: Option<std::ffi::OsString>,
) -> std::path::PathBuf {
    // `is_absolute` also rejects the empty string.
    let xdg_dir = xdg_cache_home
        .map(std::path::PathBuf::from)
        .filter(|dir| dir.is_absolute());
    if let Some(dir) = xdg_dir {
        return dir.join("ck");
    }

    if let Some(dir) = home.filter(|value| !value.is_empty()) {
        return std::path::PathBuf::from(dir).join(".cache").join("ck");
    }

    if let Some(dir) = local_app_data.filter(|value| !value.is_empty()) {
        return std::path::PathBuf::from(dir).join("ck").join("cache");
    }

    // Fallback to current directory if no home found
    std::path::PathBuf::from(".ck_models")
}
206
#[cfg(feature = "fastembed")]
impl Embedder for FastEmbedder {
    /// Always reports the fixed identifier `"fastembed"`, whichever model
    /// was loaded.
    fn id(&self) -> &'static str {
        "fastembed"
    }

    /// Embedding dimensionality recorded when the embedder was constructed.
    fn dim(&self) -> usize {
        self.dim
    }

    /// Embeds `texts` with the loaded fastembed model.
    ///
    /// The batch size is left to fastembed (`None` is passed for it).
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the fastembed model.
    fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        // fastembed takes string slices, so borrow each owned input.
        let borrowed: Vec<&str> = texts.iter().map(String::as_str).collect();
        let vectors = self.model.embed(borrowed, None)?;
        Ok(vectors)
    }
}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string literals into the owned `String`s that `embed` expects.
    fn owned(texts: &[&str]) -> Vec<String> {
        texts.iter().map(|text| text.to_string()).collect()
    }

    /// The dummy embedder reports its identity and yields zero vectors.
    #[test]
    fn test_dummy_embedder() {
        let mut embedder = DummyEmbedder::new();

        assert_eq!(embedder.id(), "dummy");
        assert_eq!(embedder.dim(), 384);

        let embeddings = embedder.embed(&owned(&["hello", "world"])).unwrap();
        assert_eq!(embeddings.len(), 2);

        for embedding in &embeddings {
            assert_eq!(embedding.len(), 384);
            // Dummy embedder should return all zeros
            assert!(embedding.iter().all(|&x| x == 0.0));
        }
    }

    /// Without the `fastembed` feature the factory hands back the dummy.
    #[test]
    fn test_create_embedder_dummy() {
        #[cfg(not(feature = "fastembed"))]
        {
            let embedder = create_embedder(None).unwrap();
            assert_eq!(embedder.id(), "dummy");
            assert_eq!(embedder.dim(), 384);
        }
    }

    /// The trait is usable through a boxed trait object.
    #[test]
    fn test_embedder_trait_object() {
        let mut embedder: Box<dyn Embedder> = Box::new(DummyEmbedder::new());

        let embeddings = embedder
            .embed(&owned(&["test"]))
            .expect("dummy embedding should succeed");

        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), 384);
    }

    #[cfg(feature = "fastembed")]
    #[test]
    fn test_fastembed_creation() {
        // This test requires downloading models, so we'll skip it in CI
        if std::env::var("CI").is_ok() {
            return;
        }

        // FastEmbed creation might fail due to network issues or missing
        // models; a failure to construct is tolerated and simply ends the
        // test without further checks.
        if let Ok(mut embedder) = FastEmbedder::new("BAAI/bge-small-en-v1.5") {
            assert_eq!(embedder.id(), "fastembed");
            assert_eq!(embedder.dim(), 384);

            let embeddings = embedder
                .embed(&owned(&["hello world"]))
                .expect("embedding should succeed once the model is loaded");

            assert_eq!(embeddings.len(), 1);
            assert_eq!(embeddings[0].len(), 384);

            // Real embeddings should not be all zeros
            assert!(!embeddings[0].iter().all(|&x| x == 0.0));
        }
    }

    #[cfg(feature = "fastembed")]
    #[test]
    fn test_create_embedder_fastembed() {
        if std::env::var("CI").is_ok() {
            return;
        }

        // Model might not be available in the test environment; only check
        // the embedder when it could actually be created.
        if let Ok(embedder) = create_embedder(Some("BAAI/bge-small-en-v1.5")) {
            assert_eq!(embedder.id(), "fastembed");
            assert_eq!(embedder.dim(), 384);
        }
    }

    /// No inputs means no outputs.
    #[test]
    fn test_embedder_empty_input() {
        let mut embedder = DummyEmbedder::new();
        let no_texts: Vec<String> = Vec::new();

        let embeddings = embedder.embed(&no_texts).unwrap();

        assert_eq!(embeddings.len(), 0);
    }

    #[test]
    fn test_embedder_single_text() {
        let mut embedder = DummyEmbedder::new();

        let embeddings = embedder.embed(&owned(&["single text"])).unwrap();

        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), 384);
    }

    #[test]
    fn test_embedder_multiple_texts() {
        let mut embedder = DummyEmbedder::new();
        let texts = owned(&["first text", "second text", "third text"]);

        let embeddings = embedder.embed(&texts).unwrap();

        assert_eq!(embeddings.len(), 3);
        assert!(embeddings.iter().all(|embedding| embedding.len() == 384));
    }
}
358}