ck_embed/
lib.rs

1use anyhow::Result;
2
3#[cfg(feature = "fastembed")]
4use std::path::{Path, PathBuf};
5
6pub mod reranker;
7pub mod tokenizer;
8
9pub use reranker::{RerankResult, Reranker, create_reranker, create_reranker_with_progress};
10pub use tokenizer::TokenEstimator;
11
12pub trait Embedder: Send + Sync {
13    fn id(&self) -> &'static str;
14    fn dim(&self) -> usize;
15    fn model_name(&self) -> &str;
16    fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
17}
18
/// Callback that receives human-readable status messages while an embedder
/// is being set up (initialisation, cache/download notices, completion).
pub type ModelDownloadCallback = Box<dyn Fn(&str) + Send + Sync>;
20
21pub fn create_embedder(model_name: Option<&str>) -> Result<Box<dyn Embedder>> {
22    create_embedder_with_progress(model_name, None)
23}
24
25pub fn create_embedder_with_progress(
26    model_name: Option<&str>,
27    progress_callback: Option<ModelDownloadCallback>,
28) -> Result<Box<dyn Embedder>> {
29    let model = model_name.unwrap_or("BAAI/bge-small-en-v1.5");
30
31    #[cfg(feature = "fastembed")]
32    {
33        Ok(Box::new(FastEmbedder::new_with_progress(
34            model,
35            progress_callback,
36        )?))
37    }
38
39    #[cfg(not(feature = "fastembed"))]
40    {
41        if let Some(callback) = progress_callback {
42            callback("Using dummy embedder (no model download required)");
43        }
44        Ok(Box::new(DummyEmbedder::new_with_model(model)))
45    }
46}
47
/// Placeholder embedder that needs no model: every text maps to a zero vector.
///
/// Returned by `create_embedder_with_progress` when the crate is built
/// without the `fastembed` feature.
pub struct DummyEmbedder {
    /// Length of each produced vector.
    dim: usize,
    /// Model name reported through `Embedder::model_name`.
    model_name: String,
}
52
53impl Default for DummyEmbedder {
54    fn default() -> Self {
55        Self::new()
56    }
57}
58
59impl DummyEmbedder {
60    pub fn new() -> Self {
61        Self {
62            dim: 384, // Match default BGE model
63            model_name: "dummy".to_string(),
64        }
65    }
66
67    pub fn new_with_model(model_name: &str) -> Self {
68        Self {
69            dim: 384, // Match default BGE model
70            model_name: model_name.to_string(),
71        }
72    }
73}
74
75impl Embedder for DummyEmbedder {
76    fn id(&self) -> &'static str {
77        "dummy"
78    }
79
80    fn dim(&self) -> usize {
81        self.dim
82    }
83
84    fn model_name(&self) -> &str {
85        &self.model_name
86    }
87
88    fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
89        Ok(texts.iter().map(|_| vec![0.0; self.dim]).collect())
90    }
91}
92
/// Embedder backed by a locally cached `fastembed` text-embedding model.
#[cfg(feature = "fastembed")]
pub struct FastEmbedder {
    /// Loaded fastembed model used for inference.
    model: fastembed::TextEmbedding,
    /// Output vector width of `model`.
    dim: usize,
    /// Model name as requested by the caller.
    ///
    /// Unrecognised names are loaded as a fallback model, so this may differ
    /// from the model actually loaded.
    model_name: String,
}
99
#[cfg(feature = "fastembed")]
impl FastEmbedder {
    /// Model loaded when the requested name is not in the lookup table.
    const FALLBACK_MODEL_NAME: &'static str = "nomic-embed-text-v1.5";

    /// Loads `model_name` without progress reporting.
    ///
    /// # Errors
    /// See [`FastEmbedder::new_with_progress`].
    pub fn new(model_name: &str) -> Result<Self> {
        Self::new_with_progress(model_name, None)
    }

