ck_embed/
lib.rs

1use anyhow::Result;
2
3#[cfg(feature = "fastembed")]
4use std::path::{Path, PathBuf};
5
6pub mod reranker;
7pub mod tokenizer;
8
9pub use reranker::{RerankResult, Reranker, create_reranker, create_reranker_with_progress};
10pub use tokenizer::TokenEstimator;
11
12pub trait Embedder: Send + Sync {
13    fn id(&self) -> &'static str;
14    fn dim(&self) -> usize;
15    fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
16}
17
/// Boxed, thread-safe callback that receives human-readable status messages while an
/// embedding model is being prepared.
pub type ModelDownloadCallback = Box<dyn Fn(&str) + Sync + Send>;
19
20pub fn create_embedder(model_name: Option<&str>) -> Result<Box<dyn Embedder>> {
21    create_embedder_with_progress(model_name, None)
22}
23
24pub fn create_embedder_with_progress(
25    model_name: Option<&str>,
26    progress_callback: Option<ModelDownloadCallback>,
27) -> Result<Box<dyn Embedder>> {
28    let model = model_name.unwrap_or("BAAI/bge-small-en-v1.5");
29
30    #[cfg(feature = "fastembed")]
31    {
32        Ok(Box::new(FastEmbedder::new_with_progress(
33            model,
34            progress_callback,
35        )?))
36    }
37
38    #[cfg(not(feature = "fastembed"))]
39    {
40        let _ = model; // Suppress unused variable warning
41        if let Some(callback) = progress_callback {
42            callback("Using dummy embedder (no model download required)");
43        }
44        Ok(Box::new(DummyEmbedder::new()))
45    }
46}
47
48pub struct DummyEmbedder {
49    dim: usize,
50}
51
52impl Default for DummyEmbedder {
53    fn default() -> Self {
54        Self::new()
55    }
56}
57
58impl DummyEmbedder {
59    pub fn new() -> Self {
60        Self { dim: 384 } // Match default BGE model
61    }
62}
63
64impl Embedder for DummyEmbedder {
65    fn id(&self) -> &'static str {
66        "dummy"
67    }
68
69    fn dim(&self) -> usize {
70        self.dim
71    }
72
73    fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
74        Ok(texts.iter().map(|_| vec![0.0; self.dim]).collect())
75    }
76}
77
/// Embedder backed by a [`fastembed::TextEmbedding`] model; available only with the
/// `fastembed` feature.
#[cfg(feature = "fastembed")]
pub struct FastEmbedder {
    /// Output vector length, fixed when the model is chosen at construction time.
    dim: usize,
    /// Loaded fastembed model used to compute embeddings.
    model: fastembed::TextEmbedding,
}
83
#[cfg(feature = "fastembed")]
impl FastEmbedder {
    /// Name of the model loaded when the requested name is not recognised.
    const FALLBACK_MODEL_NAME: &'static str = "nomic-embed-text-v1.5";

    /// Creates an embedder for `model_name` without progress reporting.
    ///
    /// # Errors
    /// See [`FastEmbedder::new_with_progress`].
    pub fn new(model_name: &str) -> Result<Self> {
        Self::new_with_progress(model_name, None)
    }

    /// Loads the fastembed model named `model_name`, downloading it into the crate's model
    /// cache directory if needed.
    ///
    /// Unrecognised names fall back to `nomic-embed-text-v1.5`. When `progress_callback` is
    /// supplied, the substitution is reported through it. The callback also receives status
    /// messages before and after loading, and its presence enables fastembed's download
    /// progress output.
    ///
    /// # Errors
    /// Returns an error if the cache directory cannot be created or fastembed fails to
    /// initialise the model.
    pub fn new_with_progress(
        model_name: &str,
        progress_callback: Option<ModelDownloadCallback>,
    ) -> Result<Self> {
        use anyhow::Context;
        use fastembed::{InitOptions, TextEmbedding};

        // Unrecognised names keep the long-standing fallback to Nomic v1.5. The substitution
        // is surfaced so the caller is not misled about which model is loaded.
        let (model, dim, effective_name) = match Self::known_model(model_name) {
            Some((model, dim)) => (model, dim, model_name),
            None => {
                if let Some(ref callback) = progress_callback {
                    callback(&format!(
                        "Unknown model '{}', falling back to {}",
                        model_name,
                        Self::FALLBACK_MODEL_NAME
                    ));
                }
                let (model, dim) = Self::known_model(Self::FALLBACK_MODEL_NAME)
                    .expect("fallback model name must be listed in known_model");
                (model, dim, Self::FALLBACK_MODEL_NAME)
            }
        };

        // Configure permanent model cache directory
        let model_cache_dir = Self::get_model_cache_dir()?;
        std::fs::create_dir_all(&model_cache_dir).with_context(|| {
            format!(
                "failed to create model cache directory {}",
                model_cache_dir.display()
            )
        })?;

        if let Some(ref callback) = progress_callback {
            callback(&format!("Initializing model: {}", effective_name));

            // Best-effort check; it only selects which status message is shown.
            if Self::check_model_exists(&model_cache_dir, effective_name) {
                callback(&format!("Using cached model: {}", effective_name));
            } else {
                callback(&format!(
                    "Downloading model {} to {}",
                    effective_name,
                    model_cache_dir.display()
                ));
            }
        }

        let init_options = InitOptions::new(model)
            .with_show_download_progress(progress_callback.is_some())
            .with_cache_dir(model_cache_dir);

        let embedding = TextEmbedding::try_new(init_options)?;

        if let Some(ref callback) = progress_callback {
            callback("Model loaded successfully");
        }

        Ok(Self {
            model: embedding,
            dim,
        })
    }

