// mnemo_core/embedding/onnx.rs

//! ONNX Runtime local embedding provider.
//!
//! Provides local embedding inference using ONNX Runtime, eliminating the
//! need for an external API. Supports sentence-transformer models such as
//! `all-MiniLM-L6-v2` exported to ONNX format.
//!
//! # Feature gating
//!
//! When compiled **without** the `onnx` feature, the module provides a stub
//! that validates the model path but returns [`Error::Embedding`] from
//! `embed()` and `embed_batch()`.
//!
//! When compiled **with** the `onnx` feature, the module loads the ONNX
//! session and a HuggingFace tokenizer (a `tokenizer.json` that must sit next
//! to the model file). It then performs real local inference: attention-mask
//! weighted mean pooling (for models that emit per-token hidden states)
//! followed by L2 normalisation.
//!
//! ```toml
//! [features]
//! onnx = ["dep:ort", "dep:tokenizers", "dep:ndarray"]
//!
//! [dependencies]
//! ort = { version = "2", optional = true }
//! tokenizers = { version = "0.21", optional = true, default-features = false }
//! ndarray = { version = "0.16", optional = true }
//! ```
//!
//! # Example
//!
//! ```rust,no_run
//! use mnemo_core::embedding::onnx::OnnxEmbedding;
//! use mnemo_core::embedding::EmbeddingProvider;
//!
//! // Succeeds only if the model file exists on disk (and, with the `onnx`
//! // feature, only if the model and its sibling `tokenizer.json` also load).
//! let provider = OnnxEmbedding::new("/models/all-MiniLM-L6-v2.onnx", 384)
//!     .expect("model path must exist");
//!
//! assert_eq!(provider.dimensions(), 384);
//! assert_eq!(provider.model_path(), "/models/all-MiniLM-L6-v2.onnx");
//! ```

41use crate::embedding::EmbeddingProvider;
42use crate::error::{Error, Result};
43
// ---------------------------------------------------------------------------
// Real implementation (feature = "onnx")
// ---------------------------------------------------------------------------
#[cfg(feature = "onnx")]
mod inner {
    use super::*;
    use ndarray::{Array2, ArrayView3, Ix3};
    use ort::Session;
    use std::path::Path;
    use std::sync::Arc;
    use tokenizers::Tokenizer;

    /// ONNX-based local embedding provider.
    ///
    /// Wraps an ONNX sentence-transformer model (e.g. `all-MiniLM-L6-v2`)
    /// together with a HuggingFace tokenizer for on-device vector generation.
    pub struct OnnxEmbedding {
        dimensions: usize,
        model_path: String,
        session: Arc<Session>,
        tokenizer: Arc<Tokenizer>,
    }

    // The session and tokenizer are shared through `Arc` so `run_inference`
    // can hand cheap clones to `tokio::task::spawn_blocking`. That in turn
    // requires both types to be `Send + Sync`; the compiler enforces this.

