Skip to main content

inference/
engine.rs

//! Core embedding engine for generating vector embeddings from text.
//!
//! The `EmbeddingEngine` provides a high-level interface for:
//! - Loading embedding models from HuggingFace Hub
//! - Generating embeddings for single texts or batches
//! - Automatic batching, with CPU-heavy inference offloaded to blocking threads
//!
//! # Example
//!
//! ```no_run
//! use inference::{EmbeddingEngine, ModelConfig, EmbeddingModel};
//!
//! #[tokio::main]
//! async fn main() {
//!     let config = ModelConfig::new(EmbeddingModel::MiniLM);
//!     let engine = EmbeddingEngine::new(config).await.unwrap();
//!
//!     // Embed a single query
//!     let embedding = engine.embed_query("What is machine learning?").await.unwrap();
//!     println!("Embedding dimension: {}", embedding.len());
//!
//!     // Embed multiple documents
//!     let docs = vec![
//!         "Machine learning is a subset of AI.".to_string(),
//!         "Deep learning uses neural networks.".to_string(),
//!     ];
//!     let embeddings = engine.embed_documents(&docs).await.unwrap();
//!     println!("Generated {} embeddings", embeddings.len());
//! }
//! ```

32use crate::batch::{mean_pooling, normalize_embeddings, BatchProcessor};
33use crate::error::{InferenceError, Result};
34use crate::models::{EmbeddingModel, ModelConfig};
35use candle_core::{DType, Device};
36use candle_nn::VarBuilder;
37use candle_transformers::models::bert::{BertModel, Config as BertConfig};
38use parking_lot::RwLock;
39use std::io::Read;
40use std::path::{Path, PathBuf};
41use std::sync::Arc;
42use tokenizers::Tokenizer;
43use tracing::{debug, info, instrument, warn};
44
45/// The main embedding engine for generating vector embeddings.
46///
47/// This struct is thread-safe and can be shared across async tasks.
48/// Internal fields use `Arc` so CPU-heavy inference can be offloaded
49/// to blocking threads via `tokio::task::spawn_blocking`.
50pub struct EmbeddingEngine {
51    /// The loaded BERT model
52    model: Arc<RwLock<BertModel>>,
53    /// Batch processor for tokenization (Arc-wrapped for spawn_blocking)
54    processor: Arc<BatchProcessor>,
55    /// Device for computation
56    device: Device,
57    /// Model configuration
58    config: ModelConfig,
59    /// Embedding dimension
60    dimension: usize,
61}
62
63impl EmbeddingEngine {
64    /// Create a new embedding engine with the given configuration.
65    ///
66    /// This will download the model from HuggingFace Hub if not cached.
67    #[instrument(skip_all, fields(model = %config.model))]
68    pub async fn new(config: ModelConfig) -> Result<Self> {
69        info!("Initializing embedding engine with model: {}", config.model);
70
71        // Select device
72        let device = Self::select_device(&config)?;
73        info!("Using device: {:?}", device);
74
75        // Download model files
76        let (model_path, tokenizer_path, config_path) = Self::download_model_files(&config).await?;
77
78        // Load tokenizer
79        info!("Loading tokenizer from {:?}", tokenizer_path);
80        let tokenizer = Tokenizer::from_file(&tokenizer_path)
81            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
82
83        // Load model config
84        info!("Loading model config from {:?}", config_path);
85        let model_config: BertConfig = {
86            let config_str = std::fs::read_to_string(&config_path)?;
87            serde_json::from_str(&config_str)
88                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
89        };
90
91        // Load model weights
92        info!("Loading model weights from {:?}", model_path);
93        let vb =
94            unsafe { VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &device)? };
95
96        let model = BertModel::load(vb, &model_config)
97            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
98
99        let dimension = config.model.dimension();
100        let processor = Arc::new(BatchProcessor::new(
101            tokenizer,
102            config.model,
103            config.max_batch_size,
104        ));
105
106        info!(
107            "Embedding engine initialized: model={}, dimension={}, max_batch={}",
108            config.model, dimension, config.max_batch_size
109        );
110
111        Ok(Self {
112            model: Arc::new(RwLock::new(model)),
113            processor,
114            device,
115            config,
116            dimension,
117        })
118    }
119
120    /// Select the appropriate compute device based on configuration.
121    fn select_device(config: &ModelConfig) -> Result<Device> {
122        if config.use_gpu {
123            // Try CUDA first
124            #[cfg(feature = "cuda")]
125            {
126                if let Ok(device) = Device::new_cuda(0) {
127                    return Ok(device);
128                }
129                warn!("CUDA requested but not available, falling back to CPU");
130            }
131
132            // Try Metal (macOS)
133            #[cfg(feature = "metal")]
134            {
135                if let Ok(device) = Device::new_metal(0) {
136                    return Ok(device);
137                }
138                warn!("Metal requested but not available, falling back to CPU");
139            }
140
141            #[cfg(not(any(feature = "cuda", feature = "metal")))]
142            {
143                warn!("GPU requested but no GPU features enabled, using CPU");
144            }
145        }
146
147        Ok(Device::Cpu)
148    }
149
150    /// Resolve model files, downloading from HuggingFace Hub if needed.
151    ///
152    /// Uses a two-tier cache strategy:
153    /// 1. Check local Dakera cache directory
154    /// 2. Check hf-hub's standard cache (~/.cache/huggingface/hub/)
155    /// 3. Download directly via ureq (bypasses hf-hub's broken redirect handling
156    ///    in its metadata() method which can't resolve relative Location headers)
157    #[instrument(skip_all, fields(model = %config.model))]
158    async fn download_model_files(config: &ModelConfig) -> Result<(PathBuf, PathBuf, PathBuf)> {
159        let model_id = config.model.model_id();
160        info!("Resolving model files for: {}", model_id);
161
162        let model_id_owned = model_id.to_string();
163
164        // Check hf-hub's standard cache first (fast path, no network)
165        let hf_cache = hf_hub::Cache::default();
166        let hf_repo = hf_hub::Repo::new(model_id_owned.clone(), hf_hub::RepoType::Model);
167        let cached_repo = hf_cache.repo(hf_repo);
168
169        let cached_model = cached_repo
170            .get("model.safetensors")
171            .or_else(|| cached_repo.get("pytorch_model.bin"));
172        let cached_tokenizer = cached_repo.get("tokenizer.json");
173        let cached_config = cached_repo.get("config.json");
174
175        if let (Some(m), Some(t), Some(c)) = (cached_model, cached_tokenizer, cached_config) {
176            info!("All model files found in HF cache");
177            return Ok((m, t, c));
178        }
179
180        // Check local Dakera cache
181        let cache_dir = Self::model_cache_dir(model_id)?;
182        let local_model = cache_dir.join("model.safetensors");
183        let local_model_bin = cache_dir.join("pytorch_model.bin");
184        let local_tokenizer = cache_dir.join("tokenizer.json");
185        let local_config = cache_dir.join("config.json");
186
187        let model_exists = local_model.exists() || local_model_bin.exists();
188        if model_exists && local_tokenizer.exists() && local_config.exists() {
189            let mp = if local_model.exists() {
190                local_model
191            } else {
192                local_model_bin
193            };
194            info!("All model files found in local cache");
195            return Ok((mp, local_tokenizer, local_config));
196        }
197
198        // Download directly using ureq (handles relative redirects correctly,
199        // unlike hf-hub 0.3's metadata() which passes relative Location headers
200        // to reqwest's get() as bare strings, causing URL parse failures).
201        info!("Downloading model files from HuggingFace...");
202
203        let cd = cache_dir.clone();
204        let mid = model_id_owned.clone();
205        tokio::task::spawn_blocking(move || {
206            Self::download_hf_file(&mid, "model.safetensors", &cd)
207                .or_else(|_| Self::download_hf_file(&mid, "pytorch_model.bin", &cd))
208                .map_err(|e| {
209                    InferenceError::HubError(format!("Failed to download model weights: {}", e))
210                })?;
211            Self::download_hf_file(&mid, "tokenizer.json", &cd).map_err(|e| {
212                InferenceError::HubError(format!("Failed to download tokenizer: {}", e))
213            })?;
214            Self::download_hf_file(&mid, "config.json", &cd).map_err(|e| {
215                InferenceError::HubError(format!("Failed to download config: {}", e))
216            })?;
217            Ok::<_, InferenceError>(())
218        })
219        .await
220        .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
221
222        let final_model = if cache_dir.join("model.safetensors").exists() {
223            cache_dir.join("model.safetensors")
224        } else {
225            cache_dir.join("pytorch_model.bin")
226        };
227
228        info!("Model files downloaded successfully to {:?}", cache_dir);
229        Ok((final_model, local_tokenizer, local_config))
230    }
231
232    /// Get or create the local model cache directory.
233    fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
234        let base = std::env::var("HF_HOME")
235            .map(PathBuf::from)
236            .unwrap_or_else(|_| {
237                let home = std::env::var("HOME").unwrap_or_else(|_| {
238                    tracing::warn!("HOME environment variable not set, using /tmp for model cache");
239                    "/tmp".to_string()
240                });
241                PathBuf::from(home).join(".cache").join("huggingface")
242            });
243        let dir = base.join("dakera").join(model_id.replace('/', "--"));
244        std::fs::create_dir_all(&dir)?;
245        Ok(dir)
246    }
247
248    /// Download a single file from HuggingFace using ureq (sync).
249    ///
250    /// Handles redirects manually because HuggingFace returns relative
251    /// Location headers (e.g. `/api/resolve-cache/...`) which ureq 2.x
252    /// cannot resolve — it fails with "relative URL without a base".
253    fn download_hf_file(
254        model_id: &str,
255        filename: &str,
256        cache_dir: &Path,
257    ) -> std::result::Result<PathBuf, String> {
258        let file_path = cache_dir.join(filename);
259        if file_path.exists() {
260            info!("Cached: {}", filename);
261            return Ok(file_path);
262        }
263
264        let url = format!(
265            "https://huggingface.co/{}/resolve/main/{}",
266            model_id, filename
267        );
268        info!("Downloading: {}", url);
269
270        // Disable automatic redirects so we can resolve relative Location headers ourselves.
271        let agent = ureq::AgentBuilder::new()
272            .redirects(0)
273            .timeout(std::time::Duration::from_secs(300))
274            .build();
275
276        let mut current_url = url.clone();
277        let mut redirects = 0;
278        let max_redirects = 10;
279
280        let response = loop {
281            let resp = agent.get(&current_url).call();
282
283            let r = match resp {
284                Ok(r) => r,
285                Err(ureq::Error::Status(_status, r)) => r,
286                Err(e) => return Err(format!("{}: {}", filename, e)),
287            };
288
289            let status = r.status();
290            if (200..300).contains(&status) {
291                break r;
292            } else if (300..400).contains(&status) {
293                redirects += 1;
294                if redirects > max_redirects {
295                    return Err(format!("{}: too many redirects", filename));
296                }
297                let location = r
298                    .header("location")
299                    .ok_or_else(|| format!("{}: redirect without Location header", filename))?
300                    .to_string();
301
302                // Resolve relative redirects against the current URL's origin
303                current_url = if location.starts_with('/') {
304                    let parsed = url::Url::parse(&current_url)
305                        .map_err(|e| format!("{}: bad URL {}: {}", filename, current_url, e))?;
306                    let host = parsed.host_str().ok_or_else(|| {
307                        format!("{}: redirect URL missing host: {}", filename, current_url)
308                    })?;
309                    format!("{}://{}{}", parsed.scheme(), host, location)
310                } else {
311                    location
312                };
313                info!("Redirect {} → {}", redirects, current_url);
314            } else {
315                return Err(format!("{}: HTTP {}", filename, status));
316            }
317        };
318
319        let mut bytes = Vec::new();
320        response
321            .into_reader()
322            .take(500_000_000) // 500MB safety limit
323            .read_to_end(&mut bytes)
324            .map_err(|e| format!("Failed to read {}: {}", filename, e))?;
325
326        std::fs::write(&file_path, &bytes)
327            .map_err(|e| format!("Failed to write {}: {}", filename, e))?;
328
329        info!("Downloaded {} ({} bytes)", filename, bytes.len());
330        Ok(file_path)
331    }
332
333    /// Get the embedding dimension for the loaded model.
334    pub fn dimension(&self) -> usize {
335        self.dimension
336    }
337
338    /// Get the model being used.
339    pub fn model(&self) -> EmbeddingModel {
340        self.config.model
341    }
342
343    /// Get the device being used.
344    pub fn device(&self) -> &Device {
345        &self.device
346    }
347
348    /// Embed a single query text.
349    ///
350    /// For models like E5, this automatically applies the query prefix.
351    #[instrument(skip(self, text), fields(text_len = text.len()))]
352    pub async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
353        let texts = vec![text.to_string()];
354        let prepared = self.processor.prepare_texts(&texts, true);
355        let embeddings = self.embed_batch_internal(&prepared).await?;
356        embeddings.into_iter().next().ok_or_else(|| {
357            crate::error::InferenceError::InferenceError(
358                "No embedding returned for query".to_string(),
359            )
360        })
361    }
362
363    /// Embed multiple query texts.
364    ///
365    /// For models like E5, this automatically applies the query prefix.
366    #[instrument(skip(self, texts), fields(count = texts.len()))]
367    pub async fn embed_queries(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
368        let prepared = self.processor.prepare_texts(texts, true);
369        self.embed_batch_internal(&prepared).await
370    }
371
372    /// Embed a single document/passage.
373    ///
374    /// For models like E5, this automatically applies the document prefix.
375    #[instrument(skip(self, text), fields(text_len = text.len()))]
376    pub async fn embed_document(&self, text: &str) -> Result<Vec<f32>> {
377        let texts = vec![text.to_string()];
378        let prepared = self.processor.prepare_texts(&texts, false);
379        let embeddings = self.embed_batch_internal(&prepared).await?;
380        embeddings.into_iter().next().ok_or_else(|| {
381            crate::error::InferenceError::InferenceError(
382                "No embedding returned for document".to_string(),
383            )
384        })
385    }
386
387    /// Embed multiple documents/passages.
388    ///
389    /// For models like E5, this automatically applies the document prefix.
390    #[instrument(skip(self, texts), fields(count = texts.len()))]
391    pub async fn embed_documents(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
392        let prepared = self.processor.prepare_texts(texts, false);
393        self.embed_batch_internal(&prepared).await
394    }
395
396    /// Embed texts without any prefix (raw embedding).
397    #[instrument(skip(self, texts), fields(count = texts.len()))]
398    pub async fn embed_raw(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
399        self.embed_batch_internal(texts).await
400    }
401
402    /// Internal batch embedding implementation.
403    ///
404    /// Each batch is offloaded to a blocking thread via `spawn_blocking`
405    /// so that CPU-heavy BERT inference does not block the Tokio runtime.
406    async fn embed_batch_internal(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
407        if texts.is_empty() {
408            return Ok(vec![]);
409        }
410
411        // Split into batches if needed
412        let batches = self.processor.split_into_batches(texts);
413        let mut all_embeddings = Vec::with_capacity(texts.len());
414
415        for batch in batches {
416            // Clone Arc refs and owned data for the blocking closure
417            let batch_owned: Vec<String> = batch.to_vec();
418            let model = Arc::clone(&self.model);
419            let processor = Arc::clone(&self.processor);
420            let device = self.device.clone();
421            let normalize = self.config.model.normalize_embeddings();
422
423            let batch_embeddings = tokio::task::spawn_blocking(move || {
424                Self::process_batch_blocking(&batch_owned, &model, &processor, &device, normalize)
425            })
426            .await
427            .map_err(|e| {
428                InferenceError::InferenceError(format!("Inference task panicked: {}", e))
429            })??;
430
431            all_embeddings.extend(batch_embeddings);
432        }
433
434        Ok(all_embeddings)
435    }
436
437    /// Process a single batch through tokenization + model forward pass.
438    ///
439    /// This is a static method designed to run inside `spawn_blocking`
440    /// so it does not hold `&self` (which is not `Send`).
441    fn process_batch_blocking(
442        texts: &[String],
443        model: &Arc<RwLock<BertModel>>,
444        processor: &BatchProcessor,
445        device: &Device,
446        normalize: bool,
447    ) -> Result<Vec<Vec<f32>>> {
448        // Tokenize
449        let prepared = processor.tokenize_batch(texts, device)?;
450
451        // Forward pass: acquire read lock on the model
452        let model_guard = model.read();
453
454        let input_ids = prepared.input_ids.to_dtype(DType::U32)?;
455        let attention_mask = prepared.attention_mask.to_dtype(DType::U32)?;
456        let token_type_ids = prepared.token_type_ids.to_dtype(DType::U32)?;
457
458        let output = model_guard.forward(&input_ids, &token_type_ids, Some(&attention_mask))?;
459
460        // Apply mean pooling
461        let attention_mask_f32 = prepared.attention_mask.to_dtype(DType::F32)?;
462        let pooled = mean_pooling(&output, &attention_mask_f32)?;
463
464        // Normalize if configured
465        let normalized = if normalize {
466            normalize_embeddings(&pooled)?
467        } else {
468            pooled
469        };
470
471        // Release model lock before conversion
472        drop(model_guard);
473
474        // Convert to Vec<Vec<f32>>
475        let embeddings = normalized.to_vec2::<f32>()?;
476
477        debug!(
478            "Generated {} embeddings of dimension {}",
479            embeddings.len(),
480            embeddings.first().map(|e| e.len()).unwrap_or(0)
481        );
482
483        Ok(embeddings)
484    }
485
486    /// Estimate the time to embed a batch of texts (in milliseconds).
487    pub fn estimate_time_ms(&self, text_count: usize, avg_text_len: usize) -> f64 {
488        // Rough estimation based on model speed and text length
489        let tokens_per_text =
490            (avg_text_len as f64 / 4.0).min(self.config.model.max_seq_length() as f64);
491        let total_tokens = tokens_per_text * text_count as f64;
492        let tokens_per_second = self.config.model.tokens_per_second_cpu() as f64;
493
494        // GPU is roughly 10x faster
495        let speed_multiplier = if matches!(self.device, Device::Cpu) {
496            1.0
497        } else {
498            10.0
499        };
500
501        (total_tokens / (tokens_per_second * speed_multiplier)) * 1000.0
502    }
503}
504
505impl std::fmt::Debug for EmbeddingEngine {
506    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
507        f.debug_struct("EmbeddingEngine")
508            .field("model", &self.config.model)
509            .field("dimension", &self.dimension)
510            .field("device", &self.device)
511            .field("max_batch_size", &self.config.max_batch_size)
512            .finish()
513    }
514}
515
516/// Builder for creating an EmbeddingEngine with fluent API.
517pub struct EmbeddingEngineBuilder {
518    config: ModelConfig,
519}
520
521impl EmbeddingEngineBuilder {
522    /// Create a new builder with default configuration.
523    pub fn new() -> Self {
524        Self {
525            config: ModelConfig::default(),
526        }
527    }
528
529    /// Set the embedding model to use.
530    pub fn model(mut self, model: EmbeddingModel) -> Self {
531        self.config.model = model;
532        self
533    }
534
535    /// Set the cache directory for model files.
536    pub fn cache_dir(mut self, dir: impl Into<String>) -> Self {
537        self.config.cache_dir = Some(dir.into());
538        self
539    }
540
541    /// Set the maximum batch size.
542    pub fn max_batch_size(mut self, size: usize) -> Self {
543        self.config.max_batch_size = size;
544        self
545    }
546
547    /// Enable GPU acceleration.
548    pub fn use_gpu(mut self, enable: bool) -> Self {
549        self.config.use_gpu = enable;
550        self
551    }
552
553    /// Set the number of CPU threads.
554    pub fn num_threads(mut self, threads: usize) -> Self {
555        self.config.num_threads = Some(threads);
556        self
557    }
558
559    /// Build the embedding engine.
560    pub async fn build(self) -> Result<EmbeddingEngine> {
561        EmbeddingEngine::new(self.config).await
562    }
563}
564
565impl Default for EmbeddingEngineBuilder {
566    fn default() -> Self {
567        Self::new()
568    }
569}
570
#[cfg(test)]
mod tests {
    use super::*;

    /// `estimate_time_ms` divides by the model's CPU throughput, so that figure
    /// must be positive. The estimate itself needs a loaded engine, which
    /// requires model files.
    #[test]
    fn test_estimate_time() {
        let throughput = ModelConfig::new(EmbeddingModel::MiniLM)
            .model
            .tokens_per_second_cpu() as f64;
        assert!(throughput > 0.0);
    }

    /// Builder setters land in the wrapped config.
    #[test]
    fn test_builder() {
        let cfg = EmbeddingEngineBuilder::new()
            .model(EmbeddingModel::BgeSmall)
            .max_batch_size(64)
            .use_gpu(false)
            .config;

        assert_eq!(cfg.model, EmbeddingModel::BgeSmall);
        assert_eq!(cfg.max_batch_size, 64);
        assert!(!cfg.use_gpu);
    }
}