1use anyhow::Result;
2use std::path::Path;
3use std::path::PathBuf;
4
5pub trait Embedder: Send + Sync {
6 fn id(&self) -> &'static str;
7 fn dim(&self) -> usize;
8 fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
9}
10
/// Boxed listener that receives human-readable progress messages while an embedder
/// is being constructed.
///
/// The `Send + Sync` bounds allow the callback to be moved to or shared with other threads.
pub type ModelDownloadCallback = Box<dyn Fn(&str) + Sync + Send>;
12
13pub fn create_embedder(model_name: Option<&str>) -> Result<Box<dyn Embedder>> {
14 create_embedder_with_progress(model_name, None)
15}
16
17pub fn create_embedder_with_progress(
18 model_name: Option<&str>,
19 progress_callback: Option<ModelDownloadCallback>,
20) -> Result<Box<dyn Embedder>> {
21 let model = model_name.unwrap_or("BAAI/bge-small-en-v1.5");
22
23 #[cfg(feature = "fastembed")]
24 {
25 Ok(Box::new(FastEmbedder::new_with_progress(
26 model,
27 progress_callback,
28 )?))
29 }
30
31 #[cfg(not(feature = "fastembed"))]
32 {
33 if let Some(callback) = progress_callback {
34 callback("Using dummy embedder (no model download required)");
35 }
36 Ok(Box::new(DummyEmbedder::new()))
37 }
38}
39
/// Placeholder embedder that needs no model: every input text maps to an all-zero vector.
///
/// Returned by `create_embedder_with_progress` when the crate is built without the
/// `fastembed` feature.
#[derive(Debug, Clone)]
pub struct DummyEmbedder {
    /// Length of the zero vectors this embedder produces.
    dim: usize,
}

impl DummyEmbedder {
    /// Creates a dummy embedder producing 384-element vectors.
    ///
    /// 384 is the same dimension the fastembed backend in this module reports,
    /// presumably so both backends yield identically shaped output.
    pub fn new() -> Self {
        Self { dim: 384 }
    }
}

impl Default for DummyEmbedder {
    /// Equivalent to [`DummyEmbedder::new`].
    fn default() -> Self {
        Self::new()
    }
}
55
56impl Embedder for DummyEmbedder {
57 fn id(&self) -> &'static str {
58 "dummy"
59 }
60
61 fn dim(&self) -> usize {
62 self.dim
63 }
64
65 fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
66 Ok(texts.iter().map(|_| vec![0.0; self.dim]).collect())
67 }
68}
69
/// Embedding backend backed by the `fastembed` crate.
///
/// Only compiled when the `fastembed` feature is enabled.
#[cfg(feature = "fastembed")]
pub struct FastEmbedder {
    /// Vector length reported for the loaded model.
    dim: usize,
    /// The loaded fastembed model used for inference.
    model: fastembed::TextEmbedding,
}
75
#[cfg(feature = "fastembed")]
impl FastEmbedder {
    /// Loads `model_name` with fastembed, without progress reporting.
    ///
    /// # Errors
    ///
    /// Same as [`FastEmbedder::new_with_progress`].
    pub fn new(model_name: &str) -> Result<Self> {
        Self::new_with_progress(model_name, None)
    }

    /// Loads `model_name` with fastembed, reporting progress through `progress_callback`.
    ///
    /// Recognised names are `"BAAI/bge-small-en-v1.5"` and
    /// `"sentence-transformers/all-MiniLM-L6-v2"`. Any other name falls back to
    /// `"BAAI/bge-small-en-v1.5"`. The fallback is announced through the callback, and later
    /// progress messages name the model actually loaded rather than the one requested.
    ///
    /// # Errors
    ///
    /// Returns an error if the model cache directory cannot be created, or if fastembed
    /// fails to initialise the model (for example because its download fails).
    pub fn new_with_progress(
        model_name: &str,
        progress_callback: Option<ModelDownloadCallback>,
    ) -> Result<Self> {
        use anyhow::Context as _;
        use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};

        const DEFAULT_MODEL: &str = "BAAI/bge-small-en-v1.5";

        // Forward a progress message to the caller, if they supplied a callback.
        let report = |message: &str| {
            if let Some(callback) = progress_callback.as_ref() {
                callback(message);
            }
        };

