use std::fmt;
/// Value types that can appear in emitted PTX.
///
/// Every variant maps to exactly one PTX mnemonic (see
/// [`PtxType::to_ptx_string`]) and has a fixed size (see
/// [`PtxType::size_bytes`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PtxType {
    /// `.pred` predicate; reported as 1 byte by [`PtxType::size_bytes`].
    Pred,
    /// `.u8` unsigned integer, 1 byte.
    U8,
    /// `.u16` unsigned integer, 2 bytes.
    U16,
    /// `.u32` unsigned integer, 4 bytes.
    U32,
    /// `.u64` unsigned integer, 8 bytes.
    U64,
    /// `.s8` signed integer, 1 byte.
    S8,
    /// `.s16` signed integer, 2 bytes.
    S16,
    /// `.s32` signed integer, 4 bytes.
    S32,
    /// `.s64` signed integer, 8 bytes.
    S64,
    /// `.f16` float, 2 bytes.
    F16,
    /// `.bf16` float, 2 bytes.
    BF16,
    /// `.f32` float, 4 bytes.
    F32,
    /// `.f64` float, 8 bytes.
    F64,
    /// `.b8`, 1 byte; neither float, signed nor unsigned.
    B8,
    /// `.b16`, 2 bytes; neither float, signed nor unsigned.
    B16,
    /// `.b32`, 4 bytes; neither float, signed nor unsigned.
    B32,
    /// `.b64`, 8 bytes; neither float, signed nor unsigned.
    B64,
    /// `.v2.f32` float vector, 8 bytes.
    V2F32,
    /// `.v4.f32` float vector, 16 bytes.
    V4F32,
}

impl PtxType {
    /// Size of one value of this type in bytes.
    ///
    /// NOTE(review): `Pred` reports 1 byte (and therefore 8 bits from
    /// [`PtxType::size_bits`]); confirm callers treat this as a storage size
    /// rather than a logical bit width.
    #[must_use]
    pub const fn size_bytes(self) -> usize {
        match self {
            Self::Pred | Self::U8 | Self::S8 | Self::B8 => 1,
            Self::U16 | Self::S16 | Self::F16 | Self::BF16 | Self::B16 => 2,
            Self::U32 | Self::S32 | Self::F32 | Self::B32 => 4,
            Self::U64 | Self::S64 | Self::F64 | Self::B64 | Self::V2F32 => 8,
            Self::V4F32 => 16,
        }
    }

    /// Size of one value of this type in bits (`size_bytes() * 8`).
    #[must_use]
    pub const fn size_bits(self) -> usize {
        self.size_bytes() * 8
    }

