1use anyhow::Result;
2use std::path::Path;
3use std::path::PathBuf;
4
5pub mod reranker;
6pub mod tokenizer;
7
8pub use reranker::{RerankResult, Reranker, create_reranker, create_reranker_with_progress};
9pub use tokenizer::TokenEstimator;
10
/// A text-embedding backend that maps strings to dense `f32` vectors.
///
/// The `Send + Sync` supertraits allow implementors (and `Box<dyn Embedder>`)
/// to be moved and shared across threads.
pub trait Embedder: Send + Sync {
    /// Short, stable identifier of the backend (for example `"dummy"`).
    fn id(&self) -> &'static str;

    /// Width of the vectors produced by [`Embedder::embed`].
    fn dim(&self) -> usize;

    /// Embeds every string in `texts`.
    ///
    /// Implementations are expected to return one vector of length
    /// [`Embedder::dim`] per input text.
    ///
    /// # Errors
    /// Backend-specific; see the implementing type.
    fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
}
16
/// Boxed, thread-safe callback that receives human-readable progress messages
/// while an embedder is being set up.
pub type ModelDownloadCallback = Box<dyn Fn(&str) + Send + Sync + 'static>;
18
19pub fn create_embedder(model_name: Option<&str>) -> Result<Box<dyn Embedder>> {
20 create_embedder_with_progress(model_name, None)
21}
22
23pub fn create_embedder_with_progress(
24 model_name: Option<&str>,
25 progress_callback: Option<ModelDownloadCallback>,
26) -> Result<Box<dyn Embedder>> {
27 let model = model_name.unwrap_or("BAAI/bge-small-en-v1.5");
28
29 #[cfg(feature = "fastembed")]
30 {
31 Ok(Box::new(FastEmbedder::new_with_progress(
32 model,
33 progress_callback,
34 )?))
35 }
36
37 #[cfg(not(feature = "fastembed"))]
38 {
39 if let Some(callback) = progress_callback {
40 callback("Using dummy embedder (no model download required)");
41 }
42 Ok(Box::new(DummyEmbedder::new()))
43 }
44}
45
/// Placeholder [`Embedder`] that needs no model files and emits all-zero vectors.
///
/// It is the backend returned when the `fastembed` feature is disabled.
#[derive(Debug, Clone)]
pub struct DummyEmbedder {
    /// Length of every vector produced by `embed`.
    dim: usize,
}
49
50impl Default for DummyEmbedder {
51 fn default() -> Self {
52 Self::new()
53 }
54}
55
56impl DummyEmbedder {
57 pub fn new() -> Self {
58 Self { dim: 384 } }
60}
61
62impl Embedder for DummyEmbedder {
63 fn id(&self) -> &'static str {
64 "dummy"
65 }
66
67 fn dim(&self) -> usize {
68 self.dim
69 }
70
71 fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
72 Ok(texts.iter().map(|_| vec![0.0; self.dim]).collect())
73 }
74}
75
/// [`Embedder`] backed by a fastembed `TextEmbedding` model.
#[cfg(feature = "fastembed")]
pub struct FastEmbedder {
    /// Output width of `model`, resolved from the model name at construction.
    dim: usize,
    /// Loaded fastembed pipeline used by `embed`.
    model: fastembed::TextEmbedding,
}
81
#[cfg(feature = "fastembed")]
impl FastEmbedder {
    /// Loads `model_name` without progress reporting.
    ///
    /// This is shorthand for [`FastEmbedder::new_with_progress`] with no callback.
    pub fn new(model_name: &str) -> Result<Self> {
        Self::new_with_progress(model_name, None)
    }

    /// Loads `model_name` through fastembed, reporting progress to `progress_callback`.
    ///
    /// Recognised names:
    /// * `BAAI/bge-small-en-v1.5`
    /// * `sentence-transformers/all-MiniLM-L6-v2`
    /// * `nomic-embed-text-v1`
    /// * `nomic-embed-text-v1.5`
    /// * `jina-embeddings-v2-base-code`
    /// * `BAAI/bge-base-en-v1.5`
    /// * `BAAI/bge-large-en-v1.5`
    ///
    /// Any other name falls back to `nomic-embed-text-v1.5`. When a callback is
    /// supplied, it is told about the fallback.
    ///
    /// Model files are cached under the directory chosen by `get_model_cache_dir`.
    ///
    /// # Errors
    /// Fails if the cache directory cannot be created, or if fastembed cannot
    /// fetch or initialise the model.
    pub fn new_with_progress(
        model_name: &str,
        progress_callback: Option<ModelDownloadCallback>,
    ) -> Result<Self> {
        use anyhow::Context;
        use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};

