Skip to main content

oxicuda_ptx/ir/
types.rs

//! PTX type system and supporting enumerations.
//!
//! This module defines the full set of PTX data types as specified in the PTX ISA,
//! including integer, floating-point, bit-width, and predicate types. It also
//! provides the enumerations used throughout the IR for atomic operations, vector
//! widths, rounding modes, multiplication modes, comparison operators, memory
//! spaces, cache qualifiers, fence scopes, and special registers.
7
8use std::fmt;
9
/// PTX data types as defined in the PTX ISA.
///
/// Covers unsigned/signed integers, all floating-point widths (including FP8/FP6/FP4
/// formats introduced in Hopper and Blackwell architectures), untyped bit-width types,
/// and predicates.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PtxType {
    // Unsigned integers
    /// 8-bit unsigned integer.
    U8,
    /// 16-bit unsigned integer.
    U16,
    /// 32-bit unsigned integer.
    U32,
    /// 64-bit unsigned integer.
    U64,
    // Signed integers
    /// 8-bit signed integer.
    S8,
    /// 16-bit signed integer.
    S16,
    /// 32-bit signed integer.
    S32,
    /// 64-bit signed integer.
    S64,
    // Floating point
    /// IEEE 754 half-precision (16-bit) float.
    F16,
    /// Packed pair of half-precision floats.
    F16x2,
    /// Brain floating-point (16-bit, 8-bit exponent).
    BF16,
    /// Packed pair of BF16 floats.
    BF16x2,
    /// IEEE 754 single-precision (32-bit) float.
    F32,
    /// IEEE 754 double-precision (64-bit) float.
    F64,
    // Special floating point
    /// TensorFloat-32 (19 significant bits in a 32-bit container, used in Tensor Cores).
    TF32,
    /// FP8 E4M3 format (Hopper+).
    E4M3,
    /// FP8 E5M2 format (Hopper+).
    E5M2,
    /// FP6 E2M3 format (Blackwell), stored in an 8-bit container.
    E2M3,
    /// FP6 E3M2 format (Blackwell), stored in an 8-bit container.
    E3M2,
    /// FP4 E2M1 format (Blackwell), stored in an 8-bit container.
    E2M1,
    // Bit-width types (untyped)
    /// 8-bit untyped.
    B8,
    /// 16-bit untyped.
    B16,
    /// 32-bit untyped.
    B32,
    /// 64-bit untyped.
    B64,
    /// 128-bit untyped.
    B128,
    // Predicate
    /// 1-bit predicate register type.
    Pred,
}