    /// PTX mnemonic for this type, including the leading dot (e.g. `".f32"`).
    #[must_use]
    pub const fn to_ptx_string(self) -> &'static str {
        match self {
            Self::Pred => ".pred",
            Self::U8 => ".u8",
            Self::U16 => ".u16",
            Self::U32 => ".u32",
            Self::U64 => ".u64",
            Self::S8 => ".s8",
            Self::S16 => ".s16",
            Self::S32 => ".s32",
            Self::S64 => ".s64",
            Self::F16 => ".f16",
            Self::BF16 => ".bf16",
            Self::F32 => ".f32",
            Self::F64 => ".f64",
            Self::B8 => ".b8",
            Self::B16 => ".b16",
            Self::B32 => ".b32",
            Self::B64 => ".b64",
            Self::V2F32 => ".v2.f32",
            Self::V4F32 => ".v4.f32",
        }
    }

    /// Returns `true` for the floating-point types, including the `f32`
    /// vector types.
    #[must_use]
    pub const fn is_float(self) -> bool {
        matches!(
            self,
            Self::F16 | Self::BF16 | Self::F32 | Self::F64 | Self::V2F32 | Self::V4F32
        )
    }

    /// Returns `true` for the signed integer types (`S8`..`S64`).
    #[must_use]
    pub const fn is_signed(self) -> bool {
        matches!(self, Self::S8 | Self::S16 | Self::S32 | Self::S64)
    }

    /// Returns `true` for the unsigned integer types (`U8`..`U64`).
    #[must_use]
    pub const fn is_unsigned(self) -> bool {
        matches!(self, Self::U8 | Self::U16 | Self::U32 | Self::U64)
    }

    /// Type mnemonic to use when declaring a register that holds this type.
    ///
    /// The 8-bit types are widened to their 16-bit counterparts
    /// (`.u8` -> `.u16`, `.s8` -> `.s16`, `.b8` -> `.b16`); every other type
    /// is declared with its own mnemonic from [`PtxType::to_ptx_string`].
    ///
    /// NOTE(review): the PTX ISA appears to require bf16 data to live in
    /// `.b16` registers; confirm whether `BF16` should map to `".b16"` here.
    #[must_use]
    pub const fn register_declaration_type(self) -> &'static str {
        // Exhaustive on purpose: a new variant must decide explicitly whether
        // it needs widening instead of silently falling through.
        match self {
            Self::U8 => ".u16",
            Self::S8 => ".s16",
            Self::B8 => ".b16",
            Self::Pred
            | Self::U16
            | Self::U32
            | Self::U64
            | Self::S16
            | Self::S32
            | Self::S64
            | Self::F16
            | Self::BF16
            | Self::F32
            | Self::F64
            | Self::B16
            | Self::B32
            | Self::B64
            | Self::V2F32
            | Self::V4F32 => self.to_ptx_string(),
        }
    }

    /// Register-name prefix used for values of this type (e.g. `"%f"`).
    ///
    /// Some types deliberately share a prefix: `U8`/`B8`, `U16`/`B16`,
    /// `U64`/`B64`, `F16`/`BF16`, and `F32` with both `f32` vector types.
    #[must_use]
    pub const fn register_prefix(self) -> &'static str {
        match self {
            Self::Pred => "%p",
            Self::U8 | Self::B8 => "%rs",
            Self::S8 => "%rsi",
            Self::U16 | Self::B16 => "%rh",
            Self::S16 => "%rhi",
            Self::U32 => "%r",
            Self::S32 => "%ri",
            Self::B32 => "%rb",
            Self::U64 | Self::B64 => "%rd",
            Self::S64 => "%rdi",
            Self::F16 | Self::BF16 => "%h",
            Self::F32 | Self::V2F32 | Self::V4F32 => "%f",
            Self::F64 => "%fd",
        }
    }
}

impl fmt::Display for PtxType {
    /// Writes the PTX mnemonic, honouring the caller's width, fill and
    /// alignment flags (e.g. `{:<8}`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `write!(f, "{}", ..)` would start a fresh format spec and drop the
        // caller's padding flags; `pad` applies them. Output for a plain
        // `{}` is unchanged.
        f.pad(self.to_ptx_string())
    }
}
/// PTX state spaces, each mapping to one directive string (see
/// [`PtxStateSpace::to_ptx_string`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PtxStateSpace {
    /// `.reg` state space.
    Reg,
    /// `.shared` state space.
    Shared,
    /// `.global` state space.
    Global,
    /// `.local` state space.
    Local,
    /// `.const` state space.
    Const,
    /// `.tex` state space.
    Tex,
    /// `.param` state space.
    Param,
}

impl PtxStateSpace {
    /// PTX directive for this state space, including the leading dot
    /// (e.g. `".global"`).
    #[must_use]
    pub const fn to_ptx_string(self) -> &'static str {
        match self {
            Self::Reg => ".reg",
            Self::Shared => ".shared",
            Self::Global => ".global",
            Self::Local => ".local",
            Self::Const => ".const",
            Self::Tex => ".tex",
            Self::Param => ".param",
        }
    }

    /// Returns `true` for the state spaces classified as cached here:
    /// `Const` and `Tex`.
    #[must_use]
    pub const fn is_cached(self) -> bool {
        matches!(self, Self::Const | Self::Tex)
    }

    /// Returns `true` for the state spaces classified as per-thread here:
    /// `Reg` and `Local`.
    #[must_use]
    pub const fn is_per_thread(self) -> bool {
        matches!(self, Self::Reg | Self::Local)
    }
}

