Skip to main content

leann_core/embedding/
onnx.rs

1use anyhow::Result;
2use ndarray::Array2;
3use std::path::PathBuf;
4
5use super::EmbeddingProvider;
6
/// ONNX Runtime embedding provider for local sentence-transformer models.
///
/// The provider is meant to load an ONNX-exported model and run inference
/// locally, without needing Python, torch, or sentence-transformers.
///
/// NOTE(review): inference is not wired up in this file yet; the struct only
/// records where the model lives and how wide its embeddings are.
#[derive(Debug, Clone)]
pub struct OnnxEmbedding {
    /// Directory that holds the ONNX model file and `tokenizer.json`.
    model_path: PathBuf,
    /// Width of each embedding vector (number of columns in the output).
    dimensions: usize,
    /// Maximum sequence length; stored but not read anywhere in this file
    /// (hence the leading underscore).
    _max_seq_length: usize,
    /// Human-readable provider name.
    model_name: String,
}
17
18impl OnnxEmbedding {
19    /// Create a new ONNX embedding provider.
20    ///
21    /// `model_path` should point to a directory containing:
22    /// - `model.onnx` or `model_optimized.onnx`
23    /// - `tokenizer.json` (HuggingFace tokenizer)
24    pub fn new(model_path: &str, dimensions: Option<usize>) -> Result<Self> {
25        let path = PathBuf::from(model_path);
26
27        if !path.exists() {
28            anyhow::bail!("ONNX model path does not exist: {}", model_path);
29        }
30
31        // Check for model file
32        let _model_file = if path.join("model_optimized.onnx").exists() {
33            path.join("model_optimized.onnx")
34        } else if path.join("model.onnx").exists() {
35            path.join("model.onnx")
36        } else {
37            anyhow::bail!("No ONNX model file found in {}", model_path);
38        };
39
40        let _tokenizer_file = path.join("tokenizer.json");
41
42        // Default dimensions for common models
43        let dimensions = dimensions.unwrap_or(768);
44
45        Ok(Self {
46            model_path: path,
47            dimensions,
48            _max_seq_length: 512,
49            model_name: model_path.to_string(),
50        })
51    }
52
53    /// Get the path to the ONNX model file.
54    pub fn model_file(&self) -> PathBuf {
55        if self.model_path.join("model_optimized.onnx").exists() {
56            self.model_path.join("model_optimized.onnx")
57        } else {
58            self.model_path.join("model.onnx")
59        }
60    }
61
62    /// Get the path to the tokenizer file.
63    pub fn tokenizer_file(&self) -> PathBuf {
64        self.model_path.join("tokenizer.json")
65    }
66}
67
68impl EmbeddingProvider for OnnxEmbedding {
69    fn compute_embeddings(
70        &self,
71        chunks: &[String],
72        _progress: Option<&dyn crate::hnsw::IndexProgress>,
73    ) -> Result<Array2<f32>> {
74        if chunks.is_empty() {
75            return Ok(Array2::zeros((0, self.dimensions)));
76        }
77
78        // Note: Full ONNX Runtime integration requires the `ort` crate and
79        // a compiled ONNX Runtime library. This is a placeholder that shows
80        // the intended API. To enable, add `ort = "2"` to dependencies and
81        // uncomment the implementation below.
82        //
83        // The full implementation would:
84        // 1. Load tokenizer from tokenizer.json
85        // 2. Tokenize input texts (input_ids, attention_mask, token_type_ids)
86        // 3. Run ONNX session inference
87        // 4. Mean-pool the token embeddings using attention_mask
88        // 5. Optionally normalize to unit length
89
90        anyhow::bail!(
91            "ONNX Runtime inference not yet enabled. \
92             Install the ort crate and ONNX Runtime library, \
93             or use --embedding-mode openai/ollama instead. \
94             Model path: {}",
95            self.model_path.display()
96        )
97    }
98
99    fn dimensions(&self) -> usize {
100        self.dimensions
101    }
102
103    fn name(&self) -> &str {
104        &self.model_name
105    }
106}