ck_embed/
lib.rs

1use anyhow::Result;
2use std::path::Path;
3use std::path::PathBuf;
4
5pub trait Embedder: Send + Sync {
6    fn id(&self) -> &'static str;
7    fn dim(&self) -> usize;
8    fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
9}
10
/// Boxed closure that receives human-readable progress messages while an
/// embedder is being created (e.g. model initialisation or download notices).
/// The closure must be `Send + Sync`.
pub type ModelDownloadCallback = Box<dyn for<'msg> Fn(&'msg str) + Send + Sync>;
12
13pub fn create_embedder(model_name: Option<&str>) -> Result<Box<dyn Embedder>> {
14    create_embedder_with_progress(model_name, None)
15}
16
17pub fn create_embedder_with_progress(
18    model_name: Option<&str>,
19    progress_callback: Option<ModelDownloadCallback>,
20) -> Result<Box<dyn Embedder>> {
21    let model = model_name.unwrap_or("BAAI/bge-small-en-v1.5");
22
23    #[cfg(feature = "fastembed")]
24    {
25        Ok(Box::new(FastEmbedder::new_with_progress(
26            model,
27            progress_callback,
28        )?))
29    }
30
31    #[cfg(not(feature = "fastembed"))]
32    {
33        if let Some(callback) = progress_callback {
34            callback("Using dummy embedder (no model download required)");
35        }
36        Ok(Box::new(DummyEmbedder::new()))
37    }
38}
39
/// Placeholder [`Embedder`] that produces all-zero vectors.
///
/// [`create_embedder_with_progress`] returns this backend when the crate is
/// built without the `fastembed` feature. It loads no model, which also makes
/// it convenient in tests.
#[derive(Debug, Clone)]
pub struct DummyEmbedder {
    /// Length of every vector produced by this embedder.
    dim: usize,
}

impl Default for DummyEmbedder {
    fn default() -> Self {
        Self::new()
    }
}

impl DummyEmbedder {
    /// Dimension used by [`DummyEmbedder::new`]. It matches the 384 dimensions
    /// the `fastembed` backend reports for the models it supports.
    const DEFAULT_DIM: usize = 384;

    /// Creates a dummy embedder producing 384-dimensional zero vectors.
    pub fn new() -> Self {
        Self::with_dim(Self::DEFAULT_DIM)
    }

    /// Creates a dummy embedder producing zero vectors of length `dim`.
    ///
    /// Useful when a test needs placeholder vectors shaped like a model with a
    /// different output dimension.
    pub fn with_dim(dim: usize) -> Self {
        Self { dim }
    }
}
55
56impl Embedder for DummyEmbedder {
57    fn id(&self) -> &'static str {
58        "dummy"
59    }
60
61    fn dim(&self) -> usize {
62        self.dim
63    }
64
65    fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
66        Ok(texts.iter().map(|_| vec![0.0; self.dim]).collect())
67    }
68}
69
/// [`Embedder`] backed by a loaded `fastembed::TextEmbedding` model.
#[cfg(feature = "fastembed")]
pub struct FastEmbedder {
    /// Output dimension reported through [`Embedder::dim`].
    dim: usize,
    /// The loaded `fastembed` text-embedding model.
    model: fastembed::TextEmbedding,
}
75
#[cfg(feature = "fastembed")]
impl FastEmbedder {
    /// Loads `model_name` without progress reporting.
    ///
    /// See [`FastEmbedder::new_with_progress`] for model resolution and errors.
    pub fn new(model_name: &str) -> Result<Self> {
        Self::new_with_progress(model_name, None)
    }

    /// Loads a `fastembed` model using the persistent cache directory, reporting
    /// progress through `progress_callback` when one is given.
    ///
    /// Recognised names are `"BAAI/bge-small-en-v1.5"` and
    /// `"sentence-transformers/all-MiniLM-L6-v2"`. Any other name falls back to
    /// `"BAAI/bge-small-en-v1.5"`, and the callback (if any) is told so.
    ///
    /// # Errors
    ///
    /// Returns an error if the cache directory cannot be created or if
    /// `TextEmbedding::try_new` fails to initialise the model.
    pub fn new_with_progress(
        model_name: &str,
        progress_callback: Option<ModelDownloadCallback>,
    ) -> Result<Self> {
        use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};

