Skip to main content

zer_judge/
spec.rs

//! Model specifications: where to find the ONNX model and tokenizer, and how
//! to interpret the model's output.
//!
//! # Model resolution
//!
//! Models are loaded from a local directory.  The recommended layout (produced
//! by `scripts/download_models.sh`) is:
//!
//! ```text
//! $ZER_MODEL_DIR/
//!   nli-base/
//!     base/             # FP32, CPU baseline
//!     fp16/             # FP16 weights
//!     fp16_fused/       # FP16 + graph fusions (CUDA / TensorRT preferred)
//! ```
//!
//! Resolution order for the `from_env` constructors:
//!
//! 1. `ZER_MODEL_DIR` environment variable (explicit override)
//! 2. `~/.cache/zer/models` (user cache, populated by `scripts/download_models.sh`)
//! 3. `./models` relative to the current working directory (workspace default)
22use std::path::{Path, PathBuf};
23
24// ── Model resolution helpers ──────────────────────────────────────────────────
25
/// Returns the directory where zer looks for judge models.
///
/// Resolution order:
/// 1. `ZER_MODEL_DIR` environment variable, when set to a non-empty value.
///    The value is read as an OS string, so a path that is not valid UTF-8
///    is honoured instead of being silently skipped.
/// 2. `$HOME/.cache/zer/models`, when `HOME` is set to a non-empty value and
///    that directory exists.
/// 3. `./models` (workspace fallback), returned as the relative path `models`.
///
/// Only step 2 probes the filesystem; the returned path is not otherwise
/// checked for existence.
pub fn default_models_dir() -> PathBuf {
    // An empty variable is treated like an unset one: `PathBuf::from("")`
    // would otherwise make every later `join` resolve against the current
    // working directory, which is never what an "override" means.
    let non_empty_var = |key: &str| std::env::var_os(key).filter(|value| !value.is_empty());

    if let Some(dir) = non_empty_var("ZER_MODEL_DIR") {
        return PathBuf::from(dir);
    }

    if let Some(home) = non_empty_var("HOME") {
        let cache = PathBuf::from(home)
            .join(".cache")
            .join("zer")
            .join("models");
        // The user cache is only used once it has actually been populated.
        if cache.exists() {
            return cache;
        }
    }

    PathBuf::from("models")
}
47
48// ── TokenizerSource ───────────────────────────────────────────────────────────
49
/// Specifies how the tokenizer is loaded.
///
/// Two sources compare equal when they are the same variant holding the same
/// path or model identifier, so values can be checked with `==` /
/// `assert_eq!` and used as keys in hashed collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TokenizerSource {
    /// Load from a `tokenizer.json` file on disk.
    File(PathBuf),
    /// Use a Hugging Face model identifier (downloads on first use via the
    /// `tokenizers` crate's built-in hub).
    HuggingFace(String),
}

impl TokenizerSource {
    /// Convenience: builds [`TokenizerSource::File`] from any `AsRef<Path>`.
    ///
    /// The path is copied into an owned `PathBuf`; it is not checked for
    /// existence here.
    pub fn file(path: impl AsRef<Path>) -> Self {
        Self::File(path.as_ref().to_path_buf())
    }