impl PtxType {
    /// Returns the PTX ISA string representation of this type (e.g., `".f32"`, `".u64"`).
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match self {
            Self::U8 => ".u8",
            Self::U16 => ".u16",
            Self::U32 => ".u32",
            Self::U64 => ".u64",
            Self::S8 => ".s8",
            Self::S16 => ".s16",
            Self::S32 => ".s32",
            Self::S64 => ".s64",
            Self::F16 => ".f16",
            Self::F16x2 => ".f16x2",
            Self::BF16 => ".bf16",
            Self::BF16x2 => ".bf16x2",
            Self::F32 => ".f32",
            Self::F64 => ".f64",
            Self::TF32 => ".tf32",
            Self::E4M3 => ".e4m3",
            Self::E5M2 => ".e5m2",
            Self::E2M3 => ".e2m3",
            Self::E3M2 => ".e3m2",
            Self::E2M1 => ".e2m1",
            Self::B8 => ".b8",
            Self::B16 => ".b16",
            Self::B32 => ".b32",
            Self::B64 => ".b64",
            Self::B128 => ".b128",
            Self::Pred => ".pred",
        }
    }

    /// Returns the size in bytes of a single value of this type.
    ///
    /// Packed types (e.g., `F16x2`) return the size of the packed value.
    /// The FP4 (`E2M1`) and FP6 (`E2M3`, `E3M2`) formats occupy one 8-bit
    /// container each, so they report 1 byte. This is consistent with
    /// [`PtxType::bit_width`], which never exceeds `size_bytes() * 8`.
    /// Predicates return 1 byte (the minimum addressable unit).
    #[must_use]
    pub const fn size_bytes(&self) -> usize {
        match self {
            // 8-bit containers, including the FP8, FP6 and FP4 formats.
            Self::U8
            | Self::S8
            | Self::B8
            | Self::E4M3
            | Self::E5M2
            | Self::E2M3
            | Self::E3M2
            | Self::E2M1
            | Self::Pred => 1,
            Self::U16 | Self::S16 | Self::F16 | Self::BF16 | Self::B16 => 2,
            Self::U32
            | Self::S32
            | Self::F32
            | Self::F16x2
            | Self::BF16x2
            | Self::B32
            | Self::TF32 => 4,
            Self::U64 | Self::S64 | Self::F64 | Self::B64 => 8,
            Self::B128 => 16,
        }
    }

    /// Returns the register-width class type used in `.reg` declarations.
    ///
    /// - Predicates stay `.pred`.
    /// - 16-bit scalar types (`U16`, `S16`, `F16`, `BF16`, `B16`) use `.b16` registers.
    /// - 8-bit and narrower types (8-bit integers, FP8, FP6, FP4) are widened to `.b32`.
    /// - 32-bit types, including packed `F16x2`/`BF16x2` and `TF32`, use `.b32`.
    /// - 64-bit types use `.b64`; 128-bit values use `.b128`.
    #[must_use]
    pub const fn reg_type(&self) -> Self {
        // Listed exhaustively (no `_` arm) so that adding a variant forces an
        // explicit register-class decision here.
        match self {
            Self::Pred => Self::Pred,
            Self::U16 | Self::S16 | Self::F16 | Self::BF16 | Self::B16 => Self::B16,
            Self::U8
            | Self::S8
            | Self::B8
            | Self::E4M3
            | Self::E5M2
            | Self::E2M3
            | Self::E3M2
            | Self::E2M1
            | Self::U32
            | Self::S32
            | Self::F32
            | Self::F16x2
            | Self::BF16x2
            | Self::B32
            | Self::TF32 => Self::B32,
            Self::U64 | Self::S64 | Self::F64 | Self::B64 => Self::B64,
            Self::B128 => Self::B128,
        }
    }

    /// Returns `true` if this is an integer type (signed or unsigned).
    #[must_use]
    pub const fn is_integer(&self) -> bool {
        matches!(
            self,
            Self::U8
                | Self::U16
                | Self::U32
                | Self::U64
                | Self::S8
                | Self::S16
                | Self::S32
                | Self::S64
        )
    }

    /// Returns `true` if this is a floating-point type (including packed and special formats).
    #[must_use]
    pub const fn is_float(&self) -> bool {
        matches!(
            self,
            Self::F16
                | Self::F16x2
                | Self::BF16
                | Self::BF16x2
                | Self::F32
                | Self::F64
                | Self::TF32
                | Self::E4M3
                | Self::E5M2
                | Self::E2M3
                | Self::E3M2
                | Self::E2M1
        )
    }

    /// Returns the bit-width of a single element of this type.
    ///
    /// For sub-byte types (E2M1 = FP4), returns 4; the FP6 types return 6. For packed
    /// types like `F16x2` and `BF16x2`, returns the total packed width (32 bits).
    /// Predicates are reported as 1 bit.
    #[must_use]
    pub const fn bit_width(&self) -> u32 {
        match self {
            // Sub-byte: FP4 (E2M1)
            Self::E2M1 => 4,
            // 6-bit types (stored in 8-bit containers but logically 6 bits)
            Self::E2M3 | Self::E3M2 => 6,
            // 8-bit types
            Self::U8 | Self::S8 | Self::B8 | Self::E4M3 | Self::E5M2 => 8,
            // Predicate (1 bit)
            Self::Pred => 1,
            // 16-bit types
            Self::U16 | Self::S16 | Self::F16 | Self::BF16 | Self::B16 => 16,
            // 32-bit types (including packed 16-bit pairs)
            Self::U32
            | Self::S32
            | Self::F32
            | Self::F16x2
            | Self::BF16x2
            | Self::B32
            | Self::TF32 => 32,
            // 64-bit types
            Self::U64 | Self::S64 | Self::F64 | Self::B64 => 64,
            // 128-bit types
            Self::B128 => 128,
        }
    }

    /// Returns `true` if this is a signed type (signed integers or all floats).
    #[must_use]
    pub const fn is_signed(&self) -> bool {
        matches!(
            self,
            Self::S8
                | Self::S16
                | Self::S32
                | Self::S64
                | Self::F16
                | Self::F16x2
                | Self::BF16
                | Self::BF16x2
                | Self::F32
                | Self::F64
                | Self::TF32
                | Self::E4M3
                | Self::E5M2
                | Self::E2M3
                | Self::E3M2
                | Self::E2M1
        )
    }
}

