use std::path::Path;
use tracing::debug;
use crate::error::DetectionError;
use crate::hardware::AcceleratorType;
use crate::profile::AcceleratorProfile;
use super::command::{DEFAULT_TIMEOUT, run_tool};
/// Arguments passed to `nvidia-smi` for the bandwidth query.
///
/// The output is one CSV row per GPU containing the maximum memory clock and the compute
/// capability, with no header line and no unit suffixes.
const NVIDIA_BW_ARGS: &[&str] = &[
    // Fields requested per GPU: max memory clock, then compute capability.
    "--query-gpu=clocks.max.memory,compute_cap",
    // Bare comma-separated values, one row per GPU.
    "--format=csv,noheader,nounits",
];
/// Fills in `memory_bandwidth_gbps` for CUDA and ROCm accelerator profiles.
///
/// `nvidia-smi` is only queried when at least one CUDA profile is still missing a bandwidth
/// value; otherwise an empty measurement list is handed to [`apply_bandwidth`]. `warnings` is
/// forwarded to the NVIDIA query.
pub(crate) fn enrich_bandwidth(
    profiles: &mut [AcceleratorProfile],
    warnings: &mut Vec<DetectionError>,
) {
    let cuda_missing_bandwidth = profiles.iter().any(|profile| {
        matches!(profile.accelerator, AcceleratorType::CudaGpu { .. })
            && profile.memory_bandwidth_gbps.is_none()
    });

    let cuda_bw = if cuda_missing_bandwidth {
        query_nvidia_bandwidth(warnings)
    } else {
        Vec::new()
    };

    let enriched = apply_bandwidth(profiles, &cuda_bw);
    debug!(enriched, "memory bandwidth enrichment complete");
}
/// Async counterpart of [`enrich_bandwidth`]: fills in `memory_bandwidth_gbps` for CUDA and
/// ROCm accelerator profiles.
///
/// Like the synchronous path, `nvidia-smi` is only spawned when at least one CUDA profile is
/// still missing a bandwidth value. This avoids a pointless process spawn, and a possible wait
/// of up to `DEFAULT_TIMEOUT`, on machines without NVIDIA GPUs. Tool failures are treated as
/// "no measurements". `warnings` is currently never written to.
#[cfg(feature = "async-detect")]
pub(crate) async fn enrich_bandwidth_async(
    profiles: &mut [AcceleratorProfile],
    warnings: &mut Vec<DetectionError>,
) {
    let needs_nvidia_query = profiles.iter().any(|p| {
        matches!(p.accelerator, AcceleratorType::CudaGpu { .. })
            && p.memory_bandwidth_gbps.is_none()
    });

    let cuda_bw = if needs_nvidia_query {
        match super::command::run_tool_async("nvidia-smi", NVIDIA_BW_ARGS, DEFAULT_TIMEOUT).await
        {
            Ok(o) => parse_nvidia_bandwidth_output(&o.stdout),
            // Best-effort: a missing or failing nvidia-smi simply means no measured data.
            Err(_) => Vec::new(),
        }
    } else {
        Vec::new()
    };

    // NOTE: apply_bandwidth performs synchronous std::fs reads for ROCm devices.
    let count = apply_bandwidth(profiles, &cuda_bw);
    debug!(enriched = count, "memory bandwidth enrichment complete");

    // Kept for signature parity with the synchronous path; nothing is reported yet.
    let _ = warnings;
}
fn apply_bandwidth(profiles: &mut [AcceleratorProfile], cuda_bw: &[Option<f64>]) -> usize {
let mut nvidia_idx = 0usize;
let mut count = 0usize;
for profile in profiles.iter_mut() {
match &profile.accelerator {
AcceleratorType::CudaGpu { device_id } => {
if profile.memory_bandwidth_gbps.is_none() {
if let Some(bw) = cuda_bw.get(nvidia_idx).copied().flatten() {
profile.memory_bandwidth_gbps = Some(bw);
count += 1;
} else if let Some(cc) = &profile.compute_capability {
if let Some(bw) = estimate_nvidia_bandwidth_from_cc(cc) {
profile.memory_bandwidth_gbps = Some(bw);
count += 1;
} else {
let device_id = *device_id;
debug!(device_id, "no memory bandwidth data available for CUDA GPU");
}
} else {
let device_id = *device_id;
debug!(device_id, "no memory bandwidth data available for CUDA GPU");
}
} else {
count += 1;
}
nvidia_idx += 1;
}
AcceleratorType::RocmGpu { device_id } => {
if let Some(bw) = probe_rocm_bandwidth(*device_id) {
profile.memory_bandwidth_gbps = Some(bw);
count += 1;
}
}
_ => {}
}
}
count
}
/// Runs `nvidia-smi` with [`NVIDIA_BW_ARGS`] and converts each output row into a bandwidth
/// estimate.
///
/// Best-effort: if the tool cannot be run or fails, an empty list is returned. `_warnings` is
/// not written to.
fn query_nvidia_bandwidth(_warnings: &mut Vec<DetectionError>) -> Vec<Option<f64>> {
    run_tool("nvidia-smi", NVIDIA_BW_ARGS, DEFAULT_TIMEOUT)
        .map(|tool_output| parse_nvidia_bandwidth_output(&tool_output.stdout))
        .unwrap_or_default()
}
/// Parses `nvidia-smi` CSV output (`<max memory clock>, <compute capability>` per row) into
/// one bandwidth value in GB/s per line.
///
/// Exactly one element is produced per input line, so indices line up with GPU rows. An
/// element is `None` when the line has fewer than two comma-separated fields, the clock is not
/// a number (e.g. `[N/A]`), or the compute capability has no known bus width.
pub fn parse_nvidia_bandwidth_output(stdout: &str) -> Vec<Option<f64>> {
    stdout.lines().map(parse_nvidia_bandwidth_line).collect()
}