    /// Loads the fastembed model for `model_name`, downloading it into the
    /// permanent model cache if fastembed needs to.
    ///
    /// Unrecognised names fall back to `nomic-embed-text-v1.5`. In that case
    /// `model_name()` still reports the name that was requested.
    ///
    /// When `progress_callback` is set:
    /// - it receives status messages, including a notice when the fallback
    ///   is used;
    /// - fastembed's download-progress display is enabled.
    ///
    /// # Errors
    /// Fails if the cache directory cannot be created, or if fastembed cannot
    /// initialise the model.
    pub fn new_with_progress(
        model_name: &str,
        progress_callback: Option<ModelDownloadCallback>,
    ) -> Result<Self> {
        use anyhow::Context;
        use fastembed::{InitOptions, TextEmbedding};

        let resolved = Self::lookup_model(model_name);
        let recognised = resolved.is_some();
        let (model, max_length, dim) = resolved.unwrap_or_else(|| {
            Self::lookup_model(Self::FALLBACK_MODEL_NAME)
                .expect("FALLBACK_MODEL_NAME must be present in the lookup table")
        });

        // Configure permanent model cache directory.
        let model_cache_dir = Self::get_model_cache_dir()?;
        std::fs::create_dir_all(&model_cache_dir).with_context(|| {
            format!(
                "failed to create model cache directory {}",
                model_cache_dir.display()
            )
        })?;

        if let Some(callback) = &progress_callback {
            callback(&format!("Initializing model: {}", model_name));

            if !recognised {
                // Previously this fallback was silent; surface it so users know
                // which model is actually being loaded.
                callback(&format!(
                    "Unknown model '{}', falling back to {}",
                    model_name,
                    Self::FALLBACK_MODEL_NAME
                ));
            }

            if Self::check_model_exists(&model_cache_dir, model_name) {
                callback(&format!("Using cached model: {}", model_name));
            } else {
                callback(&format!(
                    "Downloading model {} to {}",
                    model_name,
                    model_cache_dir.display()
                ));
            }
        }

        let init_options = InitOptions::new(model)
            .with_show_download_progress(progress_callback.is_some())
            .with_cache_dir(model_cache_dir)
            .with_max_length(max_length);

        let embedding = TextEmbedding::try_new(init_options)?;

        if let Some(callback) = &progress_callback {
            callback("Model loaded successfully");
        }

        Ok(Self {
            model: embedding,
            dim,
            model_name: model_name.to_string(),
        })
    }

    /// Maps a supported model name to `(model, max_length, dim)`.
    ///
    /// The tuple holds the fastembed model, the max token length passed to
    /// fastembed, and the output vector width.
    ///
    /// Returns `None` for names that are not in the table.
    fn lookup_model(model_name: &str) -> Option<(fastembed::EmbeddingModel, usize, usize)> {
        use fastembed::EmbeddingModel;

        let spec = match model_name {
            // Small models: 512-token context, 384-dim output.
            "BAAI/bge-small-en-v1.5" => (EmbeddingModel::BGESmallENV15, 512, 384),
            "sentence-transformers/all-MiniLM-L6-v2" => (EmbeddingModel::AllMiniLML6V2, 512, 384),

            // Long-context models: use their full 8192-token window, 768-dim output.
            "nomic-embed-text-v1" => (EmbeddingModel::NomicEmbedTextV1, 8192, 768),
            "nomic-embed-text-v1.5" => (EmbeddingModel::NomicEmbedTextV15, 8192, 768),
            "jina-embeddings-v2-base-code" => {
                (EmbeddingModel::JinaEmbeddingsV2BaseCode, 8192, 768)
            }

            // Larger BGE variants; max_length kept conservative at 512.
            "BAAI/bge-base-en-v1.5" => (EmbeddingModel::BGEBaseENV15, 512, 768),
            "BAAI/bge-large-en-v1.5" => (EmbeddingModel::BGELargeENV15, 512, 1024),

            _ => return None,
        };
        Some(spec)
    }

