ck_embed/
lib.rs

1use anyhow::Result;
2use std::path::Path;
3use std::path::PathBuf;
4
5pub mod reranker;
6pub mod tokenizer;
7
8pub use reranker::{RerankResult, Reranker, create_reranker, create_reranker_with_progress};
9pub use tokenizer::TokenEstimator;
10
/// A text-embedding backend.
///
/// Implementations in this crate are [`DummyEmbedder`] (all-zero vectors) and,
/// with the `fastembed` feature, `FastEmbedder`. The `Send + Sync` bound lets a
/// boxed embedder be shared across threads.
pub trait Embedder: Send + Sync {
    /// Short, static identifier of the backend (e.g. `"dummy"`, `"fastembed"`).
    fn id(&self) -> &'static str;
    /// Length of every vector returned by [`Embedder::embed`].
    fn dim(&self) -> usize;
    /// Embeds each input text.
    ///
    /// The implementations in this crate return one vector of length
    /// [`Embedder::dim`] per input text. Takes `&mut self` so backends may keep
    /// mutable inference state.
    ///
    /// # Errors
    /// Backend-specific; the dummy implementation never fails.
    fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
}
16
/// Callback that receives human-readable status messages while an embedder is
/// being created (e.g. "Initializing model: …", "Model loaded successfully").
pub type ModelDownloadCallback = Box<dyn Fn(&str) + Send + Sync>;
18
19pub fn create_embedder(model_name: Option<&str>) -> Result<Box<dyn Embedder>> {
20    create_embedder_with_progress(model_name, None)
21}
22
23pub fn create_embedder_with_progress(
24    model_name: Option<&str>,
25    progress_callback: Option<ModelDownloadCallback>,
26) -> Result<Box<dyn Embedder>> {
27    let model = model_name.unwrap_or("BAAI/bge-small-en-v1.5");
28
29    #[cfg(feature = "fastembed")]
30    {
31        Ok(Box::new(FastEmbedder::new_with_progress(
32            model,
33            progress_callback,
34        )?))
35    }
36
37    #[cfg(not(feature = "fastembed"))]
38    {
39        if let Some(callback) = progress_callback {
40            callback("Using dummy embedder (no model download required)");
41        }
42        Ok(Box::new(DummyEmbedder::new()))
43    }
44}
45
46pub struct DummyEmbedder {
47    dim: usize,
48}
49
50impl Default for DummyEmbedder {
51    fn default() -> Self {
52        Self::new()
53    }
54}
55
56impl DummyEmbedder {
57    pub fn new() -> Self {
58        Self { dim: 384 } // Match default BGE model
59    }
60}
61
62impl Embedder for DummyEmbedder {
63    fn id(&self) -> &'static str {
64        "dummy"
65    }
66
67    fn dim(&self) -> usize {
68        self.dim
69    }
70
71    fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
72        Ok(texts.iter().map(|_| vec![0.0; self.dim]).collect())
73    }
74}
75
/// Embedder backed by the `fastembed` crate.
#[cfg(feature = "fastembed")]
pub struct FastEmbedder {
    model: fastembed::TextEmbedding,
    dim: usize,
}

#[cfg(feature = "fastembed")]
impl FastEmbedder {
    /// Model used when the requested name is not recognized.
    const FALLBACK_MODEL_NAME: &'static str = "nomic-embed-text-v1.5";

    /// Loads `model_name` without progress reporting.
    ///
    /// See [`FastEmbedder::new_with_progress`] for accepted names and errors.
    pub fn new(model_name: &str) -> Result<Self> {
        Self::new_with_progress(model_name, None)
    }

    /// Loads `model_name`, reporting status through `progress_callback`.
    ///
    /// Recognized names are listed in `resolve_model`. Any other name falls back
    /// to `nomic-embed-text-v1.5` (768 dimensions). When a callback is supplied,
    /// that substitution is reported through it.
    ///
    /// # Errors
    /// Returns an error if the model cache directory cannot be created or if
    /// fastembed fails to initialize the model.
    pub fn new_with_progress(
        model_name: &str,
        progress_callback: Option<ModelDownloadCallback>,
    ) -> Result<Self> {
        use anyhow::Context;
        use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};

        let resolved = Self::resolve_model(model_name);
        let fell_back = resolved.is_none();
        // Keep this fallback in sync with the "nomic-embed-text-v1.5" row of
        // `resolve_model` (chosen as the default for better performance).
        let (model, dim) = resolved.unwrap_or((EmbeddingModel::NomicEmbedTextV15, 768));
        // The name of the model that will actually be loaded.
        let effective_name = if fell_back {
            Self::FALLBACK_MODEL_NAME
        } else {
            model_name
        };

        // Configure permanent model cache directory.
        let model_cache_dir = Self::get_model_cache_dir()?;
        std::fs::create_dir_all(&model_cache_dir).with_context(|| {
            format!(
                "failed to create model cache directory {}",
                model_cache_dir.display()
            )
        })?;

