use std::collections::HashMap;
use hf_fetch_model::inspect::TensorInfo;
use hypomnesis::GpuDeviceInfo;
use crate::format::format_size;
/// Outcome of probing one GPU: either the device info or a human-readable
/// error message — never both populated by `query_gpu`.
#[derive(Debug)]
// Intentionally an open struct: callers construct/destructure it freely.
#[allow(clippy::exhaustive_structs)] pub struct GpuCheckResult {
/// Index of the GPU that was queried (as passed to `query_gpu`).
pub device_index: u32,
/// Device details when the query succeeded; `None` on failure.
pub device: Option<GpuDeviceInfo>,
/// Friendly error text when the query failed; `None` on success.
pub error: Option<String>,
}
#[must_use]
pub fn query_gpu(index: u32) -> GpuCheckResult {
match hypomnesis::device_info(index) {
Ok(device) => GpuCheckResult {
device_index: index,
device: Some(device),
error: None,
},
Err(e) => GpuCheckResult {
device_index: index,
device: None,
error: Some(friendly_error(&e)),
},
}
}
/// Translate a hypomnesis error into a short, user-facing message.
fn friendly_error(err: &hypomnesis::HypomnesisError) -> String {
    use hypomnesis::HypomnesisError as E;
    match err {
        E::DeviceIndexOutOfRange { index, count } => format!(
            "index {index} out of range (have {count} {})",
            // Singular/plural chosen inline from the device count.
            if *count == 1 { "device" } else { "devices" }
        ),
        E::NoGpuSource => String::from("no NVIDIA device detected (NVML / DXGI not usable)"),
        E::Nvml(detail) => format!("NVML backend reported: {detail}"),
        E::Dxgi(detail) => format!("DXGI backend reported: {detail}"),
        E::NvidiaSmi(detail) => format!("nvidia-smi backend reported: {detail}"),
        // Variants that should not occur in this code path.
        E::Ram(_) | E::Io(_) => format!("unexpected error: {err}"),
        // Catch-all for variants added to the (non-local) enum later.
        _ => format!("hypomnesis error: {err}"),
    }
}
#[must_use]
pub fn sum_tensor_bytes(tensors: &[TensorInfo]) -> u64 {
tensors
.iter()
.map(TensorInfo::byte_len)
.fold(0u64, u64::saturating_add)
}
/// Label for the dtype holding the most elements, e.g. "BF16" or
/// "BF16 + others" when the mix is genuinely heterogeneous (< 99% dominant).
#[must_use]
pub fn dominant_dtype_label(tensors: &[TensorInfo]) -> String {
    const UNKNOWN: &str = "unknown";
    if tensors.is_empty() {
        return UNKNOWN.to_owned();
    }
    // Tally element counts per dtype (saturating, so huge models can't wrap).
    let mut elements_by_dtype: HashMap<&str, u64> = HashMap::new();
    for tensor in tensors {
        let count = elements_by_dtype.entry(tensor.dtype.as_str()).or_default();
        *count = count.saturating_add(tensor.num_elements());
    }
    let Some((&winner, &winner_count)) = elements_by_dtype.iter().max_by_key(|&(_, &c)| c)
    else {
        // Unreachable in practice (tensors is non-empty), kept for safety.
        return UNKNOWN.to_owned();
    };
    if elements_by_dtype.len() == 1 {
        return winner.to_owned();
    }
    let total = elements_by_dtype
        .values()
        .fold(0u64, |acc, &c| acc.saturating_add(c));
    if total == 0 {
        return winner.to_owned();
    }
    #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
    let dominant_share = winner_count as f64 / total as f64;
    // Tiny norm/bias tensors in another dtype shouldn't flag a mixed model.
    if dominant_share >= 0.99 {
        winner.to_owned()
    } else {
        format!("{winner} + others")
    }
}
/// Print a human-readable report comparing model weight size to GPU VRAM.
///
/// When device info is unavailable, prints the stored error (or a generic
/// fallback) and stops; the fit verdict and note are only shown for a
/// successfully queried device.
pub fn print_gpu_check(
    result: &GpuCheckResult,
    weight_bytes: u64,
    dtype_label: &str,
    total_params: u64,
) {
    println!();
    println!(
        " Model weights: {} ({dtype_label}, {} params)",
        format_size(weight_bytes),
        format_params(total_params),
    );
    match result.device {
        None => {
            let detail = result
                .error
                .as_deref()
                .unwrap_or("device info unavailable (no further detail)");
            println!(" GPU {}: unavailable — {detail}", result.device_index);
        }
        Some(ref dev) => {
            println!(
                " GPU {}: {} — {} VRAM",
                result.device_index,
                dev.name.as_deref().unwrap_or("unknown GPU"),
                format_size(dev.total_bytes),
            );
            println!(
                " free: {}, used: {}",
                format_size(dev.free_bytes),
                format_size(dev.used_bytes),
            );
            // Subtraction in both branches is guarded by the comparison, so
            // it cannot underflow.
            if weight_bytes <= dev.free_bytes {
                println!(
                    " Fit: \u{2713} {} headroom for weights + KV cache + runtime",
                    format_size(dev.free_bytes - weight_bytes),
                );
            } else {
                println!(
                    " Fit: \u{2717} short by {} for the weights alone",
                    format_size(weight_bytes - dev.free_bytes),
                );
            }
            println!();
            println!(
                " Note: reports weights only. Large-context inference typically needs ~1.3\u{2013}1.5\u{00d7}"
            );
            println!(" weight size for KV cache and activations.");
        }
    }
}
#[must_use]
pub fn gpu_check_json(
result: &GpuCheckResult,
weight_bytes: u64,
dtype_label: &str,
total_params: u64,
) -> serde_json::Value {
let mut out = serde_json::Map::new();
out.insert(
"device_index".to_owned(),
serde_json::Value::Number(result.device_index.into()),
);
if let Some(ref dev) = result.device {
let mut dev_obj = serde_json::Map::new();
if let Some(ref name) = dev.name {
dev_obj.insert("name".to_owned(), serde_json::Value::String(name.clone()));
}
dev_obj.insert(
"total_bytes".to_owned(),
serde_json::Value::Number(dev.total_bytes.into()),
);
dev_obj.insert(
"free_bytes".to_owned(),
serde_json::Value::Number(dev.free_bytes.into()),
);
dev_obj.insert(
"used_bytes".to_owned(),
serde_json::Value::Number(dev.used_bytes.into()),
);
out.insert("device".to_owned(), serde_json::Value::Object(dev_obj));
let fits = dev.free_bytes >= weight_bytes;
out.insert("fits".to_owned(), serde_json::Value::Bool(fits));
if fits {
out.insert(
"headroom_bytes".to_owned(),
serde_json::Value::Number((dev.free_bytes - weight_bytes).into()),
);
} else {
out.insert(
"short_bytes".to_owned(),
serde_json::Value::Number((weight_bytes - dev.free_bytes).into()),
);
}
}
if let Some(ref msg) = result.error {
out.insert("error".to_owned(), serde_json::Value::String(msg.clone()));
}
let mut model = serde_json::Map::new();
model.insert(
"weight_bytes".to_owned(),
serde_json::Value::Number(weight_bytes.into()),
);
model.insert(
"dtype_label".to_owned(),
serde_json::Value::String(dtype_label.to_owned()),
);
model.insert(
"total_params".to_owned(),
serde_json::Value::Number(total_params.into()),
);
out.insert("model".to_owned(), serde_json::Value::Object(model));
serde_json::Value::Object(out)
}
/// Format a raw parameter count with a K/M/B/T suffix for display.
///
/// Precision mirrors common model naming: two decimals for B/T ("5.12B"),
/// one for K/M ("1.5K"), plain digits below 1000.
fn format_params(n: u64) -> String {
    const THOUSAND: u64 = 1_000;
    const MILLION: u64 = 1_000_000;
    const BILLION: u64 = 1_000_000_000;
    const TRILLION: u64 = 1_000_000_000_000;
    // Precision loss is fine here: the output is rounded display text.
    #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
    let scaled = |divisor: u64| n as f64 / divisor as f64;
    if n >= TRILLION {
        format!("{:.2}T", scaled(TRILLION))
    } else if n >= BILLION {
        format!("{:.2}B", scaled(BILLION))
    } else if n >= MILLION {
        format!("{:.1}M", scaled(MILLION))
    } else if n >= THOUSAND {
        format!("{:.1}K", scaled(THOUSAND))
    } else {
        n.to_string()
    }
}
#[cfg(test)]
mod tests {
// Test code may panic/unwrap/index freely.
#![allow(
clippy::panic,
clippy::unwrap_used,
clippy::expect_used,
clippy::indexing_slicing
)]
use super::{
dominant_dtype_label, format_params, gpu_check_json, sum_tensor_bytes, GpuCheckResult,
};
use hf_fetch_model::inspect::TensorInfo;
// Build a minimal TensorInfo fixture; data_offsets (0, byte_len) encodes the
// tensor's byte size — presumably byte_len() derives from these offsets
// (TODO confirm against TensorInfo's definition).
fn make_tensor(name: &str, dtype: &str, shape: Vec<usize>, byte_len: u64) -> TensorInfo {
TensorInfo {
name: name.to_owned(),
dtype: dtype.to_owned(),
shape,
data_offsets: (0, byte_len),
}
}
#[test]
fn sum_tensor_bytes_empty() {
assert_eq!(sum_tensor_bytes(&[]), 0);
}
#[test]
fn sum_tensor_bytes_one_tensor() {
let t = make_tensor("a", "BF16", vec![10, 10], 200);
assert_eq!(sum_tensor_bytes(&[t]), 200);
}
#[test]
fn sum_tensor_bytes_many() {
let a = make_tensor("a", "BF16", vec![4, 4], 32);
let b = make_tensor("b", "F16", vec![4, 4], 32);
let c = make_tensor("c", "F32", vec![4, 4], 64);
assert_eq!(sum_tensor_bytes(&[a, b, c]), 128);
}
#[test]
fn dominant_dtype_label_empty() {
assert_eq!(dominant_dtype_label(&[]), "unknown");
}
// All tensors share one dtype: label is that dtype with no suffix.
#[test]
fn dominant_dtype_label_pure() {
let tensors = vec![
make_tensor("a", "BF16", vec![1000, 1000], 2_000_000),
make_tensor("b", "BF16", vec![1000, 1000], 2_000_000),
];
assert_eq!(dominant_dtype_label(&tensors), "BF16");
}
// A tiny F32 norm tensor (10 elements vs 1M) stays under the 1% threshold,
// so the label collapses to the dominant dtype.
#[test]
fn dominant_dtype_label_near_pure_collapses_to_dominant() {
let tensors = vec![
make_tensor("weight", "BF16", vec![1000, 1000], 2_000_000),
make_tensor("norm", "F32", vec![10], 40),
];
assert_eq!(dominant_dtype_label(&tensors), "BF16");
}
// 60/40 element split is a real mix: label flags the minority dtypes.
#[test]
fn dominant_dtype_label_true_mixed_flags_others() {
let tensors = vec![
make_tensor("a", "BF16", vec![60, 100], 12_000),
make_tensor("b", "F8_E4M3", vec![40, 100], 4_000),
];
assert_eq!(dominant_dtype_label(&tensors), "BF16 + others");
}
// One assertion per magnitude bucket, including both boundaries of the
// plain-digit range (0 and 999).
#[test]
fn format_params_buckets() {
assert_eq!(format_params(0), "0");
assert_eq!(format_params(999), "999");
assert_eq!(format_params(1_500), "1.5K");
assert_eq!(format_params(2_000_000), "2.0M");
assert_eq!(format_params(5_120_000_000), "5.12B");
assert_eq!(format_params(1_500_000_000_000), "1.50T");
}
// Failed query: JSON carries error text but no device/fit keys; the model
// section is always present.
#[test]
fn gpu_check_json_error_path() {
let result = GpuCheckResult {
device_index: 3,
device: None,
error: Some("index 3 out of range (have 1 device)".to_owned()),
};
let v = gpu_check_json(&result, 1024, "BF16", 100);
assert_eq!(v.get("device_index"), Some(&serde_json::json!(3)));
assert_eq!(
v.get("error"),
Some(&serde_json::json!("index 3 out of range (have 1 device)"))
);
assert!(v.get("device").is_none());
assert!(v.get("fits").is_none());
let model = v.get("model").expect("model object present");
assert_eq!(model.get("weight_bytes"), Some(&serde_json::json!(1024)));
assert_eq!(model.get("dtype_label"), Some(&serde_json::json!("BF16")));
assert_eq!(model.get("total_params"), Some(&serde_json::json!(100)));
}
}