use std::path::PathBuf;
use std::sync::Mutex;
use anyhow::{anyhow, Context, Result};
use ort::session::Session;
use ort::value::Value;
use tokenizers::Tokenizer;
use super::{Backend, GenAnswer, GenRequest, Generator};
/// Name of the environment variable consulted for the model directory when
/// `OnnxGenerator::new` is called with an empty `model` string.
const ONNX_DIR_ENV: &str = "NORNIR_GEN_ONNX_DIR";
/// The pair of resources produced by `OnnxGenerator::load`: the ONNX session and
/// the tokenizer that belongs to it.
struct Loaded {
/// Session committed from the `model.onnx` file of the model directory.
session: Session,
/// Tokenizer read from the `tokenizer.json` file of the same directory.
tokenizer: Tokenizer,
}
/// Text generator backed by an ONNX model directory.
///
/// Construction only records where the model lives; the session and tokenizer
/// are loaded on the first `complete` call and then kept for later calls.
pub struct OnnxGenerator {
/// Backend identifier: `onnx:` followed by the displayed model directory.
id: String,
/// Directory expected to contain `model.onnx` and `tokenizer.json`.
dir: PathBuf,
/// Loaded session and tokenizer; `None` until the first successful load.
/// The mutex also gives `complete` the exclusive access it takes to the session.
loaded: Mutex<Option<Loaded>>,
}
impl OnnxGenerator {
    /// Creates a generator for a model directory without loading anything yet.
    ///
    /// The directory is resolved in this order:
    /// 1. `model`, when it is non-empty;
    /// 2. the value of the `NORNIR_GEN_ONNX_DIR` environment variable, when set;
    /// 3. the current directory (`.`).
    ///
    /// The environment variable is read with `var_os`, so a directory whose name
    /// is not valid UTF-8 is honoured instead of silently falling back to `.`.
    ///
    /// # Errors
    ///
    /// This constructor currently always returns `Ok`; the `Result` is part of
    /// the existing signature.
    pub fn new(model: &str) -> Result<Self> {
        let dir = if !model.is_empty() {
            PathBuf::from(model)
        } else {
            std::env::var_os(ONNX_DIR_ENV)
                .map(PathBuf::from)
                .unwrap_or_else(|| PathBuf::from("."))
        };
        let id = format!("onnx:{}", dir.display());
        Ok(Self {
            id,
            dir,
            loaded: Mutex::new(None),
        })
    }

    /// Path of the model file inside the model directory.
    fn model_path(&self) -> PathBuf {
        self.dir.join("model.onnx")
    }

    /// Path of the tokenizer definition inside the model directory.
    fn tokenizer_path(&self) -> PathBuf {
        self.dir.join("tokenizer.json")
    }

    /// Builds the ONNX session and reads the tokenizer from the model directory.
    ///
    /// # Errors
    ///
    /// Fails when `model.onnx` or `tokenizer.json` is missing, when the session
    /// cannot be built or the model cannot be committed, or when the tokenizer
    /// file cannot be loaded.
    fn load(&self) -> Result<Loaded> {
        let model_path = self.model_path();
        let tok_path = self.tokenizer_path();

        // Checked up front so a missing file yields a message naming the directory
        // rather than whatever the underlying loader reports.
        if !model_path.exists() {
            return Err(anyhow!("onnx: no model.onnx in {}", self.dir.display()));
        }
        if !tok_path.exists() {
            return Err(anyhow!("onnx: no tokenizer.json in {}", self.dir.display()));
        }

        let session = Session::builder()
            .context("onnx: session builder")?
            .commit_from_file(&model_path)
            .with_context(|| format!("onnx: load {}", model_path.display()))?;
        let tokenizer =
            Tokenizer::from_file(&tok_path).map_err(|e| anyhow!("onnx: load tokenizer: {e}"))?;

        Ok(Loaded { session, tokenizer })
    }