    /// Maps a supported model identifier to its fastembed model and output dimension.
    ///
    /// Single source of truth for both, so the dimension can never disagree with the
    /// selected model. Returns `None` for identifiers that are not recognised.
    fn known_model(model_name: &str) -> Option<(fastembed::EmbeddingModel, usize)> {
        use fastembed::EmbeddingModel;

        let entry = match model_name {
            // Small models (384 dimensions)
            "BAAI/bge-small-en-v1.5" => (EmbeddingModel::BGESmallENV15, 384),
            "sentence-transformers/all-MiniLM-L6-v2" => (EmbeddingModel::AllMiniLML6V2, 384),

            // Enhanced models with longer context (768 dimensions)
            "nomic-embed-text-v1" => (EmbeddingModel::NomicEmbedTextV1, 768),
            "nomic-embed-text-v1.5" => (EmbeddingModel::NomicEmbedTextV15, 768),
            "jina-embeddings-v2-base-code" => (EmbeddingModel::JinaEmbeddingsV2BaseCode, 768),

            // BGE variants
            "BAAI/bge-base-en-v1.5" => (EmbeddingModel::BGEBaseENV15, 768),
            "BAAI/bge-large-en-v1.5" => (EmbeddingModel::BGELargeENV15, 1024),

            _ => return None,
        };
        Some(entry)
    }

    /// Returns the directory that models are cached in.
    ///
    /// The base directory is the first of the following that applies:
    /// - `$XDG_CACHE_HOME/ck`
    /// - `$HOME/.cache/ck`
    /// - `%LOCALAPPDATA%/ck/cache`
    /// - `.ck_models`, relative to the working directory
    ///
    /// `models` is then appended to the chosen base. Empty variables are treated as unset,
    /// and a relative `XDG_CACHE_HOME` is ignored, as the XDG Base Directory specification
    /// requires.
    fn get_model_cache_dir() -> Result<PathBuf> {
        /// Reads `key` as a path, treating an empty value the same as an unset one.
        fn env_path(key: &str) -> Option<PathBuf> {
            std::env::var_os(key)
                .filter(|value| !value.is_empty())
                .map(PathBuf::from)
        }

        let cache_dir = if let Some(cache_home) =
            env_path("XDG_CACHE_HOME").filter(|path| path.is_absolute())
        {
            cache_home.join("ck")
        } else if let Some(home) = env_path("HOME") {
            home.join(".cache").join("ck")
        } else if let Some(appdata) = env_path("LOCALAPPDATA") {
            appdata.join("ck").join("cache")
        } else {
            // Fallback to current directory if no home found
            PathBuf::from(".ck_models")
        };

        Ok(cache_dir.join("models"))
    }

