use std::path::{Path, PathBuf};
/// Resolves the directory that model files are loaded from.
///
/// Lookup order:
/// 1. the `ZER_MODEL_DIR` environment variable, used verbatim when it is set;
/// 2. `$HOME/.cache/zer/models`, but only when that path already exists;
/// 3. the relative path `models`.
///
/// Both variables are read with [`std::env::var_os`], so a value that is not
/// valid Unicode is still honoured instead of being silently skipped.
pub fn default_models_dir() -> PathBuf {
    // `var_os` rather than `var`: paths are `OsStr`-based, so rejecting a
    // non-UTF-8 override would only discard a perfectly usable directory.
    if let Some(dir) = std::env::var_os("ZER_MODEL_DIR") {
        return PathBuf::from(dir);
    }

    std::env::var_os("HOME")
        .map(|home| PathBuf::from(home).join(".cache").join("zer").join("models"))
        // The per-user cache is only preferred when it is actually present.
        .filter(|cache| cache.exists())
        .unwrap_or_else(|| PathBuf::from("models"))
}
/// Where a tokenizer definition is obtained from.
///
/// Values compare by variant and payload, so they can be used directly in
/// `==` checks and `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenizerSource {
    /// A tokenizer file at the given filesystem path.
    File(PathBuf),
    /// A tokenizer identified by a Hugging Face model id string.
    HuggingFace(String),
}
impl TokenizerSource {
pub fn file(path: impl AsRef<Path>) -> Self {
Self::File(path.as_ref().to_owned())
}
pub fn hub(model_id: impl Into<String>) -> Self {
Self::HuggingFace(model_id.into())
}
}
/// Describes one judge model: where its files live and the constants that go with it.
///
/// Implementors must be `Send + Sync`; `spec_from_vram` hands specs out as
/// `Box<dyn JudgeModelSpec>`.
pub trait JudgeModelSpec: Send + Sync {
/// Identifier of the model, e.g. `"cross-encoder/nli-deberta-v3-base"`.
fn name(&self) -> &str;
/// Path of the model file.
fn model_path(&self) -> &Path;
/// Where the tokenizer that belongs to this model is obtained from.
fn tokenizer_source(&self) -> &TokenizerSource;
/// NOTE(review): presumably the maximum input sequence length in tokens — confirm.
fn max_length(&self) -> usize;
/// NOTE(review): presumably the index of the "entailment" class in the model output — confirm.
fn entailment_idx(&self) -> usize;
/// VRAM figure for this model, in bytes.
/// NOTE(review): presumably the approximate amount the model needs — confirm.
fn vram_bytes(&self) -> u64;
}
/// Selects which on-disk model variant is used.
///
/// Each variant maps to a directory name via `ModelPrecision::subfolder`.
/// The default is [`ModelPrecision::Fp16Fused`].
///
/// Values are `Copy` and compare by variant, so they can be used in `==`
/// checks and `assert_eq!`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum ModelPrecision {
    /// Variant stored under the `base` directory.
    Base,
    /// Variant stored under the `fp16` directory.
    Fp16,
    /// Variant stored under the `fp16_fused` directory (the default).
    #[default]
    Fp16Fused,
}
impl ModelPrecision {
pub fn subfolder(self) -> &'static str {
match self {
Self::Base => "base",
Self::Fp16 => "fp16",
Self::Fp16Fused => "fp16_fused",
}
}
}
/// Model spec for the `cross-encoder/nli-MiniLM2-L6-H768` judge model.
///
/// Holds only file locations, so it is cheap to clone and safe to print
/// with `{:?}` for diagnostics.
#[derive(Debug, Clone)]
pub struct MiniLmSpec {
    /// Path of the model file.
    model_path: PathBuf,
    /// Where the tokenizer for this model is obtained from.
    tokenizer_source: TokenizerSource,
}
impl MiniLmSpec {
pub fn new(model_path: impl AsRef<Path>, tokenizer_source: TokenizerSource) -> Self {
Self {
model_path: model_path.as_ref().to_owned(),
tokenizer_source,
}
}
pub fn from_dir(dir: impl AsRef<Path>) -> Self {
let dir = dir.as_ref();
Self {
model_path: dir.join("model.onnx"),
tokenizer_source: TokenizerSource::file(dir.join("tokenizer.json")),
}
}
pub fn from_env(precision: ModelPrecision) -> Self {
let base = default_models_dir()
.join("nli-base")
.join(precision.subfolder())
.join("nli-minilm-onnx");
Self::from_dir(base)
}
}
impl JudgeModelSpec for MiniLmSpec {
    /// Fixed model identifier for this spec.
    fn name(&self) -> &str {
        "cross-encoder/nli-MiniLM2-L6-H768"
    }

