use std::fmt;
/// Data types that can be spelled as a PTX type suffix.
///
/// [`PtxType::as_ptx_str`] yields the dot-prefixed spelling (for example `.f32`), while the
/// [`fmt::Display`] implementation prints the same text without the leading dot.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PtxType {
    /// Unsigned 8-bit integer (`.u8`).
    U8,
    /// Unsigned 16-bit integer (`.u16`).
    U16,
    /// Unsigned 32-bit integer (`.u32`).
    U32,
    /// Unsigned 64-bit integer (`.u64`).
    U64,
    /// Signed 8-bit integer (`.s8`).
    S8,
    /// Signed 16-bit integer (`.s16`).
    S16,
    /// Signed 32-bit integer (`.s32`).
    S32,
    /// Signed 64-bit integer (`.s64`).
    S64,
    /// 16-bit float (`.f16`).
    F16,
    /// 32-bit-wide float type (`.f16x2`).
    F16x2,
    /// 16-bit float (`.bf16`).
    BF16,
    /// 32-bit-wide float type (`.bf16x2`).
    BF16x2,
    /// 32-bit float (`.f32`).
    F32,
    /// 64-bit float (`.f64`).
    F64,
    /// 32-bit-wide float type (`.tf32`).
    TF32,
    /// 8-bit float (`.e4m3`).
    E4M3,
    /// 8-bit float (`.e5m2`).
    E5M2,
    /// 6-bit float (`.e2m3`); occupies one byte of storage.
    E2M3,
    /// 6-bit float (`.e3m2`); occupies one byte of storage.
    E3M2,
    /// 4-bit float (`.e2m1`); occupies one byte of storage.
    E2M1,
    /// Untyped 8-bit value (`.b8`).
    B8,
    /// Untyped 16-bit value (`.b16`).
    B16,
    /// Untyped 32-bit value (`.b32`).
    B32,
    /// Untyped 64-bit value (`.b64`).
    B64,
    /// Untyped 128-bit value (`.b128`).
    B128,
    /// Predicate (`.pred`), one bit wide.
    Pred,
}