        // Resolve the request to (fastembed model, name actually loaded, output dimension)
        // in one place, so the dimension cannot drift from the model selection.
        let (model, canonical_name, dim) = match model_name {
            "BAAI/bge-small-en-v1.5" => (EmbeddingModel::BGESmallENV15, model_name, 384),
            "sentence-transformers/all-MiniLM-L6-v2" => {
                (EmbeddingModel::AllMiniLML6V2, model_name, 384)
            }
            unknown => {
                report(&format!(
                    "Unknown model {}, falling back to {}",
                    unknown, DEFAULT_MODEL
                ));
                (EmbeddingModel::BGESmallENV15, DEFAULT_MODEL, 384)
            }
        };

        let model_cache_dir = Self::get_model_cache_dir()?;
        std::fs::create_dir_all(&model_cache_dir).with_context(|| {
            format!(
                "failed to create model cache directory {}",
                model_cache_dir.display()
            )
        })?;

        // The cache probe only feeds progress messages, so skip it when nobody is listening.
        if progress_callback.is_some() {
            report(&format!("Initializing model: {}", canonical_name));
            if Self::check_model_exists(&model_cache_dir, canonical_name) {
                report(&format!("Using cached model: {}", canonical_name));
            } else {
                report(&format!(
                    "Downloading model {} to {}",
                    canonical_name,
                    model_cache_dir.display()
                ));
            }
        }

        let init_options = InitOptions::new(model)
            .with_show_download_progress(progress_callback.is_some())
            .with_cache_dir(model_cache_dir);

        let embedding = TextEmbedding::try_new(init_options)
            .with_context(|| format!("failed to load embedding model {}", canonical_name))?;

        report("Model loaded successfully");