        // Resolve the model and its width in one place so the two can't drift apart.
        let (model, dim) = match Self::resolve_model(model_name) {
            Some(known) => known,
            None => {
                if let Some(callback) = &progress_callback {
                    callback(&format!(
                        "Unknown model '{}', falling back to nomic-embed-text-v1.5",
                        model_name
                    ));
                }
                (EmbeddingModel::NomicEmbedTextV15, 768)
            }
        };

        let model_cache_dir = Self::get_model_cache_dir()?;
        std::fs::create_dir_all(&model_cache_dir).with_context(|| {
            format!(
                "failed to create model cache directory {}",
                model_cache_dir.display()
            )
        })?;

        if let Some(callback) = &progress_callback {
            callback(&format!("Initializing model: {}", model_name));

            if Self::check_model_exists(&model_cache_dir, model_name) {
                callback(&format!("Using cached model: {}", model_name));
            } else {
                callback(&format!(
                    "Downloading model {} to {}",
                    model_name,
                    model_cache_dir.display()
                ));
            }
        }

        let init_options = InitOptions::new(model)
            .with_show_download_progress(progress_callback.is_some())
            .with_cache_dir(model_cache_dir);

        let embedding = TextEmbedding::try_new(init_options)
            .with_context(|| format!("failed to initialise embedding model {}", model_name))?;

        if let Some(callback) = &progress_callback {
            callback("Model loaded successfully");
        }