/// Converts a single `<clock MHz>, <compute capability>[, ...]` row into GB/s.
fn parse_nvidia_bandwidth_line(line: &str) -> Option<f64> {
    let mut fields = line.split(',').map(str::trim);
    let clock_field = fields.next()?;
    let cc = fields.next()?;

    let max_mem_clock_mhz = clock_field.parse::<f64>().ok()?;
    let bus_width = nvidia_bus_width_bits(cc)?;
    let bandwidth = calculate_bandwidth(max_mem_clock_mhz, bus_width);
    debug!(
        compute_cap = cc,
        mem_clock_mhz = max_mem_clock_mhz,
        bus_width,
        bandwidth_gbps = bandwidth,
        "NVIDIA memory bandwidth calculated"
    );
    Some(bandwidth)
}
fn calculate_bandwidth(clock_mhz: f64, bus_width_bits: u32) -> f64 {
let bw = clock_mhz * bus_width_bits as f64 * crate::units::DDR_MULTIPLIER
/ crate::units::BITS_PER_BYTE
/ crate::units::MHZ_PER_GHZ;
(bw * 10.0).round() / 10.0 }
/// Returns the memory bus width (in bits) associated with a CUDA compute capability string
/// such as `"8.9"`.
///
/// The table holds a single value per compute capability, so GPUs that share a capability but
/// differ in bus width all get the same answer. Unknown capabilities yield `None`.
pub fn nvidia_bus_width_bits(cc: &str) -> Option<u32> {
    const BUS_WIDTH_BY_CC: &[(&str, u32)] = &[
        ("6.0", 4096),
        ("6.1", 352),
        ("7.0", 4096),
        ("7.5", 352),
        ("8.0", 5120),
        ("8.6", 384),
        ("8.9", 384),
        ("9.0", 5120),
        ("10.0", 8192),
    ];

    BUS_WIDTH_BY_CC
        .iter()
        .find(|&&(known_cc, _)| known_cc == cc)
        .map(|&(_, bits)| bits)
}
/// Returns a fixed memory-bandwidth estimate (GB/s) for a CUDA compute capability string.
///
/// Used as a fallback when no measured clock is available. Unknown capabilities yield `None`.
pub fn estimate_nvidia_bandwidth_from_cc(cc: &str) -> Option<f64> {
    let estimate = match cc {
        "6.0" => 732.0,
        "6.1" => 484.0,
        "7.0" => 900.0,
        "7.5" => 616.0,
        "8.0" => 2039.0,
        "8.6" => 936.0,
        "8.9" => 1008.0,
        "9.0" => 3350.0,
        "10.0" => 8000.0,
        _ => return None,
    };
    Some(estimate)
}
/// Finds the `device_id`-th amdgpu card under `/sys/class/drm` and reads its memory bandwidth.
///
/// `device_id` is treated as an ordinal among cards bound to the `amdgpu` driver, counted in
/// ascending card-number order (`card0`, `card1`, …, `card10`). Connector entries such as
/// `card0-DP-1` and non-`card` nodes are ignored. Returns `None` if sysfs cannot be read, no
/// matching card exists, or its bandwidth cannot be determined.
fn probe_rocm_bandwidth(device_id: u32) -> Option<f64> {
    let drm = Path::new("/sys/class/drm");

    // `read_dir` yields entries in an unspecified, platform-dependent order, so collect the
    // primary card nodes and sort them numerically before counting ordinals.
    let mut cards: Vec<(u32, std::path::PathBuf)> = std::fs::read_dir(drm)
        .ok()?
        .flatten()
        .filter_map(|entry| {
            let index = drm_card_index(&entry.file_name().to_string_lossy())?;
            Some((index, entry.path().join("device")))
        })
        .collect();
    cards.sort_unstable_by_key(|&(index, _)| index);

    let ordinal = usize::try_from(device_id).ok()?;
    let device_dir = cards
        .into_iter()
        .map(|(_, dir)| dir)
        .filter(|dir| is_amdgpu_device(dir))
        .nth(ordinal)?;
    read_rocm_bandwidth(&device_dir)
}