    /// Convenience: builds [`TokenizerSource::HuggingFace`] from anything
    /// convertible into a `String` model id.
    pub fn hub(model_id: impl Into<String>) -> Self {
        Self::HuggingFace(model_id.into())
    }
}
71
72// ── JudgeModelSpec ────────────────────────────────────────────────────────────
73
74/// Everything needed to load and run a judge model.
75///
76/// Implement this trait to add a new model variant; the built-in specs are
77/// [`MiniLmSpec`] (small default) and [`DebertaBaseSpec`] (large default).
78pub trait JudgeModelSpec: Send + Sync {
79    /// Human-readable model name for diagnostics.
80    fn name(&self) -> &str;
81
82    /// Path to the ONNX model file on disk.
83    fn model_path(&self) -> &Path;
84
85    /// Where to load the tokenizer from.
86    fn tokenizer_source(&self) -> &TokenizerSource;
87
88    /// Maximum token sequence length the model accepts.
89    fn max_length(&self) -> usize;
90
91    /// Index of the "entailment" class in the model output logits.
92    ///
93    /// For NLI cross-encoders: `entailment` → match, `contradiction` → no-match.
94    fn entailment_idx(&self) -> usize;
95
96    /// Approximate VRAM requirement in bytes for this model.
97    fn vram_bytes(&self) -> u64;
98}
99
100// ── ModelPrecision ────────────────────────────────────────────────────────────
101
/// Precision variant of the ONNX model to load.
///
/// Matches the subdirectory layout produced by `scripts/download_models.sh`:
///
/// | Variant      | Subfolder      | Notes                                         |
/// |--------------|----------------|-----------------------------------------------|
/// | `Base`       | `base/`        | FP32, no optimisation; CPU baseline           |
/// | `Fp16`       | `fp16/`        | FP16 weights, no graph fusions                |
/// | `Fp16Fused`  | `fp16_fused/`  | FP16 + level-2 fusions; CUDA / TensorRT best  |
///
/// The default is [`ModelPrecision::Fp16Fused`]. Variants are plain values:
/// they can be copied, compared with `==`, and hashed.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum ModelPrecision {
    /// FP32 model stored under `base/`.
    Base,
    /// FP16 weights stored under `fp16/`.
    Fp16,
    /// FP16 weights with graph fusions, stored under `fp16_fused/` (default).
    #[default]
    Fp16Fused,
}

