/// Summary of what the NCCL library available to this process can do.
///
/// The `Default` value has a `(0, 0, 0)` version and every flag cleared; the
/// same value is available through [`NcclCapabilities::zeroed`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct NcclCapabilities {
    /// Library version as `(major, minor, patch)`.
    pub version: (i32, i32, i32),
    /// Whether FP8 support is reported.
    pub has_fp8: bool,
    /// Whether NVLS support is reported.
    pub has_nvls: bool,
    /// Whether SHARP support is reported.
    pub has_sharp: bool,
}

impl NcclCapabilities {
    /// Returns the "nothing available" value: version `(0, 0, 0)` and all
    /// capability flags `false`. Equal to `NcclCapabilities::default()`.
    pub fn zeroed() -> Self {
        Self {
            version: (0, 0, 0),
            has_fp8: false,
            has_nvls: false,
            has_sharp: false,
        }
    }
}
/// Splits a raw NCCL version code into `(major, minor, patch)`.
///
/// Codes of `20000` and above are decoded as
/// `major * 10000 + minor * 100 + patch`; smaller positive codes are decoded
/// with the older `major * 1000 + minor * 100 + patch` layout.
///
/// Returns `None` for zero or negative codes, which cannot describe a real
/// library version and would otherwise decode to a meaningless triple.
fn decode_nccl_version(code: i32) -> Option<(i32, i32, i32)> {
    if code <= 0 {
        return None;
    }
    let triple = if code >= 20000 {
        (code / 10000, (code / 100) % 100, code % 100)
    } else {
        (code / 1000, (code / 100) % 10, code % 100)
    };
    Some(triple)
}

/// Queries the NCCL version and derives the capability flags from it.
///
/// The version lookup runs inside `catch_unwind`, so a panic raised by the
/// call is treated the same way as an error result.
///
/// # Returns
///
/// * [`NcclCapabilities::zeroed`] when the lookup panics, returns an error, or
///   reports a non-positive version code.
/// * Otherwise a value carrying the decoded version, where:
///   * `has_fp8` requires both the `nccl-fp8` cargo feature and a library
///     version of at least 2.20;
///   * `has_nvls` mirrors the `nccl-nvls` cargo feature only (no version
///     check is applied);
///   * `has_sharp` is always `false`.
pub fn probe_capabilities() -> NcclCapabilities {
    // Collapse "panicked" and "returned Err" into the same fallback path.
    let code = match std::panic::catch_unwind(cudarc::nccl::result::get_nccl_version) {
        Ok(Ok(code)) => code,
        Ok(Err(_)) | Err(_) => return NcclCapabilities::zeroed(),
    };

    let (major, minor, patch) = match decode_nccl_version(code) {
        Some(triple) => triple,
        None => return NcclCapabilities::zeroed(),
    };

    // Tuple comparison is lexicographic: major first, then minor.
    let supports_fp8 = (major, minor) >= (2, 20);

    NcclCapabilities {
        version: (major, minor, patch),
        has_fp8: cfg!(feature = "nccl-fp8") && supports_fp8,
        has_nvls: cfg!(feature = "nccl-nvls"),
        has_sharp: false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A probe either reports the all-zero fallback (and then every field must
    /// match `zeroed()`), or it reports a version whose major part is >= 1.
    #[test]
    fn probe_returns_zeroed_when_nccl_uninit() {
        let probed = probe_capabilities();
        match probed.version {
            (0, 0, 0) => assert_eq!(probed, NcclCapabilities::zeroed()),
            (major, _, _) => assert!(major >= 1),
        }
    }
}