nornir 0.4.28

Companion to cargo: dependency tracking, release gating, deploy, benchmarks, and documentation assembly. Project-agnostic.
//! `onnx` generative backend (`gen-onnx`) — generative inference over the ORT
//! (ONNX Runtime) stack, the onnxruntime-genai-style path.
//!
//! Loads an exported decoder model (`model.onnx`) + its `tokenizer.json` and
//! runs a real autoregressive loop: tokenize → session.run → greedy (argmax)
//! pick → append → repeat, until the token cap, an EOS token, or a stop
//! sequence. This is the same shape `onnxruntime-genai` automates; we drive the
//! `ort` session directly so the backend needs only the `ort` + `tokenizers`
//! crates (no extra genai C library).
//!
//! ## Model spec
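//! `onnx:<path>` points at a directory holding `model.onnx` + `tokenizer.json`
//! (a Hugging Face Optimum/genai export). The decode loop feeds only
//! `input_ids` + `attention_mask`, so the export must not require KV-cache
//! inputs. An empty spec falls back to `$NORNIR_GEN_ONNX_DIR`, then to the
//! current directory. The session is built lazily on first `complete`.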
//!
//! ## `available()`
//! Reports `true` only when BOTH the ORT shared library can be located AND the
//! model dir holds `model.onnx` + `tokenizer.json` — a real probe of "the genai
//! libs + the model are present", with no canned fallback.

use std::path::PathBuf;
use std::sync::Mutex;

use anyhow::{anyhow, Context, Result};
use ort::session::Session;
use ort::value::Value;
use tokenizers::Tokenizer;

use super::{Backend, GenAnswer, GenRequest, Generator};

/// `$NORNIR_GEN_ONNX_DIR` — the model dir used when the `onnx:` spec is empty
/// (read by `OnnxGenerator::new`; the current dir is the final fallback).
const ONNX_DIR_ENV: &str = "NORNIR_GEN_ONNX_DIR";

/// Loaded session + tokenizer (the expensive state), built lazily on the first
/// `complete` call and then cached for reuse by later calls.
struct Loaded {
    /// ORT inference session committed from `<dir>/model.onnx`.
    session: Session,
    /// Tokenizer read from `<dir>/tokenizer.json`.
    tokenizer: Tokenizer,
}

/// The ONNX generator. Holds the model dir + a lazily-built [`Session`] behind a
/// mutex (a session's `run` takes `&mut self`; the mutex makes the generator
/// `Sync` for the bake-off/server).
pub struct OnnxGenerator {
    /// Backend id, `onnx:<dir>`, returned by `Backend::id`.
    id: String,
    /// Directory expected to hold `model.onnx` + `tokenizer.json`.
    dir: PathBuf,
    /// `None` until the first `complete` loads the model. The lock is held for a
    /// whole generation, so concurrent `complete` calls run one at a time.
    loaded: Mutex<Option<Loaded>>,
}

impl OnnxGenerator {
    /// Build the generator for `model`.
    ///
    /// The model dir is resolved as: `model` when non-empty, else
    /// `$NORNIR_GEN_ONNX_DIR` when set and non-empty, else the current dir (`.`).
    /// Does NOT build the session — that happens lazily on the first `complete`.
    ///
    /// # Errors
    /// Currently always returns `Ok`; nothing is touched on disk here.
    pub fn new(model: &str) -> Result<Self> {
        let dir = if model.is_empty() {
            // An env var that is set but empty would yield an id of bare `onnx:`;
            // treat it like "unset" and fall through to the current dir.
            std::env::var(ONNX_DIR_ENV)
                .ok()
                .filter(|d| !d.is_empty())
                .map_or_else(|| PathBuf::from("."), PathBuf::from)
        } else {
            PathBuf::from(model)
        };
        Ok(Self {
            id: format!("onnx:{}", dir.display()),
            dir,
            loaded: Mutex::new(None),
        })
    }

    /// Path of the exported decoder graph, `<dir>/model.onnx`.
    fn model_path(&self) -> PathBuf {
        self.dir.join("model.onnx")
    }

    /// Path of the Hugging Face tokenizer, `<dir>/tokenizer.json`.
    fn tokenizer_path(&self) -> PathBuf {
        self.dir.join("tokenizer.json")
    }

    /// Build the ORT session + tokenizer from the model dir.
    ///
    /// # Errors
    /// Fails when either file is missing, the ORT session builder cannot be
    /// created, the graph fails to load, or the tokenizer JSON fails to parse.
    fn load(&self) -> Result<Loaded> {
        let model_path = self.model_path();
        let tok_path = self.tokenizer_path();
        if !model_path.exists() {
            return Err(anyhow!("onnx: no model.onnx in {}", self.dir.display()));
        }
        if !tok_path.exists() {
            return Err(anyhow!("onnx: no tokenizer.json in {}", self.dir.display()));
        }
        let session = Session::builder()
            .context("onnx: session builder")?
            .commit_from_file(&model_path)
            .with_context(|| format!("onnx: load {}", model_path.display()))?;
        let tokenizer =
            Tokenizer::from_file(&tok_path).map_err(|e| anyhow!("onnx: load tokenizer: {e}"))?;
        Ok(Loaded { session, tokenizer })
    }

    /// Greedy argmax over the last position's logits row.
    ///
    /// Returns the index of the first maximal value; NaNs never win. An empty
    /// slice (or one containing only NaN / `-inf`) returns `0`, so callers that
    /// care should reject empty rows before calling.
    fn argmax(logits: &[f32]) -> usize {
        // Seed with -inf (not f32::MIN) so a finite f32::MIN beats a -inf logit;
        // strict `>` keeps the first index on ties and skips NaN.
        logits
            .iter()
            .enumerate()
            .fold((0usize, f32::NEG_INFINITY), |(best, best_v), (i, &v)| {
                if v > best_v {
                    (i, v)
                } else {
                    (best, best_v)
                }
            })
            .0
    }
}

impl Backend for OnnxGenerator {
    /// The backend id, `onnx:<dir>`.
    fn id(&self) -> &str {
        &self.id
    }

    /// Available when the ORT runtime lib is locatable AND the model dir holds
    /// `model.onnx` + `tokenizer.json`. No canned fallback: a missing lib or
    /// model reports `false`.
    fn available(&self) -> bool {
        if !self.model_path().exists() || !self.tokenizer_path().exists() {
            return false;
        }
        // Probe the ORT runtime by creating a throwaway session builder. When
        // onnxruntime is loaded dynamically, a missing/unloadable shared library
        // can surface as a panic inside `ort` rather than an `Err`; catch it so
        // the probe answers `false` instead of unwinding through the caller.
        // (The panic hook still prints its message; with `panic = "abort"` this
        // guard cannot help.)
        std::panic::catch_unwind(|| Session::builder().is_ok()).unwrap_or(false)
    }
}

impl Generator for OnnxGenerator {
    /// Run greedy autoregressive generation for `req`.
    ///
    /// Lazily loads the session + tokenizer on first use, then repeatedly runs the
    /// whole (uncached) token sequence through the model, appending the argmax
    /// token of the last position. Stops after `req.max_tokens` steps, when an
    /// end-of-sequence token is picked, or when a non-empty `req.stop` string
    /// appears in the decoded output.
    ///
    /// # Errors
    /// Fails when the model dir can't be loaded, the prompt encodes to zero
    /// tokens, an input tensor can't be built, the session run fails, or the
    /// first output's logits are empty / smaller than their vocab dimension.
    fn complete(&self, req: &GenRequest) -> Result<GenAnswer> {
        /// End-of-sequence token strings across common tokenizer families
        /// (Llama-2/Mistral, GPT-2/Qwen, ChatML, Llama-3, Phi-3, Gemma). Only the
        /// ones present in the loaded vocabulary are used.
        const EOS_TOKENS: [&str; 6] =
            ["</s>", "<|endoftext|>", "<|im_end|>", "<|eot_id|>", "<|end|>", "<eos>"];

        // NOTE: started before the lazy load, so the first call's latency also
        // includes loading the session + tokenizer.
        let started = std::time::Instant::now();
        // A panic in an earlier call poisons the mutex; the cached state is still
        // usable, so recover the guard rather than panicking on every later call.
        let mut guard = self.loaded.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
        if guard.is_none() {
            *guard = Some(self.load()?);
        }
        let Loaded { session, tokenizer } = guard.as_mut().expect("loaded set above");

        let prompt = match &req.system {
            Some(sys) => format!("{sys}\n{}", req.prompt),
            None => req.prompt.clone(),
        };
        let encoding =
            tokenizer.encode(prompt, true).map_err(|e| anyhow!("onnx: encode: {e}"))?;
        let mut ids: Vec<i64> = encoding.get_ids().iter().map(|&i| i64::from(i)).collect();
        if ids.is_empty() {
            // A [1, 0] input would either fail in ORT or yield no logits row.
            return Err(anyhow!("onnx: prompt encoded to zero tokens"));
        }
        let tokens_in = ids.len() as i64;

        let eos_ids: Vec<i64> = EOS_TOKENS
            .iter()
            .filter_map(|&t| tokenizer.token_to_id(t))
            .map(i64::from)
            .collect();
        let mut generated: Vec<u32> = Vec::new();

        // Autoregressive loop: feed the running sequence, take the last-position
        // logits, pick the next token, append. This is the genai decode loop;
        // models with a KV-cache export run faster but the dense path is correct.
        for _ in 0..req.max_tokens {
            let seq_len = ids.len();
            let input_ids = Value::from_array(([1usize, seq_len], ids.clone()))
                .map_err(|e| anyhow!("onnx: input_ids tensor: {e}"))?;
            let attn = Value::from_array(([1usize, seq_len], vec![1i64; seq_len]))
                .map_err(|e| anyhow!("onnx: attention_mask tensor: {e}"))?;

            let outputs = session
                .run(ort::inputs![
                    "input_ids" => input_ids,
                    "attention_mask" => attn,
                ])
                .map_err(|e| anyhow!("onnx: session run: {e}"))?;

            // Logits: presumably [batch, seq, vocab] with batch = 1; the last
            // `vocab` values are the final position's row.
            let (shape, data) = outputs[0]
                .try_extract_tensor::<f32>()
                .map_err(|e| anyhow!("onnx: extract logits: {e}"))?;
            let last_dim = *shape.last().ok_or_else(|| anyhow!("onnx: empty logits shape"))?;
            let vocab = usize::try_from(last_dim)
                .ok()
                .filter(|&v| v > 0)
                .ok_or_else(|| anyhow!("onnx: bad logits vocab dim {last_dim}"))?;
            let last_row_start = data.len().checked_sub(vocab).ok_or_else(|| {
                anyhow!("onnx: logits hold {} values, fewer than vocab {vocab}", data.len())
            })?;
            let next = Self::argmax(&data[last_row_start..]) as i64;

            if eos_ids.contains(&next) {
                break;
            }
            generated.push(next as u32);
            ids.push(next);

            if !req.stop.is_empty() {
                let so_far = tokenizer
                    .decode(&generated, true)
                    .map_err(|e| anyhow!("onnx: decode: {e}"))?;
                // An empty stop string matches everything; ignore it rather than
                // stopping after the first token.
                // NOTE(review): the matched stop string stays in the returned text.
                if req.stop.iter().any(|s| !s.is_empty() && so_far.contains(s)) {
                    break;
                }
            }
        }

        let text = tokenizer
            .decode(&generated, true)
            .map_err(|e| anyhow!("onnx: final decode: {e}"))?;
        let latency_ms = started.elapsed().as_secs_f64() * 1000.0;
        let tokens_out = generated.len() as i64;
        let tokens_per_s = if latency_ms > 0.0 {
            tokens_out as f64 / (latency_ms / 1000.0)
        } else {
            0.0
        };
        Ok(GenAnswer { text, tokens_in, tokens_out, tokens_per_s, latency_ms })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn id_reflects_the_model_dir() {
        let generator = OnnxGenerator::new("/models/qwen-onnx").unwrap();
        assert_eq!(generator.id(), "onnx:/models/qwen-onnx");
    }

    #[test]
    fn argmax_picks_the_first_maximum() {
        // Ties resolve to the earliest index.
        assert_eq!(OnnxGenerator::argmax(&[0.1, 3.0, 3.0, -1.0]), 1);
        // NaN never wins.
        assert_eq!(OnnxGenerator::argmax(&[f32::NAN, 1.0]), 1);
        // Single element and empty rows fall back to index 0.
        assert_eq!(OnnxGenerator::argmax(&[-5.0]), 0);
        assert_eq!(OnnxGenerator::argmax(&[]), 0);
    }

    #[test]
    fn unavailable_when_model_dir_is_empty() {
        // A dir with no model.onnx / tokenizer.json must report not-available
        // (real probe of "the genai libs + model are present" — no stub).
        let tmp = tempfile::tempdir().unwrap();
        let generator = OnnxGenerator::new(tmp.path().to_str().unwrap()).unwrap();
        assert!(!generator.available(), "no model in dir ⇒ unavailable");
    }

    /// Heavy: needs a real ONNX export in `$NORNIR_GEN_ONNX_DIR`.
    #[test]
    #[ignore = "needs a real ONNX model export on disk"]
    fn real_generation_round_trips() {
        let dir = std::env::var(ONNX_DIR_ENV).expect("set NORNIR_GEN_ONNX_DIR");
        let generator = OnnxGenerator::new(&dir).unwrap();
        assert!(generator.available());
        let req = GenRequest::new("Reply with the single word: pong").with_max_tokens(8);
        let ans = generator.complete(&req).unwrap();
        assert!(ans.tokens_in > 0);
        assert!(!ans.text.is_empty());
    }
}