    /// Best-effort check for a previously downloaded model; used only to pick a status message.
    ///
    /// Looks for `<cache_dir>/<model_name with '/' replaced by '_'>`.
    /// NOTE(review): fastembed may lay out its cache differently (e.g. hf-hub style
    /// directories). If so, this reports "not cached" even for cached models — confirm
    /// against the fastembed version in use.
    fn check_model_exists(cache_dir: &Path, model_name: &str) -> bool {
        cache_dir.join(model_name.replace('/', "_")).exists()
    }
}
189
#[cfg(feature = "fastembed")]
impl Embedder for FastEmbedder {
    /// Always `"fastembed"`, whichever model was loaded.
    fn id(&self) -> &'static str {
        "fastembed"
    }

    /// Dimension resolved when the model was selected.
    fn dim(&self) -> usize {
        self.dim
    }

    /// Borrows each text as `&str` and embeds the whole slice in one fastembed call.
    ///
    /// The second argument is passed as `None`; presumably this is fastembed's optional
    /// batch size — confirm against the fastembed version in use.
    fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let borrowed = texts.iter().map(String::as_str).collect::<Vec<&str>>();
        let vectors = self.model.embed(borrowed, None)?;
        Ok(vectors)
    }
}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Vector length expected from the dummy embedder and from BGE-small.
    const EXPECTED_DIM: usize = 384;

    /// Builds owned test inputs from string literals.
    fn owned(texts: &[&str]) -> Vec<String> {
        texts.iter().map(|text| text.to_string()).collect()
    }

    /// Asserts that `embeddings` holds `count` vectors, each of length `EXPECTED_DIM`.
    fn assert_shape(embeddings: &[Vec<f32>], count: usize) {
        assert_eq!(embeddings.len(), count);
        for vector in embeddings {
            assert_eq!(vector.len(), EXPECTED_DIM);
        }
    }

    #[test]
    fn test_dummy_embedder() {
        let mut embedder = DummyEmbedder::new();
        assert_eq!(embedder.id(), "dummy");
        assert_eq!(embedder.dim(), EXPECTED_DIM);

        let embeddings = embedder.embed(&owned(&["hello", "world"])).unwrap();
        assert_shape(&embeddings, 2);

        // Dummy embedder should return all zeros
        assert!(embeddings.iter().flatten().all(|&x| x == 0.0));
    }

    #[test]
    fn test_create_embedder_dummy() {
        #[cfg(not(feature = "fastembed"))]
        {
            let embedder = create_embedder(None).unwrap();
            assert_eq!(embedder.id(), "dummy");
            assert_eq!(embedder.dim(), EXPECTED_DIM);
        }
    }

    #[test]
    fn test_embedder_trait_object() {
        let mut embedder: Box<dyn Embedder> = Box::new(DummyEmbedder::new());

        let result = embedder.embed(&owned(&["test"]));
        assert!(result.is_ok());
        assert_shape(&result.unwrap(), 1);
    }

    #[cfg(feature = "fastembed")]
    #[test]
    fn test_fastembed_creation() {
        // This test requires downloading models, so we'll skip it in CI
        if std::env::var("CI").is_ok() {
            return;
        }

        // FastEmbed creation might fail due to network issues or missing models. That is
        // acceptable for unit tests, so only a successful load is checked.
        if let Ok(mut embedder) = FastEmbedder::new("BAAI/bge-small-en-v1.5") {
            assert_eq!(embedder.id(), "fastembed");
            assert_eq!(embedder.dim(), EXPECTED_DIM);

            let result = embedder.embed(&owned(&["hello world"]));
            assert!(result.is_ok());
            let embeddings = result.unwrap();
            assert_shape(&embeddings, 1);

            // Real embeddings should not be all zeros
            assert!(embeddings[0].iter().any(|&x| x != 0.0));
        }
    }

    #[cfg(feature = "fastembed")]
    #[test]
    fn test_create_embedder_fastembed() {
        if std::env::var("CI").is_ok() {
            return;
        }

        // Model might not be available in test environment
        if let Ok(embedder) = create_embedder(Some("BAAI/bge-small-en-v1.5")) {
            assert_eq!(embedder.id(), "fastembed");
            assert_eq!(embedder.dim(), EXPECTED_DIM);
        }
    }

    #[test]
    fn test_embedder_empty_input() {
        let mut embedder = DummyEmbedder::new();
        let texts: Vec<String> = Vec::new();
        let embeddings = embedder.embed(&texts).unwrap();
        assert!(embeddings.is_empty());
    }

    #[test]
    fn test_embedder_single_text() {
        let mut embedder = DummyEmbedder::new();
        let embeddings = embedder.embed(&owned(&["single text"])).unwrap();
        assert_shape(&embeddings, 1);
    }

    #[test]
    fn test_embedder_multiple_texts() {
        let mut embedder = DummyEmbedder::new();
        let texts = owned(&["first text", "second text", "third text"]);
        let embeddings = embedder.embed(&texts).unwrap();
        assert_shape(&embeddings, 3);
    }
}