use anyhow::Result;
use ndarray::Array2;
use std::path::PathBuf;
use super::EmbeddingProvider;
/// Embedding provider backed by an ONNX model stored in a local directory.
///
/// The directory is expected to hold the model file (`model_optimized.onnx`
/// or `model.onnx`) and, per `tokenizer_file`, a `tokenizer.json`.
#[derive(Debug, Clone)]
pub struct OnnxEmbedding {
    /// Directory that contains the model and tokenizer files.
    model_path: PathBuf,
    /// Number of components in each embedding vector.
    dimensions: usize,
    /// Maximum sequence length; set at construction but not read anywhere
    /// in this file (hence the leading underscore).
    _max_seq_length: usize,
    /// Human-readable provider name, taken from the path string given to `new`.
    model_name: String,
}
impl OnnxEmbedding {
    /// Creates a provider for the ONNX model directory at `model_path`.
    ///
    /// `dimensions` sets the embedding width; when `None`, 768 is used.
    /// The maximum sequence length is fixed at 512, and the provider name is
    /// the `model_path` string exactly as passed in.
    ///
    /// # Errors
    ///
    /// Returns an error if `model_path` does not exist, or if it contains
    /// neither `model_optimized.onnx` nor `model.onnx`.
    pub fn new(model_path: &str, dimensions: Option<usize>) -> Result<Self> {
        let path = PathBuf::from(model_path);
        if !path.exists() {
            anyhow::bail!("ONNX model path does not exist: {}", model_path);
        }
        // Only the model file is validated here; the tokenizer file is not
        // checked at construction time.
        if Self::find_model_file(&path).is_none() {
            anyhow::bail!("No ONNX model file found in {}", model_path);
        }
        Ok(Self {
            model_path: path,
            dimensions: dimensions.unwrap_or(768),
            _max_seq_length: 512,
            model_name: model_path.to_string(),
        })
    }

    /// Returns the first existing model file inside `dir`, preferring the
    /// optimized variant, or `None` when neither candidate exists.
    fn find_model_file(dir: &std::path::Path) -> Option<PathBuf> {
        // Ordered by preference: the optimized model wins when both exist.
        const CANDIDATES: [&str; 2] = ["model_optimized.onnx", "model.onnx"];
        CANDIDATES
            .iter()
            .map(|file_name| dir.join(file_name))
            .find(|candidate| candidate.exists())
    }

    /// Returns the path of the model file to load.
    ///
    /// This is `model_optimized.onnx` when it exists in the model directory;
    /// otherwise the `model.onnx` path is returned, whether or not that file
    /// is currently present.
    pub fn model_file(&self) -> PathBuf {
        Self::find_model_file(&self.model_path)
            .unwrap_or_else(|| self.model_path.join("model.onnx"))
    }

    /// Returns the path of `tokenizer.json` inside the model directory.
    ///
    /// The path is built without checking that the file exists.
    pub fn tokenizer_file(&self) -> PathBuf {
        self.model_path.join("tokenizer.json")
    }
}
impl EmbeddingProvider for OnnxEmbedding {
    /// Computes embeddings for `chunks`.
    ///
    /// An empty `chunks` slice yields an empty `0 x dimensions` matrix.
    /// Any non-empty input currently fails, because ONNX Runtime inference
    /// is not enabled in this provider. The progress handle is unused.
    ///
    /// # Errors
    ///
    /// Always returns an error when `chunks` is non-empty; the message
    /// includes the configured model path.
    fn compute_embeddings(
        &self,
        chunks: &[String],
        _progress: Option<&dyn crate::hnsw::IndexProgress>,
    ) -> Result<Array2<f32>> {
        if !chunks.is_empty() {
            anyhow::bail!(
                "ONNX Runtime inference not yet enabled. \
                 Install the ort crate and ONNX Runtime library, \
                 or use --embedding-mode openai/ollama instead. \
                 Model path: {}",
                self.model_path.display()
            );
        }

        // Nothing to embed: hand back a matrix with zero rows.
        let empty = Array2::zeros((0, self.dimensions));
        Ok(empty)
    }

    /// Returns the width of each embedding vector.
    fn dimensions(&self) -> usize {
        self.dimensions
    }

    /// Returns the provider name (the model path string given at construction).
    fn name(&self) -> &str {
        self.model_name.as_str()
    }
}