impl fmt::Display for PtxType {
    /// Formats the type as its PTX ISA string without the leading dot.
    ///
    /// For example, `PtxType::F32` displays as `"f32"`, and
    /// `PtxType::E2M1` displays as `"e2m1"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // as_ptx_str() always yields a dotted modifier such as ".f32".
        let s = self.as_ptx_str();
        f.write_str(s.strip_prefix('.').unwrap_or(s))
    }
}
258
/// Read-modify-write operation carried out by the PTX `atom` (fetch-and-op)
/// and `red` (reduction) instructions.
///
/// The update is applied atomically to a global or shared memory location, so
/// concurrent updates issued by many threads are never lost.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AtomOp {
    /// Atomic addition (`.add`).
    Add,
    /// Atomic minimum (`.min`).
    Min,
    /// Atomic maximum (`.max`).
    Max,
    /// Atomic wrapping increment (`.inc`): new value is `0` if `old >= b`, else `old + 1`.
    Inc,
    /// Atomic wrapping decrement (`.dec`): new value is `b` if `old == 0 || old > b`,
    /// else `old - 1`.
    Dec,
    /// Atomic bitwise AND (`.and`).
    And,
    /// Atomic bitwise OR (`.or`).
    Or,
    /// Atomic bitwise XOR (`.xor`).
    Xor,
    /// Atomic exchange (`.exch`): stores the operand and yields the previous value.
    Exch,
}

impl AtomOp {
    /// PTX instruction modifier for this operation, leading dot included
    /// (for example `".add"` or `".exch"`).
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match *self {
            AtomOp::Add => ".add",
            AtomOp::Min => ".min",
            AtomOp::Max => ".max",
            AtomOp::Inc => ".inc",
            AtomOp::Dec => ".dec",
            AtomOp::And => ".and",
            AtomOp::Or => ".or",
            AtomOp::Xor => ".xor",
            AtomOp::Exch => ".exch",
        }
    }
}
302
/// Number of lanes moved by a vectorized `ld`/`st` instruction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum VectorWidth {
    /// Plain scalar access; no `.vN` modifier is emitted.
    V1,
    /// Two-lane vector access (`.v2`).
    V2,
    /// Four-lane vector access (`.v4`).
    V4,
}

impl VectorWidth {
    /// PTX vector modifier (`".v2"` / `".v4"`); the scalar width maps to the
    /// empty string so it can be concatenated unconditionally.
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match *self {
            VectorWidth::V2 => ".v2",
            VectorWidth::V4 => ".v4",
            VectorWidth::V1 => "",
        }
    }
}
325
/// Rounding-direction modifiers for floating-point instructions, following the
/// IEEE 754 rounding directions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RoundingMode {
    /// Round to nearest, ties to even (`.rn`).
    Rn,
    /// Round toward zero, i.e. truncate (`.rz`).
    Rz,
    /// Round toward positive infinity (`.ru`).
    Ru,
    /// Round toward negative infinity (`.rd`).
    Rd,
}

impl RoundingMode {
    /// PTX rounding modifier, leading dot included (for example `".rn"` or `".rz"`).
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match *self {
            RoundingMode::Rn => ".rn",
            RoundingMode::Rz => ".rz",
            RoundingMode::Ru => ".ru",
            RoundingMode::Rd => ".rd",
        }
    }
}
351
/// Selects which part of an integer product a `mul`/`mad` instruction keeps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MulMode {
    /// Keep the low half of the product, giving a result as wide as the operands (`.lo`).
    Lo,
    /// Keep the high half of the product (`.hi`).
    Hi,
    /// Keep the full product, twice as wide as the operands (`.wide`).
    Wide,
}

