/// Scalar type tags, each rendered by `Display` as a dotted lowercase suffix
/// (for example `.s32`).
///
/// The letter prefix selects the family (`S`, `U`, `F`, `B`) and the number is
/// the width in bits; `Pred` is the odd one out and renders as `.pred`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PtxType {
    /// 1-byte signed type (`.s8`).
    S8,
    /// 2-byte signed type (`.s16`).
    S16,
    /// 4-byte signed type (`.s32`).
    S32,
    /// 8-byte signed type (`.s64`).
    S64,
    /// 1-byte `.u8` type.
    U8,
    /// 2-byte `.u16` type.
    U16,
    /// 4-byte `.u32` type.
    U32,
    /// 8-byte `.u64` type.
    U64,
    /// 2-byte floating-point type (`.f16`).
    F16,
    /// 4-byte floating-point type (`.f32`).
    F32,
    /// 8-byte floating-point type (`.f64`).
    F64,
    /// 1-byte `.b8` type.
    B8,
    /// 2-byte `.b16` type.
    B16,
    /// 4-byte `.b32` type.
    B32,
    /// 8-byte `.b64` type.
    B64,
    /// `.pred` type; `size_bytes` reports it as 1 byte.
    Pred,
}

impl PtxType {
    /// Returns the width of this type in bytes.
    ///
    /// `Pred` is reported as a single byte.
    pub fn size_bytes(&self) -> usize {
        match *self {
            Self::Pred | Self::S8 | Self::U8 | Self::B8 => 1,
            Self::F16 | Self::S16 | Self::U16 | Self::B16 => 2,
            Self::F32 | Self::S32 | Self::U32 | Self::B32 => 4,
            Self::F64 | Self::S64 | Self::U64 | Self::B64 => 8,
        }
    }

    /// Returns `true` only for the `S*` variants.
    ///
    /// Floating-point variants are *not* reported as signed by this predicate.
    pub fn is_signed(&self) -> bool {
        matches!(*self, Self::S8 | Self::S16 | Self::S32 | Self::S64)
    }

    /// Returns `true` for `F16`, `F32` and `F64`.
    pub fn is_float(&self) -> bool {
        matches!(*self, Self::F16 | Self::F32 | Self::F64)
    }

    /// Returns `true` for the four 8-byte variants (`S64`, `U64`, `B64`, `F64`).
    pub fn is_64bit(&self) -> bool {
        // Exactly the 64-bit variants map to 8 in `size_bytes`.
        self.size_bytes() == 8
    }
}

impl std::fmt::Display for PtxType {
    /// Writes the dotted suffix for this type, e.g. `.u64`.
    ///
    /// Width/alignment flags on the formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match *self {
            Self::S8 => ".s8",
            Self::S16 => ".s16",
            Self::S32 => ".s32",
            Self::S64 => ".s64",
            Self::U8 => ".u8",
            Self::U16 => ".u16",
            Self::U32 => ".u32",
            Self::U64 => ".u64",
            Self::F16 => ".f16",
            Self::F32 => ".f32",
            Self::F64 => ".f64",
            Self::B8 => ".b8",
            Self::B16 => ".b16",
            Self::B32 => ".b32",
            Self::B64 => ".b64",
            Self::Pred => ".pred",
        })
    }
}
/// Address-space tags, rendered by `Display` as a dotted qualifier.
///
/// `Generic` renders as the empty string; every other variant renders as a
/// leading dot followed by a short lowercase name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AddressSpace {
    /// Renders as `""` (no qualifier).
    Generic,
    /// Renders as `.global`.
    Global,
    /// Renders as `.shared`.
    Shared,
    /// Renders as `.local`.
    Local,
    /// Renders as `.const`.
    Const,
    /// Renders as `.param`.
    Param,
    /// Renders as `.tex`.
    Texture,
    /// Renders as `.surf`.
    Surface,
}