    /// Returns the directory where model files are cached.
    ///
    /// Resolution order:
    /// 1. `$XDG_CACHE_HOME/ck/models`, only if set to an absolute path, as the
    ///    XDG Base Directory spec requires.
    /// 2. `$HOME/.cache/ck/models`.
    /// 3. `%LOCALAPPDATA%/ck/cache/models`.
    /// 4. `./.ck_models/models`.
    ///
    /// Empty environment variables are treated as unset. Previously an empty
    /// `XDG_CACHE_HOME` produced a CWD-relative cache.
    fn get_model_cache_dir() -> Result<PathBuf> {
        let env_path = |key: &str| {
            std::env::var_os(key)
                .filter(|value| !value.is_empty())
                .map(PathBuf::from)
        };

        let cache_dir = if let Some(cache_home) =
            env_path("XDG_CACHE_HOME").filter(|path| path.is_absolute())
        {
            cache_home.join("ck")
        } else if let Some(home) = env_path("HOME") {
            home.join(".cache").join("ck")
        } else if let Some(appdata) = env_path("LOCALAPPDATA") {
            appdata.join("ck").join("cache")
        } else {
            // No usable home/cache location: fall back to a CWD-relative directory.
            PathBuf::from(".ck_models")
        };

        Ok(cache_dir.join("models"))
    }

