/// An SM architecture version, encoded as `major * 10 + minor`
/// (so `SmVersion(86)` displays as `SM 8.6` and `SmVersion(120)` as `SM 12.0`).
///
/// Ordering compares the raw encoded value, so newer architectures sort higher.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SmVersion(pub u32);

impl SmVersion {
    /// Returns the PTX version string selected for this architecture.
    ///
    /// The choice is a step function on the encoded value:
    /// - `100` and above → `"8.7"`
    /// - `90..=99` → `"8.4"`
    /// - `80..=89` → `"8.0"`
    /// - below `80` → `"7.5"`
    pub fn ptx_version_str(self) -> &'static str {
        match self.0 {
            100.. => "8.7",
            90..=99 => "8.4",
            80..=89 => "8.0",
            _ => "7.5",
        }
    }

    /// Returns the target name for this architecture, e.g. `"sm_80"` for `SmVersion(80)`.
    pub fn target_str(self) -> String {
        let SmVersion(raw) = self;
        format!("sm_{raw}")
    }
}

impl std::fmt::Display for SmVersion {
    /// Writes `SM <major>.<minor>`, treating the last decimal digit of the
    /// encoded value as the minor version and everything above it as the major.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let major = self.0 / 10;
        let minor = self.0 % 10;
        write!(f, "SM {}.{}", major, minor)
    }
}
/// Pairs a device index with the SM architecture version associated with it.
///
/// `LmHandle::default()` and [`LmHandle::default_handle`] produce the same value:
/// device `0` with `SmVersion(80)`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LmHandle {
    /// Device index. Stored exactly as given; no validation is performed here.
    /// NOTE(review): presumably a GPU device ordinal — confirm against callers.
    pub device: i32,
    /// Architecture version associated with `device`.
    pub sm_version: SmVersion,
}

impl LmHandle {
    /// Builds a handle from an explicit device index and architecture version.
    pub fn new(device: i32, sm_version: SmVersion) -> Self {
        Self { device, sm_version }
    }

    /// Returns the fallback handle: device `0` with `SmVersion(80)`.
    ///
    /// Kept for existing callers; equivalent to `LmHandle::default()`.
    pub fn default_handle() -> Self {
        Self::new(0, SmVersion(80))
    }
}

impl Default for LmHandle {
    /// Delegates to [`LmHandle::default_handle`] so the type works with generic
    /// code that relies on `Default` (`unwrap_or_default`, `..Default::default()`,
    /// `#[derive(Default)]` on containing structs).
    fn default() -> Self {
        Self::default_handle()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each architecture maps to the PTX version pinned by the step function.
    #[test]
    fn sm_version_ptx_strings() {
        let cases = [
            (75, "7.5"),
            (80, "8.0"),
            (86, "8.0"),
            (90, "8.4"),
            (100, "8.7"),
            (120, "8.7"),
        ];
        for &(raw, expected) in &cases {
            assert_eq!(SmVersion(raw).ptx_version_str(), expected);
        }
    }

    /// Target names are `sm_` followed by the raw encoded value.
    #[test]
    fn sm_version_target_str() {
        for &(raw, expected) in &[(80, "sm_80"), (90, "sm_90"), (120, "sm_120")] {
            assert_eq!(SmVersion(raw).target_str(), expected);
        }
    }

    /// Display splits the encoded value into major and minor.
    #[test]
    fn sm_version_display() {
        for &(raw, expected) in &[(80, "SM 8.0"), (90, "SM 9.0")] {
            assert_eq!(SmVersion(raw).to_string(), expected);
        }
    }

    /// Ordering follows the raw encoded value.
    #[test]
    fn sm_version_ordering() {
        assert!(SmVersion(90) > SmVersion(80));
        assert!(SmVersion(90) < SmVersion(100));
        assert_eq!(SmVersion(80), SmVersion(80));
    }

    /// The fallback handle targets device 0 at SM 8.0.
    #[test]
    fn lm_handle_default() {
        let LmHandle { device, sm_version } = LmHandle::default_handle();
        assert_eq!(device, 0);
        assert_eq!(sm_version, SmVersion(80));
    }

    /// `new` stores its arguments unchanged.
    #[test]
    fn lm_handle_custom() {
        let LmHandle { device, sm_version } = LmHandle::new(2, SmVersion(90));
        assert_eq!(device, 2);
        assert_eq!(sm_version, SmVersion(90));
    }
}