1use anyhow::Result;
2
3#[cfg(feature = "fastembed")]
4use std::path::{Path, PathBuf};
5
6pub mod reranker;
7pub mod tokenizer;
8
9pub use reranker::{RerankResult, Reranker, create_reranker, create_reranker_with_progress};
10pub use tokenizer::TokenEstimator;
11
12pub trait Embedder: Send + Sync {
13 fn id(&self) -> &'static str;
14 fn dim(&self) -> usize;
15 fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
16}
17
/// Boxed, thread-safe sink for human-readable progress messages emitted while
/// an embedder is being created (e.g. download / cache status).
pub type ModelDownloadCallback = Box<dyn Fn(&str) + Send + Sync + 'static>;
19
20pub fn create_embedder(model_name: Option<&str>) -> Result<Box<dyn Embedder>> {
21 create_embedder_with_progress(model_name, None)
22}
23
24pub fn create_embedder_with_progress(
25 model_name: Option<&str>,
26 progress_callback: Option<ModelDownloadCallback>,
27) -> Result<Box<dyn Embedder>> {
28 let model = model_name.unwrap_or("BAAI/bge-small-en-v1.5");
29
30 #[cfg(feature = "fastembed")]
31 {
32 Ok(Box::new(FastEmbedder::new_with_progress(
33 model,
34 progress_callback,
35 )?))
36 }
37
38 #[cfg(not(feature = "fastembed"))]
39 {
40 let _ = model; if let Some(callback) = progress_callback {
42 callback("Using dummy embedder (no model download required)");
43 }
44 Ok(Box::new(DummyEmbedder::new()))
45 }
46}
47
/// Placeholder embedder that needs no model.
///
/// Every text is mapped to an all-zero vector of fixed width.
pub struct DummyEmbedder {
    /// Length of every vector produced by `embed`.
    dim: usize,
}

impl DummyEmbedder {
    /// Vector width used by `new`. It is the same width this module assigns
    /// to `BAAI/bge-small-en-v1.5`.
    const DEFAULT_DIM: usize = 384;

    /// Creates a dummy embedder producing 384-dimensional zero vectors.
    pub fn new() -> Self {
        DummyEmbedder {
            dim: Self::DEFAULT_DIM,
        }
    }
}

impl Default for DummyEmbedder {
    /// Same as [`DummyEmbedder::new`].
    fn default() -> Self {
        DummyEmbedder::new()
    }
}
63
64impl Embedder for DummyEmbedder {
65 fn id(&self) -> &'static str {
66 "dummy"
67 }
68
69 fn dim(&self) -> usize {
70 self.dim
71 }
72
73 fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
74 Ok(texts.iter().map(|_| vec![0.0; self.dim]).collect())
75 }
76}
77
/// Embedder backed by a `fastembed::TextEmbedding` model.
///
/// Only available with the `fastembed` feature.
#[cfg(feature = "fastembed")]
pub struct FastEmbedder {
    /// Loaded fastembed model used by `embed`.
    model: fastembed::TextEmbedding,
    /// Output vector width, chosen from the model selection at construction.
    dim: usize,
}
83
#[cfg(feature = "fastembed")]
impl FastEmbedder {
    /// Loads `model_name` without progress reporting.
    ///
    /// # Errors
    /// Same as [`FastEmbedder::new_with_progress`].
    pub fn new(model_name: &str) -> Result<Self> {
        Self::new_with_progress(model_name, None)
    }

    /// Loads a fastembed model by name, reporting progress through
    /// `progress_callback` when one is supplied.
    ///
    /// Recognised names are those handled by `resolve_model`. Any other name
    /// falls back to `nomic-embed-text-v1.5`, and the callback, if present, is
    /// told about the fallback. Model files are cached under the directory
    /// returned by `get_model_cache_dir`, which is created if missing.
    ///
    /// # Errors
    /// Fails if the cache directory cannot be created or if
    /// `TextEmbedding::try_new` fails.
    pub fn new_with_progress(
        model_name: &str,
        progress_callback: Option<ModelDownloadCallback>,
    ) -> Result<Self> {
        use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};

        let (model, dim) = match Self::resolve_model(model_name) {
            Some(resolved) => resolved,
            None => {
                // Keep the existing best-effort fallback, but stop doing it silently.
                if let Some(ref callback) = progress_callback {
                    callback(&format!(
                        "Unknown model {}, falling back to nomic-embed-text-v1.5",
                        model_name
                    ));
                }
                (EmbeddingModel::NomicEmbedTextV15, 768)
            }
        };