impl PtxType {
    /// Returns the PTX spelling of the type, including the leading dot (e.g. `".u32"`).
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match self {
            Self::U8 => ".u8",
            Self::U16 => ".u16",
            Self::U32 => ".u32",
            Self::U64 => ".u64",
            Self::S8 => ".s8",
            Self::S16 => ".s16",
            Self::S32 => ".s32",
            Self::S64 => ".s64",
            Self::F16 => ".f16",
            Self::F16x2 => ".f16x2",
            Self::BF16 => ".bf16",
            Self::BF16x2 => ".bf16x2",
            Self::F32 => ".f32",
            Self::F64 => ".f64",
            Self::TF32 => ".tf32",
            Self::E4M3 => ".e4m3",
            Self::E5M2 => ".e5m2",
            Self::E2M3 => ".e2m3",
            Self::E3M2 => ".e3m2",
            Self::E2M1 => ".e2m1",
            Self::B8 => ".b8",
            Self::B16 => ".b16",
            Self::B32 => ".b32",
            Self::B64 => ".b64",
            Self::B128 => ".b128",
            Self::Pred => ".pred",
        }
    }

    /// Returns the storage size in bytes.
    ///
    /// This is [`PtxType::bit_width`] rounded up to a whole byte, so the sub-byte types
    /// (`Pred`, `E2M1`, `E2M3`, `E3M2`) all report `1`.
    #[must_use]
    pub const fn size_bytes(&self) -> usize {
        match self {
            Self::U8
            | Self::S8
            | Self::B8
            | Self::E4M3
            | Self::E5M2
            // 6-bit, 4-bit and 1-bit types still need one whole byte.
            | Self::E2M3
            | Self::E3M2
            | Self::E2M1
            | Self::Pred => 1,
            Self::U16 | Self::S16 | Self::F16 | Self::BF16 | Self::B16 => 2,
            Self::U32
            | Self::S32
            | Self::F32
            | Self::F16x2
            | Self::BF16x2
            | Self::B32
            | Self::TF32 => 4,
            Self::U64 | Self::S64 | Self::F64 | Self::B64 => 8,
            Self::B128 => 16,
        }
    }

    /// Maps the type to the untyped (`B16`/`B32`/`B64`/`B128`) or predicate type of the
    /// matching class.
    ///
    /// 16-bit scalars map to `B16`, 64-bit scalars to `B64`, `B128` and `Pred` map to
    /// themselves, and everything else — including all 8-bit and sub-byte types — maps to
    /// `B32`. Presumably this is the register class used to hold a value of the type;
    /// confirm against callers.
    #[must_use]
    pub const fn reg_type(&self) -> Self {
        // The match is exhaustive on purpose: adding a variant must force a decision here
        // instead of silently falling into the 32-bit class.
        match self {
            Self::Pred => Self::Pred,
            Self::F64 | Self::U64 | Self::S64 | Self::B64 => Self::B64,
            Self::B128 => Self::B128,
            Self::F16 | Self::BF16 | Self::U16 | Self::S16 | Self::B16 => Self::B16,
            Self::U8
            | Self::S8
            | Self::B8
            | Self::U32
            | Self::S32
            | Self::B32
            | Self::F32
            | Self::TF32
            | Self::F16x2
            | Self::BF16x2
            | Self::E4M3
            | Self::E5M2
            | Self::E2M3
            | Self::E3M2
            | Self::E2M1 => Self::B32,
        }
    }

    /// Returns `true` for the signed and unsigned integer types (`U*` and `S*`).
    #[must_use]
    pub const fn is_integer(&self) -> bool {
        matches!(
            self,
            Self::U8
                | Self::U16
                | Self::U32
                | Self::U64
                | Self::S8
                | Self::S16
                | Self::S32
                | Self::S64
        )
    }

    /// Returns `true` for every floating-point type, including the `x2` and
    /// narrow (`E*M*`) formats.
    #[must_use]
    pub const fn is_float(&self) -> bool {
        matches!(
            self,
            Self::F16
                | Self::F16x2
                | Self::BF16
                | Self::BF16x2
                | Self::F32
                | Self::F64
                | Self::TF32
                | Self::E4M3
                | Self::E5M2
                | Self::E2M3
                | Self::E3M2
                | Self::E2M1
        )
    }

    /// Returns the width of the type in bits.
    ///
    /// Unlike [`PtxType::size_bytes`] this is not rounded: `Pred` is 1 bit, `E2M1` is 4 bits
    /// and `E2M3`/`E3M2` are 6 bits.
    #[must_use]
    pub const fn bit_width(&self) -> u32 {
        match self {
            Self::Pred => 1,
            Self::E2M1 => 4,
            Self::E2M3 | Self::E3M2 => 6,
            Self::U8 | Self::S8 | Self::B8 | Self::E4M3 | Self::E5M2 => 8,
            Self::U16 | Self::S16 | Self::F16 | Self::BF16 | Self::B16 => 16,
            Self::U32
            | Self::S32
            | Self::F32
            | Self::F16x2
            | Self::BF16x2
            | Self::B32
            | Self::TF32 => 32,
            Self::U64 | Self::S64 | Self::F64 | Self::B64 => 64,
            Self::B128 => 128,
        }
    }

    /// Returns `true` for signed integers (`S*`) and for every floating-point type.
    ///
    /// Unsigned integers, untyped `B*` types and `Pred` report `false`.
    #[must_use]
    pub const fn is_signed(&self) -> bool {
        matches!(
            self,
            Self::S8
                | Self::S16
                | Self::S32
                | Self::S64
                | Self::F16
                | Self::F16x2
                | Self::BF16
                | Self::BF16x2
                | Self::F32
                | Self::F64
                | Self::TF32
                | Self::E4M3
                | Self::E5M2
                | Self::E2M3
                | Self::E3M2
                | Self::E2M1
        )
    }
}

/// Formats the type as its PTX spelling without the leading dot (e.g. `f32`).
impl fmt::Display for PtxType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let s = self.as_ptx_str();
        f.write_str(s.trim_start_matches('.'))
    }
}
/// Atomic operation selector; each variant renders as a dot-prefixed PTX modifier.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AtomOp {
    /// Rendered as `.add`.
    Add,
    /// Rendered as `.min`.
    Min,
    /// Rendered as `.max`.
    Max,
    /// Rendered as `.inc`.
    Inc,
    /// Rendered as `.dec`.
    Dec,
    /// Rendered as `.and`.
    And,
    /// Rendered as `.or`.
    Or,
    /// Rendered as `.xor`.
    Xor,
    /// Rendered as `.exch`.
    Exch,
}