impl std::fmt::Display for AddressSpace {
    /// Writes the qualifier text for this address space.
    ///
    /// Nothing is written for `Generic`. Width/alignment flags on the
    /// formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let qualifier = match *self {
            Self::Generic => "",
            Self::Global => ".global",
            Self::Shared => ".shared",
            Self::Local => ".local",
            Self::Const => ".const",
            Self::Param => ".param",
            Self::Texture => ".tex",
            Self::Surface => ".surf",
        };
        f.write_str(qualifier)
    }
}
/// GPU `sm_XX` target architectures known to this module.
///
/// `Unknown` is the `Default` value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum SmTarget {
    /// No specific target selected; this is the default.
    #[default]
    Unknown,
    Sm50,
    Sm52,
    Sm60,
    Sm61,
    Sm70,
    Sm75,
    Sm80,
    Sm86,
    Sm89,
    Sm90,
}

impl SmTarget {
    /// Returns the lowest PTX ISA version, as `(major, minor)`, required to
    /// name this target.
    ///
    /// The tuple compares lexicographically, so results can be ordered with
    /// the usual comparison operators. `Unknown` reports `(1, 0)`.
    pub fn min_ptx_version(&self) -> (u8, u8) {
        match self {
            SmTarget::Unknown => (1, 0),
            SmTarget::Sm50 => (4, 0),
            // sm_52 was introduced one minor release after sm_50.
            SmTarget::Sm52 => (4, 1),
            SmTarget::Sm60 | SmTarget::Sm61 => (5, 0),
            SmTarget::Sm70 => (6, 0),
            SmTarget::Sm75 => (6, 3),
            SmTarget::Sm80 => (7, 0),
            // sm_86 was introduced one minor release after sm_80.
            SmTarget::Sm86 => (7, 1),
            SmTarget::Sm89 => (7, 8),
            // NOTE(review): the PTX ISA `.target` table appears to list sm_90
            // from 7.8; (8, 0) is kept because a higher version is still
            // accepted and existing callers may rely on it — confirm.
            SmTarget::Sm90 => (8, 0),
        }
    }

    /// Returns `true` for `Sm70` and every later listed target.
    ///
    /// The match is exhaustive on purpose: adding a variant forces an
    /// explicit decision here instead of silently defaulting to `false`.
    pub fn has_tensor_cores(&self) -> bool {
        match self {
            SmTarget::Sm70
            | SmTarget::Sm75
            | SmTarget::Sm80
            | SmTarget::Sm86
            | SmTarget::Sm89
            | SmTarget::Sm90 => true,
            SmTarget::Unknown
            | SmTarget::Sm50
            | SmTarget::Sm52
            | SmTarget::Sm60
            | SmTarget::Sm61 => false,
        }
    }
}
/// Instruction opcode tags.
///
/// The predicates on this type sort a subset of the variants into load,
/// store, memory, synchronisation and branch categories; variants not named
/// by a predicate simply answer `false` to it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Opcode {
    Ld,
    St,
    Mov,
    Cvta,
    Cvt,
    Add,
    Sub,
    Mul,
    Div,
    Rem,
    Mad,
    Fma,
    Neg,
    Abs,
    Min,
    Max,
    And,
    Or,
    Xor,
    Not,
    Shl,
    Shr,
    Setp,
    Selp,
    Bra,
    Call,
    Ret,
    Exit,
    Bar,
    MemBar,
    Atom,
    Red,
    Tex,
    Tld4,
    Suld,
    Sust,
    Shfl,
    Vote,
    Mma,
    Wmma,
    LdMatrix,
    Cp,
    Prefetch,
    /// Catch-all variant; presumably used for mnemonics that were not
    /// recognised — confirm against the code that constructs `Opcode`.
    Unknown,
}

impl Opcode {
    /// Returns `true` for `Ld`, `Tex`, `Tld4`, `Suld` and `LdMatrix`.
    pub fn is_load(&self) -> bool {
        matches!(
            *self,
            Self::Ld | Self::Tex | Self::Tld4 | Self::Suld | Self::LdMatrix
        )
    }

    /// Returns `true` for `St` and `Sust`.
    pub fn is_store(&self) -> bool {
        matches!(*self, Self::St | Self::Sust)
    }

    /// Returns `true` for every load, every store, and for `Atom` / `Red`.
    pub fn is_memory_op(&self) -> bool {
        let is_atomic_like = matches!(*self, Self::Atom | Self::Red);
        is_atomic_like || self.is_load() || self.is_store()
    }

    /// Returns `true` for `Bar` and `MemBar`.
    pub fn is_sync(&self) -> bool {
        matches!(*self, Self::Bar | Self::MemBar)
    }