        let model_cache_dir = Self::get_model_cache_dir()?;
        std::fs::create_dir_all(&model_cache_dir)?;

        if let Some(ref callback) = progress_callback {
            callback(&format!("Initializing model: {}", model_name));

            if Self::check_model_exists(&model_cache_dir, model_name) {
                callback(&format!("Using cached model: {}", model_name));
            } else {
                callback(&format!(
                    "Downloading model {} to {}",
                    model_name,
                    model_cache_dir.display()
                ));
            }
        }

        let init_options = InitOptions::new(model)
            .with_show_download_progress(progress_callback.is_some())
            .with_cache_dir(model_cache_dir);

        let embedding = TextEmbedding::try_new(init_options)?;

        if let Some(ref callback) = progress_callback {
            callback("Model loaded successfully");
        }

        Ok(Self {
            model: embedding,
            dim,
        })
    }

    /// Maps a supported model name to its fastembed variant and output dimension.
    ///
    /// Returns `None` for unrecognised names. Keeping both in one table means
    /// the dimension cannot drift out of sync with the model choice.
    fn resolve_model(model_name: &str) -> Option<(fastembed::EmbeddingModel, usize)> {
        use fastembed::EmbeddingModel;

        let resolved = match model_name {
            "BAAI/bge-small-en-v1.5" => (EmbeddingModel::BGESmallENV15, 384),
            "sentence-transformers/all-MiniLM-L6-v2" => (EmbeddingModel::AllMiniLML6V2, 384),

            "nomic-embed-text-v1" => (EmbeddingModel::NomicEmbedTextV1, 768),
            "nomic-embed-text-v1.5" => (EmbeddingModel::NomicEmbedTextV15, 768),
            "jina-embeddings-v2-base-code" => (EmbeddingModel::JinaEmbeddingsV2BaseCode, 768),
            "BAAI/bge-base-en-v1.5" => (EmbeddingModel::BGEBaseENV15, 768),

            "BAAI/bge-large-en-v1.5" => (EmbeddingModel::BGELargeENV15, 1024),

            _ => return None,
        };
        Some(resolved)
    }

    /// Returns `<cache root>/models`, where the cache root is chosen in this order:
    /// 1. `$XDG_CACHE_HOME/ck` (only if set, non-empty and absolute)
    /// 2. `$HOME/.cache/ck`
    /// 3. `%LOCALAPPDATA%/ck/cache`
    /// 4. `.ck_models` relative to the current directory
    ///
    /// Never returns `Err` today; the `Result` is kept for callers.
    fn get_model_cache_dir() -> Result<PathBuf> {
        // Empty variables are treated as unset. Otherwise `PathBuf::from("")`
        // would silently produce a path relative to the working directory.
        let env_path = |key: &str| {
            std::env::var_os(key)
                .filter(|value| !value.is_empty())
                .map(PathBuf::from)
        };

        // The XDG spec says relative XDG_CACHE_HOME values must be ignored.
        let cache_dir = if let Some(cache_home) =
            env_path("XDG_CACHE_HOME").filter(|path| path.is_absolute())
        {
            cache_home.join("ck")
        } else if let Some(home) = env_path("HOME") {
            home.join(".cache").join("ck")
        } else if let Some(appdata) = env_path("LOCALAPPDATA") {
            appdata.join("ck").join("cache")
        } else {
            PathBuf::from(".ck_models")
        };

        Ok(cache_dir.join("models"))
    }

    /// Heuristic used only to pick a progress message. Checks whether
    /// `<cache_dir>/<model_name with '/' replaced by '_'>` exists.
    ///
    /// NOTE(review): assumes this is the layout fastembed uses for its cache.
    /// Confirm against the fastembed version in use, since a mismatch only
    /// makes the "Downloading"/"Using cached" message wrong.
    fn check_model_exists(cache_dir: &Path, model_name: &str) -> bool {
        let model_dir = cache_dir.join(model_name.replace('/', "_"));
        model_dir.exists()
    }
}
189
#[cfg(feature = "fastembed")]
impl Embedder for FastEmbedder {
    /// Always `"fastembed"`.
    fn id(&self) -> &'static str {
        "fastembed"
    }

    /// Dimension chosen for the model at construction.
    fn dim(&self) -> usize {
        self.dim
    }

    /// Embeds `texts` with the loaded fastembed model, passing `None` as the
    /// batch size.
    ///
    /// An empty slice returns an empty result without calling the model,
    /// matching `DummyEmbedder`.
    ///
    /// # Errors
    /// Returns the fastembed error, annotated with the number of inputs.
    fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        use anyhow::Context as _;

        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let text_refs: Vec<&str> = texts.iter().map(String::as_str).collect();
        self.model
            .embed(text_refs, None)
            .with_context(|| format!("fastembed failed to embed {} text(s)", texts.len()))
    }
}
206
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dummy_embedder() {
        let mut embedder = DummyEmbedder::new();

        assert_eq!(embedder.id(), "dummy");
        assert_eq!(embedder.dim(), 384);

        let texts = vec!["hello".to_string(), "world".to_string()];
        let embeddings = embedder.embed(&texts).unwrap();

        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].len(), 384);
        assert_eq!(embeddings[1].len(), 384);

        assert!(embeddings[0].iter().all(|&x| x == 0.0));
        assert!(embeddings[1].iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_dummy_embedder_default_matches_new() {
        assert_eq!(DummyEmbedder::default().dim(), DummyEmbedder::new().dim());
        assert_eq!(DummyEmbedder::default().id(), "dummy");
    }

    #[test]
    fn test_create_embedder_dummy() {
        #[cfg(not(feature = "fastembed"))]
        {
            let embedder = create_embedder(None).unwrap();
            assert_eq!(embedder.id(), "dummy");
            assert_eq!(embedder.dim(), 384);
        }
    }

    #[test]
    fn test_embedder_trait_object() {
        let mut embedder: Box<dyn Embedder> = Box::new(DummyEmbedder::new());

        let texts = vec!["test".to_string()];
        let result = embedder.embed(&texts);
        assert!(result.is_ok());

        let embeddings = result.unwrap();
        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), 384);
    }

    #[cfg(feature = "fastembed")]
    #[test]
    fn test_fastembed_creation() {
        // Skipped on CI: model download needs network access.
        if std::env::var("CI").is_ok() {
            return;
        }

        match FastEmbedder::new("BAAI/bge-small-en-v1.5") {
            Ok(mut embedder) => {
                assert_eq!(embedder.id(), "fastembed");
                assert_eq!(embedder.dim(), 384);

                let texts = vec!["hello world".to_string()];
                let result = embedder.embed(&texts);
                assert!(result.is_ok());

                let embeddings = result.unwrap();
                assert_eq!(embeddings.len(), 1);
                assert_eq!(embeddings[0].len(), 384);

                assert!(!embeddings[0].iter().all(|&x| x == 0.0));
            }
            Err(err) => {
                // Best-effort: model may be unavailable offline. Report instead of hiding.
                eprintln!("skipping test_fastembed_creation: {}", err);
            }
        }
    }

    #[cfg(feature = "fastembed")]
    #[test]
    fn test_create_embedder_fastembed() {
        // Skipped on CI: model download needs network access.
        if std::env::var("CI").is_ok() {
            return;
        }

        match create_embedder(Some("BAAI/bge-small-en-v1.5")) {
            Ok(embedder) => {
                assert_eq!(embedder.id(), "fastembed");
                assert_eq!(embedder.dim(), 384);
            }
            Err(err) => {
                // Best-effort: model may be unavailable offline. Report instead of hiding.
                eprintln!("skipping test_create_embedder_fastembed: {}", err);
            }
        }
    }

    #[test]
    fn test_embedder_empty_input() {
        let mut embedder = DummyEmbedder::new();
        let texts: Vec<String> = vec![];
        let embeddings = embedder.embed(&texts).unwrap();
        assert_eq!(embeddings.len(), 0);
    }

    #[test]
    fn test_embedder_single_text() {
        let mut embedder = DummyEmbedder::new();
        let texts = vec!["single text".to_string()];
        let embeddings = embedder.embed(&texts).unwrap();

        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), 384);
    }

    #[test]
    fn test_embedder_multiple_texts() {
        let mut embedder = DummyEmbedder::new();
        let texts = vec![
            "first text".to_string(),
            "second text".to_string(),
            "third text".to_string(),
        ];
        let embeddings = embedder.embed(&texts).unwrap();

        assert_eq!(embeddings.len(), 3);
        for embedding in &embeddings {
            assert_eq!(embedding.len(), 384);
        }
    }
}