use ort::ep::ExecutionProviderDispatch;
/// Minimum / optimum / maximum input dimensions used to build the shape
/// strings that this file passes to the TensorRT execution provider's
/// `with_profile_{min,opt,max}_shapes` options.
///
/// Every shape string applies the same `<batch>x<seq>` pair to the three
/// inputs `input_ids`, `attention_mask` and `token_type_ids`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TrtProfile {
    /// First dimension of the minimum shape.
    pub min_batch: usize,
    /// Second dimension of the minimum shape.
    pub min_seq: usize,
    /// First dimension of the optimum shape.
    pub opt_batch: usize,
    /// Second dimension of the optimum shape.
    pub opt_seq: usize,
    /// First dimension of the maximum shape.
    pub max_batch: usize,
    /// Second dimension of the maximum shape.
    pub max_seq: usize,
}

impl TrtProfile {
    /// Profile used by every constructor unless overridden:
    /// min `1x1`, opt `32x64`, max `64x512`.
    pub const DEFAULT: Self = Self {
        min_batch: 1,
        min_seq: 1,
        opt_batch: 32,
        opt_seq: 64,
        max_batch: 64,
        max_seq: 512,
    };

    /// Renders `<input>:<batch>x<seq>` for each of the three inputs and joins
    /// the pieces with commas (no spaces).
    fn to_shape_string(self, batch: usize, seq: usize) -> String {
        ["input_ids", "attention_mask", "token_type_ids"]
            .map(|input| format!("{input}:{batch}x{seq}"))
            .join(",")
    }

    /// Shape string built from `min_batch` and `min_seq`.
    pub fn min_shapes(self) -> String {
        let Self { min_batch, min_seq, .. } = self;
        self.to_shape_string(min_batch, min_seq)
    }

    /// Shape string built from `opt_batch` and `opt_seq`.
    pub fn opt_shapes(self) -> String {
        let Self { opt_batch, opt_seq, .. } = self;
        self.to_shape_string(opt_batch, opt_seq)
    }

    /// Shape string built from `max_batch` and `max_seq`.
    pub fn max_shapes(self) -> String {
        let Self { max_batch, max_seq, .. } = self;
        self.to_shape_string(max_batch, max_seq)
    }
}
/// Backend a judge session can be asked to run on.
///
/// `Cpu` is the `Default` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum JudgeTarget {
    #[default]
    Cpu,
    Cuda,
    TensorRt,
    Rocm,
    DirectMl,
    OpenVino,
}

impl JudgeTarget {
    /// Every variant, in declaration order. `from_name` searches this table,
    /// so a variant missing here can never be parsed.
    const ALL: [Self; 6] = [
        Self::Cpu,
        Self::Cuda,
        Self::TensorRt,
        Self::Rocm,
        Self::DirectMl,
        Self::OpenVino,
    ];

    /// Parses the lowercase name produced by [`JudgeTarget::as_str`].
    ///
    /// The comparison is exact and case-sensitive; anything else (including
    /// the empty string) yields `None`.
    pub fn from_name(name: &str) -> Option<Self> {
        Self::ALL
            .iter()
            .copied()
            .find(|candidate| candidate.as_str() == name)
    }