impl fmt::Display for PtxStateSpace {
    /// Writes the PTX directive, honouring the caller's width, fill and
    /// alignment flags (e.g. `{:<8}`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `write!(f, "{}", ..)` would start a fresh format spec and drop the
        // caller's padding flags; `pad` applies them. Output for a plain
        // `{}` is unchanged.
        f.pad(self.to_ptx_string())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Byte sizes of the predicate, unsigned and float scalar types.
    #[test]
    fn test_type_sizes() {
        let cases = [
            (PtxType::Pred, 1),
            (PtxType::U8, 1),
            (PtxType::U16, 2),
            (PtxType::U32, 4),
            (PtxType::U64, 8),
            (PtxType::F16, 2),
            (PtxType::F32, 4),
            (PtxType::F64, 8),
        ];
        for &(ty, bytes) in &cases {
            assert_eq!(ty.size_bytes(), bytes, "size_bytes of {:?}", ty);
        }
    }

    /// Bit sizes are eight times the byte sizes.
    #[test]
    fn test_type_bits() {
        let cases = [(PtxType::U8, 8), (PtxType::U32, 32), (PtxType::U64, 64)];
        for &(ty, bits) in &cases {
            assert_eq!(ty.size_bits(), bits, "size_bits of {:?}", ty);
        }
    }

    /// Float types are detected; integer types are not.
    #[test]
    fn test_float_detection() {
        for &ty in &[PtxType::F16, PtxType::F32, PtxType::F64, PtxType::BF16] {
            assert!(ty.is_float(), "{:?} should be float", ty);
        }
        for &ty in &[PtxType::U32, PtxType::S32] {
            assert!(!ty.is_float(), "{:?} should not be float", ty);
        }
    }

    /// Only the `S*` types are signed.
    #[test]
    fn test_signed_detection() {
        for &ty in &[PtxType::S8, PtxType::S32] {
            assert!(ty.is_signed(), "{:?} should be signed", ty);
        }
        for &ty in &[PtxType::U32, PtxType::F32] {
            assert!(!ty.is_signed(), "{:?} should not be signed", ty);
        }
    }

    /// Directive strings for the most common state spaces.
    #[test]
    fn test_state_space_strings() {
        let cases = [
            (PtxStateSpace::Global, ".global"),
            (PtxStateSpace::Shared, ".shared"),
            (PtxStateSpace::Reg, ".reg"),
        ];
        for &(space, text) in &cases {
            assert_eq!(space.to_ptx_string(), text, "string of {:?}", space);
        }
    }

    /// `Display` output matches the PTX strings for both enums.
    #[test]
    fn test_display_impl() {
        assert_eq!(PtxType::F32.to_string(), ".f32");
        assert_eq!(PtxStateSpace::Global.to_string(), ".global");
    }

    /// Only the `U*` types are unsigned.
    #[test]
    fn test_unsigned_detection() {
        for &ty in &[PtxType::U8, PtxType::U16, PtxType::U32, PtxType::U64] {
            assert!(ty.is_unsigned(), "{:?} should be unsigned", ty);
        }
        for &ty in &[PtxType::S32, PtxType::F32] {
            assert!(!ty.is_unsigned(), "{:?} should not be unsigned", ty);
        }
    }

    /// `Const` and `Tex` are cached; `Global` and `Shared` are not.
    #[test]
    fn test_state_space_cached() {
        for &space in &[PtxStateSpace::Const, PtxStateSpace::Tex] {
            assert!(space.is_cached(), "{:?} should be cached", space);
        }
        for &space in &[PtxStateSpace::Global, PtxStateSpace::Shared] {
            assert!(!space.is_cached(), "{:?} should not be cached", space);
        }
    }

    /// `Reg` and `Local` are per-thread; `Global` and `Shared` are not.
    #[test]
    fn test_state_space_per_thread() {
        for &space in &[PtxStateSpace::Reg, PtxStateSpace::Local] {
            assert!(space.is_per_thread(), "{:?} should be per-thread", space);
        }
        for &space in &[PtxStateSpace::Global, PtxStateSpace::Shared] {
            assert!(!space.is_per_thread(), "{:?} should not be per-thread", space);
        }
    }

