#![cfg(feature = "gemma-default")]
use std::process::Stdio;
use atomr_infer_core::error::InferenceError;
use crate::hf_cache::HfCache;
/// Outcome of the local-backend preflight performed by `probe`.
#[derive(Debug)]
pub enum ProbeResult {
/// Every check in `probe` passed.
Ready {
/// Free GPU memory as returned by the GPU probe; this is the value
/// `probe` compares against its `min_vram_gb` argument.
vram_free_gb: f32,
/// The cache handle obtained from `HfCache::resolve()` during the probe.
hf_cache: HfCache,
},
/// A prerequisite check did not pass. `reason` describes what was
/// missing or insufficient; `hint` suggests how to fix or work around it.
Skipped { reason: String, hint: String },
/// An `InferenceError` surfaced while resolving the cache or looking up
/// the HuggingFace token.
Error(InferenceError),
}
pub fn probe(
model_id: &str,
min_vram_gb: f32,
min_disk_gb: f32,
suggest_smaller_variant: Option<&str>,
) -> ProbeResult {
let hf_cache = match HfCache::resolve() {
Ok(c) => c,
Err(e) => return ProbeResult::Error(e),
};
let vram_free_gb = match probe_gpu() {
Ok(gb) => gb,
Err(reason) => {
return ProbeResult::Skipped {
reason,
hint: "set ATOMR_INFER_GEMMA_AUTO=skip-quietly to suppress this message; \
set ATOMR_INFER_GEMMA_MODEL=<remote-provider> to use a remote backend"
.into(),
}
}
};
if vram_free_gb < min_vram_gb {
let hint = if let Some(smaller) = suggest_smaller_variant {
format!(
"free VRAM {vram_free_gb:.1} GB < {min_vram_gb:.1} GB needed for {model_id}; \
try ATOMR_INFER_GEMMA_MODEL={smaller}"
)
} else {
format!(
"free VRAM {vram_free_gb:.1} GB < {min_vram_gb:.1} GB needed for {model_id}; \
no smaller supported variant — consider a remote backend"
)
};
return ProbeResult::Skipped {
reason: format!("insufficient VRAM for {model_id}"),
hint,
};
}
match probe_python() {
Ok(()) => {}
Err(reason) => {
return ProbeResult::Skipped {
reason,
hint: "install Python 3.10+ and ensure `python3` is on PATH".into(),
}
}
}
match probe_vllm() {
Ok(version) => tracing::debug!(version, "vllm import probe ok"),
Err(reason) => {
return ProbeResult::Skipped {
reason,
hint: "install vLLM in the active venv: `pip install 'vllm>=0.6.4'`".into(),
}
}
}
match hf_cache.discover_token() {
Ok(Some(_)) => {}
Ok(None) => {
return ProbeResult::Skipped {
reason: "no HuggingFace token found".into(),
hint: format!(
"run `huggingface-cli login`, then accept the Gemma ToS at \
https://huggingface.co/{model_id}"
),
}
}
Err(e) => return ProbeResult::Error(e),
}
if let Some(free) = hf_cache.free_bytes() {
let free_gb = free as f32 / 1e9;
if free_gb < min_disk_gb {
return ProbeResult::Skipped {
reason: format!(
"insufficient disk for {model_id}: {free_gb:.1} GB free at {} \
(need ~{min_disk_gb:.1} GB)",
hf_cache.hub_cache.display()
),
hint: "free up space or set HF_HUB_CACHE to a different mountpoint".into(),
};
}
}
ProbeResult::Ready {
vram_free_gb,
hf_cache,
}
}
/// Asks `nvidia-smi` for the free memory of GPU index 0.
///
/// The first line of the tool's output is parsed as a number (treated as
/// MiB) and divided by 1024 before being returned.
///
/// # Errors
/// Returns a human-readable message when `nvidia-smi` cannot be spawned,
/// exits unsuccessfully, prints nothing, or prints something that is not a
/// number.
fn probe_gpu() -> Result<f32, String> {
    let mut query = std::process::Command::new("nvidia-smi");
    query
        .arg("--query-gpu=memory.free")
        .arg("--format=csv,noheader,nounits")
        .arg("--id=0")
        .stdout(Stdio::piped())
        .stderr(Stdio::piped());

    let output = match query.output() {
        Ok(output) => output,
        Err(e) => {
            return Err(format!(
                "nvidia-smi not on PATH (no NVIDIA driver?): {e}"
            ))
        }
    };

    if !output.status.success() {
        let complaint = String::from_utf8_lossy(&output.stderr);
        return Err(format!("nvidia-smi exited non-zero: {}", complaint.trim()));
    }

    let text = String::from_utf8_lossy(&output.stdout);
    let report = text.trim();
    match report.lines().next() {
        None => Err("nvidia-smi returned empty output".to_string()),
        Some(first_line) => first_line
            .trim()
            .parse::<f32>()
            .map(|free_mib| free_mib / 1024.0)
            // The message quotes the whole trimmed output, not just line one.
            .map_err(|e| format!("could not parse nvidia-smi output `{report}`: {e}")),
    }
}
/// Checks that `python3` on `PATH` reports version 3.10 or newer.
///
/// Runs `python3 --version`, concatenates stdout and stderr, takes the
/// second whitespace-separated token as the version, and compares its
/// first two dot-separated components against (3, 10).
///
/// # Errors
/// Returns a human-readable message when the interpreter cannot be
/// spawned, exits unsuccessfully, prints an unparseable banner, or is
/// older than 3.10.
fn probe_python() -> Result<(), String> {
    let output = std::process::Command::new("python3")
        .arg("--version")
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .output()
        .map_err(|e| format!("python3 not on PATH: {e}"))?;

    if !output.status.success() {
        return Err("python3 --version returned non-zero exit".into());
    }

    // stdout first, then stderr, so either stream may carry the banner.
    let mut banner = String::from_utf8_lossy(&output.stdout).into_owned();
    banner.push_str(&String::from_utf8_lossy(&output.stderr));

    let version = match banner.split_whitespace().nth(1) {
        Some(token) => token,
        None => return Err(format!("could not parse `{}`", banner.trim())),
    };

    let mut fields = version.split('.');
    let mut component = |label: &str| -> Result<u32, String> {
        fields
            .next()
            .ok_or_else(|| format!("missing {label}"))?
            .parse::<u32>()
            .map_err(|e| format!("bad {label}: {e}"))
    };
    let major = component("major")?;
    let minor = component("minor")?;

    let too_old = major < 3 || (major == 3 && minor < 10);
    if too_old {
        return Err(format!(
            "Python {major}.{minor} on PATH; vLLM requires 3.10+"
        ));
    }
    Ok(())
}
/// Checks that `import vllm` works in the `python3` found on `PATH` and
/// returns the reported `vllm.__version__`.
///
/// Importing vLLM may itself write log lines to stdout before the
/// `print(...)` runs, so only the last non-empty stdout line is taken as
/// the version rather than the whole captured output. If stdout has no
/// non-empty line, an empty string is returned.
///
/// # Errors
/// Returns a human-readable message when `python3` cannot be spawned or
/// the import fails. A stderr mentioning `ModuleNotFoundError` or
/// `No module named 'vllm'` is reported as "not importable"; any other
/// failure includes the trimmed stderr.
fn probe_vllm() -> Result<String, String> {
    let out = std::process::Command::new("python3")
        .arg("-c")
        .arg("import vllm; print(vllm.__version__)")
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .output()
        .map_err(|e| format!("python3 spawn failed: {e}"))?;

    if !out.status.success() {
        let stderr = String::from_utf8_lossy(&out.stderr);
        if stderr.contains("ModuleNotFoundError") || stderr.contains("No module named 'vllm'") {
            return Err("vLLM not importable in active python3".into());
        }
        return Err(format!("vllm import probe failed: {}", stderr.trim()));
    }

    // The version is whatever `print` emitted last; anything before it is
    // import-time noise.
    let stdout = String::from_utf8_lossy(&out.stdout);
    let version = stdout
        .lines()
        .rev()
        .map(str::trim)
        .find(|line| !line.is_empty())
        .unwrap_or_default()
        .to_string();
    Ok(version)
}