        Ok(Self {
            model: embedding,
            dim,
        })
    }

    /// Resolves the directory in which model files are cached.
    ///
    /// The base directory is the first of these that applies:
    /// 1. `$XDG_CACHE_HOME/ck`
    /// 2. `$HOME/.cache/ck`
    /// 3. `%LOCALAPPDATA%/ck/cache`
    /// 4. the CWD-relative `.ck_models`
    ///
    /// `models` is appended to whichever base is chosen. Empty variables are treated as
    /// unset. A relative `XDG_CACHE_HOME` is ignored, as the XDG Base Directory
    /// specification requires, so a blank or relative variable no longer silently
    /// redirects the cache into the current working directory.
    ///
    /// # Errors
    ///
    /// Currently never fails. The `Result` lets callers propagate with `?`.
    fn get_model_cache_dir() -> Result<PathBuf> {
        // Read an environment variable as a path, treating an empty value as unset.
        let env_path = |key: &str| {
            std::env::var_os(key)
                .filter(|value| !value.is_empty())
                .map(PathBuf::from)
        };

        let cache_dir = env_path("XDG_CACHE_HOME")
            .filter(|path| path.is_absolute())
            .map(|path| path.join("ck"))
            .or_else(|| env_path("HOME").map(|home| home.join(".cache").join("ck")))
            .or_else(|| env_path("LOCALAPPDATA").map(|appdata| appdata.join("ck").join("cache")))
            .unwrap_or_else(|| PathBuf::from(".ck_models"));

        Ok(cache_dir.join("models"))
    }

    /// Best-effort check for a cached copy of `model_name` inside `cache_dir`.
    ///
    /// Looks for the directory `cache_dir/<model_name with '/' replaced by '_'>`. It is
    /// only used to choose between the "Using cached" and "Downloading" progress messages.
    // NOTE(review): fastembed may store downloads in a different layout (e.g. hf-hub style
    // `models--<org>--<repo>`, possibly under a repository name other than `model_name`).
    // If so, this probe never matches and "Downloading" is reported even for cached
    // models. Confirm against the pinned fastembed version.
    fn check_model_exists(cache_dir: &Path, model_name: &str) -> bool {
        cache_dir.join(model_name.replace('/', "_")).exists()
    }
}
158
#[cfg(feature = "fastembed")]
impl Embedder for FastEmbedder {
    /// Always `"fastembed"`.
    fn id(&self) -> &'static str {
        "fastembed"
    }

    /// Vector length recorded when the model was loaded.
    fn dim(&self) -> usize {
        self.dim
    }

    /// Runs the loaded fastembed model over `texts` and propagates its errors.
    fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let batch = texts.iter().map(String::as_str).collect::<Vec<&str>>();
        // `None` leaves the batch size to fastembed (presumably its default; confirm
        // against the pinned fastembed version).
        let vectors = self.model.embed(batch, None)?;
        Ok(vectors)
    }
}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// The dummy backend reports its id/dim and yields one all-zero vector per input.
    #[test]
    fn test_dummy_embedder() {
        let mut embedder = DummyEmbedder::new();

        assert_eq!(embedder.id(), "dummy");
        assert_eq!(embedder.dim(), 384);

        let texts = vec!["hello".to_string(), "world".to_string()];
        let embeddings = embedder.embed(&texts).unwrap();

        assert_eq!(embeddings.len(), 2);
        for embedding in &embeddings {
            assert_eq!(embedding.len(), 384);
            assert!(embedding.iter().all(|&x| x == 0.0));
        }
    }

    // Gate the whole test, not just its body. Otherwise it is reported as a vacuous
    // pass when the `fastembed` feature is enabled.
    #[cfg(not(feature = "fastembed"))]
    #[test]
    fn test_create_embedder_dummy() {
        let embedder = create_embedder(None).unwrap();
        assert_eq!(embedder.id(), "dummy");
        assert_eq!(embedder.dim(), 384);
    }

    /// Without `fastembed`, the progress callback receives exactly one message.
    #[cfg(not(feature = "fastembed"))]
    #[test]
    fn test_create_embedder_dummy_reports_progress() {
        use std::sync::{Arc, Mutex};

        let messages = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&messages);
        let callback: ModelDownloadCallback =
            Box::new(move |message: &str| sink.lock().unwrap().push(message.to_string()));

        let embedder = create_embedder_with_progress(None, Some(callback)).unwrap();

        assert_eq!(embedder.id(), "dummy");
        assert_eq!(
            *messages.lock().unwrap(),
            vec!["Using dummy embedder (no model download required)".to_string()]
        );
    }

    /// Embedding works through a `Box<dyn Embedder>` trait object.
    #[test]
    fn test_embedder_trait_object() {
        let mut embedder: Box<dyn Embedder> = Box::new(DummyEmbedder::new());

        let texts = vec!["test".to_string()];
        let embeddings = embedder
            .embed(&texts)
            .expect("dummy embedder should not fail");

        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), 384);
    }

    #[cfg(feature = "fastembed")]
    #[test]
    fn test_fastembed_creation() {
        // Skipped on CI, presumably because the model download is not available there.
        if std::env::var("CI").is_ok() {
            return;
        }

        let mut embedder = match FastEmbedder::new("BAAI/bge-small-en-v1.5") {
            Ok(embedder) => embedder,
            Err(err) => {
                // Load failures (e.g. offline) are tolerated, but report the skip
                // instead of passing silently.
                eprintln!(
                    "skipping test_fastembed_creation: model unavailable: {:#}",
                    err
                );
                return;
            }
        };

        assert_eq!(embedder.id(), "fastembed");
        assert_eq!(embedder.dim(), 384);

        let texts = vec!["hello world".to_string()];
        let embeddings = embedder
            .embed(&texts)
            .expect("embedding with a loaded model should succeed");

        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), 384);
        // A real model must not return the dummy backend's all-zero output.
        assert!(!embeddings[0].iter().all(|&x| x == 0.0));
    }

    #[cfg(feature = "fastembed")]
    #[test]
    fn test_create_embedder_fastembed() {
        // Skipped on CI, presumably because the model download is not available there.
        if std::env::var("CI").is_ok() {
            return;
        }

        match create_embedder(Some("BAAI/bge-small-en-v1.5")) {
            Ok(embedder) => {
                assert_eq!(embedder.id(), "fastembed");
                assert_eq!(embedder.dim(), 384);
            }
            Err(err) => {
                eprintln!(
                    "skipping test_create_embedder_fastembed: model unavailable: {:#}",
                    err
                );
            }
        }
    }

    #[test]
    fn test_embedder_empty_input() {
        let mut embedder = DummyEmbedder::new();
        let texts: Vec<String> = vec![];
        let embeddings = embedder.embed(&texts).unwrap();
        assert!(embeddings.is_empty());
    }

    #[test]
    fn test_embedder_single_text() {
        let mut embedder = DummyEmbedder::new();
        let texts = vec!["single text".to_string()];
        let embeddings = embedder.embed(&texts).unwrap();

        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), 384);
    }

    #[test]
    fn test_embedder_multiple_texts() {
        let mut embedder = DummyEmbedder::new();
        let texts = vec![
            "first text".to_string(),
            "second text".to_string(),
            "third text".to_string(),
        ];
        let embeddings = embedder.embed(&texts).unwrap();

        assert_eq!(embeddings.len(), 3);
        for embedding in &embeddings {
            assert_eq!(embedding.len(), 384);
        }
    }
}