    /// Canonical lowercase name of the variant.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Cpu => "cpu",
            Self::Cuda => "cuda",
            Self::TensorRt => "tensorrt",
            Self::Rocm => "rocm",
            Self::DirectMl => "directml",
            Self::OpenVino => "openvino",
        }
    }
}
/// A selected [`JudgeTarget`] together with the [`TrtProfile`] that is used
/// when that target is TensorRT.
///
/// Both fields are private; values are created through the constructor
/// functions on this type.
#[derive(Debug, Clone)]
pub struct JudgeBackend {
    /// Backend the session should run on.
    target: JudgeTarget,
    /// Shape profile for the TensorRT execution provider; only read when
    /// `target` is `JudgeTarget::TensorRt`.
    trt_profile: TrtProfile,
}
impl JudgeBackend {
pub fn auto_detect() -> Self {
let target = std::env::args()
.find_map(|a| a.strip_prefix("--judge-target=").map(str::to_owned));
match target.as_deref() {
Some(t) => Self::from_target(t),
None => Self::cpu(),
}
}
pub fn cpu() -> Self {
Self { target: JudgeTarget::Cpu, trt_profile: TrtProfile::DEFAULT }
}
pub fn cuda_or_cpu() -> Self {
if cfg!(feature = "judge_cuda") {
Self { target: JudgeTarget::Cuda, trt_profile: TrtProfile::DEFAULT }
} else {
Self::cpu()
}
}
pub fn with_trt_profile(mut self, profile: TrtProfile) -> Self {
self.trt_profile = profile;
self
}
pub fn from_target(target: &str) -> Self {
match JudgeTarget::from_name(target) {
Some(t) => {
if !Self::target_compiled_in(t) {
tracing::error!(target, "judge target not compiled in; rebuild with 'judge_{{target}}' feature flag");
std::process::exit(1);
}
Self { target: t, trt_profile: TrtProfile::DEFAULT }
}
None => {
tracing::error!(target, "unknown --judge-target; valid: cpu, cuda, tensorrt, rocm, directml, openvino");
std::process::exit(1);
}
}
}
fn target_compiled_in(target: JudgeTarget) -> bool {
match target {
JudgeTarget::Cpu => true,
JudgeTarget::Cuda => cfg!(feature = "judge_cuda"),
JudgeTarget::TensorRt => cfg!(feature = "judge_tensorrt"),
JudgeTarget::Rocm => cfg!(feature = "judge_rocm"),
JudgeTarget::DirectMl => cfg!(feature = "judge_directml"),
JudgeTarget::OpenVino => cfg!(feature = "judge_openvino"),
}
}
pub fn name(&self) -> &'static str {
self.target.as_str()
}
pub fn models_subdir(&self) -> &'static str {
match self.target {
JudgeTarget::TensorRt | JudgeTarget::Cpu => "base",
_ => "fp16_fused",
}
}
pub fn resolve_models_dir(&self, base: &std::path::Path) -> std::path::PathBuf {
match self.target {
JudgeTarget::TensorRt | JudgeTarget::Cpu => base.join("base"),
_ => {
let fp16_fused = base.join("fp16_fused");
if fp16_fused.exists() {
return fp16_fused;
}
let fp16 = base.join("fp16");
if fp16.exists() {
return fp16;
}
base.join("base")
}
}
}
pub fn target(&self) -> JudgeTarget {
self.target
}
pub fn execution_providers(&self) -> Vec<ExecutionProviderDispatch> {
let mut eps: Vec<ExecutionProviderDispatch> = vec![];
match self.target {
JudgeTarget::Cpu => {}
JudgeTarget::Cuda => {
#[cfg(feature = "judge_cuda")]
eps.push(ort::ep::CUDA::default().build());
#[cfg(not(feature = "judge_cuda"))]
unreachable!("judge_cuda feature not compiled in, guarded by from_target()");
}
JudgeTarget::TensorRt => {
#[cfg(feature = "judge_tensorrt")]
{
let cache_dir = std::env::var("HOME")
.unwrap_or_else(|_| ".".to_string())
+ "/.cache/zer-judge/trt-engines";
let _ = std::fs::create_dir_all(&cache_dir);
let p = self.trt_profile;
eps.push(
ort::ep::TensorRT::default()
.with_fp16(true)
.with_engine_cache(true)
.with_engine_cache_path(&cache_dir)
.with_profile_min_shapes(&p.min_shapes())
.with_profile_opt_shapes(&p.opt_shapes())
.with_profile_max_shapes(&p.max_shapes())
.build(),
);
#[cfg(feature = "judge_cuda")]
eps.push(ort::ep::CUDA::default().build());
}
#[cfg(not(feature = "judge_tensorrt"))]
unreachable!("judge_tensorrt feature not compiled in, guarded by from_target()");
}
JudgeTarget::Rocm => {
#[cfg(feature = "judge_rocm")]
eps.push(ort::ep::ROCm::default().build());
#[cfg(not(feature = "judge_rocm"))]
unreachable!("judge_rocm feature not compiled in");
}
JudgeTarget::DirectMl => {
#[cfg(feature = "judge_directml")]
eps.push(ort::ep::DirectML::default().build());
#[cfg(not(feature = "judge_directml"))]
unreachable!("judge_directml feature not compiled in");
}
JudgeTarget::OpenVino => {
#[cfg(feature = "judge_openvino")]
eps.push(ort::ep::OpenVINO::default().build());
#[cfg(not(feature = "judge_openvino"))]
unreachable!("judge_openvino feature not compiled in");
}
}
eps.push(ort::ep::CPU::default().build());
eps
}
pub fn configure_session(
&self,
builder: ort::session::builder::SessionBuilder,
) -> ort::Result<ort::session::builder::SessionBuilder> {
Ok(builder.with_execution_providers(self.execution_providers())?)
}
}
/// Formats the backend as `JudgeBackend(<name>)`, where `<name>` is the
/// value returned by `JudgeBackend::name`.
impl std::fmt::Display for JudgeBackend {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = self.name();
        f.write_fmt(format_args!("JudgeBackend({name})"))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every target variant, in declaration order.
    const ALL_TARGETS: [JudgeTarget; 6] = [
        JudgeTarget::Cpu,
        JudgeTarget::Cuda,
        JudgeTarget::TensorRt,
        JudgeTarget::Rocm,
        JudgeTarget::DirectMl,
        JudgeTarget::OpenVino,
    ];

