// cognee_embedding/onnx.rs

1use async_trait::async_trait;
2use ort::session::{Session, builder::GraphOptimizationLevel};
3use ort::value::Tensor;
4use std::sync::{Arc, Mutex};
5use tokenizers::Tokenizer;
6use tracing::info;
7
8use crate::{
9    config::OnnxEmbeddingConfig,
10    download::{ModelUrls, ensure_model_exists, ensure_tokenizer_exists},
11    engine::EmbeddingEngine,
12    error::{EmbeddingError, EmbeddingResult},
13    utils::{l2_normalize, mean_pool},
14};
/// Tokenization output for one batch: `(input_ids, attention_mask)`.
///
/// Each outer `Vec` holds one row per input text. `tokenize_batch` truncates or pads every
/// row to `config.max_sequence_length`, so the rows can be flattened into a rectangular
/// `[batch_size, max_seq_len]` tensor.
type TokenizationBatch = (Vec<Vec<i64>>, Vec<Vec<i64>>);
/// ONNX-based embedding engine for local inference
///
/// Wraps ONNX Runtime session and HuggingFace tokenizer.
/// Based on examples/embeddings.rs with proper tokenization for Python parity.
///
/// The session is held as `Arc<Mutex<_>>` so a clone of the `Arc` can be moved into a
/// blocking worker thread for each inference run.
pub struct OnnxEmbeddingEngine {
    /// ONNX Runtime session; the mutex is held for the duration of each `run` call.
    session: Arc<Mutex<Session>>,
    /// HuggingFace tokenizer; the mutex is held while one batch of texts is encoded.
    tokenizer: Arc<Mutex<Tokenizer>>,
    /// Model/tokenizer paths, model name, output dimensions, max sequence length and batch size.
    config: OnnxEmbeddingConfig,
}
26
27impl OnnxEmbeddingEngine {
28    /// Create a new ONNX embedding engine
29    ///
30    /// Initializes ONNX Runtime, loads the model, and downloads/caches the tokenizer.
31    ///
32    /// # Arguments
33    /// * `config` - Engine configuration with model path and tokenizer model ID
34    ///
35    /// # Returns
36    /// * Initialized engine or error
37    ///
38    /// # Errors
39    /// * Returns error if model file not found, ONNX Runtime init fails, or tokenizer download fails
40    ///
41    /// # Example
42    /// ```ignore
43    /// let config = OnnxEmbeddingConfig::bge_small("./target/models");
44    /// let engine = OnnxEmbeddingEngine::new(config)?;
45    /// ```
46    pub fn new(config: OnnxEmbeddingConfig) -> EmbeddingResult<Self> {
47        ort::init().commit();
48
49        if !config.model_path.exists() {
50            return Err(EmbeddingError::ModelLoadError(format!(
51                "Model file not found: {:?}",
52                config.model_path
53            )));
54        }
55
56        info!("Loading tokenizer: {:?}", config.tokenizer_path);
57        let tokenizer = Tokenizer::from_file(&config.tokenizer_path).map_err(|e| {
58            EmbeddingError::TokenizerError(format!(
59                "Failed to load tokenizer from {:?}: {}. Please ensure tokenizer.json file exists.",
60                config.tokenizer_path, e
61            ))
62        })?;
63
64        info!("Loading ONNX model: {:?}", config.model_path);
65        let session = Session::builder()
66            .map_err(|e| EmbeddingError::ModelLoadError(e.to_string()))?
67            .with_optimization_level(GraphOptimizationLevel::Level3)
68            .map_err(|e| EmbeddingError::ModelLoadError(e.to_string()))?
69            .commit_from_file(&config.model_path)
70            .map_err(|e| EmbeddingError::ModelLoadError(e.to_string()))?;
71
72        info!(
73            "✓ Loaded {} (dims: {}, max_seq_len: {})",
74            config.model_name, config.dimensions, config.max_sequence_length
75        );
76
77        Ok(Self {
78            session: Arc::new(Mutex::new(session)),
79            tokenizer: Arc::new(Mutex::new(tokenizer)),
80            config,
81        })
82    }
83
84    /// Create a new ONNX embedding engine with automatic model downloading
85    ///
86    /// Downloads model and tokenizer from HuggingFace Hub if not found locally.
87    /// This is the recommended constructor for most use cases.
88    ///
89    /// # Arguments
90    /// * `config` - Engine configuration with model path and tokenizer model ID
91    ///
92    /// # Returns
93    /// * Initialized engine or error
94    ///
95    /// # Errors
96    /// * Returns error if download fails, ONNX Runtime init fails, or tokenizer load fails
97    ///
98    /// # Example
99    /// ```ignore
100    /// let config = OnnxEmbeddingConfig::bge_small("./target/models");
101    /// let engine = OnnxEmbeddingEngine::with_auto_download(config).await?;
102    /// ```
103    pub async fn with_auto_download(config: OnnxEmbeddingConfig) -> EmbeddingResult<Self> {
104        let (model_url, tokenizer_url) = match config.model_name.to_lowercase().as_str() {
105            "bge-small-en-v1.5" | "bge-small-v1.5" => (
106                ModelUrls::BGE_SMALL.model_url,
107                ModelUrls::BGE_SMALL.tokenizer_url,
108            ),
109            "all-minilm-l6-v2" => (
110                ModelUrls::MINILM_L6.model_url,
111                ModelUrls::MINILM_L6.tokenizer_url,
112            ),
113            _ => {
114                return Err(EmbeddingError::ModelLoadError(format!(
115                    "Unknown model name '{}'. Supported: 'bge-small-en-v1.5', 'all-MiniLM-L6-v2'",
116                    config.model_name
117                )));
118            }
119        };
120
121        let model_downloaded = ensure_model_exists(&config.model_path, model_url).await?;
122        if model_downloaded {
123            info!("✓ Downloaded model to {:?}", config.model_path);
124        }
125
126        let tokenizer_downloaded =
127            ensure_tokenizer_exists(&config.tokenizer_path, tokenizer_url).await?;
128        if tokenizer_downloaded {
129            info!("✓ Downloaded tokenizer to {:?}", config.tokenizer_path);
130        }
131
132        Self::new(config)
133    }
134
135    /// Tokenize a batch of texts using HuggingFace tokenizer
136    ///
137    /// Uses proper BPE/WordPiece tokenization matching Python fastembed.
138    ///
139    /// # Arguments
140    /// * `texts` - Texts to tokenize
141    ///
142    /// # Returns
143    /// * Tuple of (input_ids, attention_mask) tensors, both with shape [batch_size, max_seq_len]
144    fn tokenize_batch(&self, texts: &[&str]) -> EmbeddingResult<TokenizationBatch> {
145        #[allow(clippy::unwrap_used, reason = "lock poison is unrecoverable")]
146        let tokenizer = self.tokenizer.lock().unwrap(); // lock poison is unrecoverable
147        let max_len = self.config.max_sequence_length;
148
149        let mut input_ids_batch = Vec::new();
150        let mut attention_mask_batch = Vec::new();
151
152        for text in texts {
153            let encoding = tokenizer
154                .encode(*text, true)
155                .map_err(|e| EmbeddingError::TokenizerError(e.to_string()))?;
156
157            let mut ids = encoding
158                .get_ids()
159                .iter()
160                .map(|&id| id as i64)
161                .collect::<Vec<_>>();
162            let mut mask = encoding
163                .get_attention_mask()
164                .iter()
165                .map(|&m| m as i64)
166                .collect::<Vec<_>>();
167
168            if ids.len() > max_len {
169                ids.truncate(max_len);
170                mask.truncate(max_len);
171            }
172
173            while ids.len() < max_len {
174                ids.push(0); // [PAD] token
175                mask.push(0); // Padding mask
176            }
177
178            input_ids_batch.push(ids);
179            attention_mask_batch.push(mask);
180        }
181
182        Ok((input_ids_batch, attention_mask_batch))
183    }
184
185    /// Extract embedding from ONNX output tensor
186    ///
187    /// Handles both 2D [seq_len, hidden_dim] and 3D [batch_size, seq_len, hidden_dim] outputs.
188    fn extract_embedding(
189        &self,
190        output_data: &[f32],
191        output_shape: &[usize],
192        attention_mask: &[i64],
193    ) -> EmbeddingResult<Vec<f32>> {
194        let output_dim = self.config.dimensions;
195
196        if output_shape.len() == 3 {
197            let seq_len = output_shape[1];
198            let hidden_dim = output_shape[2];
199
200            let pooled = mean_pool(output_data, seq_len, hidden_dim, attention_mask, output_dim);
201            Ok(l2_normalize(&pooled))
202        } else if output_shape.len() == 2 {
203            let embedding: Vec<f32> = output_data.iter().take(output_dim).copied().collect();
204            Ok(l2_normalize(&embedding))
205        } else {
206            Err(EmbeddingError::InferenceError(format!(
207                "Unexpected output shape: {output_shape:?}"
208            )))
209        }
210    }
211}
212
213impl OnnxEmbeddingEngine {
214    /// Run ONNX inference over a SINGLE bounded batch of `texts`.
215    ///
216    /// A transformer's activation memory scales with `batch × seq_len`
217    /// (attention is `batch × heads × seq_len²`), so the batch must stay small.
218    /// The public [`OnnxEmbeddingEngine::embed`] splits large inputs into
219    /// `config.batch_size` chunks before calling this — never pass an unbounded
220    /// slice here.
221    async fn embed_batch(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
222        if texts.is_empty() {
223            return Ok(Vec::new());
224        }
225
226        let batch_size = texts.len();
227        let seq_len = self.config.max_sequence_length;
228
229        let (input_ids_batch, attention_mask_batch) = self.tokenize_batch(texts)?;
230
231        let input_ids_flat: Vec<i64> = input_ids_batch.iter().flatten().copied().collect();
232        let attention_mask_flat: Vec<i64> =
233            attention_mask_batch.iter().flatten().copied().collect();
234
235        let input_ids_tensor = Tensor::from_array((vec![batch_size, seq_len], input_ids_flat))
236            .map_err(|e| EmbeddingError::InferenceError(e.to_string()))?;
237        let attention_mask_tensor =
238            Tensor::from_array((vec![batch_size, seq_len], attention_mask_flat))
239                .map_err(|e| EmbeddingError::InferenceError(e.to_string()))?;
240        let token_type_ids_tensor =
241            Tensor::from_array((vec![batch_size, seq_len], vec![0i64; batch_size * seq_len]))
242                .map_err(|e| EmbeddingError::InferenceError(e.to_string()))?;
243
244        let session = Arc::clone(&self.session);
245        let attention_masks = attention_mask_batch.clone();
246
247        let (output_shape, output_data) = tokio::task::spawn_blocking(move || {
248            #[allow(clippy::unwrap_used, reason = "lock poison is unrecoverable")]
249            let mut session = session.lock().unwrap(); // lock poison is unrecoverable
250            let outputs = session.run(ort::inputs! {
251                "input_ids" => input_ids_tensor,
252                "attention_mask" => attention_mask_tensor,
253                "token_type_ids" => token_type_ids_tensor,
254            })?;
255
256            let (shape, data) = outputs[0].try_extract_tensor::<f32>()?;
257            let shape_usize: Vec<usize> = shape.iter().map(|&d| d as usize).collect();
258            Ok::<_, Box<dyn std::error::Error + Send + Sync>>((shape_usize, data.to_vec()))
259        })
260        .await
261        .map_err(|e| EmbeddingError::InferenceError(e.to_string()))?
262        .map_err(|e| EmbeddingError::InferenceError(e.to_string()))?;
263
264        let mut embeddings = Vec::with_capacity(batch_size);
265
266        if output_shape.len() == 3 {
267            let seq_len = output_shape[1];
268            let hidden_dim = output_shape[2];
269            let sample_size = seq_len * hidden_dim;
270
271            for (i, mask) in attention_masks.iter().enumerate().take(batch_size) {
272                let start = i * sample_size;
273                let end = start + sample_size;
274                let sample_data = &output_data[start..end];
275
276                let embedding =
277                    self.extract_embedding(sample_data, &[1, seq_len, hidden_dim], mask)?;
278
279                embeddings.push(embedding);
280            }
281        } else if output_shape.len() == 2 && batch_size == 1 {
282            let embedding =
283                self.extract_embedding(&output_data, &output_shape, &attention_masks[0])?;
284            embeddings.push(embedding);
285        } else {
286            return Err(EmbeddingError::InferenceError(format!(
287                "Unexpected output shape: {output_shape:?} for batch_size {batch_size}"
288            )));
289        }
290
291        Ok(embeddings)
292    }
293}
294
295#[async_trait]
296impl EmbeddingEngine for OnnxEmbeddingEngine {
297    /// Embed `texts`, splitting into `config.batch_size` sub-batches so the ONNX
298    /// session never receives an unbounded batch. A transformer's activation
299    /// memory scales with `batch × seq_len²`; embedding a whole corpus in one
300    /// call (several thousand chunks) would allocate tens of GB and OOM.
301    /// Sub-batching keeps peak memory flat regardless of how many texts are
302    /// passed.
303    async fn embed(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
304        if texts.is_empty() {
305            return Ok(Vec::new());
306        }
307        let batch = self.config.batch_size.max(1);
308        if texts.len() <= batch {
309            return self.embed_batch(texts).await;
310        }
311        let mut embeddings = Vec::with_capacity(texts.len());
312        for chunk in texts.chunks(batch) {
313            embeddings.extend(self.embed_batch(chunk).await?);
314        }
315        Ok(embeddings)
316    }
317
318    fn dimension(&self) -> usize {
319        self.config.dimensions
320    }
321
322    fn batch_size(&self) -> usize {
323        self.config.batch_size
324    }
325
326    fn max_sequence_length(&self) -> usize {
327        self.config.max_sequence_length
328    }
329}
330
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Path of a downloaded artifact in the workspace-level `target/models` directory.
    ///
    /// Anchored at this crate's manifest directory (cargo's working directory for tests), so
    /// the lookup does not depend on where the test binary is launched from.
    fn model_artifact(file_name: &str) -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("../../target/models")
            .join(file_name)
    }

    /// The HuggingFace tokenizer loads from file and wraps text in BERT's `[CLS] … [SEP]`.
    /// Skipped when the tokenizer file has not been downloaded.
    #[test]
    fn test_tokenization() {
        let tokenizer_path = model_artifact("bge-small-tokenizer.json");
        if !tokenizer_path.exists() {
            return; // tokenizer not available in this environment — skip
        }

        let tokenizer = Tokenizer::from_file(&tokenizer_path).expect("Failed to load tokenizer");
        let encoding = tokenizer.encode("Hello world", true).unwrap();
        let ids = encoding.get_ids();

        assert!(!ids.is_empty());
        assert_eq!(ids.first(), Some(&101)); // [CLS] for BERT-based models
        assert_eq!(ids.last(), Some(&102)); // [SEP] for BERT-based models
    }

    #[test]
    fn test_l2_normalization() {
        use crate::utils::{compute_norm, l2_normalize};

        let vec = vec![3.0, 4.0];
        let normalized = l2_normalize(&vec);
        let norm = compute_norm(&normalized);

        assert!((norm - 1.0).abs() < 0.001);
    }

    /// Construction with the default config either succeeds (artifacts present) or fails
    /// with an error that names what is missing.
    #[test]
    fn test_engine_creation() {
        let config = OnnxEmbeddingConfig::default();

        if let Err(e) = OnnxEmbeddingEngine::new(config) {
            let message = e.to_string();
            assert!(
                message.contains("Model file not found") || message.contains("tokenizer"),
                "unexpected construction error: {message}"
            );
        }
    }

    /// Regression test for the unbounded-batch OOM: `embed` must split inputs
    /// larger than `config.batch_size` into sub-batches (so ONNX never sees a
    /// giant `[N, seq_len]` tensor), while returning one embedding per input
    /// that matches the single-batch result. Skips when the model artifacts
    /// have not been downloaded.
    #[tokio::test]
    async fn embed_sub_batches_large_inputs() {
        let model = model_artifact("BGE-Small-v1.5-model_quantized.onnx");
        let tok = model_artifact("bge-small-tokenizer.json");
        if !model.exists() || !tok.exists() {
            return; // model not available in this environment — skip
        }

        let config = OnnxEmbeddingConfig {
            model_path: model,
            tokenizer_path: tok,
            batch_size: 4, // force several sub-batches
            ..Default::default()
        };

        let engine = OnnxEmbeddingEngine::new(config).expect("engine creation");

        // 10 inputs with batch_size 4 → 3 sub-batches (4 + 4 + 2).
        let texts: Vec<String> = (0..10).map(|i| format!("sentence number {i}")).collect();
        let refs: Vec<&str> = texts.iter().map(String::as_str).collect();

        let chunked = engine.embed(&refs).await.expect("embed");
        assert_eq!(
            chunked.len(),
            10,
            "one embedding per input across sub-batches"
        );
        assert_eq!(chunked[0].len(), engine.dimension());

        // Sub-batching must not change an embedding's meaning. (Exact equality
        // can't be required: the quantized model selects batch-size-dependent
        // kernels, so values differ by tiny numerical noise.) The L2-normalized
        // vectors must stay essentially parallel — cosine similarity ≈ 1.
        let single = engine.embed_batch(&refs).await.expect("embed_batch");
        assert_eq!(single.len(), chunked.len());
        for (a, b) in chunked.iter().zip(single.iter()) {
            assert_eq!(a.len(), b.len());
            let cos: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
            assert!(cos > 0.999, "chunked vs single-batch diverged: cos={cos}");
        }
    }
}