use std::path::Path;
use tracing::debug;
use crate::hardware::AcceleratorType;
use crate::profile::AcceleratorProfile;
/// Fills in `pcie_bandwidth_gbps` on GPU profiles using PCIe link data from sysfs.
///
/// Addresses are paired with profiles positionally, per vendor: the n-th
/// `CudaGpu` profile in `profiles` uses `nvidia_addrs[n]`, and the n-th
/// `RocmGpu` profile uses `amdgpu_addrs[n]`. Profiles of any other accelerator
/// type are ignored.
///
/// A profile is left unchanged in any of these cases:
/// - there is no address for it (the vendor list is shorter than the number of
///   GPUs of that vendor);
/// - `/sys/bus/pci/devices/<addr>` cannot be canonicalized;
/// - the path resolves outside `/sys/`;
/// - `read_pcie_bandwidth` yields nothing.
///
/// Emits a debug event per enriched profile, plus a final summary count.
pub(crate) fn enrich_pcie(
    profiles: &mut [AcceleratorProfile],
    nvidia_addrs: &[String],
    amdgpu_addrs: &[String],
) {
    // Per-vendor cursors. Each is advanced once for every GPU of its vendor,
    // whether or not enrichment succeeds, so the positional pairing between
    // profiles and addresses never drifts.
    let mut nvidia_cursor = nvidia_addrs.iter();
    let mut amdgpu_cursor = amdgpu_addrs.iter();
    let mut enriched = 0usize;

    for profile in profiles.iter_mut() {
        let next_addr = match &profile.accelerator {
            AcceleratorType::CudaGpu { .. } => nvidia_cursor.next(),
            AcceleratorType::RocmGpu { .. } => amdgpu_cursor.next(),
            _ => None,
        };
        let addr = match next_addr {
            Some(addr) => addr,
            None => continue,
        };

        // Resolve symlinks, then refuse any path that escapes sysfs (e.g. an
        // address string containing `..`).
        let canonical = match std::fs::canonicalize(format!("/sys/bus/pci/devices/{}", addr)) {
            Ok(path) if path.starts_with("/sys/") => path,
            _ => continue,
        };

        if let Some(bw) = read_pcie_bandwidth(&canonical) {
            debug!(addr = %addr, bandwidth_gbps = bw, "PCIe link detected");
            profile.pcie_bandwidth_gbps = Some(bw);
            enriched += 1;
        }
    }

    debug!(enriched = enriched, "PCIe bandwidth enrichment complete");
}
/// Estimates PCIe link bandwidth, in GB/s, for the sysfs device directory `device_path`.
///
/// Reads two attributes from the device directory:
/// - `current_link_width`: the lane count;
/// - `current_link_speed`: the per-lane rate, parsed by `parse_link_speed`.
///
/// It then computes `speed * width * encoding / BITS_PER_BYTE`, rounded to two
/// decimal places. The encoding factor switches from `PCIE_GEN1_GEN2_ENCODING`
/// to `PCIE_GEN3_PLUS_ENCODING` once the speed reaches `PCIE_GEN3_SPEED_GTS`.
/// These constants live in `crate::units`; presumably they are the
/// 8b/10b and 128b/130b line-coding efficiencies — confirm there.
///
/// Returns `None` in any of these cases:
/// - either attribute cannot be read or parsed;
/// - the lane count is zero;
/// - the computed bandwidth is not a finite positive number.
fn read_pcie_bandwidth(device_path: &Path) -> Option<f64> {
    // NOTE(review): 256 is presumably a maximum read size for the attribute — confirm
    // against `read_sysfs_string`.
    let width_str = super::read_sysfs_string(&device_path.join("current_link_width"), 256)?;
    let speed_str = super::read_sysfs_string(&device_path.join("current_link_speed"), 256)?;

    // The lane count is an integer. Parsing it as `u32` rejects inputs such as
    // "NaN", "inf" or "-4", which `f64::from_str` would accept. A width of 0
    // has no active lanes and must not be reported as a detected 0 GB/s link.
    let width = width_str.trim().parse::<u32>().ok().filter(|&w| w > 0)?;
    let speed_gts = parse_link_speed(speed_str.trim())?;

    let encoding_overhead = if speed_gts >= crate::units::PCIE_GEN3_SPEED_GTS {
        crate::units::PCIE_GEN3_PLUS_ENCODING
    } else {
        crate::units::PCIE_GEN1_GEN2_ENCODING
    };
    let bandwidth_gbps =
        speed_gts * f64::from(width) * encoding_overhead / crate::units::BITS_PER_BYTE;

    // Never hand a non-finite or non-positive value back to the caller, which
    // would store it on the profile as a real measurement.
    if !(bandwidth_gbps.is_finite() && bandwidth_gbps > 0.0) {
        return None;
    }
    Some((bandwidth_gbps * 100.0).round() / 100.0)
}
/// Parses the numeric per-lane transfer rate (GT/s) from a link-speed string.
///
/// Only the first whitespace-separated token is parsed, so trailing units and
/// suffixes are ignored. For example, `"16 GT/s"` → `16.0` and
/// `"8.0 GT/s PCIe"` → `8.0`.
///
/// Returns `None` in any of these cases:
/// - the input is empty or all whitespace;
/// - the first token is not a number (for example `"Unknown"`);
/// - the value is not a finite, strictly positive rate.
///
/// The last check matters because `f64::from_str` happily accepts `"inf"`,
/// `"NaN"` and negative numbers, none of which is a meaningful link speed.
pub fn parse_link_speed(s: &str) -> Option<f64> {
    s.split_whitespace()
        .next()?
        .parse::<f64>()
        .ok()
        .filter(|gts| gts.is_finite() && *gts > 0.0)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Representative link-speed strings, including one with a trailing
    /// suffix, parse to their numeric GT/s rate.
    #[test]
    fn parse_link_speed_values() {
        assert_eq!(parse_link_speed("16 GT/s"), Some(16.0));
        assert_eq!(parse_link_speed("8.0 GT/s PCIe"), Some(8.0));
        assert_eq!(parse_link_speed("2.5 GT/s"), Some(2.5));
    }

    /// Leading/trailing whitespace does not affect the parsed rate, since only
    /// the first whitespace-separated token is considered.
    #[test]
    fn parse_link_speed_ignores_surrounding_whitespace() {
        assert_eq!(parse_link_speed("  32.0 GT/s\n"), Some(32.0));
    }

    /// Empty or non-numeric input yields `None` rather than a made-up rate.
    #[test]
    fn parse_link_speed_rejects_non_numeric() {
        assert_eq!(parse_link_speed(""), None);
        assert_eq!(parse_link_speed("   "), None);
        assert_eq!(parse_link_speed("Unknown"), None);
        assert_eq!(parse_link_speed("GT/s 8.0"), None);
    }

    /// A device directory that does not exist yields no bandwidth.
    #[test]
    fn read_pcie_bandwidth_missing_device_is_none() {
        assert!(read_pcie_bandwidth(Path::new("/nonexistent")).is_none());
    }
}