    /// Builds a backend pinned to `target` by writing the private field
    /// directly, which bypasses the feature checks in `from_target`.
    fn backend_for(target: JudgeTarget) -> JudgeBackend {
        let mut backend = JudgeBackend::cpu();
        backend.target = target;
        backend
    }

    /// Asserts that `name` parses to exactly `expected`.
    fn assert_parses(name: &str, expected: Option<JudgeTarget>) {
        assert_eq!(JudgeTarget::from_name(name), expected);
    }

    // --- JudgeTarget::from_name -------------------------------------------

    #[test]
    fn from_name_cpu() {
        assert_parses("cpu", Some(JudgeTarget::Cpu));
    }

    #[test]
    fn from_name_cuda() {
        assert_parses("cuda", Some(JudgeTarget::Cuda));
    }

    #[test]
    fn from_name_tensorrt() {
        assert_parses("tensorrt", Some(JudgeTarget::TensorRt));
    }

    #[test]
    fn from_name_rocm() {
        assert_parses("rocm", Some(JudgeTarget::Rocm));
    }

    #[test]
    fn from_name_directml() {
        assert_parses("directml", Some(JudgeTarget::DirectMl));
    }

    #[test]
    fn from_name_openvino() {
        assert_parses("openvino", Some(JudgeTarget::OpenVino));
    }

    #[test]
    fn from_name_unknown_returns_none() {
        // Unknown backend, empty input, and wrong case are all rejected.
        for rejected in ["vulkan", "", "CUDA"] {
            assert_parses(rejected, None);
        }
    }

    #[test]
    fn as_str_round_trips_for_all_variants() {
        for target in ALL_TARGETS {
            let name = target.as_str();
            let reparsed = JudgeTarget::from_name(name);
            assert_eq!(reparsed, Some(target), "round-trip failed for {name}");
        }
    }

    // --- JudgeBackend basics ----------------------------------------------

    #[test]
    fn judge_backend_cpu_has_cpu_name() {
        let cpu = JudgeBackend::cpu();
        assert_eq!(cpu.name(), "cpu");
        assert_eq!(cpu.target(), JudgeTarget::Cpu);
    }

