1use crate::batch::{mean_pooling, normalize_embeddings, BatchProcessor};
33use crate::error::{InferenceError, Result};
34use crate::models::{EmbeddingModel, ModelConfig};
35use candle_core::{DType, Device};
36use candle_nn::VarBuilder;
37use candle_transformers::models::bert::{BertModel, Config as BertConfig};
38use parking_lot::RwLock;
39use std::io::Read;
40use std::path::{Path, PathBuf};
41use std::sync::Arc;
42use tokenizers::Tokenizer;
43use tracing::{debug, info, instrument, warn};
44
/// Embedding inference engine backed by a BERT-family model loaded via candle.
///
/// Owns the model, the tokenization/batching pipeline, and the compute
/// device; exposes async embedding APIs that offload tensor work to the
/// blocking thread pool.
pub struct EmbeddingEngine {
    // Loaded BERT model; RwLock because forward() takes &self, so concurrent
    // batches can share a read lock.
    model: Arc<RwLock<BatchModelAlias>>,
    // Tokenizer + batch-splitting helper, shared with blocking worker tasks.
    processor: Arc<BatchProcessor>,
    // Compute device (CPU / CUDA / Metal) chosen once at construction.
    device: Device,
    // Configuration this engine was built with.
    config: ModelConfig,
    // Output embedding dimension, derived from the model variant.
    dimension: usize,
}
62
63impl EmbeddingEngine {
    /// Builds a new engine: selects a device, resolves model files (from the
    /// HF hub cache, the local cache, or a fresh download), then loads the
    /// tokenizer, model config, and weights.
    ///
    /// # Errors
    /// Fails when downloads fail, files cannot be parsed, or the weights
    /// cannot be loaded onto the chosen device.
    #[instrument(skip_all, fields(model = %config.model))]
    pub async fn new(config: ModelConfig) -> Result<Self> {
        info!("Initializing embedding engine with model: {}", config.model);

        let device = Self::select_device(&config)?;
        info!("Using device: {:?}", device);

        // Paths to (weights, tokenizer.json, config.json).
        let (model_path, tokenizer_path, config_path) = Self::download_model_files(&config).await?;

        info!("Loading tokenizer from {:?}", tokenizer_path);
        let tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;

        info!("Loading model config from {:?}", config_path);
        let model_config: BertConfig = {
            let config_str = std::fs::read_to_string(&config_path)?;
            serde_json::from_str(&config_str)
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
        };

        info!("Loading model weights from {:?}", model_path);
        // SAFETY: memory-maps the safetensors file; sound as long as the file
        // is not truncated or rewritten while mapped (standard candle usage).
        // NOTE(review): download_model_files may return pytorch_model.bin as a
        // fallback, but this loader only reads safetensors — confirm the .bin
        // fallback path is ever actually usable.
        let vb =
            unsafe { VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &device)? };

        let model = BertModel::load(vb, &model_config)
            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;

        let dimension = config.model.dimension();
        let processor = Arc::new(BatchProcessor::new(
            tokenizer,
            config.model,
            config.max_batch_size,
        ));

        info!(
            "Embedding engine initialized: model={}, dimension={}, max_batch={}",
            config.model, dimension, config.max_batch_size
        );

        Ok(Self {
            model: Arc::new(RwLock::new(model)),
            processor,
            device,
            config,
            dimension,
        })
    }
119
    /// Picks the compute device from `config.use_gpu` and the compiled GPU
    /// features: CUDA first, then Metal, falling back to CPU when neither is
    /// available or enabled.
    fn select_device(config: &ModelConfig) -> Result<Device> {
        if config.use_gpu {
            #[cfg(feature = "cuda")]
            {
                if let Ok(device) = Device::new_cuda(0) {
                    return Ok(device);
                }
                warn!("CUDA requested but not available, falling back to CPU");
            }

            // Reached only when the cuda feature is off or CUDA init failed.
            #[cfg(feature = "metal")]
            {
                if let Ok(device) = Device::new_metal(0) {
                    return Ok(device);
                }
                warn!("Metal requested but not available, falling back to CPU");
            }

            #[cfg(not(any(feature = "cuda", feature = "metal")))]
            {
                warn!("GPU requested but no GPU features enabled, using CPU");
            }
        }

        Ok(Device::Cpu)
    }