    /// Returns `true` for `Bra`, `Call`, `Ret` and `Exit`.
    pub fn is_branch(&self) -> bool {
        matches!(*self, Self::Bra | Self::Call | Self::Ret | Self::Exit)
    }
}
/// Instruction modifier tags.
///
/// Some variants convert to an [`AddressSpace`] (see
/// [`Modifier::as_address_space`]) and some to a [`PtxType`] (see
/// [`Modifier::as_type`]); the rest convert to neither.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Modifier {
    // Variants with an `AddressSpace` counterpart.
    Shared,
    Global,
    Local,
    Const,
    Param,
    // Variants with a `PtxType` counterpart.
    U32,
    U64,
    S32,
    S64,
    F32,
    F64,
    B32,
    B64,
    // Variants with no counterpart in either conversion.
    Sync,
    Cta,
    Gl,
    Sys,
    AtomicAdd,
    AtomicCas,
    AtomicExch,
    AtomicMin,
    AtomicMax,
    /// Carries an owned string; presumably the raw text of a modifier with no
    /// dedicated variant — confirm against the code that constructs it.
    Other(String),
}

impl Modifier {
    /// Converts this modifier to the address space of the same name.
    ///
    /// Returns `None` for every variant other than `Shared`, `Global`,
    /// `Local`, `Const` and `Param`.
    pub fn as_address_space(&self) -> Option<AddressSpace> {
        let space = match self {
            Self::Shared => AddressSpace::Shared,
            Self::Global => AddressSpace::Global,
            Self::Local => AddressSpace::Local,
            Self::Const => AddressSpace::Const,
            Self::Param => AddressSpace::Param,
            _ => return None,
        };
        Some(space)
    }

    /// Converts this modifier to the type tag of the same name.
    ///
    /// Returns `None` for every variant other than the 32- and 64-bit
    /// `U`/`S`/`F`/`B` modifiers.
    pub fn as_type(&self) -> Option<PtxType> {
        let ty = match self {
            Self::U32 => PtxType::U32,
            Self::U64 => PtxType::U64,
            Self::S32 => PtxType::S32,
            Self::S64 => PtxType::S64,
            Self::F32 => PtxType::F32,
            Self::F64 => PtxType::F64,
            Self::B32 => PtxType::B32,
            Self::B64 => PtxType::B64,
            _ => return None,
        };
        Some(ty)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Byte widths for a sample of unsigned and floating-point types.
    #[test]
    fn test_ptx_type_size() {
        let expected = [
            (PtxType::U8, 1),
            (PtxType::U16, 2),
            (PtxType::U32, 4),
            (PtxType::U64, 8),
            (PtxType::F32, 4),
            (PtxType::F64, 8),
        ];
        for (ty, bytes) in expected.iter() {
            assert_eq!(ty.size_bytes(), *bytes, "size of {:?}", ty);
        }
    }

    /// Each predicate is checked once positively and once negatively.
    #[test]
    fn test_ptx_type_properties() {
        let (signed, unsigned) = (PtxType::S32, PtxType::U32);
        assert!(signed.is_signed());
        assert!(!unsigned.is_signed());

        assert!(PtxType::F32.is_float());
        assert!(!unsigned.is_float());

        assert!(PtxType::U64.is_64bit());
        assert!(!unsigned.is_64bit());
    }

    /// Lower bounds on the PTX version reported for two targets.
    #[test]
    fn test_sm_target_ptx_version() {
        let floors = [(SmTarget::Sm90, (8, 0)), (SmTarget::Sm70, (6, 0))];
        for (target, floor) in floors.iter() {
            assert!(target.min_ptx_version() >= *floor, "{:?}", target);
        }
    }

    /// One representative opcode per category predicate.
    #[test]
    fn test_opcode_categories() {
        assert!(Opcode::Ld.is_load());
        assert!(Opcode::St.is_store());
        assert!(Opcode::Bar.is_sync());
        assert!(Opcode::Bra.is_branch());
    }

    /// One modifier per conversion method.
    #[test]
    fn test_modifier_conversion() {
        let space = Modifier::Shared.as_address_space();
        assert_eq!(space, Some(AddressSpace::Shared));

        let ty = Modifier::U32.as_type();
        assert_eq!(ty, Some(PtxType::U32));
    }
}