impl AtomOp {
    /// Returns the modifier text for this operation, leading dot included.
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match *self {
            // Arithmetic-style operations.
            AtomOp::Add => ".add",
            AtomOp::Inc => ".inc",
            AtomOp::Dec => ".dec",
            AtomOp::Min => ".min",
            AtomOp::Max => ".max",
            // Bitwise operations.
            AtomOp::And => ".and",
            AtomOp::Or => ".or",
            AtomOp::Xor => ".xor",
            // Exchange.
            AtomOp::Exch => ".exch",
        }
    }
}
/// Vector-width qualifier for an instruction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum VectorWidth {
    /// Renders as the empty string (no qualifier is emitted).
    V1,
    /// Rendered as `.v2`.
    V2,
    /// Rendered as `.v4`.
    V4,
}

impl VectorWidth {
    /// Returns the qualifier text; `V1` yields `""` so it can be appended unconditionally.
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match *self {
            VectorWidth::V4 => ".v4",
            VectorWidth::V2 => ".v2",
            VectorWidth::V1 => "",
        }
    }
}
/// Rounding-mode qualifier; each variant renders as a dot-prefixed PTX modifier.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RoundingMode {
    /// Rendered as `.rn`.
    Rn,
    /// Rendered as `.rz`.
    Rz,
    /// Rendered as `.ru`.
    Ru,
    /// Rendered as `.rd`.
    Rd,
}

impl RoundingMode {
    /// Returns the modifier text for this rounding mode, leading dot included.
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match *self {
            RoundingMode::Rd => ".rd",
            RoundingMode::Ru => ".ru",
            RoundingMode::Rz => ".rz",
            RoundingMode::Rn => ".rn",
        }
    }
}
/// Multiplication-mode qualifier; each variant renders as a dot-prefixed PTX modifier.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MulMode {
    /// Rendered as `.lo`.
    Lo,
    /// Rendered as `.hi`.
    Hi,
    /// Rendered as `.wide`.
    Wide,
}

impl MulMode {
    /// Returns the modifier text for this mode, leading dot included.
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match *self {
            MulMode::Wide => ".wide",
            MulMode::Hi => ".hi",
            MulMode::Lo => ".lo",
        }
    }
}
/// Comparison-operator qualifier; each variant renders as a dot-prefixed PTX modifier.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CmpOp {
    /// Rendered as `.eq`.
    Eq,
    /// Rendered as `.ne`.
    Ne,
    /// Rendered as `.lt`.
    Lt,
    /// Rendered as `.le`.
    Le,
    /// Rendered as `.gt`.
    Gt,
    /// Rendered as `.ge`.
    Ge,
    /// Rendered as `.lo`.
    Lo,
    /// Rendered as `.ls`.
    Ls,
    /// Rendered as `.hi`.
    Hi,
    /// Rendered as `.hs`.
    Hs,
    /// Rendered as `.equ`.
    Equ,
    /// Rendered as `.neu`.
    Neu,
    /// Rendered as `.ltu`.
    Ltu,
    /// Rendered as `.leu`.
    Leu,
    /// Rendered as `.gtu`.
    Gtu,
    /// Rendered as `.geu`.
    Geu,
    /// Rendered as `.num`.
    Num,
    /// Rendered as `.nan`.
    Nan,
}

impl CmpOp {
    /// Returns the modifier text for this comparison, leading dot included.
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match *self {
            // Two-letter spellings.
            CmpOp::Eq => ".eq",
            CmpOp::Ne => ".ne",
            CmpOp::Lt => ".lt",
            CmpOp::Le => ".le",
            CmpOp::Gt => ".gt",
            CmpOp::Ge => ".ge",
            CmpOp::Lo => ".lo",
            CmpOp::Ls => ".ls",
            CmpOp::Hi => ".hi",
            CmpOp::Hs => ".hs",
            // `u`-suffixed spellings.
            CmpOp::Equ => ".equ",
            CmpOp::Neu => ".neu",
            CmpOp::Ltu => ".ltu",
            CmpOp::Leu => ".leu",
            CmpOp::Gtu => ".gtu",
            CmpOp::Geu => ".geu",
            // Number / not-a-number tests.
            CmpOp::Num => ".num",
            CmpOp::Nan => ".nan",
        }
    }
}
/// Memory (state) space qualifier; each variant renders as a dot-prefixed PTX modifier.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemorySpace {
    /// Rendered as `.global`.
    Global,
    /// Rendered as `.shared`.
    Shared,
    /// Rendered as `.local`.
    Local,
    /// Rendered as `.const` (note the abbreviated spelling).
    Constant,
    /// Rendered as `.param`.
    Param,
}