    /// Returns the index of the largest value in `logits`.
    ///
    /// Ties resolve to the first occurrence. NaN values never win because every
    /// comparison against them is false. An empty slice, or one holding only NaN
    /// and negative infinity, yields index `0`.
    fn argmax(logits: &[f32]) -> usize {
        // Seed with negative infinity rather than `f32::MIN`: `f32::MIN` is the
        // lowest *finite* value, so seeding with it would make a slice such as
        // `[-inf, f32::MIN]` report index 0 instead of 1.
        logits
            .iter()
            .enumerate()
            .fold((0usize, f32::NEG_INFINITY), |(best, best_v), (i, &v)| {
                if v > best_v {
                    (i, v)
                } else {
                    (best, best_v)
                }
            })
            .0
    }
}
impl Backend for OnnxGenerator {
    /// Identifier of this backend, as stored at construction time.
    fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Reports whether this backend looks usable without loading the model.
    ///
    /// Both `model.onnx` and `tokenizer.json` must exist in the model directory,
    /// and an ONNX session builder must be constructible. The builder is only
    /// attempted when both files are present.
    fn available(&self) -> bool {
        let files_present = self.model_path().exists() && self.tokenizer_path().exists();
        files_present && Session::builder().is_ok()
    }
}
impl Generator for OnnxGenerator {
    /// Produces a completion for `req` by greedy (argmax) decoding.
    ///
    /// The session and tokenizer are loaded on the first call and reused after
    /// that. The internal lock is held for the whole call, so concurrent calls on
    /// the same generator run one at a time.
    ///
    /// The prompt is `req.system` (when present), a newline, then `req.prompt`.
    /// Each step feeds the entire token sequence so far through the session and
    /// appends the highest-scoring token of the last position. Decoding stops
    /// when:
    /// - `req.max_tokens` tokens have been produced,
    /// - the token whose text is `</s>` is selected (it is not added to the
    ///   output), or
    /// - the decoded output contains any non-empty entry of `req.stop`; the
    ///   returned text still includes that stop sequence.
    ///
    /// Empty stop strings are ignored, because every text contains the empty
    /// string and honouring them would end generation after a single token.
    ///
    /// # Errors
    ///
    /// Fails when the model or tokenizer cannot be loaded, when encoding,
    /// tensor construction, inference, or decoding fails, or when the first
    /// model output does not have a usable last dimension.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned by an earlier panic.
    fn complete(&self, req: &GenRequest) -> Result<GenAnswer> {
        let started = std::time::Instant::now();

        let mut guard = self.loaded.lock().expect("onnx loaded mutex");
        if guard.is_none() {
            // A failed load leaves the slot empty, so the next call retries.
            *guard = Some(self.load()?);
        }
        let Loaded { session, tokenizer } = guard.as_mut().expect("loaded set above");

        let prompt = match &req.system {
            Some(sys) => format!("{sys}\n{}", req.prompt),
            None => req.prompt.clone(),
        };
        let encoding =
            tokenizer.encode(prompt, true).map_err(|e| anyhow!("onnx: encode: {e}"))?;
        let mut ids: Vec<i64> = encoding.get_ids().iter().map(|&i| i64::from(i)).collect();
        let tokens_in = ids.len() as i64;

        // `None` when the vocabulary has no `</s>` entry; decoding then only ends
        // through the token budget or a stop sequence.
        // NOTE(review): `</s>` is the only end-of-sequence spelling checked here —
        // confirm it matches the models this backend is used with.
        let eos: Option<u32> = tokenizer.token_to_id("</s>");

        // Only non-empty stop strings are meaningful; skip decoding entirely when
        // there are none.
        let has_stops = req.stop.iter().any(|s| !s.is_empty());

        let mut generated: Vec<u32> = Vec::new();
        for _ in 0..req.max_tokens {
            let seq_len = ids.len();
            let input_ids = Value::from_array(([1usize, seq_len], ids.clone()))
                .map_err(|e| anyhow!("onnx: input_ids tensor: {e}"))?;
            let mask: Vec<i64> = vec![1; seq_len];
            let attn = Value::from_array(([1usize, seq_len], mask))
                .map_err(|e| anyhow!("onnx: attention_mask tensor: {e}"))?;
            let outputs = session
                .run(ort::inputs![
                    "input_ids" => input_ids,
                    "attention_mask" => attn,
                ])
                .map_err(|e| anyhow!("onnx: session run: {e}"))?;
            // The first output is taken to be the logits tensor.
            let (shape, data) = outputs[0]
                .try_extract_tensor::<f32>()
                .map_err(|e| anyhow!("onnx: extract logits: {e}"))?;

            // The last dimension is the vocabulary size. Reject values that cannot
            // describe a trailing row of `data` instead of silently running argmax
            // over the wrong slice.
            let last_dim = *shape.last().ok_or_else(|| anyhow!("onnx: empty logits shape"))?;
            let vocab = usize::try_from(last_dim)
                .ok()
                .filter(|&v| v > 0 && v <= data.len())
                .ok_or_else(|| {
                    anyhow!(
                        "onnx: logits last dim {last_dim} does not fit {} values",
                        data.len()
                    )
                })?;
            let last_logits = &data[data.len() - vocab..];

            let next = u32::try_from(Self::argmax(last_logits))
                .map_err(|e| anyhow!("onnx: token id out of range: {e}"))?;
            if Some(next) == eos {
                break;
            }
            generated.push(next);
            ids.push(i64::from(next));

            if has_stops {
                // Decode everything generated so far: a stop sequence may span
                // several tokens.
                let so_far = tokenizer
                    .decode(&generated, true)
                    .map_err(|e| anyhow!("onnx: decode: {e}"))?;
                if req.stop.iter().any(|s| !s.is_empty() && so_far.contains(s)) {
                    break;
                }
            }
        }

        let text = tokenizer
            .decode(&generated, true)
            .map_err(|e| anyhow!("onnx: final decode: {e}"))?;
        let latency_ms = started.elapsed().as_secs_f64() * 1000.0;
        let tokens_out = generated.len() as i64;
        // Guard against a zero elapsed time to avoid dividing by zero.
        let tokens_per_s = if latency_ms > 0.0 {
            tokens_out as f64 / (latency_ms / 1000.0)
        } else {
            0.0
        };
        Ok(GenAnswer { text, tokens_in, tokens_out, tokens_per_s, latency_ms })
    }
}
#[cfg(test)]
mod tests {
    // Locals are named `generator` rather than `gen`: `gen` is a reserved keyword
    // in the Rust 2024 edition.
    use super::*;