    /// Register prefixes for predicate, float and integer types.
    #[test]
    fn test_register_prefix() {
        let cases = [
            (PtxType::Pred, "%p"),
            (PtxType::F16, "%h"),
            (PtxType::BF16, "%h"),
            (PtxType::F32, "%f"),
            (PtxType::F64, "%fd"),
            (PtxType::U32, "%r"),
            (PtxType::S32, "%ri"),
            (PtxType::U64, "%rd"),
            (PtxType::S64, "%rdi"),
            (PtxType::U8, "%rs"),
            (PtxType::S8, "%rsi"),
            (PtxType::U16, "%rh"),
            (PtxType::S16, "%rhi"),
        ];
        for &(ty, prefix) in &cases {
            assert_eq!(ty.register_prefix(), prefix, "prefix of {:?}", ty);
        }
    }

    /// PTX strings for the types not covered by the other tests.
    #[test]
    fn test_all_type_strings() {
        let cases = [
            (PtxType::Pred, ".pred"),
            (PtxType::S8, ".s8"),
            (PtxType::S16, ".s16"),
            (PtxType::S64, ".s64"),
            (PtxType::B8, ".b8"),
            (PtxType::B16, ".b16"),
            (PtxType::B32, ".b32"),
            (PtxType::B64, ".b64"),
            (PtxType::BF16, ".bf16"),
        ];
        for &(ty, text) in &cases {
            assert_eq!(ty.to_ptx_string(), text, "string of {:?}", ty);
        }
    }

    /// Directive strings for the remaining state spaces.
    #[test]
    fn test_all_state_space_strings() {
        let cases = [
            (PtxStateSpace::Local, ".local"),
            (PtxStateSpace::Param, ".param"),
            (PtxStateSpace::Tex, ".tex"),
            (PtxStateSpace::Const, ".const"),
        ];
        for &(space, text) in &cases {
            assert_eq!(space.to_ptx_string(), text, "string of {:?}", space);
        }
    }

    /// `Display` for state spaces yields the directive string.
    #[test]
    fn test_state_space_display() {
        let cases = [
            (PtxStateSpace::Shared, ".shared"),
            (PtxStateSpace::Reg, ".reg"),
            (PtxStateSpace::Local, ".local"),
        ];
        for &(space, text) in &cases {
            assert_eq!(space.to_string(), text, "display of {:?}", space);
        }
    }

    /// Byte sizes of the `B*` types and the small signed types.
    #[test]
    fn test_byte_type_sizes() {
        let cases = [
            (PtxType::B8, 1),
            (PtxType::B16, 2),
            (PtxType::B32, 4),
            (PtxType::B64, 8),
            (PtxType::S8, 1),
            (PtxType::S16, 2),
        ];
        for &(ty, bytes) in &cases {
            assert_eq!(ty.size_bytes(), bytes, "size_bytes of {:?}", ty);
        }
    }

    /// Size, classification, prefix and string of both vector types.
    #[test]
    fn test_vector_types() {
        let cases = [
            (PtxType::V2F32, 8, 64, ".v2.f32"),
            (PtxType::V4F32, 16, 128, ".v4.f32"),
        ];
        for &(ty, bytes, bits, text) in &cases {
            assert_eq!(ty.size_bytes(), bytes, "size_bytes of {:?}", ty);
            assert_eq!(ty.size_bits(), bits, "size_bits of {:?}", ty);
            assert!(ty.is_float(), "{:?} should be float", ty);
            assert!(!ty.is_signed(), "{:?} should not be signed", ty);
            assert!(!ty.is_unsigned(), "{:?} should not be unsigned", ty);
            assert_eq!(ty.register_prefix(), "%f", "prefix of {:?}", ty);
            assert_eq!(ty.to_ptx_string(), text, "string of {:?}", ty);
        }
    }

    /// `B32` has its own register prefix.
    #[test]
    fn test_b32_register_prefix() {
        assert_eq!(PtxType::B32.register_prefix(), "%rb");
    }

    /// `Display` for a sample of types yields the PTX string.
    #[test]
    fn test_type_display_all() {
        let cases = [
            (PtxType::Pred, ".pred"),
            (PtxType::V2F32, ".v2.f32"),
            (PtxType::V4F32, ".v4.f32"),
            (PtxType::B32, ".b32"),
        ];
        for &(ty, text) in &cases {
            assert_eq!(ty.to_string(), text, "display of {:?}", ty);
        }
    }
}