    /// Path of the model file this spec was built with.
    fn model_path(&self) -> &Path {
        self.model_path.as_path()
    }

    /// Tokenizer source this spec was built with.
    fn tokenizer_source(&self) -> &TokenizerSource {
        &self.tokenizer_source
    }

    /// Always 512.
    fn max_length(&self) -> usize {
        512
    }

    /// Always 1.
    fn entailment_idx(&self) -> usize {
        1
    }

    /// Always 256 MiB, expressed in bytes.
    fn vram_bytes(&self) -> u64 {
        const MIB: u64 = 1 << 20;
        256 * MIB
    }
}
/// Model spec for the `cross-encoder/nli-deberta-v3-base` judge model.
///
/// Holds only file locations, so it is cheap to clone and safe to print
/// with `{:?}` for diagnostics.
#[derive(Debug, Clone)]
pub struct DebertaBaseSpec {
    /// Path of the model file.
    model_path: PathBuf,
    /// Where the tokenizer for this model is obtained from.
    tokenizer_source: TokenizerSource,
}
impl DebertaBaseSpec {
pub fn new(model_path: impl AsRef<Path>, tokenizer_source: TokenizerSource) -> Self {
Self {
model_path: model_path.as_ref().to_owned(),
tokenizer_source,
}
}
pub fn from_dir(dir: impl AsRef<Path>) -> Self {
let dir = dir.as_ref();
Self {
model_path: dir.join("model.onnx"),
tokenizer_source: TokenizerSource::file(dir.join("tokenizer.json")),
}
}
pub fn from_env(precision: ModelPrecision) -> Self {
let base = default_models_dir()
.join("nli-base")
.join(precision.subfolder())
.join("nli-deberta-v3-base-onnx");
Self::from_dir(base)
}
}
impl JudgeModelSpec for DebertaBaseSpec {
    /// Fixed model identifier for this spec.
    fn name(&self) -> &str {
        "cross-encoder/nli-deberta-v3-base"
    }

    /// Path of the model file this spec was built with.
    fn model_path(&self) -> &Path {
        self.model_path.as_path()
    }

    /// Tokenizer source this spec was built with.
    fn tokenizer_source(&self) -> &TokenizerSource {
        &self.tokenizer_source
    }

    /// Always 512.
    fn max_length(&self) -> usize {
        512
    }

    /// Always 1.
    fn entailment_idx(&self) -> usize {
        1
    }