impl MulMode {
    /// PTX multiplication modifier, leading dot included
    /// (`".lo"`, `".hi"` or `".wide"`).
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match *self {
            MulMode::Lo => ".lo",
            MulMode::Hi => ".hi",
            MulMode::Wide => ".wide",
        }
    }
}
374
/// Comparison operators for `setp` and related instructions.
///
/// The variants fall into four groups:
/// - `Eq`, `Ne`, `Lt`, `Le`, `Gt`, `Ge` are the signed forms. On floating-point
///   operands they are *ordered*: false if either operand is NaN.
/// - `Lo`, `Ls`, `Hi`, `Hs` are unsigned-integer comparisons.
/// - `Equ`, `Neu`, `Ltu`, `Leu`, `Gtu`, `Geu` are *unordered* floating-point
///   comparisons: true if either operand is NaN.
/// - `Num` and `Nan` test operands for NaN.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CmpOp {
    /// Equal.
    Eq,
    /// Not equal.
    Ne,
    /// Less than (signed; ordered for floats).
    Lt,
    /// Less than or equal (signed; ordered for floats).
    Le,
    /// Greater than (signed; ordered for floats).
    Gt,
    /// Greater than or equal (signed; ordered for floats).
    Ge,
    /// Lower: unsigned less than.
    Lo,
    /// Lower or same: unsigned less than or equal.
    Ls,
    /// Higher: unsigned greater than.
    Hi,
    /// Higher or same: unsigned greater than or equal.
    Hs,
    /// Equal, or unordered.
    Equ,
    /// Not equal, or unordered.
    Neu,
    /// Less than, or unordered.
    Ltu,
    /// Less than or equal, or unordered.
    Leu,
    /// Greater than, or unordered.
    Gtu,
    /// Greater than or equal, or unordered.
    Geu,
    /// Numeric: neither operand is NaN.
    Num,
    /// NaN: at least one operand is NaN.
    Nan,
}

impl CmpOp {
    /// PTX comparison modifier, leading dot included
    /// (for example `".eq"`, `".lt"` or `".geu"`).
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match *self {
            // Signed / ordered.
            CmpOp::Eq => ".eq",
            CmpOp::Ne => ".ne",
            CmpOp::Lt => ".lt",
            CmpOp::Le => ".le",
            CmpOp::Gt => ".gt",
            CmpOp::Ge => ".ge",
            // Unsigned integer.
            CmpOp::Lo => ".lo",
            CmpOp::Ls => ".ls",
            CmpOp::Hi => ".hi",
            CmpOp::Hs => ".hs",
            // Unordered floating point.
            CmpOp::Equ => ".equ",
            CmpOp::Neu => ".neu",
            CmpOp::Ltu => ".ltu",
            CmpOp::Leu => ".leu",
            CmpOp::Gtu => ".gtu",
            CmpOp::Geu => ".geu",
            // NaN tests.
            CmpOp::Num => ".num",
            CmpOp::Nan => ".nan",
        }
    }
}
445
/// State spaces (address spaces) that a PTX memory operation can target.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemorySpace {
    /// Device-wide global memory.
    Global,
    /// Per-block shared scratchpad memory.
    Shared,
    /// Per-thread local memory (used for spills; resides in DRAM).
    Local,
    /// Read-only, cached constant memory.
    Constant,
    /// Parameter space holding kernel arguments.
    Param,
}

impl MemorySpace {
    /// PTX state-space modifier, leading dot included
    /// (for example `".global"` or `".shared"`).
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match *self {
            MemorySpace::Global => ".global",
            MemorySpace::Shared => ".shared",
            MemorySpace::Local => ".local",
            // PTX spells the constant state space `.const`.
            MemorySpace::Constant => ".const",
            MemorySpace::Param => ".param",
        }
    }
}
474
/// Cache-operator modifiers attached to load/store instructions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CacheQualifier {
    /// No cache operator; the instruction uses the hardware default.
    None,
    /// Cache at all levels (`.ca`).
    Ca,
    /// Cache in L2 only, bypassing L1 (`.cg`).
    Cg,
    /// Streaming access, likely touched once; evict first (`.cs`).
    Cs,
    /// Last use; the line may be evicted after this access (`.lu`).
    Lu,
    /// Volatile: do not cache, fetch again on every access (`.cv`).
    Cv,
}