    /// The backend id is `onnx:` followed by the directory passed to `new`.
    #[test]
    fn id_reflects_the_model_dir() {
        let generator = OnnxGenerator::new("/models/qwen-onnx").unwrap();
        assert_eq!(generator.id(), "onnx:/models/qwen-onnx");
    }

    /// A directory without `model.onnx` / `tokenizer.json` is reported unavailable.
    #[test]
    fn unavailable_when_model_dir_is_empty() {
        let tmp = tempfile::tempdir().unwrap();
        let generator = OnnxGenerator::new(tmp.path().to_str().unwrap()).unwrap();
        assert!(!generator.available(), "no model in dir ⇒ unavailable");
    }

    /// `argmax` picks the largest value, prefers the first of equal maxima, and
    /// falls back to index 0 for an empty slice.
    #[test]
    fn argmax_picks_first_largest() {
        assert_eq!(OnnxGenerator::argmax(&[0.1, 0.9, 0.3]), 1);
        assert_eq!(OnnxGenerator::argmax(&[2.0, 2.0, 1.0]), 0);
        assert_eq!(OnnxGenerator::argmax(&[]), 0);
    }

    /// End-to-end generation against a real model export; opt-in only.
    #[test]
    #[ignore = "needs a real ONNX model export on disk"]
    fn real_generation_round_trips() {
        let dir = std::env::var(ONNX_DIR_ENV).expect("set NORNIR_GEN_ONNX_DIR");
        let generator = OnnxGenerator::new(&dir).unwrap();
        assert!(generator.available());
        let req = GenRequest::new("Reply with the single word: pong").with_max_tokens(8);
        let ans = generator.complete(&req).unwrap();
        assert!(ans.tokens_in > 0);
        assert!(!ans.text.is_empty());
    }
}