    /// Best-effort check, used only to pick a progress message.
    ///
    /// Reports whether `cache_dir/<model_name with '/' replaced by '_'>` exists.
    // NOTE(review): fastembed may store cached models under different directory
    // names than this heuristic assumes, in which case this reports "not cached"
    // even for cached models. Confirm against the fastembed version in use.
    fn check_model_exists(cache_dir: &Path, model_name: &str) -> bool {
        cache_dir.join(model_name.replace('/', "_")).exists()
    }
}
223
#[cfg(feature = "fastembed")]
impl Embedder for FastEmbedder {
    fn id(&self) -> &'static str {
        "fastembed"
    }

    fn dim(&self) -> usize {
        self.dim
    }

    fn model_name(&self) -> &str {
        self.model_name.as_str()
    }

    /// Embeds `texts` with the loaded fastembed model.
    ///
    /// Texts are passed as borrowed `&str` slices, so no string is copied.
    // NOTE(review): `None` is fastembed's optional second argument; presumably
    // it selects fastembed's default batch size. Confirm.
    fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let borrowed = texts.iter().map(String::as_str).collect::<Vec<&str>>();
        let vectors = self.model.embed(borrowed, None)?;
        Ok(vectors)
    }
}
244
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dummy_embedder() {
        let mut embedder = DummyEmbedder::new();

        assert_eq!(embedder.id(), "dummy");
        assert_eq!(embedder.dim(), 384);

        let texts = vec!["hello".to_string(), "world".to_string()];
        let embeddings = embedder.embed(&texts).unwrap();

        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].len(), 384);
        assert_eq!(embeddings[1].len(), 384);

        // Dummy embedder should return all zeros
        assert!(embeddings[0].iter().all(|&x| x == 0.0));
        assert!(embeddings[1].iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_dummy_embedder_model_names() {
        assert_eq!(DummyEmbedder::new().model_name(), "dummy");
        assert_eq!(DummyEmbedder::default().model_name(), "dummy");

        let named = DummyEmbedder::new_with_model("custom/model");
        assert_eq!(named.model_name(), "custom/model");
        assert_eq!(named.id(), "dummy");
        assert_eq!(named.dim(), 384);
    }

    // Gate the whole test rather than its body: with `fastembed` enabled the
    // old version compiled to an empty test that always passed.
    #[cfg(not(feature = "fastembed"))]
    #[test]
    fn test_create_embedder_dummy() {
        let embedder = create_embedder(None).unwrap();
        assert_eq!(embedder.id(), "dummy");
        assert_eq!(embedder.dim(), 384);
        assert_eq!(embedder.model_name(), "BAAI/bge-small-en-v1.5");
    }

    #[cfg(not(feature = "fastembed"))]
    #[test]
    fn test_create_embedder_dummy_reports_progress() {
        use std::sync::{Arc, Mutex};

        let messages = Arc::new(Mutex::new(Vec::<String>::new()));
        let sink = Arc::clone(&messages);
        let callback: ModelDownloadCallback =
            Box::new(move |msg: &str| sink.lock().unwrap().push(msg.to_string()));

        let embedder = create_embedder_with_progress(Some("custom"), Some(callback)).unwrap();
        assert_eq!(embedder.model_name(), "custom");
        assert_eq!(
            *messages.lock().unwrap(),
            vec!["Using dummy embedder (no model download required)".to_string()]
        );
    }

    #[test]
    fn test_embedder_trait_object() {
        let mut embedder: Box<dyn Embedder> = Box::new(DummyEmbedder::new());

        let texts = vec!["test".to_string()];
        let result = embedder.embed(&texts);
        assert!(result.is_ok());

        let embeddings = result.unwrap();
        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), 384);
    }

    #[cfg(feature = "fastembed")]
    #[test]
    fn test_fastembed_creation() {
        // This test requires downloading models, so we'll skip it in CI
        if std::env::var("CI").is_ok() {
            return;
        }

        let embedder = FastEmbedder::new("BAAI/bge-small-en-v1.5");

        // FastEmbed creation might fail due to network issues or missing models.
        // In a real test environment, you'd want to ensure models are available.
        match embedder {
            Ok(mut embedder) => {
                assert_eq!(embedder.id(), "fastembed");
                assert_eq!(embedder.dim(), 384);

                let texts = vec!["hello world".to_string()];
                let result = embedder.embed(&texts);
                assert!(result.is_ok());

                let embeddings = result.unwrap();
                assert_eq!(embeddings.len(), 1);
                assert_eq!(embeddings[0].len(), 384);

                // Real embeddings should not be all zeros
                assert!(!embeddings[0].iter().all(|&x| x == 0.0));
            }
            Err(e) => {
                // Acceptable in environments without the model, but say so
                // instead of passing silently.
                eprintln!("skipping fastembed test, model unavailable: {}", e);
            }
        }
    }

    #[cfg(feature = "fastembed")]
    #[test]
    fn test_create_embedder_fastembed() {
        if std::env::var("CI").is_ok() {
            return;
        }

        let embedder = create_embedder(Some("BAAI/bge-small-en-v1.5"));

        match embedder {
            Ok(embedder) => {
                assert_eq!(embedder.id(), "fastembed");
                assert_eq!(embedder.dim(), 384);
            }
            Err(e) => {
                // Model might not be available in test environment.
                eprintln!("skipping fastembed test, model unavailable: {}", e);
            }
        }
    }

    #[test]
    fn test_embedder_empty_input() {
        let mut embedder = DummyEmbedder::new();
        let texts: Vec<String> = vec![];
        let embeddings = embedder.embed(&texts).unwrap();
        assert_eq!(embeddings.len(), 0);
    }

    #[test]
    fn test_embedder_single_text() {
        let mut embedder = DummyEmbedder::new();
        let texts = vec!["single text".to_string()];
        let embeddings = embedder.embed(&texts).unwrap();

        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), 384);
    }

    #[test]
    fn test_embedder_multiple_texts() {
        let mut embedder = DummyEmbedder::new();
        let texts = vec![
            "first text".to_string(),
            "second text".to_string(),
            "third text".to_string(),
        ];
        let embeddings = embedder.embed(&texts).unwrap();

        assert_eq!(embeddings.len(), 3);
        for embedding in &embeddings {
            assert_eq!(embedding.len(), 384);
        }
    }
}