mnemo_core/embedding/onnx.rs
//! ONNX Runtime local embedding provider.
//!
//! Provides local embedding inference using ONNX Runtime, eliminating the
//! need for an external API. Supports sentence-transformer models such as
//! `all-MiniLM-L6-v2` exported to ONNX format.
//!
//! # Feature gating
//!
//! When compiled **without** the `onnx` feature the module provides a stub
//! that validates the model path but returns [`Error::Embedding`] from
//! `embed()` and `embed_batch()`.
//!
//! When compiled **with** the `onnx` feature the module loads the ONNX
//! session and a HuggingFace tokenizer, then performs real local inference
//! with mean-pooling and L2 normalisation.
//!
//! ```toml
//! [features]
//! onnx = ["dep:ort", "dep:tokenizers", "dep:ndarray"]
//!
//! [dependencies]
//! ort = { version = "2", optional = true }
//! tokenizers = { version = "0.21", optional = true, default-features = false }
//! ndarray = { version = "0.16", optional = true }
//! ```
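//!
//! Build with `cargo build --features onnx` (or enable the feature from a
//! downstream crate) to compile the real inference path.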
//!
//! # Example (stub)
//!
//! ```rust,no_run
//! use mnemo_core::embedding::onnx::OnnxEmbedding;
//! use mnemo_core::embedding::EmbeddingProvider;
//!
//! // Will succeed only if the path exists on disk.
//! let provider = OnnxEmbedding::new("/models/all-MiniLM-L6-v2.onnx", 384)
//!     .expect("model path must exist");
//!
//! assert_eq!(provider.dimensions(), 384);
//! assert_eq!(provider.model_path(), "/models/all-MiniLM-L6-v2.onnx");
//! ```
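//!
//! # Example (with the `onnx` feature)
//!
//! A sketch of full local inference, assuming an `all-MiniLM-L6-v2` export
//! with a `tokenizer.json` beside the model file and a public `error` module
//! (marked `ignore` because it needs the `onnx` feature plus model files on
//! disk):
//!
//! ```rust,ignore
//! use mnemo_core::embedding::onnx::OnnxEmbedding;
//! use mnemo_core::embedding::EmbeddingProvider;
//!
//! # async fn demo() -> mnemo_core::error::Result<()> {
//! let provider = OnnxEmbedding::new("/models/all-MiniLM-L6-v2.onnx", 384)?;
//! let vector = provider.embed("hello world").await?;
//! assert_eq!(vector.len(), 384);
//!
//! // Vectors come back L2-normalised, so the norm is ~1.0.
//! let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
//! assert!((norm - 1.0).abs() < 1e-4);
//! # Ok(())
//! # }
//! ```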

use crate::embedding::EmbeddingProvider;
use crate::error::{Error, Result};

// ---------------------------------------------------------------------------
// Real implementation (feature = "onnx")
// ---------------------------------------------------------------------------
#[cfg(feature = "onnx")]
mod inner {
    use super::*;
    use ndarray::Array2;
    use ort::Session;
    use std::path::Path;
    use std::sync::Arc;
    use tokenizers::Tokenizer;

    /// ONNX-based local embedding provider.
    ///
    /// Wraps an ONNX sentence-transformer model (e.g. `all-MiniLM-L6-v2`)
    /// together with a HuggingFace tokenizer for on-device vector generation.
    pub struct OnnxEmbedding {
        dimensions: usize,
        model_path: String,
        session: Arc<Session>,
        tokenizer: Arc<Tokenizer>,
    }

    // `ort::Session` is Send + Sync in ort v2.
    // `tokenizers::Tokenizer` is Send + Sync.
    // The Arc wrappers enable cheap cloning for spawn_blocking moves.

    // Manual Debug because Session/Tokenizer do not implement Debug.
    impl std::fmt::Debug for OnnxEmbedding {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("OnnxEmbedding")
                .field("dimensions", &self.dimensions)
                .field("model_path", &self.model_path)
                .finish_non_exhaustive()
        }
    }

    impl OnnxEmbedding {
        /// Create a new ONNX embedding provider from a model path.
        ///
        /// The model should be an ONNX file for a sentence-transformer model
        /// (e.g. `all-MiniLM-L6-v2` exported to ONNX format).
        ///
        /// A `tokenizer.json` file **must** exist in the same directory as the
        /// model file. This is the standard layout produced by
        /// `optimum-cli export onnx` or manual HuggingFace model export.
        ///
        /// # Errors
        ///
        /// Returns [`Error::Validation`] if the model file does not exist.
        /// Returns [`Error::Embedding`] if the ONNX session or tokenizer
        /// fails to load.
        pub fn new(model_path: &str, dimensions: usize) -> Result<Self> {
            let model = Path::new(model_path);
            if !model.exists() {
                return Err(Error::Validation(format!(
                    "ONNX model not found at: {model_path}"
                )));
            }

            // Locate tokenizer.json next to the model file.
            let tokenizer_path = model
                .parent()
                .map(|p| p.join("tokenizer.json"))
                .unwrap_or_else(|| Path::new("tokenizer.json").to_path_buf());

            if !tokenizer_path.exists() {
                return Err(Error::Embedding(format!(
                    "tokenizer.json not found next to ONNX model (expected at {})",
                    tokenizer_path.display()
                )));
            }

            let session = Session::builder()
                .map_err(|e| {
                    Error::Embedding(format!("failed to create ONNX session builder: {e}"))
                })?
                .with_intra_threads(4)
                .map_err(|e| Error::Embedding(format!("failed to set intra threads: {e}")))?
                .commit_from_file(model_path)
                .map_err(|e| Error::Embedding(format!("failed to load ONNX model: {e}")))?;

            let tokenizer = Tokenizer::from_file(&tokenizer_path)
                .map_err(|e| Error::Embedding(format!("failed to load tokenizer: {e}")))?;

            Ok(Self {
                dimensions,
                model_path: model_path.to_string(),
                session: Arc::new(session),
                tokenizer: Arc::new(tokenizer),
            })
        }

        /// Get the model path.
        #[must_use]
        pub fn model_path(&self) -> &str {
            &self.model_path
        }

        /// Tokenize a batch of texts and return (input_ids, attention_mask,
        /// token_type_ids) as 2-D i64 arrays with shape `[batch, max_len]`.
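        ///
        /// Shorter sequences are zero-padded on the right. For an illustrative
        /// batch of `["hi", "hello world"]` the attention mask could come out
        /// as `[[1, 1, 1, 0], [1, 1, 1, 1]]` (exact lengths depend on the
        /// tokenizer's vocabulary and special tokens).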
        fn tokenize_batch(
            tokenizer: &Tokenizer,
            texts: &[&str],
        ) -> Result<(Array2<i64>, Array2<i64>, Array2<i64>)> {
            let encodings = tokenizer
                .encode_batch(texts.to_vec(), true)
                .map_err(|e| Error::Embedding(format!("tokenization failed: {e}")))?;

            let batch_size = encodings.len();
            let max_len = encodings
                .iter()
                .map(|e| e.get_ids().len())
                .max()
                .unwrap_or(0);

            let mut input_ids = Array2::<i64>::zeros((batch_size, max_len));
            let mut attention_mask = Array2::<i64>::zeros((batch_size, max_len));
            let mut token_type_ids = Array2::<i64>::zeros((batch_size, max_len));

            for (i, enc) in encodings.iter().enumerate() {
                for (j, &id) in enc.get_ids().iter().enumerate() {
                    input_ids[[i, j]] = i64::from(id);
                }
                for (j, &mask) in enc.get_attention_mask().iter().enumerate() {
                    attention_mask[[i, j]] = i64::from(mask);
                }
                for (j, &tid) in enc.get_type_ids().iter().enumerate() {
                    token_type_ids[[i, j]] = i64::from(tid);
                }
            }

            Ok((input_ids, attention_mask, token_type_ids))
        }

        /// Mean-pool the last hidden state over the token dimension, weighted
        /// by the attention mask, then L2-normalise each vector.
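        ///
        /// For each sequence `i` the pooled vector is
        /// `pooled[k] = sum_j(mask[i][j] * hidden[i][j][k]) / sum_j(mask[i][j])`,
        /// which is then scaled to unit length (vectors with zero norm are
        /// left untouched).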
        fn mean_pool_and_normalize(
            hidden: &Array2<f32>,
            mask: &Array2<i64>,
            batch_size: usize,
            seq_len: usize,
            hidden_dim: usize,
        ) -> Vec<Vec<f32>> {
            // `hidden` arrives with shape [batch * seq_len, hidden_dim]: the
            // caller flattens the model's [batch, seq_len, hidden_dim] output
            // before handing it to us, so row `i * seq_len + j` holds token
            // `j` of sequence `i`.

            let mut results = Vec::with_capacity(batch_size);

            for i in 0..batch_size {
                let mut pooled = vec![0.0f32; hidden_dim];
                let mut count = 0.0f32;

                for j in 0..seq_len {
                    let m = mask[[i, j]] as f32;
                    if m > 0.0 {
                        for k in 0..hidden_dim {
                            pooled[k] += hidden[[i * seq_len + j, k]] * m;
                        }
                        count += m;
                    }
                }

                if count > 0.0 {
                    for v in &mut pooled {
                        *v /= count;
                    }
                }

                // L2 normalise
                let norm: f32 = pooled.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm > 0.0 {
                    for v in &mut pooled {
                        *v /= norm;
                    }
                }

                results.push(pooled);
            }

            results
        }

        /// Run inference on a batch of texts. This is the shared
        /// implementation used by both `embed` and `embed_batch`.
        async fn run_inference(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
            if texts.is_empty() {
                return Ok(Vec::new());
            }

            let session = Arc::clone(&self.session);
            let tokenizer = Arc::clone(&self.tokenizer);
            let dims = self.dimensions;
            let owned_texts: Vec<String> = texts.iter().map(|t| (*t).to_string()).collect();

            let result = tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>> {
                let text_refs: Vec<&str> = owned_texts.iter().map(String::as_str).collect();
                let (input_ids, attention_mask, token_type_ids) =
                    Self::tokenize_batch(&tokenizer, &text_refs)?;

                let batch_size = input_ids.nrows();
                let seq_len = input_ids.ncols();

                let outputs = session
                    .run(
                        ort::inputs![
                            "input_ids" => input_ids.view(),
                            "attention_mask" => attention_mask.view(),
                            "token_type_ids" => token_type_ids.view(),
                        ]
                        .map_err(|e| Error::Embedding(format!("failed to create inputs: {e}")))?,
                    )
                    .map_err(|e| Error::Embedding(format!("ONNX inference failed: {e}")))?;

                // Sentence-transformer models typically output
                // "last_hidden_state" at index 0 with shape
                // [batch, seq_len, hidden_dim].
                let output_tensor = outputs
                    .get("last_hidden_state")
                    .or_else(|| outputs.iter().next().map(|(_, v)| v))
                    .ok_or_else(|| {
                        Error::Embedding("no output tensor from ONNX model".to_string())
                    })?;

                let output_array = output_tensor
                    .try_extract_tensor::<f32>()
                    .map_err(|e| {
                        Error::Embedding(format!("failed to extract output tensor: {e}"))
                    })?;

                let shape = output_array.shape();

                // Handle different output shapes:
                // - [batch, seq_len, hidden_dim]: needs mean-pooling
                // - [batch, hidden_dim]: already pooled (e.g. sentence_embedding output)
                if shape.len() == 3 {
                    let hidden_dim = shape[2];
                    if hidden_dim != dims {
                        return Err(Error::Embedding(format!(
                            "model hidden dim ({hidden_dim}) does not match configured dimensions ({dims})"
                        )));
                    }

                    // Reshape to [batch * seq_len, hidden_dim] for pooling
                    let flat = output_array
                        .to_shape((batch_size * seq_len, hidden_dim))
                        .map_err(|e| Error::Embedding(format!("reshape failed: {e}")))?;

                    let flat_owned: Array2<f32> = flat.to_owned();
                    Ok(Self::mean_pool_and_normalize(
                        &flat_owned,
                        &attention_mask,
                        batch_size,
                        seq_len,
                        hidden_dim,
                    ))
                } else if shape.len() == 2 {
                    // Already pooled output [batch, hidden_dim]
                    let hidden_dim = shape[1];
                    if hidden_dim != dims {
                        return Err(Error::Embedding(format!(
                            "model hidden dim ({hidden_dim}) does not match configured dimensions ({dims})"
                        )));
                    }

                    let mut results = Vec::with_capacity(batch_size);
                    for i in 0..batch_size {
                        let mut vec = Vec::with_capacity(hidden_dim);
                        for j in 0..hidden_dim {
                            vec.push(output_array[[i, j]]);
                        }
                        // L2 normalise
                        let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
                        if norm > 0.0 {
                            for v in &mut vec {
                                *v /= norm;
                            }
                        }
                        results.push(vec);
                    }
                    Ok(results)
                } else {
                    Err(Error::Embedding(format!(
                        "unexpected output tensor shape: {shape:?}"
                    )))
                }
            })
            .await
            .map_err(|e| Error::Embedding(format!("inference task panicked: {e}")))?;

            result
        }
    }

    #[async_trait::async_trait]
    impl EmbeddingProvider for OnnxEmbedding {
        /// Generate an embedding vector for a single text input.
        ///
        /// Tokenizes the input, runs ONNX inference, applies mean-pooling
        /// weighted by the attention mask, and L2-normalises the result.
        ///
        /// # Errors
        ///
        /// Returns [`Error::Embedding`] if tokenization or inference fails.
        async fn embed(&self, text: &str) -> Result<Vec<f32>> {
            let mut results = self.run_inference(&[text]).await?;
            results
                .pop()
                .ok_or_else(|| Error::Embedding("empty inference result".to_string()))
        }

        /// Generate embedding vectors for a batch of text inputs.
        ///
        /// Processes all texts in a single batched ONNX inference call for
        /// maximum throughput.
        ///
        /// # Errors
        ///
        /// Returns [`Error::Embedding`] if tokenization or inference fails.
        async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
            self.run_inference(texts).await
        }

        fn dimensions(&self) -> usize {
            self.dimensions
        }
    }
}

// ---------------------------------------------------------------------------
// Stub implementation (no onnx feature)
// ---------------------------------------------------------------------------
#[cfg(not(feature = "onnx"))]
mod inner {
    use super::*;

    /// ONNX-based local embedding provider.
    ///
    /// Wraps an ONNX sentence-transformer model for on-device vector generation.
    /// When the `onnx` feature is not enabled, `embed` and `embed_batch` return
    /// an [`Error::Embedding`] explaining how to enable full inference.
    #[derive(Debug)]
    pub struct OnnxEmbedding {
        dimensions: usize,
        model_path: String,
        // In a full implementation, this would hold:
        //   session: ort::Session,
        //   tokenizer: tokenizers::Tokenizer,
    }

    impl OnnxEmbedding {
        /// Create a new ONNX embedding provider from a model path.
        ///
        /// The model should be an ONNX sentence-transformer model
        /// (e.g., `all-MiniLM-L6-v2` exported to ONNX format).
        ///
        /// # Errors
        ///
        /// Returns [`Error::Validation`] if the file at `model_path` does not
        /// exist on disk.
        pub fn new(model_path: &str, dimensions: usize) -> Result<Self> {
            if !std::path::Path::new(model_path).exists() {
                return Err(Error::Validation(format!(
                    "ONNX model not found at: {model_path}"
                )));
            }
            Ok(Self {
                dimensions,
                model_path: model_path.to_string(),
            })
        }

        /// Get the model path.
        #[must_use]
        pub fn model_path(&self) -> &str {
            &self.model_path
        }
    }

    #[async_trait::async_trait]
    impl EmbeddingProvider for OnnxEmbedding {
        /// Generate an embedding vector for a single text input.
        ///
        /// # Errors
        ///
        /// Currently returns [`Error::Embedding`] because full ONNX Runtime
        /// inference requires the `onnx` feature (with `ort`, `tokenizers`,
        /// and `ndarray` crates).
        async fn embed(&self, _text: &str) -> Result<Vec<f32>> {
            Err(Error::Embedding(
                "ONNX Runtime not available: enable the `onnx` feature \
                 (pulling in ort, tokenizers, ndarray) to run local inference"
                    .to_string(),
            ))
        }

        /// Generate embedding vectors for a batch of text inputs.
        ///
        /// # Errors
        ///
        /// Currently returns [`Error::Embedding`] because full ONNX Runtime
        /// inference requires the `onnx` feature (with `ort`, `tokenizers`,
        /// and `ndarray` crates).
        async fn embed_batch(&self, _texts: &[&str]) -> Result<Vec<Vec<f32>>> {
            Err(Error::Embedding(
                "ONNX Runtime not available: enable the `onnx` feature \
                 (pulling in ort, tokenizers, ndarray) to run local inference"
                    .to_string(),
            ))
        }

        fn dimensions(&self) -> usize {
            self.dimensions
        }
    }
}

// Re-export `OnnxEmbedding` from the active inner module so that
// downstream code can use `crate::embedding::onnx::OnnxEmbedding`
// regardless of the feature flag.
pub use inner::OnnxEmbedding;

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_onnx_missing_model() {
        let result = OnnxEmbedding::new("/nonexistent/path/model.onnx", 384);
        assert!(result.is_err());
        let err = result.unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("ONNX model not found"),
            "unexpected error message: {msg}"
        );
    }

    #[test]
    fn test_onnx_dimensions() {
        // Use Cargo.toml as a stand-in file that is guaranteed to exist.
        let path = concat!(env!("CARGO_MANIFEST_DIR"), "/Cargo.toml");
        #[cfg(not(feature = "onnx"))]
        {
            let provider = OnnxEmbedding::new(path, 384).expect("file should exist");
            assert_eq!(provider.dimensions(), 384);
        }
        // When the onnx feature is on, construction also requires
        // tokenizer.json, so we only test that the path validation
        // passes for the stub variant.
        #[cfg(feature = "onnx")]
        {
            // Without a tokenizer.json next to Cargo.toml, we expect an
            // embedding error rather than a validation error.
            let result = OnnxEmbedding::new(path, 384);
            assert!(result.is_err());
            let msg = result.unwrap_err().to_string();
            assert!(
                msg.contains("tokenizer.json"),
                "expected tokenizer.json error, got: {msg}"
            );
        }
    }

    #[test]
    fn test_onnx_model_path() {
        let path = concat!(env!("CARGO_MANIFEST_DIR"), "/Cargo.toml");
        #[cfg(not(feature = "onnx"))]
        {
            let provider = OnnxEmbedding::new(path, 768).expect("file should exist");
            assert_eq!(provider.model_path(), path);
        }
        #[cfg(feature = "onnx")]
        {
            let result = OnnxEmbedding::new(path, 768);
            assert!(result.is_err());
        }
    }

    #[cfg(not(feature = "onnx"))]
    #[tokio::test]
    async fn test_onnx_embed_returns_error_without_runtime() {
        let path = concat!(env!("CARGO_MANIFEST_DIR"), "/Cargo.toml");
        let provider = OnnxEmbedding::new(path, 384).expect("file should exist");
        let result = provider.embed("hello world").await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("ONNX Runtime not available"),
            "unexpected error: {msg}"
        );
    }

    #[cfg(not(feature = "onnx"))]
    #[tokio::test]
    async fn test_onnx_embed_batch_returns_error_without_runtime() {
        let path = concat!(env!("CARGO_MANIFEST_DIR"), "/Cargo.toml");
        let provider = OnnxEmbedding::new(path, 384).expect("file should exist");
        let result = provider.embed_batch(&["a", "b"]).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("ONNX Runtime not available"),
            "unexpected error: {msg}"
        );
    }
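
    /// Illustrative sketch of the mean-pool + L2-normalise recipe described
    /// in the module docs, recomputed by hand on a tiny fixture (the real
    /// helper is private to the feature-gated `inner` module, so this does
    /// not call it).
    #[test]
    fn test_mean_pool_math_sketch() {
        // Two tokens with hidden_dim = 2; the second token is padding
        // (mask 0) and must not influence the result.
        let hidden = [[3.0f32, 4.0], [100.0, 100.0]];
        let mask = [1i64, 0];

        // Mask-weighted mean over the token dimension.
        let count: f32 = mask.iter().map(|&m| m as f32).sum();
        let mut pooled = [0.0f32; 2];
        for (row, &m) in hidden.iter().zip(&mask) {
            for (p, &h) in pooled.iter_mut().zip(row) {
                *p += h * m as f32;
            }
        }
        for p in &mut pooled {
            *p /= count;
        }

        // L2 normalise: [3, 4] has norm 5, so the result is [0.6, 0.8].
        let norm: f32 = pooled.iter().map(|x| x * x).sum::<f32>().sqrt();
        let normalised: Vec<f32> = pooled.iter().map(|x| x / norm).collect();
        assert!((normalised[0] - 0.6).abs() < 1e-6);
        assert!((normalised[1] - 0.8).abs() < 1e-6);
    }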
}