Skip to main content

leann_core/embedding/
onnx.rs

1use anyhow::Result;
2use ndarray::Array2;
3use std::path::PathBuf;
4
5use super::EmbeddingProvider;
6
/// ONNX Runtime embedding provider for local sentence-transformer models.
///
/// Points at a directory holding an ONNX-exported model, with the goal of running
/// inference locally without needing Python, torch, or sentence-transformers.
///
/// NOTE: inference is not implemented yet — `compute_embeddings` currently returns an
/// error for any non-empty batch.
#[derive(Debug)]
pub struct OnnxEmbedding {
    /// Directory expected to contain the `.onnx` model file and `tokenizer.json`.
    model_path: PathBuf,
    /// Width of each embedding vector reported to callers.
    dimensions: usize,
    /// Maximum token sequence length; stored for the future tokenizer step, not yet read.
    _max_seq_length: usize,
    /// Name reported for this provider (the model directory path as passed to `new`).
    model_name: String,
}
17
18impl OnnxEmbedding {
19    /// Create a new ONNX embedding provider.
20    ///
21    /// `model_path` should point to a directory containing:
22    /// - `model.onnx` or `model_optimized.onnx`
23    /// - `tokenizer.json` (HuggingFace tokenizer)
24    pub fn new(model_path: &str, dimensions: Option<usize>) -> Result<Self> {
25        let path = PathBuf::from(model_path);
26
27        if !path.exists() {
28            anyhow::bail!("ONNX model path does not exist: {}", model_path);
29        }
30
31        // Check for model file
32        let _model_file = if path.join("model_optimized.onnx").exists() {
33            path.join("model_optimized.onnx")
34        } else if path.join("model.onnx").exists() {
35            path.join("model.onnx")
36        } else {
37            anyhow::bail!("No ONNX model file found in {}", model_path);
38        };
39
40        let _tokenizer_file = path.join("tokenizer.json");
41
42        // Default dimensions for common models
43        let dimensions = dimensions.unwrap_or(768);
44
45        Ok(Self {
46            model_path: path,
47            dimensions,
48            _max_seq_length: 512,
49            model_name: model_path.to_string(),
50        })
51    }
52
53    /// Get the path to the ONNX model file.
54    pub fn model_file(&self) -> PathBuf {
55        if self.model_path.join("model_optimized.onnx").exists() {
56            self.model_path.join("model_optimized.onnx")
57        } else {
58            self.model_path.join("model.onnx")
59        }
60    }
61
62    /// Get the path to the tokenizer file.
63    pub fn tokenizer_file(&self) -> PathBuf {
64        self.model_path.join("tokenizer.json")
65    }
66}
67
68impl EmbeddingProvider for OnnxEmbedding {
69    fn compute_embeddings(&self, chunks: &[String]) -> Result<Array2<f32>> {
70        if chunks.is_empty() {
71            return Ok(Array2::zeros((0, self.dimensions)));
72        }
73
74        // Note: Full ONNX Runtime integration requires the `ort` crate and
75        // a compiled ONNX Runtime library. This is a placeholder that shows
76        // the intended API. To enable, add `ort = "2"` to dependencies and
77        // uncomment the implementation below.
78        //
79        // The full implementation would:
80        // 1. Load tokenizer from tokenizer.json
81        // 2. Tokenize input texts (input_ids, attention_mask, token_type_ids)
82        // 3. Run ONNX session inference
83        // 4. Mean-pool the token embeddings using attention_mask
84        // 5. Optionally normalize to unit length
85
86        anyhow::bail!(
87            "ONNX Runtime inference not yet enabled. \
88             Install the ort crate and ONNX Runtime library, \
89             or use --embedding-mode openai/ollama instead. \
90             Model path: {}",
91            self.model_path.display()
92        )
93    }
94
95    fn dimensions(&self) -> usize {
96        self.dimensions
97    }
98
99    fn name(&self) -> &str {
100        &self.model_name
101    }
102}