impl CacheQualifier {
    /// PTX cache-operator modifier, leading dot included. `None` maps to the
    /// empty string so the result can be concatenated unconditionally.
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match *self {
            CacheQualifier::Ca => ".ca",
            CacheQualifier::Cg => ".cg",
            CacheQualifier::Cs => ".cs",
            CacheQualifier::Lu => ".lu",
            CacheQualifier::Cv => ".cv",
            CacheQualifier::None => "",
        }
    }
}
506
/// Set of threads that a fence or scoped memory-ordering instruction orders against.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FenceScope {
    /// Threads in the same CTA (cooperative thread array, i.e. thread block).
    Cta,
    /// All threads on the same GPU device.
    Gpu,
    /// All threads in the system, including other GPUs and the host.
    Sys,
}

impl FenceScope {
    /// PTX scope modifier, leading dot included (`".cta"`, `".gpu"` or `".sys"`).
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match *self {
            FenceScope::Cta => ".cta",
            FenceScope::Gpu => ".gpu",
            FenceScope::Sys => ".sys",
        }
    }
}
529
/// Read-only special registers exposed by PTX, read via `mov.u32` / `mov.u64`.
///
/// They report thread and block coordinates, launch dimensions, and hardware
/// state such as lane/warp/SM identifiers and clock counters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SpecialReg {
    /// X component of the thread index within its block (`%tid.x`).
    TidX,
    /// Y component of the thread index within its block (`%tid.y`).
    TidY,
    /// Z component of the thread index within its block (`%tid.z`).
    TidZ,
    /// X component of the block index within the grid (`%ctaid.x`).
    CtaidX,
    /// Y component of the block index within the grid (`%ctaid.y`).
    CtaidY,
    /// Z component of the block index within the grid (`%ctaid.z`).
    CtaidZ,
    /// Threads per block along X (`%ntid.x`).
    NtidX,
    /// Threads per block along Y (`%ntid.y`).
    NtidY,
    /// Threads per block along Z (`%ntid.z`).
    NtidZ,
    /// Blocks per grid along X (`%nctaid.x`).
    NctaidX,
    /// Blocks per grid along Y (`%nctaid.y`).
    NctaidY,
    /// Blocks per grid along Z (`%nctaid.z`).
    NctaidZ,
    /// Identifier of the warp within the CTA (`%warpid`).
    WarpId,
    /// Index of the thread within its warp, 0 through 31 (`%laneid`).
    LaneId,
    /// Identifier of the streaming multiprocessor (`%smid`).
    SmId,
    /// 32-bit cycle counter (`%clock`).
    Clock,
    /// 64-bit cycle counter (`%clock64`).
    Clock64,
    /// Size in bytes of the dynamically allocated shared memory (`%dynamic_smem_size`).
    DynamicSmemSize,
}