149
    /// Resolves local paths for (weights, tokenizer.json, config.json), in
    /// priority order: the HuggingFace hub cache, this crate's own cache
    /// directory, then a fresh download from huggingface.co.
    ///
    /// Weights prefer `model.safetensors`, falling back to
    /// `pytorch_model.bin` when the safetensors file is unavailable.
    #[instrument(skip_all, fields(model = %config.model))]
    async fn download_model_files(config: &ModelConfig) -> Result<(PathBuf, PathBuf, PathBuf)> {
        let model_id = config.model.model_id();
        info!("Resolving model files for: {}", model_id);

        let model_id_owned = model_id.to_string();

        // Tier 1: files already present in the standard HF hub cache.
        let hf_cache = hf_hub::Cache::default();
        let hf_repo = hf_hub::Repo::new(model_id_owned.clone(), hf_hub::RepoType::Model);
        let cached_repo = hf_cache.repo(hf_repo);

        let cached_model = cached_repo
            .get("model.safetensors")
            .or_else(|| cached_repo.get("pytorch_model.bin"));
        let cached_tokenizer = cached_repo.get("tokenizer.json");
        let cached_config = cached_repo.get("config.json");

        if let (Some(m), Some(t), Some(c)) = (cached_model, cached_tokenizer, cached_config) {
            info!("All model files found in HF cache");
            return Ok((m, t, c));
        }

        // Tier 2: this crate's own cache directory.
        let cache_dir = Self::model_cache_dir(model_id)?;
        let local_model = cache_dir.join("model.safetensors");
        let local_model_bin = cache_dir.join("pytorch_model.bin");
        let local_tokenizer = cache_dir.join("tokenizer.json");
        let local_config = cache_dir.join("config.json");

        let model_exists = local_model.exists() || local_model_bin.exists();
        if model_exists && local_tokenizer.exists() && local_config.exists() {
            let mp = if local_model.exists() {
                local_model
            } else {
                local_model_bin
            };
            info!("All model files found in local cache");
            return Ok((mp, local_tokenizer, local_config));
        }

        // Tier 3: download. The HTTP work is synchronous (ureq), so it runs
        // on the blocking pool to avoid stalling the async runtime.
        info!("Downloading model files from HuggingFace...");

        let cd = cache_dir.clone();
        let mid = model_id_owned.clone();
        tokio::task::spawn_blocking(move || {
            Self::download_hf_file(&mid, "model.safetensors", &cd)
                .or_else(|_| Self::download_hf_file(&mid, "pytorch_model.bin", &cd))
                .map_err(|e| {
                    InferenceError::HubError(format!("Failed to download model weights: {}", e))
                })?;
            Self::download_hf_file(&mid, "tokenizer.json", &cd).map_err(|e| {
                InferenceError::HubError(format!("Failed to download tokenizer: {}", e))
            })?;
            Self::download_hf_file(&mid, "config.json", &cd).map_err(|e| {
                InferenceError::HubError(format!("Failed to download config: {}", e))
            })?;
            Ok::<_, InferenceError>(())
        })
        .await
        // Outer ? unwraps the JoinHandle, inner ? the download result.
        .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;

        let final_model = if cache_dir.join("model.safetensors").exists() {
            cache_dir.join("model.safetensors")
        } else {
            cache_dir.join("pytorch_model.bin")
        };

        info!("Model files downloaded successfully to {:?}", cache_dir);
        Ok((final_model, local_tokenizer, local_config))
    }
