use std::path::PathBuf;
/// Checks that `target` names a compute backend this binary can actually run.
///
/// `"auto"` and `"cpu"` are always accepted. `"cuda"`, `"avx2"` and `"vulkan"` are
/// accepted only when the crate was compiled with the Cargo feature of the same name;
/// otherwise the returned error tells the user how to rebuild.
///
/// # Errors
///
/// Fails for a feature-gated target whose feature was not compiled in, and for any
/// unrecognised target string (matching is exact and case-sensitive).
pub fn validate_compute_target(target: &str) -> anyhow::Result<()> {
    // (target name, whether its same-named Cargo feature is enabled in this build)
    const GATED: [(&str, bool); 3] = [
        ("cuda", cfg!(feature = "cuda")),
        ("avx2", cfg!(feature = "avx2")),
        ("vulkan", cfg!(feature = "vulkan")),
    ];

    if matches!(target, "auto" | "cpu") {
        return Ok(());
    }

    let compiled_in = match GATED.iter().find(|(name, _)| *name == target) {
        Some(&(_, enabled)) => enabled,
        None => anyhow::bail!(
            "unknown compute target: {target:?}; valid: auto, cpu, cuda, avx2, vulkan"
        ),
    };

    if !compiled_in {
        anyhow::bail!(
            "--target {target} requires compiling with --features={target}\n \
             Rebuild: cargo build --release -p zer-bench --features={target}"
        );
    }
    Ok(())
}
/// Checks that `target` names a judge backend this binary can actually run.
///
/// `"cpu"` is always accepted. `"cuda"`, `"tensorrt"`, `"rocm"`, `"directml"` and
/// `"openvino"` are accepted only when the crate was compiled with the matching
/// `judge_<target>` Cargo feature; otherwise the returned error tells the user how
/// to rebuild.
///
/// # Errors
///
/// Fails for a feature-gated target whose feature was not compiled in, and for any
/// unrecognised target string (matching is exact and case-sensitive).
pub fn validate_judge_target(target: &str) -> anyhow::Result<()> {
    // (target name, whether its `judge_<target>` Cargo feature is enabled in this build)
    const GATED: [(&str, bool); 5] = [
        ("cuda", cfg!(feature = "judge_cuda")),
        ("tensorrt", cfg!(feature = "judge_tensorrt")),
        ("rocm", cfg!(feature = "judge_rocm")),
        ("directml", cfg!(feature = "judge_directml")),
        ("openvino", cfg!(feature = "judge_openvino")),
    ];

    if target == "cpu" {
        return Ok(());
    }

    let compiled_in = match GATED.iter().find(|(name, _)| *name == target) {
        Some(&(_, enabled)) => enabled,
        None => anyhow::bail!(
            "unknown judge target: {target:?}; valid: cpu, cuda, tensorrt, rocm, directml, openvino"
        ),
    };

    if !compiled_in {
        anyhow::bail!(
            "--judge-target {target} requires compiling with --features=judge_{target}\n \
             Rebuild: cargo build --release -p zer-bench --features=judge_{target}"
        );
    }
    Ok(())
}
/// Prints a framed benchmark header to stdout.
///
/// Output is a blank line, a `═` rule, the `parts` joined by ` │ ` and indented by one
/// space, then a second rule. The rule is four characters longer than the joined label
/// (counted in `char`s, not display columns), but never shorter than 52.
pub fn print_bench_header(parts: &[&str]) {
    const MIN_BAR_WIDTH: usize = 52;

    let label = parts.join(" │ ");
    let bar_len = std::cmp::max(label.chars().count() + 4, MIN_BAR_WIDTH);
    let rule: String = std::iter::repeat('═').take(bar_len).collect();

    println!("\n{rule}");
    println!(" {label}");
    println!("{rule}");
}
/// Prints whether TensorRT engine files are already present in the on-disk cache.
///
/// The cache directory is `$HOME/.cache/zer-judge/trt-engines`. If `HOME` is unset or
/// not valid Unicode, `.` is used in place of `$HOME`. Only directory entries whose
/// extension is exactly `engine` are counted. This is best-effort logging: a missing
/// or unreadable directory counts as zero engines, and entries that fail to read are
/// skipped. It never fails.
pub fn log_trt_cache_status() {
    let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_owned());
    // Build the path with `join` rather than string formatting, so a trailing `/`
    // in `HOME` does not produce a doubled separator in the logged path.
    let cache_dir = PathBuf::from(home)
        .join(".cache")
        .join("zer-judge")
        .join("trt-engines");

    let engine_count = std::fs::read_dir(&cache_dir)
        .map(|entries| {
            entries
                .flatten()
                .filter(|entry| {
                    entry
                        .path()
                        .extension()
                        .map_or(false, |ext| ext == "engine")
                })
                .count()
        })
        .unwrap_or(0);

    let cache_dir = cache_dir.display();
    if engine_count > 0 {
        println!(
            "TRT warm: cached engines found engine_count={engine_count} cache_dir={cache_dir}"
        );
    } else {
        println!("TRT cold: no cached engines, TRT will compile now (takes 2-5 min, cached after) cache_dir={cache_dir}");
    }
}
/// Locates the Cargo workspace root by walking up from the current directory.
///
/// Checks the current working directory and then each parent in turn. Returns the first
/// directory whose `Cargo.toml` can be read and contains the text `[workspace]`. This is
/// a plain substring test, not a TOML parse. Returns `"."` if no directory qualifies.
/// If the current directory cannot be determined, the search starts from an empty path.
pub fn workspace_root() -> PathBuf {
    let start = std::env::current_dir().unwrap_or_default();

    start
        .ancestors()
        .find(|dir| {
            matches!(
                std::fs::read_to_string(dir.join("Cargo.toml")),
                Ok(manifest) if manifest.contains("[workspace]")
            )
        })
        .map_or_else(|| PathBuf::from("."), std::path::Path::to_path_buf)
}
/// Returns the directory that holds benchmark datasets.
///
/// If the `ZER_DATASET_DIR` environment variable is set to a non-empty value, that
/// path is returned as-is. Otherwise the result is the `data` directory under the
/// workspace root.
pub fn bench_data_root() -> PathBuf {
    // `var_os` accepts paths that are not valid UTF-8. `env::var` rejects them, which
    // would silently fall back to the workspace dataset instead of the requested one.
    match std::env::var_os("ZER_DATASET_DIR") {
        // An empty value is treated like an unset one, not as the empty
        // (current-directory-relative) path.
        Some(dir) if !dir.is_empty() => PathBuf::from(dir),
        _ => workspace_root().join("data"),
    }
}
/// Resolves an output directory argument to a concrete path.
///
/// Absolute paths are returned unchanged. Relative paths are interpreted relative to
/// the workspace root, not the process's current directory.
pub fn resolve_out_dir(path: &str) -> PathBuf {
    let requested = std::path::Path::new(path);
    if requested.is_absolute() {
        return requested.to_path_buf();
    }
    workspace_root().join(requested)
}