impl SpecialReg {
    /// Spelling of this register in PTX source, `%` sigil included
    /// (for example `"%tid.x"` or `"%laneid"`).
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match *self {
            // Thread index within the block.
            SpecialReg::TidX => "%tid.x",
            SpecialReg::TidY => "%tid.y",
            SpecialReg::TidZ => "%tid.z",
            // Block index within the grid.
            SpecialReg::CtaidX => "%ctaid.x",
            SpecialReg::CtaidY => "%ctaid.y",
            SpecialReg::CtaidZ => "%ctaid.z",
            // Block dimensions.
            SpecialReg::NtidX => "%ntid.x",
            SpecialReg::NtidY => "%ntid.y",
            SpecialReg::NtidZ => "%ntid.z",
            // Grid dimensions.
            SpecialReg::NctaidX => "%nctaid.x",
            SpecialReg::NctaidY => "%nctaid.y",
            SpecialReg::NctaidZ => "%nctaid.z",
            // Hardware identity and state.
            SpecialReg::WarpId => "%warpid",
            SpecialReg::LaneId => "%laneid",
            SpecialReg::SmId => "%smid",
            SpecialReg::Clock => "%clock",
            SpecialReg::Clock64 => "%clock64",
            SpecialReg::DynamicSmemSize => "%dynamic_smem_size",
        }
    }
}
600
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ptx_type_as_ptx_str() {
        let cases = [
            (PtxType::F32, ".f32"),
            (PtxType::U64, ".u64"),
            (PtxType::Pred, ".pred"),
            (PtxType::B128, ".b128"),
            (PtxType::E4M3, ".e4m3"),
            (PtxType::BF16x2, ".bf16x2"),
            (PtxType::S32, ".s32"),
        ];
        for &(ty, expected) in &cases {
            assert_eq!(ty.as_ptx_str(), expected, "as_ptx_str() of {:?}", ty);
        }
    }

    #[test]
    fn ptx_type_size_bytes() {
        let cases = [
            (PtxType::U8, 1),
            (PtxType::F16, 2),
            (PtxType::F32, 4),
            (PtxType::F64, 8),
            (PtxType::B128, 16),
            (PtxType::Pred, 1),
            (PtxType::F16x2, 4),
            (PtxType::BF16x2, 4),
            (PtxType::E2M1, 1),
        ];
        for &(ty, expected) in &cases {
            assert_eq!(ty.size_bytes(), expected, "size_bytes() of {:?}", ty);
        }
    }

    #[test]
    fn ptx_type_reg_type() {
        let cases = [
            (PtxType::F32, PtxType::B32),
            (PtxType::F64, PtxType::B64),
            (PtxType::U64, PtxType::B64),
            (PtxType::Pred, PtxType::Pred),
            (PtxType::F16, PtxType::B16),
            (PtxType::B128, PtxType::B128),
            (PtxType::U8, PtxType::B32),
        ];
        for &(ty, expected) in &cases {
            assert_eq!(ty.reg_type(), expected, "reg_type() of {:?}", ty);
        }
    }

    #[test]
    fn ptx_type_classification() {
        assert!([PtxType::U32, PtxType::S64].iter().all(PtxType::is_integer));
        assert!(![PtxType::F32, PtxType::Pred].iter().any(PtxType::is_integer));

        assert!([PtxType::F32, PtxType::F16x2, PtxType::E4M3]
            .iter()
            .all(PtxType::is_float));
        assert!(![PtxType::U32, PtxType::B32].iter().any(PtxType::is_float));

        assert!([PtxType::S32, PtxType::F32].iter().all(PtxType::is_signed));
        assert!(![PtxType::U32, PtxType::B32].iter().any(PtxType::is_signed));
    }

    #[test]
    fn special_reg_ptx_str() {
        let cases = [
            (SpecialReg::TidX, "%tid.x"),
            (SpecialReg::CtaidY, "%ctaid.y"),
            (SpecialReg::LaneId, "%laneid"),
            (SpecialReg::Clock64, "%clock64"),
            (SpecialReg::DynamicSmemSize, "%dynamic_smem_size"),
        ];
        for &(reg, expected) in &cases {
            assert_eq!(reg.as_ptx_str(), expected, "as_ptx_str() of {:?}", reg);
        }
    }

    #[test]
    fn rounding_mode_ptx_str() {
        let cases = [
            (RoundingMode::Rn, ".rn"),
            (RoundingMode::Rz, ".rz"),
            (RoundingMode::Ru, ".ru"),
            (RoundingMode::Rd, ".rd"),
        ];
        for &(mode, expected) in &cases {
            assert_eq!(mode.as_ptx_str(), expected, "as_ptx_str() of {:?}", mode);
        }
    }

    #[test]
    fn memory_space_ptx_str() {
        let cases = [
            (MemorySpace::Global, ".global"),
            (MemorySpace::Shared, ".shared"),
            (MemorySpace::Constant, ".const"),
            (MemorySpace::Param, ".param"),
        ];
        for &(space, expected) in &cases {
            assert_eq!(space.as_ptx_str(), expected, "as_ptx_str() of {:?}", space);
        }
    }

    #[test]
    fn cmp_op_ptx_str() {
        let cases = [
            (CmpOp::Eq, ".eq"),
            (CmpOp::Ltu, ".ltu"),
            (CmpOp::Nan, ".nan"),
        ];
        for &(op, expected) in &cases {
            assert_eq!(op.as_ptx_str(), expected, "as_ptx_str() of {:?}", op);
        }
    }

    #[test]
    fn vector_width_ptx_str() {
        let cases = [
            (VectorWidth::V1, ""),
            (VectorWidth::V2, ".v2"),
            (VectorWidth::V4, ".v4"),
        ];
        for &(width, expected) in &cases {
            assert_eq!(width.as_ptx_str(), expected, "as_ptx_str() of {:?}", width);
        }
    }

    #[test]
    fn mul_mode_ptx_str() {
        let cases = [
            (MulMode::Lo, ".lo"),
            (MulMode::Hi, ".hi"),
            (MulMode::Wide, ".wide"),
        ];
        for &(mode, expected) in &cases {
            assert_eq!(mode.as_ptx_str(), expected, "as_ptx_str() of {:?}", mode);
        }
    }

    #[test]
    fn cache_qualifier_ptx_str() {
        let cases = [
            (CacheQualifier::None, ""),
            (CacheQualifier::Ca, ".ca"),
            (CacheQualifier::Cv, ".cv"),
        ];
        for &(qual, expected) in &cases {
            assert_eq!(qual.as_ptx_str(), expected, "as_ptx_str() of {:?}", qual);
        }
    }

    #[test]
    fn fence_scope_ptx_str() {
        let cases = [
            (FenceScope::Cta, ".cta"),
            (FenceScope::Gpu, ".gpu"),
            (FenceScope::Sys, ".sys"),
        ];
        for &(scope, expected) in &cases {
            assert_eq!(scope.as_ptx_str(), expected, "as_ptx_str() of {:?}", scope);
        }
    }

    #[test]
    fn atom_op_ptx_str() {
        let cases = [
            (AtomOp::Add, ".add"),
            (AtomOp::Min, ".min"),
            (AtomOp::Max, ".max"),
            (AtomOp::Inc, ".inc"),
            (AtomOp::Dec, ".dec"),
            (AtomOp::And, ".and"),
            (AtomOp::Or, ".or"),
            (AtomOp::Xor, ".xor"),
            (AtomOp::Exch, ".exch"),
        ];
        for &(op, expected) in &cases {
            assert_eq!(op.as_ptx_str(), expected, "as_ptx_str() of {:?}", op);
        }
    }

    #[test]
    fn test_fp4_e2m1_type() {
        let fp4 = PtxType::E2M1;
        assert_eq!(fp4.bit_width(), 4);
        assert!(fp4.is_float());
        assert_eq!(fp4.to_string(), "e2m1");
    }

    #[test]
    fn test_bit_width_correctness() {
        let cases = [
            (PtxType::Pred, 1),
            (PtxType::E2M3, 6),
            (PtxType::E3M2, 6),
            (PtxType::E4M3, 8),
            (PtxType::E5M2, 8),
            (PtxType::U8, 8),
            (PtxType::F16, 16),
            (PtxType::BF16, 16),
            (PtxType::F16x2, 32),
            (PtxType::F32, 32),
            (PtxType::TF32, 32),
            (PtxType::F64, 64),
            (PtxType::B128, 128),
        ];
        for &(ty, bits) in &cases {
            assert_eq!(ty.bit_width(), bits, "bit_width() of {:?}", ty);
        }
    }

    #[test]
    fn test_display_format() {
        let cases = [
            (PtxType::F32, "f32"),
            (PtxType::U64, "u64"),
            (PtxType::E4M3, "e4m3"),
            (PtxType::BF16x2, "bf16x2"),
            (PtxType::B128, "b128"),
            (PtxType::Pred, "pred"),
        ];
        for &(ty, expected) in &cases {
            assert_eq!(ty.to_string(), expected, "Display of {:?}", ty);
        }
    }
}