231
    /// Returns (creating it if needed) the on-disk cache directory for a
    /// model: `$HF_HOME` or `~/.cache/huggingface`, under a `dakera`
    /// subfolder, with `/` in the model id replaced by `--`.
    ///
    /// NOTE(review): `config.cache_dir` (settable via the builder) is never
    /// consulted here — confirm whether it should override this default.
    fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
        let base = std::env::var("HF_HOME")
            .map(PathBuf::from)
            .unwrap_or_else(|_| {
                // Unix-style fallback; HOME is absent on Windows — presumably
                // only Unix targets are supported. TODO confirm.
                let home = std::env::var("HOME").unwrap_or_else(|_| {
                    tracing::warn!("HOME environment variable not set, using /tmp for model cache");
                    "/tmp".to_string()
                });
                PathBuf::from(home).join(".cache").join("huggingface")
            });
        let dir = base.join("dakera").join(model_id.replace('/', "--"));
        std::fs::create_dir_all(&dir)?;
        Ok(dir)
    }
247
248 fn download_hf_file(
254 model_id: &str,
255 filename: &str,
256 cache_dir: &Path,
257 ) -> std::result::Result<PathBuf, String> {
258 let file_path = cache_dir.join(filename);
259 if file_path.exists() {
260 info!("Cached: {}", filename);
261 return Ok(file_path);
262 }
263
264 let url = format!(
265 "https://huggingface.co/{}/resolve/main/{}",
266 model_id, filename
267 );
268 info!("Downloading: {}", url);
269
270 let agent = ureq::AgentBuilder::new()
272 .redirects(0)
273 .timeout(std::time::Duration::from_secs(300))
274 .build();
275
276 let mut current_url = url.clone();
277 let mut redirects = 0;
278 let max_redirects = 10;
279
280 let response = loop {
281 let resp = agent.get(¤t_url).call();
282
283 let r = match resp {
284 Ok(r) => r,
285 Err(ureq::Error::Status(_status, r)) => r,
286 Err(e) => return Err(format!("{}: {}", filename, e)),
287 };
288
289 let status = r.status();
290 if (200..300).contains(&status) {
291 break r;
292 } else if (300..400).contains(&status) {
293 redirects += 1;
294 if redirects > max_redirects {
295 return Err(format!("{}: too many redirects", filename));
296 }
297 let location = r
298 .header("location")
299 .ok_or_else(|| format!("{}: redirect without Location header", filename))?
300 .to_string();
301
302 current_url = if location.starts_with('/') {
304 let parsed = url::Url::parse(¤t_url)
305 .map_err(|e| format!("{}: bad URL {}: {}", filename, current_url, e))?;
306 let host = parsed.host_str().ok_or_else(|| {
307 format!("{}: redirect URL missing host: {}", filename, current_url)
308 })?;
309 format!("{}://{}{}", parsed.scheme(), host, location)
310 } else {
311 location
312 };
313 info!("Redirect {} → {}", redirects, current_url);
314 } else {
315 return Err(format!("{}: HTTP {}", filename, status));
316 }
317 };
318
319 let mut bytes = Vec::new();
320 response
321 .into_reader()
322 .take(500_000_000) .read_to_end(&mut bytes)
324 .map_err(|e| format!("Failed to read {}: {}", filename, e))?;
325
326 std::fs::write(&file_path, &bytes)
327 .map_err(|e| format!("Failed to write {}: {}", filename, e))?;
328
329 info!("Downloaded {} ({} bytes)", filename, bytes.len());
330 Ok(file_path)
331 }
332
    /// Output embedding dimension of the loaded model.
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// The model variant this engine was built with.
    pub fn model(&self) -> EmbeddingModel {
        self.config.model
    }

    /// The compute device in use.
    pub fn device(&self) -> &Device {
        &self.device
    }
347
348 #[instrument(skip(self, text), fields(text_len = text.len()))]
352 pub async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
353 let texts = vec![text.to_string()];
354 let prepared = self.processor.prepare_texts(&texts, true);
355 let embeddings = self.embed_batch_internal(&prepared).await?;
356 embeddings.into_iter().next().ok_or_else(|| {
357 crate::error::InferenceError::InferenceError(
358 "No embedding returned for query".to_string(),
359 )
360 })
361 }
362
    /// Embeds a batch of query texts (query-side preprocessing applied).
    #[instrument(skip(self, texts), fields(count = texts.len()))]
    pub async fn embed_queries(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let prepared = self.processor.prepare_texts(texts, true);
        self.embed_batch_internal(&prepared).await
    }