    /// Always 2 GiB, expressed in bytes.
    fn vram_bytes(&self) -> u64 {
        const GIB: u64 = 1 << 30;
        2 * GIB
    }
}
/// Picks a judge model spec using the default models directory.
///
/// Builds `<default_models_dir()>/nli-base/<precision subfolder>` and delegates
/// the actual choice to `spec_from_vram` with `available_vram_bytes`.
pub fn spec_from_env(
    precision: ModelPrecision,
    available_vram_bytes: u64,
) -> Box<dyn JudgeModelSpec> {
    let mut precision_dir = default_models_dir();
    precision_dir.push("nli-base");
    precision_dir.push(precision.subfolder());
    spec_from_vram(precision_dir.as_path(), available_vram_bytes)
}
/// Chooses between the DeBERTa-v3-base and MiniLM specs under `models_dir`.
///
/// DeBERTa is selected only when `available_vram_bytes` is at least 2 GiB
/// **and** `<models_dir>/nli-deberta-v3-base-onnx` exists. In every other case,
/// including a missing DeBERTa directory, the MiniLM spec for
/// `<models_dir>/nli-minilm-onnx` is returned; that directory is not checked
/// for existence. The decision is logged at `info` level.
pub fn spec_from_vram(models_dir: &Path, available_vram_bytes: u64) -> Box<dyn JudgeModelSpec> {
    const DEBERTA_MIN_VRAM_BYTES: u64 = 2 * 1024 * 1024 * 1024;

    let deberta_dir = models_dir.join("nli-deberta-v3-base-onnx");
    let minilm_dir = models_dir.join("nli-minilm-onnx");

    // The existence probe only runs when the VRAM requirement is already met.
    let use_deberta = available_vram_bytes >= DEBERTA_MIN_VRAM_BYTES && deberta_dir.exists();
    if !use_deberta {
        tracing::info!("judge: selecting MiniLM-L6 (CPU or low VRAM)");
        return Box::new(MiniLmSpec::from_dir(&minilm_dir));
    }

    tracing::info!(
        "judge: selecting DeBERTa-v3-base ({:.1} GB VRAM available)",
        available_vram_bytes as f64 / 1e9
    );
    Box::new(DebertaBaseSpec::from_dir(&deberta_dir))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Model name reported by `MiniLmSpec`.
    const MINILM: &str = "cross-encoder/nli-MiniLM2-L6-H768";
    /// Model name reported by `DebertaBaseSpec`.
    const DEBERTA: &str = "cross-encoder/nli-deberta-v3-base";
    const MIB: u64 = 1 << 20;
    const GIB: u64 = 1 << 30;

    /// Returns a path below `/nonexistent`.
    fn dummy_path(name: &str) -> PathBuf {
        format!("/nonexistent/{name}").into()
    }

    /// True when `source` is the `File` variant pointing at `expected`.
    fn is_file_source(source: &TokenizerSource, expected: &str) -> bool {
        match source {
            TokenizerSource::File(path) => path.as_path() == Path::new(expected),
            TokenizerSource::HuggingFace(_) => false,
        }
    }

    /// True when `source` is the `HuggingFace` variant carrying `expected`.
    fn is_hub_source(source: &TokenizerSource, expected: &str) -> bool {
        match source {
            TokenizerSource::HuggingFace(id) => id.as_str() == expected,
            TokenizerSource::File(_) => false,
        }
    }

    #[test]
    fn minilm_from_dir_sets_expected_paths() {
        let spec = MiniLmSpec::from_dir("/some/dir");
        assert_eq!(spec.model_path(), Path::new("/some/dir/model.onnx"));
        assert!(is_file_source(spec.tokenizer_source(), "/some/dir/tokenizer.json"));
    }

    #[test]
    fn minilm_metadata() {
        let spec = MiniLmSpec::from_dir("/d");
        assert_eq!(spec.name(), MINILM);
        assert_eq!(spec.max_length(), 512);
        assert_eq!(spec.entailment_idx(), 1);
        assert_eq!(spec.vram_bytes(), 256 * MIB);
    }

    #[test]
    fn deberta_base_from_dir_sets_expected_paths() {
        let spec = DebertaBaseSpec::from_dir("/fp16_fused/dir");
        assert_eq!(spec.model_path(), Path::new("/fp16_fused/dir/model.onnx"));
        assert!(is_file_source(spec.tokenizer_source(), "/fp16_fused/dir/tokenizer.json"));
    }

    #[test]
    fn deberta_base_metadata() {
        let spec = DebertaBaseSpec::from_dir("/d");
        assert_eq!(spec.name(), DEBERTA);
        assert_eq!(spec.max_length(), 512);
        assert_eq!(spec.entailment_idx(), 1);
        assert_eq!(spec.vram_bytes(), 2 * GIB);
    }

    /// Plenty of VRAM, but the DeBERTa directory is missing, so MiniLM is chosen.
    #[test]
    fn spec_from_vram_no_dirs_returns_minilm() {
        let chosen = spec_from_vram(Path::new("/nonexistent"), 16 * GIB);
        assert_eq!(chosen.name(), MINILM);
    }

    #[test]
    fn spec_from_vram_selects_minilm_when_low_vram() {
        let chosen = spec_from_vram(Path::new("/nonexistent"), 512 * MIB);
        assert_eq!(chosen.name(), MINILM);
    }

    /// Only runs its assertion when the repository's model directory is present.
    #[test]
    fn spec_from_vram_with_real_models_dir_selects_best_available() {
        let models_dir = Path::new("../../models/nli-base/fp16_fused");
        if models_dir.exists() {
            // With zero VRAM reported, MiniLM must be picked even if DeBERTa is on disk.
            let chosen = spec_from_vram(models_dir, 0);
            assert_eq!(chosen.name(), MINILM);
        }
    }

    #[test]
    fn token_source_file_convenience() {
        let source = TokenizerSource::file("/tmp/tok.json");
        assert!(is_file_source(&source, "/tmp/tok.json"));
    }

    #[test]
    fn token_source_hub_convenience() {
        let source = TokenizerSource::hub("cross-encoder/nli-deberta-v3-base");
        assert!(is_hub_source(&source, "cross-encoder/nli-deberta-v3-base"));
    }

    #[test]
    fn minilm_new_constructor() {
        let tokenizer = TokenizerSource::file(dummy_path("tok.json"));
        let spec = MiniLmSpec::new(dummy_path("model.onnx"), tokenizer);
        assert_eq!(spec.model_path(), Path::new("/nonexistent/model.onnx"));
    }
}