        const DEFAULT_MODEL: &str = "BAAI/bge-small-en-v1.5";

        // Resolve the fastembed model, the name actually loaded, and its output
        // dimension in one place so the three cannot drift apart.
        let (model, resolved_name, dim) = match model_name {
            "BAAI/bge-small-en-v1.5" => (EmbeddingModel::BGESmallENV15, DEFAULT_MODEL, 384),
            "sentence-transformers/all-MiniLM-L6-v2" => (
                EmbeddingModel::AllMiniLML6V2,
                "sentence-transformers/all-MiniLM-L6-v2",
                384,
            ),
            _ => (EmbeddingModel::BGESmallENV15, DEFAULT_MODEL, 384),
        };

        // Persistent model cache directory, created up front.
        let model_cache_dir = Self::get_model_cache_dir()?;
        std::fs::create_dir_all(&model_cache_dir).map_err(|err| {
            anyhow::Error::new(err).context(format!(
                "failed to create model cache directory {}",
                model_cache_dir.display()
            ))
        })?;

        if let Some(callback) = progress_callback.as_ref() {
            // Surface the fallback instead of claiming to load the requested name.
            if resolved_name != model_name {
                callback(&format!(
                    "Unknown model {}, falling back to {}",
                    model_name, resolved_name
                ));
            }
            callback(&format!("Initializing model: {}", resolved_name));

            if Self::check_model_exists(&model_cache_dir, resolved_name) {
                callback(&format!("Using cached model: {}", resolved_name));
            } else {
                callback(&format!(
                    "Downloading model {} to {}",
                    resolved_name,
                    model_cache_dir.display()
                ));
            }
        }

        let init_options = InitOptions::new(model)
            .with_show_download_progress(progress_callback.is_some())
            .with_cache_dir(model_cache_dir);

        let embedding = TextEmbedding::try_new(init_options)?;

        if let Some(callback) = progress_callback.as_ref() {
            callback("Model loaded successfully");
        }