371
372 #[instrument(skip(self, text), fields(text_len = text.len()))]
376 pub async fn embed_document(&self, text: &str) -> Result<Vec<f32>> {
377 let texts = vec![text.to_string()];
378 let prepared = self.processor.prepare_texts(&texts, false);
379 let embeddings = self.embed_batch_internal(&prepared).await?;
380 embeddings.into_iter().next().ok_or_else(|| {
381 crate::error::InferenceError::InferenceError(
382 "No embedding returned for document".to_string(),
383 )
384 })
385 }
386
    /// Embeds a batch of document texts (document-side preprocessing applied).
    #[instrument(skip(self, texts), fields(count = texts.len()))]
    pub async fn embed_documents(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let prepared = self.processor.prepare_texts(texts, false);
        self.embed_batch_internal(&prepared).await
    }
395
    /// Embeds texts as-is, skipping any query/document preprocessing.
    #[instrument(skip(self, texts), fields(count = texts.len()))]
    pub async fn embed_raw(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        self.embed_batch_internal(texts).await
    }
401
    /// Splits `texts` into batches and runs each one on the blocking thread
    /// pool, returning embeddings in input order. Empty input yields an
    /// empty result without touching the model.
    async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(vec![]);
        }

        let batches = self.processor.split_into_batches(texts);
        let mut all_embeddings = Vec::with_capacity(texts.len());

        // Batches run sequentially; each is moved onto the blocking pool so
        // heavy tensor work never stalls the async runtime.
        for batch in batches {
            // Owned copies so the 'static spawn_blocking closure can move them.
            let batch_owned: Vec<String> = batch.to_vec();
            let model = Arc::clone(&self.model);
            let processor = Arc::clone(&self.processor);
            let device = self.device.clone();
            let normalize = self.config.model.normalize_embeddings();

            let batch_embeddings = tokio::task::spawn_blocking(move || {
                Self::process_batch_blocking(&batch_owned, &model, &processor, &device, normalize)
            })
            .await
            // Outer ? unwraps the JoinHandle, inner ? the inference result.
            .map_err(|e| {
                InferenceError::InferenceError(format!("Inference task panicked: {}", e))
            })??;

            all_embeddings.extend(batch_embeddings);
        }

        Ok(all_embeddings)
    }
436
    /// Runs tokenization, the BERT forward pass, mean pooling, and optional
    /// normalization for one batch. Synchronous — intended to run on a
    /// blocking thread (see `embed_batch_internal`).
    fn process_batch_blocking(
        texts: &[String],
        model: &Arc<RwLock<BertModel>>,
        processor: &BatchProcessor,
        device: &Device,
        normalize: bool,
    ) -> Result<Vec<Vec<f32>>> {
        let prepared = processor.tokenize_batch(texts, device)?;

        // Read lock only: forward() takes &self, so batches may overlap.
        let model_guard = model.read();

        // BERT expects u32 ids/masks regardless of tokenizer output dtype.
        let input_ids = prepared.input_ids.to_dtype(DType::U32)?;
        let attention_mask = prepared.attention_mask.to_dtype(DType::U32)?;
        let token_type_ids = prepared.token_type_ids.to_dtype(DType::U32)?;

        let output = model_guard.forward(&input_ids, &token_type_ids, Some(&attention_mask))?;

        // Pooling weights token positions, so the mask is needed as floats.
        let attention_mask_f32 = prepared.attention_mask.to_dtype(DType::F32)?;
        let pooled = mean_pooling(&output, &attention_mask_f32)?;

        let normalized = if normalize {
            normalize_embeddings(&pooled)?
        } else {
            pooled
        };

        // Release the model before the (potentially slow) tensor-to-host copy.
        drop(model_guard);

        let embeddings = normalized.to_vec2::<f32>()?;

        debug!(
            "Generated {} embeddings of dimension {}",
            embeddings.len(),
            embeddings.first().map(|e| e.len()).unwrap_or(0)
        );

        Ok(embeddings)
    }