/// Returns `N` for a primary DRM card node named `cardN`; `None` for connectors
/// (`card0-DP-1`), render nodes, or any other name.
fn drm_card_index(name: &str) -> Option<u32> {
    name.strip_prefix("card")?.parse().ok()
}

/// Reports whether the device directory's `driver` symlink points at the `amdgpu` driver.
fn is_amdgpu_device(device_dir: &Path) -> bool {
    std::fs::read_link(device_dir.join("driver"))
        .ok()
        .is_some_and(|target| target.file_name().is_some_and(|name| name == "amdgpu"))
}
/// Computes the peak memory bandwidth (GB/s) of one amdgpu device from its sysfs directory.
///
/// Takes the highest memory clock listed in `pp_dpm_mclk` and combines it with the bus width
/// from [`read_amd_bus_width`]. Returns `None` if `pp_dpm_mclk` cannot be read or lists no
/// positive clock.
fn read_rocm_bandwidth(device_dir: &Path) -> Option<f64> {
    let mclk_table = super::read_sysfs_string(&device_dir.join("pp_dpm_mclk"), 4096)?;
    let peak_mclk_mhz = parse_max_dpm_clock(&mclk_table)?;

    let bus_width = read_amd_bus_width(device_dir);
    let bandwidth = calculate_bandwidth(peak_mclk_mhz, bus_width);

    debug!(
        mem_clock_mhz = peak_mclk_mhz,
        bus_width,
        bandwidth_gbps = bandwidth,
        "AMD memory bandwidth calculated"
    );
    Some(bandwidth)
}
/// Returns the highest clock (MHz) listed in an amdgpu `pp_dpm_*` table.
///
/// For each line, the second whitespace-separated token is taken and a trailing `Mhz` or `MHz`
/// is stripped before parsing. Lines that do not fit this shape are skipped. Returns `None`
/// when no positive clock is found.
pub fn parse_max_dpm_clock(content: &str) -> Option<f64> {
    let highest = content
        .lines()
        .filter_map(|line| {
            let token = line.split_whitespace().nth(1)?;
            let number = token
                .strip_suffix("Mhz")
                .or_else(|| token.strip_suffix("MHz"))?;
            number.parse::<f64>().ok()
        })
        .fold(0.0f64, f64::max);

    (highest > 0.0).then_some(highest)
}
/// Determines the memory bus width (bits) of an amdgpu device.
///
/// The PCI device-ID table is consulted first. Otherwise the width is guessed from the total
/// VRAM in `mem_info_vram_total`, using whole GiB rounded down. An unreadable value counts as
/// 0 bytes.
fn read_amd_bus_width(device_dir: &Path) -> u32 {
    read_amd_bus_width_from_device_id(device_dir).unwrap_or_else(|| {
        let vram_bytes =
            super::read_sysfs_u64(&device_dir.join("mem_info_vram_total")).unwrap_or(0);
        match vram_bytes / (1024 * 1024 * 1024) {
            0..=4 => 128,
            5..=16 => 256,
            17..=24 => 384,
            25..=96 => 4096,
            97..=192 => 8192,
            _ => 256,
        }
    })
}
/// Looks up the memory bus width (bits) of an amdgpu device from its PCI device ID.
///
/// Reads the sysfs `device` attribute, trims surrounding whitespace and one optional `0x`
/// prefix, and matches the result against a built-in table. Returns `None` if the attribute
/// cannot be read or the ID is not listed.
///
/// NOTE(review): the ID → width table is hand-maintained; confirm entries against the PCI ID
/// database before extending or relying on it.
fn read_amd_bus_width_from_device_id(device_dir: &Path) -> Option<u32> {
    const KNOWN_BUS_WIDTHS: &[(&[&str], u32)] = &[
        (&["740c", "740f"], 8192),
        (&["740a", "7408"], 4096),
        (&["7400", "7401", "7402", "7403", "7404", "7405"], 4096),
        (&["738c", "738e"], 4096),
        (&["66a1", "66a0", "66af"], 4096),
        (&["744c", "7448"], 384),
        (&["7480", "7470"], 256),
        (&["73bf", "73a5"], 256),
        (&["73df"], 192),
    ];

    let raw = super::read_sysfs_string(&device_dir.join("device"), 64)?;
    let trimmed = raw.trim();
    let pci_id = trimmed.strip_prefix("0x").unwrap_or(trimmed);

    KNOWN_BUS_WIDTHS
        .iter()
        .find(|(ids, _)| ids.iter().any(|&id| id == pci_id))
        .map(|&(_, bits)| bits)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn calculate_bandwidth_h100() {
        let bw = calculate_bandwidth(2619.0, 5120);
        assert!((bw - 3352.3).abs() < 1.0, "H100 bandwidth: {}", bw);
    }

    #[test]
    fn calculate_bandwidth_a100() {
        let bw = calculate_bandwidth(1593.0, 5120);
        assert!((bw - 2039.0).abs() < 1.0, "A100 bandwidth: {}", bw);
    }

    #[test]
    fn calculate_bandwidth_rtx4090() {
        let bw = calculate_bandwidth(10501.0, 384);
        assert!((bw - 1008.1).abs() < 1.0, "RTX 4090 bandwidth: {}", bw);
    }

    #[test]
    fn parse_max_dpm_clock_normal() {
        let input = "0: 96Mhz\n1: 1000Mhz *\n";
        assert_eq!(parse_max_dpm_clock(input), Some(1000.0));
    }

    #[test]
    fn parse_max_dpm_clock_hbm() {
        let input = "0: 500Mhz\n1: 900Mhz\n2: 1600Mhz *\n";
        assert_eq!(parse_max_dpm_clock(input), Some(1600.0));
    }

    #[test]
    fn parse_max_dpm_clock_empty() {
        assert_eq!(parse_max_dpm_clock(""), None);
    }

    #[test]
    fn parse_max_dpm_clock_accepts_both_suffix_spellings() {
        assert_eq!(parse_max_dpm_clock("0: 800MHz\n1: 1200Mhz\n"), Some(1200.0));
    }

    #[test]
    fn parse_max_dpm_clock_skips_malformed_lines() {
        assert_eq!(parse_max_dpm_clock("garbage\n1: 875GHz\n0: 0Mhz\n"), None);
    }

    #[test]
    fn parse_nvidia_output_one_entry_per_gpu_row() {
        let out = parse_nvidia_bandwidth_output("10501, 8.9\n1593, 8.0\n");
        assert_eq!(out.len(), 2);
        let rtx4090 = out[0].expect("8.9 row should parse");
        let a100 = out[1].expect("8.0 row should parse");
        assert!((rtx4090 - 1008.1).abs() < 1.0, "RTX 4090 bandwidth: {}", rtx4090);
        assert!((a100 - 2039.0).abs() < 1.0, "A100 bandwidth: {}", a100);
    }

    #[test]
    fn parse_nvidia_output_keeps_positions_for_unparseable_rows() {
        // Unknown clock, unknown compute capability, and missing field must each still
        // occupy a slot so later GPUs stay aligned with their rows.
        let out = parse_nvidia_bandwidth_output("[N/A], 8.9\n10501, 5.2\n10501\n1593, 8.0\n");
        assert_eq!(out.len(), 4);
        assert!(out[..3].iter().all(Option::is_none), "{:?}", out);
        assert!(out[3].is_some(), "{:?}", out);
    }

    #[test]
    fn parse_nvidia_output_empty() {
        assert!(parse_nvidia_bandwidth_output("").is_empty());
    }

    #[test]
    fn nvidia_bus_width_known_and_unknown() {
        assert_eq!(nvidia_bus_width_bits("9.0"), Some(5120));
        assert_eq!(nvidia_bus_width_bits("8.9"), Some(384));
        assert_eq!(nvidia_bus_width_bits("10.0"), Some(8192));
        assert_eq!(nvidia_bus_width_bits("5.2"), None);
    }

    #[test]
    fn estimate_nvidia_bandwidth_known_and_unknown() {
        assert_eq!(estimate_nvidia_bandwidth_from_cc("9.0"), Some(3350.0));
        assert_eq!(estimate_nvidia_bandwidth_from_cc("8.0"), Some(2039.0));
        assert_eq!(estimate_nvidia_bandwidth_from_cc("12.0"), None);
    }
}