impl MemorySpace {
    /// Returns the qualifier text for this space, leading dot included.
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match *self {
            MemorySpace::Param => ".param",
            MemorySpace::Constant => ".const",
            MemorySpace::Local => ".local",
            MemorySpace::Shared => ".shared",
            MemorySpace::Global => ".global",
        }
    }
}
/// Cache-operator qualifier for memory instructions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CacheQualifier {
    /// Renders as the empty string (no qualifier is emitted).
    None,
    /// Rendered as `.ca`.
    Ca,
    /// Rendered as `.cg`.
    Cg,
    /// Rendered as `.cs`.
    Cs,
    /// Rendered as `.lu`.
    Lu,
    /// Rendered as `.cv`.
    Cv,
}

impl CacheQualifier {
    /// Returns the qualifier text; `None` yields `""` so it can be appended unconditionally.
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match *self {
            CacheQualifier::Ca => ".ca",
            CacheQualifier::Cg => ".cg",
            CacheQualifier::Cs => ".cs",
            CacheQualifier::Lu => ".lu",
            CacheQualifier::Cv => ".cv",
            CacheQualifier::None => "",
        }
    }
}
/// Scope qualifier for fence instructions; each variant renders as a dot-prefixed modifier.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FenceScope {
    /// Rendered as `.cta`.
    Cta,
    /// Rendered as `.gpu`.
    Gpu,
    /// Rendered as `.sys`.
    Sys,
}

impl FenceScope {
    /// Returns the scope text, leading dot included.
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match *self {
            FenceScope::Sys => ".sys",
            FenceScope::Gpu => ".gpu",
            FenceScope::Cta => ".cta",
        }
    }
}
/// Special (read-only, `%`-prefixed) registers that can be named in generated PTX.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SpecialReg {
    /// Rendered as `%tid.x`.
    TidX,
    /// Rendered as `%tid.y`.
    TidY,
    /// Rendered as `%tid.z`.
    TidZ,
    /// Rendered as `%ctaid.x`.
    CtaidX,
    /// Rendered as `%ctaid.y`.
    CtaidY,
    /// Rendered as `%ctaid.z`.
    CtaidZ,
    /// Rendered as `%ntid.x`.
    NtidX,
    /// Rendered as `%ntid.y`.
    NtidY,
    /// Rendered as `%ntid.z`.
    NtidZ,
    /// Rendered as `%nctaid.x`.
    NctaidX,
    /// Rendered as `%nctaid.y`.
    NctaidY,
    /// Rendered as `%nctaid.z`.
    NctaidZ,
    /// Rendered as `%warpid`.
    WarpId,
    /// Rendered as `%laneid`.
    LaneId,
    /// Rendered as `%smid`.
    SmId,
    /// Rendered as `%clock`.
    Clock,
    /// Rendered as `%clock64`.
    Clock64,
    /// Rendered as `%dynamic_smem_size`.
    DynamicSmemSize,
}