485
486 pub fn estimate_time_ms(&self, text_count: usize, avg_text_len: usize) -> f64 {
488 let tokens_per_text =
490 (avg_text_len as f64 / 4.0).min(self.config.model.max_seq_length() as f64);
491 let total_tokens = tokens_per_text * text_count as f64;
492 let tokens_per_second = self.config.model.tokens_per_second_cpu() as f64;
493
494 let speed_multiplier = if matches!(self.device, Device::Cpu) {
496 1.0
497 } else {
498 10.0
499 };
500
501 (total_tokens / (tokens_per_second * speed_multiplier)) * 1000.0
502 }
503}
504
// Manual Debug: the model weights and tokenizer are not Debug-printable
// (and would be enormous), so only summary fields are shown.
impl std::fmt::Debug for EmbeddingEngine {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EmbeddingEngine")
            .field("model", &self.config.model)
            .field("dimension", &self.dimension)
            .field("device", &self.device)
            .field("max_batch_size", &self.config.max_batch_size)
            .finish()
    }
}
515
/// Fluent builder for [`EmbeddingEngine`], starting from
/// `ModelConfig::default()`.
pub struct EmbeddingEngineBuilder {
    // Accumulated configuration handed to `EmbeddingEngine::new` on build().
    config: ModelConfig,
}
520
impl EmbeddingEngineBuilder {
    /// Starts a builder with the default model configuration.
    pub fn new() -> Self {
        Self {
            config: ModelConfig::default(),
        }
    }

    /// Selects the embedding model variant.
    pub fn model(mut self, model: EmbeddingModel) -> Self {
        self.config.model = model;
        self
    }

    /// Sets a custom cache directory.
    /// NOTE(review): stored in config but apparently unused by
    /// `model_cache_dir` — verify it takes effect.
    pub fn cache_dir(mut self, dir: impl Into<String>) -> Self {
        self.config.cache_dir = Some(dir.into());
        self
    }

    /// Sets the maximum number of texts per inference batch.
    pub fn max_batch_size(mut self, size: usize) -> Self {
        self.config.max_batch_size = size;
        self
    }

    /// Enables or disables GPU use (subject to compiled features).
    pub fn use_gpu(mut self, enable: bool) -> Self {
        self.config.use_gpu = enable;
        self
    }

    /// Sets the worker thread count hint.
    pub fn num_threads(mut self, threads: usize) -> Self {
        self.config.num_threads = Some(threads);
        self
    }

    /// Consumes the builder and constructs the engine (may download model
    /// files; see `EmbeddingEngine::new`).
    pub async fn build(self) -> Result<EmbeddingEngine> {
        EmbeddingEngine::new(self.config).await
    }
}
564
// Default delegates to new() so `EmbeddingEngineBuilder::default()` works in
// derive/struct-update contexts.
impl Default for EmbeddingEngineBuilder {
    fn default() -> Self {
        Self::new()
    }
}
570
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): this only checks the throughput constant is positive; it
    // never calls estimate_time_ms (which needs a built engine). Consider a
    // test that exercises the estimate itself.
    #[test]
    fn test_estimate_time() {
        let config = ModelConfig::new(EmbeddingModel::MiniLM);
        let tokens_per_second = config.model.tokens_per_second_cpu() as f64;
        assert!(tokens_per_second > 0.0);
    }

    // Verifies builder setters land in the underlying config.
    #[test]
    fn test_builder() {
        let builder = EmbeddingEngineBuilder::new()
            .model(EmbeddingModel::BgeSmall)
            .max_batch_size(64)
            .use_gpu(false);

        assert_eq!(builder.config.model, EmbeddingModel::BgeSmall);
        assert_eq!(builder.config.max_batch_size, 64);
        assert!(!builder.config.use_gpu);
    }
}