    #[test]
    fn judge_backend_display() {
        let rendered = JudgeBackend::cpu().to_string();
        assert_eq!(rendered, "JudgeBackend(cpu)");
    }

    #[test]
    fn cpu_execution_providers_has_cpu_fallback() {
        let providers = JudgeBackend::cpu().execution_providers();
        assert!(!providers.is_empty(), "execution_providers must never return an empty vec");
    }

    #[test]
    fn cpu_target_is_always_compiled_in() {
        assert!(JudgeBackend::target_compiled_in(JudgeTarget::Cpu));
    }

    // --- models_subdir ----------------------------------------------------

    #[test]
    fn models_subdir_trt_returns_base() {
        assert_eq!(backend_for(JudgeTarget::TensorRt).models_subdir(), "base");
    }

    #[test]
    fn models_subdir_cpu_returns_base() {
        assert_eq!(JudgeBackend::cpu().models_subdir(), "base");
    }

    #[test]
    fn models_subdir_gpu_providers_return_fp16_fused() {
        let gpu_targets = [
            JudgeTarget::Cuda,
            JudgeTarget::Rocm,
            JudgeTarget::DirectMl,
            JudgeTarget::OpenVino,
        ];
        for target in gpu_targets {
            let subdir = backend_for(target).models_subdir();
            assert_eq!(subdir, "fp16_fused", "expected fp16_fused for {}", target.as_str());
        }
    }

    // --- resolve_models_dir -----------------------------------------------

    #[test]
    fn resolve_models_dir_trt_always_returns_base() {
        let root = std::env::temp_dir();
        let resolved = backend_for(JudgeTarget::TensorRt).resolve_models_dir(&root);
        assert_eq!(resolved, root.join("base"));
    }

    #[test]
    fn resolve_models_dir_cpu_always_returns_base() {
        let root = std::env::temp_dir();
        let resolved = JudgeBackend::cpu().resolve_models_dir(&root);
        assert_eq!(resolved, root.join("base"));
    }

    #[test]
    fn resolve_models_dir_cuda_falls_back_to_base_when_no_dirs_exist() {
        // Neither fp16_fused nor fp16 can exist below a missing root.
        let root = std::path::Path::new("/nonexistent/models/nli-base");
        let resolved = backend_for(JudgeTarget::Cuda).resolve_models_dir(root);
        assert_eq!(resolved, root.join("base"));
    }

    // --- TrtProfile -------------------------------------------------------

    #[test]
    fn trt_profile_default_shape_strings() {
        let profile = TrtProfile::DEFAULT;
        assert_eq!(profile.min_shapes(), "input_ids:1x1,attention_mask:1x1,token_type_ids:1x1");
        assert_eq!(profile.opt_shapes(), "input_ids:32x64,attention_mask:32x64,token_type_ids:32x64");
        assert_eq!(profile.max_shapes(), "input_ids:64x512,attention_mask:64x512,token_type_ids:64x512");
    }

    #[test]
    fn trt_profile_custom_values() {
        let profile = TrtProfile {
            min_batch: 1,
            min_seq: 1,
            opt_batch: 16,
            opt_seq: 128,
            max_batch: 32,
            max_seq: 256,
        };
        assert_eq!(profile.opt_shapes(), "input_ids:16x128,attention_mask:16x128,token_type_ids:16x128");
        assert_eq!(profile.max_shapes(), "input_ids:32x256,attention_mask:32x256,token_type_ids:32x256");
    }

    #[test]
    fn with_trt_profile_overrides_default() {
        let custom = TrtProfile {
            min_batch: 1,
            min_seq: 1,
            opt_batch: 8,
            opt_seq: 32,
            max_batch: 16,
            max_seq: 128,
        };
        let backend = JudgeBackend::cpu().with_trt_profile(custom);
        assert_eq!(backend.trt_profile, custom);
    }
}