impl ModelPrecision {
    /// Name of the subdirectory (without trailing slash) that holds this
    /// precision variant.
    pub fn subfolder(self) -> &'static str {
        match self {
            Self::Base => "base",
            Self::Fp16 => "fp16",
            Self::Fp16Fused => "fp16_fused",
        }
    }
}
128
129// ── MiniLmSpec ────────────────────────────────────────────────────────────────
130
131/// MiniLM-L6-v2 NLI cross-encoder (~23 MB ONNX, fits in 256 MB VRAM).
132pub struct MiniLmSpec {
133    model_path: PathBuf,
134    tokenizer_source: TokenizerSource,
135}
136
137impl MiniLmSpec {
138    pub fn new(model_path: impl AsRef<Path>, tokenizer_source: TokenizerSource) -> Self {
139        Self {
140            model_path: model_path.as_ref().to_owned(),
141            tokenizer_source,
142        }
143    }
144
145    /// Convenience: load both model and tokenizer from the same directory.
146    /// Expects `<dir>/model.onnx` and `<dir>/tokenizer.json`.
147    pub fn from_dir(dir: impl AsRef<Path>) -> Self {
148        let dir = dir.as_ref();
149        Self {
150            model_path: dir.join("model.onnx"),
151            tokenizer_source: TokenizerSource::file(dir.join("tokenizer.json")),
152        }
153    }
154
155    /// Load from the resolved models directory (see [`default_models_dir`]).
156    ///
157    /// Looks for the FP16-fused variant first, falling back to the FP32 base.
158    /// Download models first with `scripts/download_models.sh` or set
159    /// `ZER_MODEL_DIR` to point at a local directory.
160    pub fn from_env(precision: ModelPrecision) -> Self {
161        let base = default_models_dir()
162            .join("nli-base")
163            .join(precision.subfolder())
164            .join("nli-minilm-onnx");
165        Self::from_dir(base)
166    }
167}
168
169impl JudgeModelSpec for MiniLmSpec {
170    fn name(&self) -> &str {
171        "cross-encoder/nli-MiniLM2-L6-H768"
172    }
173    fn model_path(&self) -> &Path {
174        &self.model_path
175    }
176    fn tokenizer_source(&self) -> &TokenizerSource {
177        &self.tokenizer_source
178    }
179    fn max_length(&self) -> usize {
180        512
181    }
182    fn entailment_idx(&self) -> usize {
183        1
184    }
185    fn vram_bytes(&self) -> u64 {
186        256 * 1024 * 1024
187    } // 256 MB
188}
189
190// ── DebertaBaseSpec ───────────────────────────────────────────────────────────
191
192/// DeBERTa-v3-base NLI (~185 MB ONNX, fits in 2 GB VRAM).
193pub struct DebertaBaseSpec {
194    model_path: PathBuf,
195    tokenizer_source: TokenizerSource,
196}
197
198impl DebertaBaseSpec {
199    pub fn new(model_path: impl AsRef<Path>, tokenizer_source: TokenizerSource) -> Self {
200        Self {
201            model_path: model_path.as_ref().to_owned(),
202            tokenizer_source,
203        }
204    }
205
206    pub fn from_dir(dir: impl AsRef<Path>) -> Self {
207        let dir = dir.as_ref();
208        Self {
209            model_path: dir.join("model.onnx"),
210            tokenizer_source: TokenizerSource::file(dir.join("tokenizer.json")),
211        }
212    }
213
214    /// Load from the resolved models directory (see [`default_models_dir`]).
215    ///
216    /// Download models first with `scripts/download_models.sh` or set
217    /// `ZER_MODEL_DIR` to point at a local directory.
218    pub fn from_env(precision: ModelPrecision) -> Self {
219        let base = default_models_dir()
220            .join("nli-base")
221            .join(precision.subfolder())
222            .join("nli-deberta-v3-base-onnx");
223        Self::from_dir(base)
224    }
225}
226
227impl JudgeModelSpec for DebertaBaseSpec {
228    fn name(&self) -> &str {
229        "cross-encoder/nli-deberta-v3-base"
230    }
231    fn model_path(&self) -> &Path {
232        &self.model_path
233    }
234    fn tokenizer_source(&self) -> &TokenizerSource {
235        &self.tokenizer_source
236    }
237    fn max_length(&self) -> usize {
238        512
239    }
240    fn entailment_idx(&self) -> usize {
241        1
242    }
243    fn vram_bytes(&self) -> u64 {
244        2 * 1024 * 1024 * 1024
245    } // 2 GB
246}
247
248// ── spec_from_env / spec_from_vram ───────────────────────────────────────────
249
250/// Select the most capable spec that fits within `available_vram_bytes`, loading
251/// from the resolved models directory (see [`default_models_dir`]).
252///
253/// This is the easiest entry-point for end users: run `scripts/download_models.sh`
254/// (or set `ZER_MODEL_DIR`), then call `spec_from_env` and let zer pick the best
255/// model for the available hardware.
256pub fn spec_from_env(
257    precision: ModelPrecision,
258    available_vram_bytes: u64,
259) -> Box<dyn JudgeModelSpec> {
260    let models_dir = default_models_dir()
261        .join("nli-base")
262        .join(precision.subfolder());
263    spec_from_vram(&models_dir, available_vram_bytes)
264}
265
266/// Select the most capable built-in spec that fits within `available_vram_bytes`.
267///
268/// Two defaults: small → MiniLM-L6, large → DeBERTa-v3-base.
269///
270/// Requires a directory layout produced by `models/generate_onnx_model.py`:
271/// ```text
272/// models/nli-base/base/         ← TensorRT and CPU (plain FP32, no fusions)
273///   nli-deberta-v3-base-onnx/model.onnx
274///   nli-minilm-onnx/model.onnx
275/// models/nli-base/fp16_fused/   ← CUDA / ROCm / DirectML / OpenVINO (preferred)
276///   nli-deberta-v3-base-onnx/model.onnx
277///   nli-minilm-onnx/model.onnx
278/// models/nli-base/fp16/         ← CUDA fallback (FP16 weights, no ORT fusions)
279///   nli-deberta-v3-base-onnx/model.onnx
280///   nli-minilm-onnx/model.onnx
281/// ```
282pub fn spec_from_vram(models_dir: &Path, available_vram_bytes: u64) -> Box<dyn JudgeModelSpec> {
283    let base = models_dir.join("nli-deberta-v3-base-onnx");
284    let mini = models_dir.join("nli-minilm-onnx");
285
286    if available_vram_bytes >= 2 * 1024 * 1024 * 1024 && base.exists() {
287        tracing::info!(
288            "judge: selecting DeBERTa-v3-base ({:.1} GB VRAM available)",
289            available_vram_bytes as f64 / 1e9
290        );
291        return Box::new(DebertaBaseSpec::from_dir(&base));
292    }
293
294    tracing::info!("judge: selecting MiniLM-L6 (CPU or low VRAM)");
295    Box::new(MiniLmSpec::from_dir(&mini))
296}
297
/// Unit tests for the spec types and the VRAM-based model selection.
///
/// None of the tests load a model; they only check the paths, metadata and
/// selection decisions that the specs report.
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// 2 GiB, the VRAM threshold `spec_from_vram` uses for DeBERTa.
    const TWO_GIB: u64 = 2 * 1024 * 1024 * 1024;

    /// Builds a path under a directory that is not expected to exist.
    fn dummy_path(name: &str) -> PathBuf {
        PathBuf::from(format!("/nonexistent/{name}"))
    }

    // ── ModelPrecision ────────────────────────────────────────────────────────

    #[test]
    fn precision_subfolders_match_layout() {
        assert_eq!(ModelPrecision::Base.subfolder(), "base");
        assert_eq!(ModelPrecision::Fp16.subfolder(), "fp16");
        assert_eq!(ModelPrecision::Fp16Fused.subfolder(), "fp16_fused");
        assert_eq!(ModelPrecision::default().subfolder(), "fp16_fused");
    }

    // ── MiniLmSpec ────────────────────────────────────────────────────────────

    #[test]
    fn minilm_from_dir_sets_expected_paths() {
        let spec = MiniLmSpec::from_dir("/some/dir");
        assert_eq!(spec.model_path(), Path::new("/some/dir/model.onnx"));
        assert!(
            matches!(spec.tokenizer_source(), TokenizerSource::File(p) if p == Path::new("/some/dir/tokenizer.json"))
        );
    }

    #[test]
    fn minilm_metadata() {
        let spec = MiniLmSpec::from_dir("/d");
        assert_eq!(spec.name(), "cross-encoder/nli-MiniLM2-L6-H768");
        assert_eq!(spec.max_length(), 512);
        assert_eq!(spec.entailment_idx(), 1);
        assert_eq!(spec.vram_bytes(), 256 * 1024 * 1024);
    }

    #[test]
    fn minilm_new_constructor() {
        let spec = MiniLmSpec::new(
            dummy_path("model.onnx"),
            TokenizerSource::file(dummy_path("tok.json")),
        );
        assert_eq!(spec.model_path(), Path::new("/nonexistent/model.onnx"));
    }

    // ── DebertaBaseSpec ───────────────────────────────────────────────────────

    #[test]
    fn deberta_base_from_dir_sets_expected_paths() {
        let spec = DebertaBaseSpec::from_dir("/fp16_fused/dir");
        assert_eq!(spec.model_path(), Path::new("/fp16_fused/dir/model.onnx"));
        assert!(
            matches!(spec.tokenizer_source(), TokenizerSource::File(p) if p == Path::new("/fp16_fused/dir/tokenizer.json"))
        );
    }

    #[test]
    fn deberta_base_metadata() {
        let spec = DebertaBaseSpec::from_dir("/d");
        assert_eq!(spec.name(), "cross-encoder/nli-deberta-v3-base");
        assert_eq!(spec.max_length(), 512);
        assert_eq!(spec.entailment_idx(), 1);
        assert_eq!(spec.vram_bytes(), 2 * 1024 * 1024 * 1024);
    }

    #[test]
    fn deberta_base_new_constructor() {
        let spec = DebertaBaseSpec::new(
            dummy_path("model.onnx"),
            TokenizerSource::hub("cross-encoder/nli-deberta-v3-base"),
        );
        assert_eq!(spec.model_path(), Path::new("/nonexistent/model.onnx"));
        assert!(
            matches!(spec.tokenizer_source(), TokenizerSource::HuggingFace(s) if s == "cross-encoder/nli-deberta-v3-base")
        );
    }

    // ── spec_from_vram ────────────────────────────────────────────────────────

    #[test]
    fn spec_from_vram_no_dirs_returns_minilm() {
        // Non-existent dirs → exists() == false → falls through to MiniLM even
        // with plenty of VRAM.
        let spec = spec_from_vram(Path::new("/nonexistent"), 16 * 1024 * 1024 * 1024);
        assert_eq!(spec.name(), "cross-encoder/nli-MiniLM2-L6-H768");
    }

    #[test]
    fn spec_from_vram_selects_minilm_when_low_vram() {
        // 512 MB is below the 2 GiB DeBERTa threshold, so MiniLM is chosen
        // before the filesystem is even consulted.
        let spec = spec_from_vram(Path::new("/nonexistent"), 512 * 1024 * 1024);
        assert_eq!(spec.name(), "cross-encoder/nli-MiniLM2-L6-H768");
    }

    #[test]
    fn spec_from_vram_selects_deberta_when_dir_exists_and_vram_suffices() {
        // Per-process scratch directory so concurrent test runs do not collide.
        let models_dir = std::env::temp_dir()
            .join(format!("zer-judge-spec-test-{}", std::process::id()));
        let deberta_dir = models_dir.join("nli-deberta-v3-base-onnx");
        std::fs::create_dir_all(&deberta_dir).expect("create scratch model directory");

        let at_threshold = spec_from_vram(&models_dir, TWO_GIB);
        let below_threshold = spec_from_vram(&models_dir, TWO_GIB - 1);

        // Clean up before asserting so a failing assertion does not leak the
        // scratch directory.
        let _ = std::fs::remove_dir_all(&models_dir);

        assert_eq!(at_threshold.name(), "cross-encoder/nli-deberta-v3-base");
        assert_eq!(
            at_threshold.model_path(),
            deberta_dir.join("model.onnx").as_path()
        );
        assert_eq!(below_threshold.name(), "cross-encoder/nli-MiniLM2-L6-H768");
    }

    #[test]
    fn spec_from_vram_with_real_models_dir_selects_best_available() {
        // Exercises the real models directory when it happens to be present;
        // on machines without it this test is a no-op.
        let models_dir = Path::new("../../models/nli-base/fp16_fused");
        if !models_dir.exists() {
            return; // Skip if models directory not available
        }
        // With 0 VRAM, must fall back to MiniLM even though DeBERTa may exist.
        let spec = spec_from_vram(models_dir, 0);
        assert_eq!(spec.name(), "cross-encoder/nli-MiniLM2-L6-H768");
    }

    // ── TokenizerSource ───────────────────────────────────────────────────────

    #[test]
    fn token_source_file_convenience() {
        let ts = TokenizerSource::file("/tmp/tok.json");
        assert!(matches!(ts, TokenizerSource::File(p) if p == Path::new("/tmp/tok.json")));
    }

    #[test]
    fn token_source_hub_convenience() {
        let ts = TokenizerSource::hub("cross-encoder/nli-deberta-v3-base");
        assert!(
            matches!(ts, TokenizerSource::HuggingFace(s) if s == "cross-encoder/nli-deberta-v3-base")
        );
    }
}