    // Manual Debug because Session/Tokenizer do not implement Debug.
    impl std::fmt::Debug for OnnxEmbedding {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("OnnxEmbedding")
                .field("dimensions", &self.dimensions)
                .field("model_path", &self.model_path)
                .finish_non_exhaustive()
        }
    }

    /// Scale `v` in place to unit Euclidean length.
    ///
    /// An all-zero vector is left unchanged instead of producing NaNs.
    fn l2_normalize(v: &mut [f32]) {
        let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            v.iter_mut().for_each(|x| *x /= norm);
        }
    }

    /// Fail with [`Error::Embedding`] when the model's embedding width
    /// differs from the width this provider was configured with.
    fn ensure_dims(hidden_dim: usize, dims: usize) -> Result<()> {
        if hidden_dim == dims {
            Ok(())
        } else {
            Err(Error::Embedding(format!(
                "model hidden dim ({hidden_dim}) does not match configured dimensions ({dims})"
            )))
        }
    }

    impl OnnxEmbedding {
        /// Create a new ONNX embedding provider from a model path.
        ///
        /// The model should be an ONNX file for a sentence-transformer model
        /// (e.g. `all-MiniLM-L6-v2` exported to ONNX format).
        ///
        /// A `tokenizer.json` file **must** exist in the same directory as the
        /// model file. This is the standard layout produced by
        /// `optimum-cli export onnx` or manual HuggingFace model export.
        ///
        /// # Errors
        ///
        /// Returns [`Error::Validation`] if `model_path` is not an existing
        /// regular file.
        /// Returns [`Error::Embedding`] if `tokenizer.json` is missing, or if
        /// the ONNX session or tokenizer fails to load.
        pub fn new(model_path: &str, dimensions: usize) -> Result<Self> {
            let model = Path::new(model_path);
            // `is_file` (unlike `exists`) also rejects directories, which are
            // never loadable models and would otherwise fail later inside ORT.
            if !model.is_file() {
                return Err(Error::Validation(format!(
                    "ONNX model not found at: {model_path}"
                )));
            }

            // Locate tokenizer.json next to the model file.
            let tokenizer_path = model
                .parent()
                .map(|p| p.join("tokenizer.json"))
                .unwrap_or_else(|| Path::new("tokenizer.json").to_path_buf());

            if !tokenizer_path.is_file() {
                return Err(Error::Embedding(format!(
                    "tokenizer.json not found next to ONNX model (expected at {})",
                    tokenizer_path.display()
                )));
            }

            // Four intra-op threads per session.
            let session = Session::builder()
                .map_err(|e| {
                    Error::Embedding(format!("failed to create ONNX session builder: {e}"))
                })?
                .with_intra_threads(4)
                .map_err(|e| Error::Embedding(format!("failed to set intra threads: {e}")))?
                .commit_from_file(model_path)
                .map_err(|e| Error::Embedding(format!("failed to load ONNX model: {e}")))?;

            let tokenizer = Tokenizer::from_file(&tokenizer_path)
                .map_err(|e| Error::Embedding(format!("failed to load tokenizer: {e}")))?;

            Ok(Self {
                dimensions,
                model_path: model_path.to_string(),
                session: Arc::new(session),
                tokenizer: Arc::new(tokenizer),
            })
        }

        /// Get the model path.
        #[must_use]
        pub fn model_path(&self) -> &str {
            &self.model_path
        }

        /// Tokenize a batch of texts and return (input_ids, attention_mask,
        /// token_type_ids) as 2-D i64 arrays with shape `[batch, max_len]`.
        ///
        /// Shorter sequences are right-padded with zeros in all three arrays.
        /// Their attention mask is therefore 0 at the padded positions.
        fn tokenize_batch(
            tokenizer: &Tokenizer,
            texts: &[&str],
        ) -> Result<(Array2<i64>, Array2<i64>, Array2<i64>)> {
            let encodings = tokenizer
                .encode_batch(texts.to_vec(), true)
                .map_err(|e| Error::Embedding(format!("tokenization failed: {e}")))?;

            let batch_size = encodings.len();
            let max_len = encodings
                .iter()
                .map(|e| e.get_ids().len())
                .max()
                .unwrap_or(0);

            let mut input_ids = Array2::<i64>::zeros((batch_size, max_len));
            let mut attention_mask = Array2::<i64>::zeros((batch_size, max_len));
            let mut token_type_ids = Array2::<i64>::zeros((batch_size, max_len));

            for (i, enc) in encodings.iter().enumerate() {
                for (j, &id) in enc.get_ids().iter().enumerate() {
                    input_ids[[i, j]] = i64::from(id);
                }
                for (j, &mask) in enc.get_attention_mask().iter().enumerate() {
                    attention_mask[[i, j]] = i64::from(mask);
                }
                for (j, &tid) in enc.get_type_ids().iter().enumerate() {
                    token_type_ids[[i, j]] = i64::from(tid);
                }
            }

            Ok((input_ids, attention_mask, token_type_ids))
        }

        /// Average each sequence's token vectors, weighting by the attention
        /// mask so padding tokens are ignored, then L2-normalise the result.
        ///
        /// `hidden` is `[batch, seq_len, hidden_dim]` and `mask` is
        /// `[batch, seq_len]`. The caller has already checked that their
        /// leading axes agree.
        fn mean_pool_and_normalize(
            hidden: ArrayView3<'_, f32>,
            mask: &Array2<i64>,
        ) -> Vec<Vec<f32>> {
            hidden
                .outer_iter()
                .zip(mask.outer_iter())
                .map(|(tokens, token_mask)| {
                    let mut pooled = vec![0.0f32; tokens.ncols()];
                    let mut weight_sum = 0.0f32;

                    for (token, &m) in tokens.outer_iter().zip(token_mask.iter()) {
                        if m > 0 {
                            let w = m as f32;
                            for (acc, &x) in pooled.iter_mut().zip(token.iter()) {
                                *acc += x * w;
                            }
                            weight_sum += w;
                        }
                    }

                    // A row with no attended tokens keeps its zero vector
                    // rather than dividing by zero.
                    if weight_sum > 0.0 {
                        pooled.iter_mut().for_each(|v| *v /= weight_sum);
                    }

                    l2_normalize(&mut pooled);
                    pooled
                })
                .collect()
        }

        /// Run inference on a batch of texts. This is the shared
        /// implementation used by both `embed` and `embed_batch`.
        ///
        /// Tokenization and the ONNX call are CPU-bound, so they run on
        /// tokio's blocking pool rather than on the async executor.
        async fn run_inference(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
            if texts.is_empty() {
                return Ok(Vec::new());
            }

            // `spawn_blocking` needs a `'static` closure: clone the `Arc`
            // handles and take owned copies of the input strings.
            let session = Arc::clone(&self.session);
            let tokenizer = Arc::clone(&self.tokenizer);
            let dims = self.dimensions;
            let owned_texts: Vec<String> = texts.iter().map(|t| (*t).to_string()).collect();

            tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>> {
                let text_refs: Vec<&str> = owned_texts.iter().map(String::as_str).collect();
                let (input_ids, attention_mask, token_type_ids) =
                    Self::tokenize_batch(&tokenizer, &text_refs)?;
                let (batch_size, seq_len) = input_ids.dim();

                // NOTE(review): models exported without a `token_type_ids`
                // input (e.g. DistilBERT-based ones) will likely be rejected
                // by ORT here. Consider filtering inputs by `session.inputs`.
                let outputs = session
                    .run(ort::inputs![
                        "input_ids" => input_ids.view(),
                        "attention_mask" => attention_mask.view(),
                        "token_type_ids" => token_type_ids.view(),
                    ].map_err(|e| Error::Embedding(format!("failed to create inputs: {e}")))?)
                    .map_err(|e| Error::Embedding(format!("ONNX inference failed: {e}")))?;

                // Prefer the `last_hidden_state` output. Otherwise fall back to
                // the first output the session yields.
                let output_tensor = outputs
                    .get("last_hidden_state")
                    .or_else(|| outputs.iter().next().map(|(_, v)| v))
                    .ok_or_else(|| Error::Embedding("no output tensor from ONNX model".to_string()))?;

                let output_array = output_tensor
                    .try_extract_tensor::<f32>()
                    .map_err(|e| Error::Embedding(format!("failed to extract output tensor: {e}")))?;

                let shape = output_array.shape().to_vec();

                match *shape.as_slice() {
                    // Per-token hidden states: [batch, seq_len, hidden_dim].
                    [out_batch, out_seq, hidden_dim] => {
                        if out_batch != batch_size || out_seq != seq_len {
                            return Err(Error::Embedding(format!(
                                "model output shape {shape:?} does not match input shape [{batch_size}, {seq_len}, _]"
                            )));
                        }
                        ensure_dims(hidden_dim, dims)?;

                        let hidden = output_array
                            .into_dimensionality::<Ix3>()
                            .map_err(|e| Error::Embedding(format!("reshape failed: {e}")))?;
                        Ok(Self::mean_pool_and_normalize(hidden, &attention_mask))
                    }
                    // Already pooled (e.g. a `sentence_embedding` output):
                    // [batch, hidden_dim].
                    [out_batch, hidden_dim] => {
                        if out_batch != batch_size {
                            return Err(Error::Embedding(format!(
                                "model output shape {shape:?} does not match input batch size {batch_size}"
                            )));
                        }
                        ensure_dims(hidden_dim, dims)?;

                        Ok(output_array
                            .outer_iter()
                            .map(|row| {
                                let mut v: Vec<f32> = row.iter().copied().collect();
                                l2_normalize(&mut v);
                                v
                            })
                            .collect())
                    }
                    _ => Err(Error::Embedding(format!(
                        "unexpected output tensor shape: {shape:?}"
                    ))),
                }
            })
            .await
            .map_err(|e| Error::Embedding(format!("inference task panicked: {e}")))?
        }
    }

    #[async_trait::async_trait]
    impl EmbeddingProvider for OnnxEmbedding {
        /// Generate an embedding vector for a single text input.
        ///
        /// Tokenizes the input and runs ONNX inference. For per-token model
        /// outputs it applies mean-pooling weighted by the attention mask.
        /// It then L2-normalises the result.
        ///
        /// # Errors
        ///
        /// Returns [`Error::Embedding`] if tokenization or inference fails.
        async fn embed(&self, text: &str) -> Result<Vec<f32>> {
            let mut results = self.run_inference(&[text]).await?;
            results
                .pop()
                .ok_or_else(|| Error::Embedding("empty inference result".to_string()))
        }

        /// Generate embedding vectors for a batch of text inputs.
        ///
        /// Processes all texts in a single batched ONNX inference call for
        /// maximum throughput. An empty input yields an empty result without
        /// touching the model.
        ///
        /// # Errors
        ///
        /// Returns [`Error::Embedding`] if tokenization or inference fails.
        async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
            self.run_inference(texts).await
        }

        fn dimensions(&self) -> usize {
            self.dimensions
        }
    }
}
367
368// ---------------------------------------------------------------------------
369// Stub implementation (no onnx feature)
370// ---------------------------------------------------------------------------
371#[cfg(not(feature = "onnx"))]
372mod inner {
373    use super::*;
374
375    /// ONNX-based local embedding provider.
376    ///
377    /// Wraps an ONNX sentence-transformer model for on-device vector generation.
378    /// When the `onnx` feature is not enabled, `embed` and `embed_batch` return
379    /// an [`Error::Embedding`] explaining how to enable full inference.
380    #[derive(Debug)]
381    pub struct OnnxEmbedding {
382        dimensions: usize,
383        model_path: String,
384        // In a full implementation, this would hold:
385        // session: ort::Session,
386        // tokenizer: tokenizers::Tokenizer,
387    }
388
389    impl OnnxEmbedding {
390        /// Create a new ONNX embedding provider from a model path.
391        ///
392        /// The model should be an ONNX sentence-transformer model
393        /// (e.g., `all-MiniLM-L6-v2` exported to ONNX format).
394        ///
395        /// # Errors
396        ///
397        /// Returns [`Error::Validation`] if the file at `model_path` does not
398        /// exist on disk.
399        pub fn new(model_path: &str, dimensions: usize) -> Result<Self> {
400            if !std::path::Path::new(model_path).exists() {
401                return Err(Error::Validation(format!(
402                    "ONNX model not found at: {model_path}"
403                )));
404            }
405            Ok(Self {
406                dimensions,
407                model_path: model_path.to_string(),
408            })
409        }
410
411        /// Get the model path.
412        #[must_use]
413        pub fn model_path(&self) -> &str {
414            &self.model_path
415        }
416    }
417
418    #[async_trait::async_trait]
419    impl EmbeddingProvider for OnnxEmbedding {
420        /// Generate an embedding vector for a single text input.
421        ///
422        /// # Errors
423        ///
424        /// Currently returns [`Error::Embedding`] because full ONNX Runtime
425        /// inference requires the `onnx` feature (with `ort`, `tokenizers`,
426        /// and `ndarray` crates).
427        async fn embed(&self, _text: &str) -> Result<Vec<f32>> {
428            Err(Error::Embedding(
429                "ONNX Runtime not available: compile with full onnx dependencies \
430                 (ort, tokenizers, ndarray) to enable local inference"
431                    .to_string(),
432            ))
433        }
434
435        /// Generate embedding vectors for a batch of text inputs.
436        ///
437        /// # Errors
438        ///
439        /// Currently returns [`Error::Embedding`] because full ONNX Runtime
440        /// inference requires the `onnx` feature (with `ort`, `tokenizers`,
441        /// and `ndarray` crates).
442        async fn embed_batch(&self, _texts: &[&str]) -> Result<Vec<Vec<f32>>> {
443            Err(Error::Embedding(
444                "ONNX Runtime not available: compile with full onnx dependencies \
445                 (ort, tokenizers, ndarray) to enable local inference"
446                    .to_string(),
447            ))
448        }
449
450        fn dimensions(&self) -> usize {
451            self.dimensions
452        }
453    }
454}
455
456// Re-export `OnnxEmbedding` from the active inner module so that
457// downstream code can use `crate::embedding::onnx::OnnxEmbedding`
458// regardless of the feature flag.
459pub use inner::OnnxEmbedding;
460
#[cfg(test)]
mod tests {
    use super::*;