        Ok(Self {
            model: embedding,
            dim,
        })
    }

    /// Returns the directory under which models are cached.
    ///
    /// Resolution order:
    /// 1. `$XDG_CACHE_HOME/ck/models`, only if the variable is non-empty and an
    ///    absolute path (the XDG Base Directory spec says to ignore relative values);
    /// 2. `$HOME/.cache/ck/models`, if `HOME` is non-empty;
    /// 3. `%LOCALAPPDATA%/ck/cache/models`, if `LOCALAPPDATA` is non-empty;
    /// 4. `.ck_models/models`, relative to the current directory.
    ///
    /// This function currently never fails; it returns `Result` so that fallible
    /// resolution can be added without changing callers.
    fn get_model_cache_dir() -> Result<PathBuf> {
        // Treat unset and empty variables alike as "not configured".
        fn env_path(var: &str) -> Option<PathBuf> {
            std::env::var_os(var)
                .filter(|value| !value.is_empty())
                .map(PathBuf::from)
        }

        let cache_dir = env_path("XDG_CACHE_HOME")
            .filter(|path| path.is_absolute())
            .map(|cache_home| cache_home.join("ck"))
            .or_else(|| env_path("HOME").map(|home| home.join(".cache").join("ck")))
            .or_else(|| env_path("LOCALAPPDATA").map(|appdata| appdata.join("ck").join("cache")))
            // Last resort when no home/cache location is known.
            .unwrap_or_else(|| PathBuf::from(".ck_models"));

        Ok(cache_dir.join("models"))
    }

    /// Best-effort check for whether `model_name` already exists in `cache_dir`.
    ///
    /// Looks for a directory named after the model with `/` replaced by `_`.
    /// The result only selects which progress message is shown.
    /// NOTE(review): `fastembed` may lay out its cache under different directory
    /// names, in which case this reports `false` for already-cached models. Confirm
    /// against the fastembed version in use.
    fn check_model_exists(cache_dir: &Path, model_name: &str) -> bool {
        cache_dir.join(model_name.replace('/', "_")).exists()
    }
}
158
#[cfg(feature = "fastembed")]
impl Embedder for FastEmbedder {
    /// Always `"fastembed"`, regardless of which model was loaded.
    fn id(&self) -> &'static str {
        "fastembed"
    }

    /// Dimension resolved when the model was loaded.
    fn dim(&self) -> usize {
        self.dim
    }

    /// Embeds `texts` with the loaded model. `None` leaves fastembed's optional
    /// second argument at its default; in current releases that argument is the
    /// batch size.
    fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let inputs = texts.iter().map(String::as_str).collect::<Vec<&str>>();
        let vectors = self.model.embed(inputs, None)?;
        Ok(vectors)
    }
}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `embeddings` holds `count` vectors, each of length `dim`.
    fn assert_shape(embeddings: &[Vec<f32>], count: usize, dim: usize) {
        assert_eq!(embeddings.len(), count);
        for embedding in embeddings {
            assert_eq!(embedding.len(), dim);
        }
    }

    #[test]
    fn test_dummy_embedder() {
        let mut embedder = DummyEmbedder::new();

        assert_eq!(embedder.id(), "dummy");
        assert_eq!(embedder.dim(), 384);

        let texts = ["hello".to_string(), "world".to_string()];
        let embeddings = embedder.embed(&texts).unwrap();
        assert_shape(&embeddings, 2, 384);

        // Dummy embedder should return all zeros
        assert!(embeddings.iter().flatten().all(|&x| x == 0.0));
    }

    // Gate the whole test (not just its body) so that, with `fastembed`
    // enabled, no empty test is reported as passing.
    #[cfg(not(feature = "fastembed"))]
    #[test]
    fn test_create_embedder_dummy() {
        let embedder = create_embedder(None).unwrap();
        assert_eq!(embedder.id(), "dummy");
        assert_eq!(embedder.dim(), 384);
    }

    #[test]
    fn test_embedder_trait_object() {
        let mut embedder: Box<dyn Embedder> = Box::new(DummyEmbedder::new());

        let texts = ["test".to_string()];
        let result = embedder.embed(&texts);
        assert!(result.is_ok());

        assert_shape(&result.unwrap(), 1, 384);
    }

    #[cfg(feature = "fastembed")]
    #[test]
    fn test_fastembed_creation() {
        // This test requires downloading models, so we skip it in CI.
        if std::env::var("CI").is_ok() {
            return;
        }

        match FastEmbedder::new("BAAI/bge-small-en-v1.5") {
            Ok(mut embedder) => {
                assert_eq!(embedder.id(), "fastembed");
                assert_eq!(embedder.dim(), 384);

                let texts = ["hello world".to_string()];
                let result = embedder.embed(&texts);
                assert!(result.is_ok());

                let embeddings = result.unwrap();
                assert_shape(&embeddings, 1, 384);

                // Real embeddings should not be all zeros
                assert!(!embeddings[0].iter().all(|&x| x == 0.0));
            }
            Err(err) => {
                // The model may be unavailable (e.g. no network). Treat this as a
                // skip, but leave a trace in the test output instead of passing silently.
                eprintln!("skipping test_fastembed_creation: {}", err);
            }
        }
    }

    #[cfg(feature = "fastembed")]
    #[test]
    fn test_create_embedder_fastembed() {
        if std::env::var("CI").is_ok() {
            return;
        }

        match create_embedder(Some("BAAI/bge-small-en-v1.5")) {
            Ok(embedder) => {
                assert_eq!(embedder.id(), "fastembed");
                assert_eq!(embedder.dim(), 384);
            }
            Err(err) => {
                // Model might not be available in the test environment.
                eprintln!("skipping test_create_embedder_fastembed: {}", err);
            }
        }
    }

    #[test]
    fn test_embedder_empty_input() {
        let mut embedder = DummyEmbedder::new();
        let embeddings = embedder.embed(&[]).unwrap();
        assert!(embeddings.is_empty());
    }

    #[test]
    fn test_embedder_single_text() {
        let mut embedder = DummyEmbedder::new();
        let texts = ["single text".to_string()];
        let embeddings = embedder.embed(&texts).unwrap();

        assert_shape(&embeddings, 1, 384);
    }

    #[test]
    fn test_embedder_multiple_texts() {
        let mut embedder = DummyEmbedder::new();
        let texts = [
            "first text".to_string(),
            "second text".to_string(),
            "third text".to_string(),
        ];
        let embeddings = embedder.embed(&texts).unwrap();

        assert_shape(&embeddings, 3, 384);
    }
}