impl SpecialReg {
    /// Returns the register name exactly as written in PTX, including the `%` sigil and,
    /// for the three-component registers, the `.x`/`.y`/`.z` selector.
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match *self {
            // Three-component registers, grouped by axis.
            SpecialReg::TidX => "%tid.x",
            SpecialReg::CtaidX => "%ctaid.x",
            SpecialReg::NtidX => "%ntid.x",
            SpecialReg::NctaidX => "%nctaid.x",
            SpecialReg::TidY => "%tid.y",
            SpecialReg::CtaidY => "%ctaid.y",
            SpecialReg::NtidY => "%ntid.y",
            SpecialReg::NctaidY => "%nctaid.y",
            SpecialReg::TidZ => "%tid.z",
            SpecialReg::CtaidZ => "%ctaid.z",
            SpecialReg::NtidZ => "%ntid.z",
            SpecialReg::NctaidZ => "%nctaid.z",
            // Scalar registers.
            SpecialReg::WarpId => "%warpid",
            SpecialReg::LaneId => "%laneid",
            SpecialReg::SmId => "%smid",
            SpecialReg::Clock => "%clock",
            SpecialReg::Clock64 => "%clock64",
            SpecialReg::DynamicSmemSize => "%dynamic_smem_size",
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Dotted spellings for a sample of data types.
    #[test]
    fn ptx_type_as_ptx_str() {
        let cases = [
            (PtxType::F32, ".f32"),
            (PtxType::U64, ".u64"),
            (PtxType::Pred, ".pred"),
            (PtxType::B128, ".b128"),
            (PtxType::E4M3, ".e4m3"),
            (PtxType::BF16x2, ".bf16x2"),
            (PtxType::S32, ".s32"),
        ];
        for &(ty, text) in &cases {
            assert_eq!(ty.as_ptx_str(), text);
        }
    }

    /// Byte sizes for one representative of each size class.
    #[test]
    fn ptx_type_size_bytes() {
        let cases: [(PtxType, usize); 9] = [
            (PtxType::U8, 1),
            (PtxType::F16, 2),
            (PtxType::F32, 4),
            (PtxType::F64, 8),
            (PtxType::B128, 16),
            (PtxType::Pred, 1),
            (PtxType::F16x2, 4),
            (PtxType::BF16x2, 4),
            (PtxType::E2M1, 1),
        ];
        for &(ty, bytes) in &cases {
            assert_eq!(ty.size_bytes(), bytes);
        }
    }

    /// Mapping from value types to their register class.
    #[test]
    fn ptx_type_reg_type() {
        let cases = [
            (PtxType::F32, PtxType::B32),
            (PtxType::F64, PtxType::B64),
            (PtxType::U64, PtxType::B64),
            (PtxType::Pred, PtxType::Pred),
            (PtxType::F16, PtxType::B16),
            (PtxType::B128, PtxType::B128),
            (PtxType::U8, PtxType::B32),
        ];
        for &(ty, class) in &cases {
            assert_eq!(ty.reg_type(), class);
        }
    }

    /// Integer / float / signedness predicates, positive and negative cases.
    #[test]
    fn ptx_type_classification() {
        for ty in &[PtxType::U32, PtxType::S64] {
            assert!(ty.is_integer());
        }
        for ty in &[PtxType::F32, PtxType::Pred] {
            assert!(!ty.is_integer());
        }

        for ty in &[PtxType::F32, PtxType::F16x2, PtxType::E4M3] {
            assert!(ty.is_float());
        }
        for ty in &[PtxType::U32, PtxType::B32] {
            assert!(!ty.is_float());
        }

        for ty in &[PtxType::S32, PtxType::F32] {
            assert!(ty.is_signed());
        }
        for ty in &[PtxType::U32, PtxType::B32] {
            assert!(!ty.is_signed());
        }
    }

    /// `%`-prefixed names for a sample of special registers.
    #[test]
    fn special_reg_ptx_str() {
        let cases = [
            (SpecialReg::TidX, "%tid.x"),
            (SpecialReg::CtaidY, "%ctaid.y"),
            (SpecialReg::LaneId, "%laneid"),
            (SpecialReg::Clock64, "%clock64"),
            (SpecialReg::DynamicSmemSize, "%dynamic_smem_size"),
        ];
        for &(reg, text) in &cases {
            assert_eq!(reg.as_ptx_str(), text);
        }
    }

    /// Every rounding mode has its expected modifier.
    #[test]
    fn rounding_mode_ptx_str() {
        let cases = [
            (RoundingMode::Rn, ".rn"),
            (RoundingMode::Rz, ".rz"),
            (RoundingMode::Ru, ".ru"),
            (RoundingMode::Rd, ".rd"),
        ];
        for &(mode, text) in &cases {
            assert_eq!(mode.as_ptx_str(), text);
        }
    }

    /// Memory-space qualifiers, including the abbreviated `.const`.
    #[test]
    fn memory_space_ptx_str() {
        let cases = [
            (MemorySpace::Global, ".global"),
            (MemorySpace::Shared, ".shared"),
            (MemorySpace::Constant, ".const"),
            (MemorySpace::Param, ".param"),
        ];
        for &(space, text) in &cases {
            assert_eq!(space.as_ptx_str(), text);
        }
    }

    /// A sample of comparison operators.
    #[test]
    fn cmp_op_ptx_str() {
        let cases = [
            (CmpOp::Eq, ".eq"),
            (CmpOp::Ltu, ".ltu"),
            (CmpOp::Nan, ".nan"),
        ];
        for &(op, text) in &cases {
            assert_eq!(op.as_ptx_str(), text);
        }
    }

    /// Vector widths; the scalar width has no qualifier.
    #[test]
    fn vector_width_ptx_str() {
        let cases = [
            (VectorWidth::V1, ""),
            (VectorWidth::V2, ".v2"),
            (VectorWidth::V4, ".v4"),
        ];
        for &(width, text) in &cases {
            assert_eq!(width.as_ptx_str(), text);
        }
    }

    /// All multiplication modes.
    #[test]
    fn mul_mode_ptx_str() {
        let cases = [
            (MulMode::Lo, ".lo"),
            (MulMode::Hi, ".hi"),
            (MulMode::Wide, ".wide"),
        ];
        for &(mode, text) in &cases {
            assert_eq!(mode.as_ptx_str(), text);
        }
    }

    /// A sample of cache qualifiers; `None` has no qualifier.
    #[test]
    fn cache_qualifier_ptx_str() {
        let cases = [
            (CacheQualifier::None, ""),
            (CacheQualifier::Ca, ".ca"),
            (CacheQualifier::Cv, ".cv"),
        ];
        for &(qualifier, text) in &cases {
            assert_eq!(qualifier.as_ptx_str(), text);
        }
    }

    /// All fence scopes.
    #[test]
    fn fence_scope_ptx_str() {
        let cases = [
            (FenceScope::Cta, ".cta"),
            (FenceScope::Gpu, ".gpu"),
            (FenceScope::Sys, ".sys"),
        ];
        for &(scope, text) in &cases {
            assert_eq!(scope.as_ptx_str(), text);
        }
    }

    /// All atomic operations.
    #[test]
    fn atom_op_ptx_str() {
        let cases = [
            (AtomOp::Add, ".add"),
            (AtomOp::Min, ".min"),
            (AtomOp::Max, ".max"),
            (AtomOp::Inc, ".inc"),
            (AtomOp::Dec, ".dec"),
            (AtomOp::And, ".and"),
            (AtomOp::Or, ".or"),
            (AtomOp::Xor, ".xor"),
            (AtomOp::Exch, ".exch"),
        ];
        for &(op, text) in &cases {
            assert_eq!(op.as_ptx_str(), text);
        }
    }

    /// The 4-bit float type: width, classification and display text.
    #[test]
    fn test_fp4_e2m1_type() {
        let fp4 = PtxType::E2M1;
        assert_eq!(fp4.bit_width(), 4);
        assert!(fp4.is_float());
        assert_eq!(fp4.to_string(), "e2m1");
    }

    /// Bit widths across the whole range, from 1-bit predicates to 128-bit values.
    #[test]
    fn test_bit_width_correctness() {
        let cases: [(PtxType, u32); 13] = [
            (PtxType::Pred, 1),
            (PtxType::E2M3, 6),
            (PtxType::E3M2, 6),
            (PtxType::E4M3, 8),
            (PtxType::E5M2, 8),
            (PtxType::U8, 8),
            (PtxType::F16, 16),
            (PtxType::BF16, 16),
            (PtxType::F16x2, 32),
            (PtxType::F32, 32),
            (PtxType::TF32, 32),
            (PtxType::F64, 64),
            (PtxType::B128, 128),
        ];
        for &(ty, bits) in &cases {
            assert_eq!(ty.bit_width(), bits);
        }
    }

    /// `Display` prints the PTX spelling without the leading dot.
    #[test]
    fn test_display_format() {
        let cases = [
            (PtxType::F32, "f32"),
            (PtxType::U64, "u64"),
            (PtxType::E4M3, "e4m3"),
            (PtxType::BF16x2, "bf16x2"),
            (PtxType::B128, "b128"),
            (PtxType::Pred, "pred"),
        ];
        for &(ty, text) in &cases {
            assert_eq!(ty.to_string(), text);
        }
    }
}