        Ok(Self {
            model: embedding,
            dim,
        })
    }

    /// Maps a supported model name to its fastembed variant and output width.
    ///
    /// Returns `None` for names this module does not recognise.
    fn resolve_model(model_name: &str) -> Option<(fastembed::EmbeddingModel, usize)> {
        use fastembed::EmbeddingModel;

        let resolved = match model_name {
            "BAAI/bge-small-en-v1.5" => (EmbeddingModel::BGESmallENV15, 384),
            "sentence-transformers/all-MiniLM-L6-v2" => (EmbeddingModel::AllMiniLML6V2, 384),
            "nomic-embed-text-v1" => (EmbeddingModel::NomicEmbedTextV1, 768),
            "nomic-embed-text-v1.5" => (EmbeddingModel::NomicEmbedTextV15, 768),
            "jina-embeddings-v2-base-code" => (EmbeddingModel::JinaEmbeddingsV2BaseCode, 768),
            "BAAI/bge-base-en-v1.5" => (EmbeddingModel::BGEBaseENV15, 768),
            "BAAI/bge-large-en-v1.5" => (EmbeddingModel::BGELargeENV15, 1024),
            _ => return None,
        };
        Some(resolved)
    }

    /// Returns the directory in which fastembed should cache model files.
    ///
    /// Resolution order (first match wins):
    /// 1. `$XDG_CACHE_HOME/ck/models`, only when it is a non-empty absolute path,
    ///    as the XDG Base Directory spec requires.
    /// 2. `$HOME/.cache/ck/models`.
    /// 3. `%LOCALAPPDATA%/ck/cache/models`.
    /// 4. `.ck_models/models`, relative to the working directory.
    ///
    /// Empty variables are treated as unset. Otherwise they would silently yield
    /// a working-directory-relative path.
    fn get_model_cache_dir() -> Result<PathBuf> {
        /// Reads `key` as a path, treating an unset or empty value as absent.
        fn env_dir(key: &str) -> Option<PathBuf> {
            std::env::var_os(key)
                .filter(|value| !value.is_empty())
                .map(PathBuf::from)
        }

        let cache_dir = if let Some(cache_home) =
            env_dir("XDG_CACHE_HOME").filter(|path| path.is_absolute())
        {
            cache_home.join("ck")
        } else if let Some(home) = env_dir("HOME") {
            home.join(".cache").join("ck")
        } else if let Some(appdata) = env_dir("LOCALAPPDATA") {
            appdata.join("ck").join("cache")
        } else {
            PathBuf::from(".ck_models")
        };

        Ok(cache_dir.join("models"))
    }

    /// Best-effort check for a previously cached copy of `model_name`.
    ///
    /// The result only selects which progress message is shown; loading
    /// proceeds either way.
    // NOTE(review): this assumes fastembed stores each model under
    // `<cache_dir>/<model_name with '/' replaced by '_'>`. The underlying HF-hub
    // cache may use a different layout (e.g. `models--<org>--<repo>`). If so,
    // this always reports a download. Confirm against the fastembed version in use.
    fn check_model_exists(cache_dir: &Path, model_name: &str) -> bool {
        cache_dir.join(model_name.replace('/', "_")).exists()
    }
}
187
#[cfg(feature = "fastembed")]
impl Embedder for FastEmbedder {
    /// Always `"fastembed"`.
    fn id(&self) -> &'static str {
        "fastembed"
    }

    /// The width resolved from the model name at construction.
    fn dim(&self) -> usize {
        self.dim
    }

    /// Embeds `texts` in a single fastembed call.
    ///
    /// `None` is passed as the batch-size argument, and fastembed errors are
    /// propagated unchanged.
    fn embed(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let borrowed = texts.iter().map(String::as_str).collect::<Vec<&str>>();
        Ok(self.model.embed(borrowed, None)?)
    }
}
204
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dummy_embedder() {
        let mut embedder = DummyEmbedder::new();

        assert_eq!(embedder.id(), "dummy");
        assert_eq!(embedder.dim(), 384);

        let texts = vec!["hello".to_string(), "world".to_string()];
        let embeddings = embedder.embed(&texts).unwrap();

        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].len(), 384);
        assert_eq!(embeddings[1].len(), 384);

        assert!(embeddings[0].iter().all(|&x| x == 0.0));
        assert!(embeddings[1].iter().all(|&x| x == 0.0));
    }

    // Gate the whole test rather than its body, so that fastembed builds do not
    // report it as a vacuous pass.
    #[cfg(not(feature = "fastembed"))]
    #[test]
    fn test_create_embedder_dummy() {
        let embedder = create_embedder(None).unwrap();
        assert_eq!(embedder.id(), "dummy");
        assert_eq!(embedder.dim(), 384);
    }

    #[cfg(not(feature = "fastembed"))]
    #[test]
    fn test_create_embedder_with_progress_reports_dummy() {
        use std::sync::{Arc, Mutex};

        let messages = Arc::new(Mutex::new(Vec::<String>::new()));
        let sink = Arc::clone(&messages);
        let callback: ModelDownloadCallback =
            Box::new(move |msg: &str| sink.lock().unwrap().push(msg.to_string()));

        let embedder = create_embedder_with_progress(None, Some(callback)).unwrap();

        assert_eq!(embedder.id(), "dummy");
        assert_eq!(
            *messages.lock().unwrap(),
            vec!["Using dummy embedder (no model download required)".to_string()]
        );
    }

    #[test]
    fn test_embedder_trait_object() {
        let mut embedder: Box<dyn Embedder> = Box::new(DummyEmbedder::new());

        let texts = vec!["test".to_string()];
        let result = embedder.embed(&texts);
        assert!(result.is_ok());

        let embeddings = result.unwrap();
        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), 384);
    }

    #[cfg(feature = "fastembed")]
    #[test]
    fn test_fastembed_creation() {
        if std::env::var("CI").is_ok() {
            eprintln!("skipping test_fastembed_creation: CI is set");
            return;
        }

        let mut embedder = match FastEmbedder::new("BAAI/bge-small-en-v1.5") {
            Ok(embedder) => embedder,
            Err(err) => {
                // Loading may need a network download. Treat failure as a skip,
                // but make it visible instead of passing silently.
                eprintln!(
                    "skipping test_fastembed_creation: model unavailable: {:#}",
                    err
                );
                return;
            }
        };

        assert_eq!(embedder.id(), "fastembed");
        assert_eq!(embedder.dim(), 384);

        let texts = vec!["hello world".to_string()];
        let result = embedder.embed(&texts);
        assert!(result.is_ok());

        let embeddings = result.unwrap();
        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), 384);

        assert!(!embeddings[0].iter().all(|&x| x == 0.0));
    }

    #[cfg(feature = "fastembed")]
    #[test]
    fn test_create_embedder_fastembed() {
        if std::env::var("CI").is_ok() {
            eprintln!("skipping test_create_embedder_fastembed: CI is set");
            return;
        }

        match create_embedder(Some("BAAI/bge-small-en-v1.5")) {
            Ok(embedder) => {
                assert_eq!(embedder.id(), "fastembed");
                assert_eq!(embedder.dim(), 384);
            }
            Err(err) => {
                eprintln!(
                    "skipping test_create_embedder_fastembed: model unavailable: {:#}",
                    err
                );
            }
        }
    }

    #[test]
    fn test_embedder_empty_input() {
        let mut embedder = DummyEmbedder::new();
        let texts: Vec<String> = vec![];
        let embeddings = embedder.embed(&texts).unwrap();
        assert_eq!(embeddings.len(), 0);
    }

    #[test]
    fn test_embedder_single_text() {
        let mut embedder = DummyEmbedder::new();
        let texts = vec!["single text".to_string()];
        let embeddings = embedder.embed(&texts).unwrap();

        assert_eq!(embeddings.len(), 1);
        assert_eq!(embeddings[0].len(), 384);
    }

    #[test]
    fn test_embedder_multiple_texts() {
        let mut embedder = DummyEmbedder::new();
        let texts = vec![
            "first text".to_string(),
            "second text".to_string(),
            "third text".to_string(),
        ];
        let embeddings = embedder.embed(&texts).unwrap();

        assert_eq!(embeddings.len(), 3);
        for embedding in &embeddings {
            assert_eq!(embedding.len(), 384);
        }
    }
}