        if let Some(callback) = progress_callback.as_ref() {
            callback(&format!("Initializing model: {}", model_name));
            if fell_back {
                callback(&format!(
                    "Unknown model '{}', falling back to {}",
                    model_name, effective_name
                ));
            }

            if Self::check_model_exists(&model_cache_dir, effective_name) {
                callback(&format!("Using cached model: {}", effective_name));
            } else {
                callback(&format!(
                    "Downloading model {} to {}",
                    effective_name,
                    model_cache_dir.display()
                ));
            }
        }

        // `dim` was resolved above, so the model can be moved in without a clone.
        let init_options = InitOptions::new(model)
            .with_show_download_progress(progress_callback.is_some())
            .with_cache_dir(model_cache_dir);

        let embedding = TextEmbedding::try_new(init_options)?;

        if let Some(callback) = progress_callback.as_ref() {
            callback("Model loaded successfully");
        }

        Ok(Self {
            model: embedding,
            dim,
        })
    }

    /// Maps a user-facing model name to its fastembed variant and output
    /// dimension, or `None` if the name is not recognized.
    ///
    /// Keeping variant and dimension in one table prevents a newly added model
    /// from silently picking up the wrong dimension.
    fn resolve_model(model_name: &str) -> Option<(fastembed::EmbeddingModel, usize)> {
        use fastembed::EmbeddingModel;

        let resolved = match model_name {
            // Small models (384 dimensions)
            "BAAI/bge-small-en-v1.5" => (EmbeddingModel::BGESmallENV15, 384),
            "sentence-transformers/all-MiniLM-L6-v2" => (EmbeddingModel::AllMiniLML6V2, 384),

            // Longer-context models (768 dimensions)
            "nomic-embed-text-v1" => (EmbeddingModel::NomicEmbedTextV1, 768),
            "nomic-embed-text-v1.5" => (EmbeddingModel::NomicEmbedTextV15, 768),
            "jina-embeddings-v2-base-code" => (EmbeddingModel::JinaEmbeddingsV2BaseCode, 768),

            // BGE variants
            "BAAI/bge-base-en-v1.5" => (EmbeddingModel::BGEBaseENV15, 768),
            "BAAI/bge-large-en-v1.5" => (EmbeddingModel::BGELargeENV15, 1024),

            _ => return None,
        };
        Some(resolved)
    }

    /// Returns the directory in which model files are cached (`…/ck/models`).
    ///
    /// The directory is chosen in this order:
    /// 1. `$XDG_CACHE_HOME/ck`, only if it is an absolute path (per the XDG spec,
    ///    an empty or relative value is ignored).
    /// 2. `$HOME/.cache/ck`.
    /// 3. `%LOCALAPPDATA%/ck/cache`.
    /// 4. `.ck_models` in the current directory.
    ///
    /// Empty variables are skipped, so they never produce a cwd-relative path.
    fn get_model_cache_dir() -> Result<PathBuf> {
        let xdg_cache = std::env::var_os("XDG_CACHE_HOME")
            .map(PathBuf::from)
            .filter(|path| path.is_absolute());
        let non_empty_dir = |key: &str| {
            std::env::var_os(key)
                .filter(|value| !value.is_empty())
                .map(PathBuf::from)
        };

        let cache_dir = if let Some(cache_home) = xdg_cache {
            cache_home.join("ck")
        } else if let Some(home) = non_empty_dir("HOME") {
            home.join(".cache").join("ck")
        } else if let Some(appdata) = non_empty_dir("LOCALAPPDATA") {
            appdata.join("ck").join("cache")
        } else {
            // Fallback to current directory if no home found
            PathBuf::from(".ck_models")
        };

        Ok(cache_dir.join("models"))
    }

    /// Heuristically checks whether `model_name` has already been cached.
    ///
    /// The check looks for `cache_dir/<model_name with '/' replaced by '_'>`. It
    /// only affects which progress message is shown.
    // NOTE(review): fastembed's own cache layout may use a different directory
    // name, in which case this reports "Downloading" even for cached models.
    // TODO: confirm against the fastembed version in use.
    fn check_model_exists(cache_dir: &Path, model_name: &str) -> bool {
        cache_dir.join(model_name.replace('/', "_")).exists()
    }
}
187
#[cfg(feature = "fastembed")]
impl Embedder for FastEmbedder {
    fn id(&self) -> &'static str {
        "fastembed"
    }

    fn dim(&self) -> usize {
        self.dim
    }

    /// Embeds `texts` with the loaded fastembed model.
    ///
    /// Passes `None` as the batch size, so fastembed's default batching is used.
    fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        // Borrow the inputs rather than copying each string.
        let borrowed = texts.iter().map(String::as_str).collect::<Vec<&str>>();
        self.model
            .embed(borrowed, None)
            .map_err(anyhow::Error::from)
    }
}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Output width of the default `BAAI/bge-small-en-v1.5` model, which the
    /// dummy embedder mirrors.
    const DEFAULT_DIM: usize = 384;

    #[test]
    fn test_dummy_embedder() {
        let mut embedder = DummyEmbedder::new();

        assert_eq!(embedder.id(), "dummy");
        assert_eq!(embedder.dim(), DEFAULT_DIM);

        let texts = vec!["hello".to_string(), "world".to_string()];
        let embeddings = embedder.embed(&texts).unwrap();

        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].len(), DEFAULT_DIM);
        assert_eq!(embeddings[1].len(), DEFAULT_DIM);

        // Dummy embedder should return all zeros
        assert!(embeddings[0].iter().all(|&x| x == 0.0));
        assert!(embeddings[1].iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_dummy_embedder_default_matches_new() {
        assert_eq!(DummyEmbedder::default().dim(), DummyEmbedder::new().dim());
    }

    // Gated at item level so that, in fastembed builds, this test is compiled
    // out instead of passing vacuously with an empty body.
    #[cfg(not(feature = "fastembed"))]
    #[test]
    fn test_create_embedder_dummy() {
        let embedder = create_embedder(None).unwrap();
        assert_eq!(embedder.id(), "dummy");
        assert_eq!(embedder.dim(), DEFAULT_DIM);
    }

    #[cfg(not(feature = "fastembed"))]
    #[test]
    fn test_create_embedder_with_progress_reports_dummy() {
        use std::sync::{Arc, Mutex};

        let messages: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&messages);
        let callback: ModelDownloadCallback = Box::new(move |msg: &str| {
            sink.lock().unwrap().push(msg.to_string());
        });

        let embedder = create_embedder_with_progress(Some("any-model"), Some(callback)).unwrap();
        assert_eq!(embedder.id(), "dummy");
        assert_eq!(
            *messages.lock().unwrap(),
            vec!["Using dummy embedder (no model download required)".to_string()]
        );
    }

    #[test]
    fn test_embedder_trait_object() {
        let mut embedder: Box<dyn Embedder> = Box::new(DummyEmbedder::new());

        let texts = vec!["test".to_string()];
        let result = embedder.embed(&texts);
        assert!(result.is_ok());

        let embeddings = result.unwrap();
        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), DEFAULT_DIM);
    }

    #[cfg(feature = "fastembed")]
    #[test]
    fn test_fastembed_creation() {
        // This test requires downloading models, so we skip it in CI
        if std::env::var("CI").is_ok() {
            return;
        }

        // FastEmbed creation might fail due to network issues or missing models
        match FastEmbedder::new("BAAI/bge-small-en-v1.5") {
            Ok(mut embedder) => {
                assert_eq!(embedder.id(), "fastembed");
                assert_eq!(embedder.dim(), DEFAULT_DIM);

                let texts = vec!["hello world".to_string()];
                let result = embedder.embed(&texts);
                assert!(result.is_ok());

                let embeddings = result.unwrap();
                assert_eq!(embeddings.len(), 1);
                assert_eq!(embeddings[0].len(), DEFAULT_DIM);

                // Real embeddings should not be all zeros
                assert!(!embeddings[0].iter().all(|&x| x == 0.0));
            }
            Err(err) => {
                // Tolerated in environments without model access, but make
                // the reason visible instead of silently passing.
                eprintln!("skipping fastembed creation test: {err:#}");
            }
        }
    }

    #[cfg(feature = "fastembed")]
    #[test]
    fn test_create_embedder_fastembed() {
        if std::env::var("CI").is_ok() {
            return;
        }

        match create_embedder(Some("BAAI/bge-small-en-v1.5")) {
            Ok(embedder) => {
                assert_eq!(embedder.id(), "fastembed");
                assert_eq!(embedder.dim(), DEFAULT_DIM);
            }
            Err(err) => {
                // Model might not be available in test environment
                eprintln!("skipping create_embedder fastembed test: {err:#}");
            }
        }
    }

    #[test]
    fn test_embedder_empty_input() {
        let mut embedder = DummyEmbedder::new();
        let texts: Vec<String> = vec![];
        let embeddings = embedder.embed(&texts).unwrap();
        assert!(embeddings.is_empty());
    }

    #[test]
    fn test_embedder_single_text() {
        let mut embedder = DummyEmbedder::new();
        let texts = vec!["single text".to_string()];
        let embeddings = embedder.embed(&texts).unwrap();

        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), DEFAULT_DIM);
    }

    #[test]
    fn test_embedder_multiple_texts() {
        let mut embedder = DummyEmbedder::new();
        let texts = vec![
            "first text".to_string(),
            "second text".to_string(),
            "third text".to_string(),
        ];
        let embeddings = embedder.embed(&texts).unwrap();

        assert_eq!(embeddings.len(), 3);
        assert!(embeddings.iter().all(|e| e.len() == DEFAULT_DIM));
    }
}