    /// A file guaranteed to exist, used as a stand-in "model" path.
    const EXISTING_FILE: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/Cargo.toml");

    /// Prefix of the error the stub returns from every inference call.
    #[cfg(not(feature = "onnx"))]
    const RUNTIME_MISSING: &str = "ONNX Runtime not available";

    /// Assert that `outcome` is the stub's "runtime missing" error.
    #[cfg(not(feature = "onnx"))]
    fn assert_runtime_missing<T: std::fmt::Debug>(outcome: Result<T>) {
        let msg = outcome
            .expect_err("stub must refuse to run inference")
            .to_string();
        assert!(msg.contains(RUNTIME_MISSING), "unexpected error: {msg}");
    }

    #[test]
    fn test_onnx_missing_model() {
        let msg = OnnxEmbedding::new("/nonexistent/path/model.onnx", 384)
            .expect_err("a missing model must be rejected")
            .to_string();
        assert!(
            msg.contains("ONNX model not found"),
            "unexpected error message: {msg}"
        );
    }

    #[test]
    fn test_onnx_dimensions() {
        // Stub build: only the path is validated, so any existing file is
        // accepted and the configured width is reported back unchanged.
        #[cfg(not(feature = "onnx"))]
        {
            let provider = OnnxEmbedding::new(EXISTING_FILE, 384).expect("file should exist");
            assert_eq!(provider.dimensions(), 384);
        }
        // Real build: construction also needs a `tokenizer.json` beside the
        // model. None sits next to Cargo.toml, so the failure must name that
        // file rather than report a missing model.
        #[cfg(feature = "onnx")]
        {
            let msg = OnnxEmbedding::new(EXISTING_FILE, 384)
                .expect_err("no tokenizer.json next to Cargo.toml")
                .to_string();
            assert!(
                msg.contains("tokenizer.json"),
                "expected tokenizer.json error, got: {msg}"
            );
        }
    }

    #[test]
    fn test_onnx_model_path() {
        #[cfg(not(feature = "onnx"))]
        {
            let provider = OnnxEmbedding::new(EXISTING_FILE, 768).expect("file should exist");
            assert_eq!(provider.model_path(), EXISTING_FILE);
        }
        #[cfg(feature = "onnx")]
        {
            assert!(OnnxEmbedding::new(EXISTING_FILE, 768).is_err());
        }
    }

    #[cfg(not(feature = "onnx"))]
    #[tokio::test]
    async fn test_onnx_embed_returns_error_without_runtime() {
        let provider = OnnxEmbedding::new(EXISTING_FILE, 384).expect("file should exist");
        assert_runtime_missing(provider.embed("hello world").await);
    }

    #[cfg(not(feature = "onnx"))]
    #[tokio::test]
    async fn test_onnx_embed_batch_returns_error_without_runtime() {
        let provider = OnnxEmbedding::new(EXISTING_FILE, 384).expect("file should exist");
        assert_runtime_missing(provider.embed_batch(&["a", "b"]).await);
    }
}