1use std::fmt;
9
/// A PTX data type as used in instruction type suffixes and in register/variable
/// declarations.
///
/// [`PtxType::as_ptx_str`] yields the dotted PTX spelling (e.g. `.f32`); the
/// [`Display`](fmt::Display) impl yields the same name without the leading dot (`f32`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PtxType {
    /// Unsigned 8-bit integer (`.u8`).
    U8,
    /// Unsigned 16-bit integer (`.u16`).
    U16,
    /// Unsigned 32-bit integer (`.u32`).
    U32,
    /// Unsigned 64-bit integer (`.u64`).
    U64,
    /// Signed 8-bit integer (`.s8`).
    S8,
    /// Signed 16-bit integer (`.s16`).
    S16,
    /// Signed 32-bit integer (`.s32`).
    S32,
    /// Signed 64-bit integer (`.s64`).
    S64,
    /// IEEE half-precision float (`.f16`).
    F16,
    /// Two packed half-precision floats in 32 bits (`.f16x2`).
    F16x2,
    /// bfloat16 (`.bf16`).
    BF16,
    /// Two packed bfloat16 values in 32 bits (`.bf16x2`).
    BF16x2,
    /// IEEE single-precision float (`.f32`).
    F32,
    /// IEEE double-precision float (`.f64`).
    F64,
    /// TensorFloat-32 (`.tf32`), held in a 32-bit container.
    TF32,
    /// 8-bit float, 4 exponent / 3 mantissa bits (`.e4m3`).
    E4M3,
    /// 8-bit float, 5 exponent / 2 mantissa bits (`.e5m2`).
    E5M2,
    /// 6-bit float, 2 exponent / 3 mantissa bits (`.e2m3`).
    E2M3,
    /// 6-bit float, 3 exponent / 2 mantissa bits (`.e3m2`).
    E3M2,
    /// 4-bit float, 2 exponent / 1 mantissa bit (`.e2m1`).
    E2M1,
    /// Untyped 8-bit container (`.b8`).
    B8,
    /// Untyped 16-bit container (`.b16`).
    B16,
    /// Untyped 32-bit container (`.b32`).
    B32,
    /// Untyped 64-bit container (`.b64`).
    B64,
    /// Untyped 128-bit container (`.b128`).
    B128,
    /// Predicate (`.pred`).
    Pred,
}

impl PtxType {
    /// Returns the PTX spelling of this type, including the leading dot (e.g. `".f32"`).
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match self {
            Self::U8 => ".u8",
            Self::U16 => ".u16",
            Self::U32 => ".u32",
            Self::U64 => ".u64",
            Self::S8 => ".s8",
            Self::S16 => ".s16",
            Self::S32 => ".s32",
            Self::S64 => ".s64",
            Self::F16 => ".f16",
            Self::F16x2 => ".f16x2",
            Self::BF16 => ".bf16",
            Self::BF16x2 => ".bf16x2",
            Self::F32 => ".f32",
            Self::F64 => ".f64",
            Self::TF32 => ".tf32",
            Self::E4M3 => ".e4m3",
            Self::E5M2 => ".e5m2",
            Self::E2M3 => ".e2m3",
            Self::E3M2 => ".e3m2",
            Self::E2M1 => ".e2m1",
            Self::B8 => ".b8",
            Self::B16 => ".b16",
            Self::B32 => ".b32",
            Self::B64 => ".b64",
            Self::B128 => ".b128",
            Self::Pred => ".pred",
        }
    }

    /// Returns the size of one value of this type in bytes.
    ///
    /// This is [`PtxType::bit_width`] rounded up to a whole byte. Sub-byte types
    /// (`.pred`, `.e2m1`, `.e2m3`, `.e3m2`) therefore report 1 byte. Packed types
    /// (`.f16x2`, `.bf16x2`) report the size of the whole pair.
    #[must_use]
    pub const fn size_bytes(&self) -> usize {
        // Deriving from the bit width keeps the two methods consistent: every 4- and
        // 6-bit float rounds up to a single byte, just like the 1-bit predicate.
        ((self.bit_width() + 7) / 8) as usize
    }

    /// Returns the register type for a value of this type: `.pred` for predicates,
    /// otherwise an untyped `.bN` type.
    ///
    /// 16-, 64- and 128-bit types map to `.b16`, `.b64` and `.b128`. Everything else
    /// maps to `.b32`: the 8-bit and sub-byte types, the 32-bit scalars, and the
    /// packed pairs.
    #[must_use]
    pub const fn reg_type(&self) -> Self {
        match self {
            Self::Pred => Self::Pred,
            Self::U16 | Self::S16 | Self::F16 | Self::BF16 | Self::B16 => Self::B16,
            Self::U64 | Self::S64 | Self::F64 | Self::B64 => Self::B64,
            Self::B128 => Self::B128,
            // Listed explicitly (no `_` arm) so a newly added variant forces a decision here.
            Self::U8
            | Self::S8
            | Self::B8
            | Self::E4M3
            | Self::E5M2
            | Self::E2M3
            | Self::E3M2
            | Self::E2M1
            | Self::U32
            | Self::S32
            | Self::F32
            | Self::F16x2
            | Self::BF16x2
            | Self::TF32
            | Self::B32 => Self::B32,
        }
    }

    /// Returns `true` for the signed and unsigned integer types (`.uN` / `.sN`).
    ///
    /// Untyped `.bN` containers and `.pred` are not considered integers.
    #[must_use]
    pub const fn is_integer(&self) -> bool {
        matches!(
            self,
            Self::U8
                | Self::U16
                | Self::U32
                | Self::U64
                | Self::S8
                | Self::S16
                | Self::S32
                | Self::S64
        )
    }

    /// Returns `true` for every floating-point type, including the packed pairs,
    /// `.tf32`, and the 8/6/4-bit float formats.
    #[must_use]
    pub const fn is_float(&self) -> bool {
        matches!(
            self,
            Self::F16
                | Self::F16x2
                | Self::BF16
                | Self::BF16x2
                | Self::F32
                | Self::F64
                | Self::TF32
                | Self::E4M3
                | Self::E5M2
                | Self::E2M3
                | Self::E3M2
                | Self::E2M1
        )
    }

    /// Returns the number of bits in one value of this type.
    ///
    /// Unlike [`PtxType::size_bytes`], this reports sub-byte widths exactly
    /// (1 for `.pred`, 4 for `.e2m1`, 6 for `.e2m3`/`.e3m2`).
    #[must_use]
    pub const fn bit_width(&self) -> u32 {
        match self {
            Self::Pred => 1,
            Self::E2M1 => 4,
            Self::E2M3 | Self::E3M2 => 6,
            Self::U8 | Self::S8 | Self::B8 | Self::E4M3 | Self::E5M2 => 8,
            Self::U16 | Self::S16 | Self::F16 | Self::BF16 | Self::B16 => 16,
            Self::U32
            | Self::S32
            | Self::F32
            | Self::F16x2
            | Self::BF16x2
            | Self::B32
            | Self::TF32 => 32,
            Self::U64 | Self::S64 | Self::F64 | Self::B64 => 64,
            Self::B128 => 128,
        }
    }

    /// Returns `true` for signed integers and for all floating-point types.
    ///
    /// Unsigned integers, `.bN` containers and `.pred` return `false`.
    #[must_use]
    pub const fn is_signed(&self) -> bool {
        matches!(
            self,
            Self::S8
                | Self::S16
                | Self::S32
                | Self::S64
                | Self::F16
                | Self::F16x2
                | Self::BF16
                | Self::BF16x2
                | Self::F32
                | Self::F64
                | Self::TF32
                | Self::E4M3
                | Self::E5M2
                | Self::E2M3
                | Self::E3M2
                | Self::E2M1
        )
    }
}

impl fmt::Display for PtxType {
    /// Writes the type name without its leading dot (e.g. `f32`).
    ///
    /// Uses [`fmt::Formatter::pad`] so width, fill and alignment flags such as
    /// `{:>6}` are honored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let dotted = self.as_ptx_str();
        f.pad(dotted.strip_prefix('.').unwrap_or(dotted))
    }
}
258
/// Operation selector for PTX atomic instructions (the `.op` in `atom.op.type`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AtomOp {
    /// Atomic addition (`.add`).
    Add,
    /// Atomic minimum (`.min`).
    Min,
    /// Atomic maximum (`.max`).
    Max,
    /// Bounded increment (`.inc`).
    Inc,
    /// Bounded decrement (`.dec`).
    Dec,
    /// Bitwise AND (`.and`).
    And,
    /// Bitwise OR (`.or`).
    Or,
    /// Bitwise XOR (`.xor`).
    Xor,
    /// Exchange: store the new value and return the old one (`.exch`).
    Exch,
}

impl AtomOp {
    /// Returns the dotted PTX spelling of this operation, e.g. `".add"`.
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match *self {
            AtomOp::Add => ".add",
            AtomOp::Min => ".min",
            AtomOp::Max => ".max",
            AtomOp::Inc => ".inc",
            AtomOp::Dec => ".dec",
            AtomOp::And => ".and",
            AtomOp::Or => ".or",
            AtomOp::Xor => ".xor",
            AtomOp::Exch => ".exch",
        }
    }
}
302
/// PTX vector qualifier: `.v2`/`.v4`, or no qualifier at all for a scalar (`V1`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum VectorWidth {
    /// Scalar access; emits an empty suffix.
    V1,
    /// Two-element vector (`.v2`).
    V2,
    /// Four-element vector (`.v4`).
    V4,
}

impl VectorWidth {
    /// Returns the PTX suffix for this width. The scalar case yields `""`, so the
    /// result can be spliced into an instruction unconditionally.
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match *self {
            VectorWidth::V2 => ".v2",
            VectorWidth::V4 => ".v4",
            VectorWidth::V1 => "",
        }
    }
}
325
/// Floating-point rounding modifier for PTX arithmetic and conversion instructions.
///
/// The variant names follow the "round up / round down" convention. PTX spells the
/// directed modes `.rp` (toward +∞) and `.rm` (toward −∞); `.ru`/`.rd` are not
/// valid PTX rounding modifiers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RoundingMode {
    /// Round to nearest, ties to even (`.rn`).
    Rn,
    /// Round toward zero (`.rz`).
    Rz,
    /// Round toward positive infinity; emitted as `.rp`.
    Ru,
    /// Round toward negative infinity; emitted as `.rm`.
    Rd,
}

impl RoundingMode {
    /// Returns the dotted PTX rounding modifier, e.g. `".rn"`.
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match self {
            Self::Rn => ".rn",
            Self::Rz => ".rz",
            // PTX uses "plus"/"minus" infinity mnemonics for the directed modes.
            Self::Ru => ".rp",
            Self::Rd => ".rm",
        }
    }
}

/// Result-selection modifier for PTX integer multiply (`mul`, `mad`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MulMode {
    /// Low half of the product (`.lo`).
    Lo,
    /// High half of the product (`.hi`).
    Hi,
    /// Full double-width product (`.wide`).
    Wide,
}

impl MulMode {
    /// Returns the dotted PTX modifier, e.g. `".lo"`.
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match self {
            Self::Lo => ".lo",
            Self::Hi => ".hi",
            Self::Wide => ".wide",
        }
    }
}

/// Comparison operator for PTX `setp`/`set`-style instructions.
///
/// `Lo`/`Ls`/`Hi`/`Hs` are the unsigned-integer orderings. The `*u` variants are
/// the unordered float comparisons, which are also true when either operand is NaN.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CmpOp {
    /// Equal (`.eq`).
    Eq,
    /// Not equal (`.ne`).
    Ne,
    /// Less than (`.lt`).
    Lt,
    /// Less than or equal (`.le`).
    Le,
    /// Greater than (`.gt`).
    Gt,
    /// Greater than or equal (`.ge`).
    Ge,
    /// Unsigned lower (`.lo`).
    Lo,
    /// Unsigned lower or same (`.ls`).
    Ls,
    /// Unsigned higher (`.hi`).
    Hi,
    /// Unsigned higher or same (`.hs`).
    Hs,
    /// Equal, or unordered (`.equ`).
    Equ,
    /// Not equal, or unordered (`.neu`).
    Neu,
    /// Less than, or unordered (`.ltu`).
    Ltu,
    /// Less than or equal, or unordered (`.leu`).
    Leu,
    /// Greater than, or unordered (`.gtu`).
    Gtu,
    /// Greater than or equal, or unordered (`.geu`).
    Geu,
    /// Both operands are numbers, i.e. neither is NaN (`.num`).
    Num,
    /// Either operand is NaN (`.nan`).
    Nan,
}

impl CmpOp {
    /// Returns the dotted PTX comparison operator, e.g. `".eq"`.
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match self {
            Self::Eq => ".eq",
            Self::Ne => ".ne",
            Self::Lt => ".lt",
            Self::Le => ".le",
            Self::Gt => ".gt",
            Self::Ge => ".ge",
            Self::Lo => ".lo",
            Self::Ls => ".ls",
            Self::Hi => ".hi",
            Self::Hs => ".hs",
            Self::Equ => ".equ",
            Self::Neu => ".neu",
            Self::Ltu => ".ltu",
            Self::Leu => ".leu",
            Self::Gtu => ".gtu",
            Self::Geu => ".geu",
            Self::Num => ".num",
            Self::Nan => ".nan",
        }
    }
}

/// PTX state space qualifier for memory accesses and declarations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemorySpace {
    /// Global memory (`.global`).
    Global,
    /// Shared memory (`.shared`).
    Shared,
    /// Thread-private local memory (`.local`).
    Local,
    /// Constant memory; note the PTX spelling is `.const`.
    Constant,
    /// Parameter space (`.param`).
    Param,
}

impl MemorySpace {
    /// Returns the dotted PTX state-space name, e.g. `".global"`.
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match self {
            Self::Global => ".global",
            Self::Shared => ".shared",
            Self::Local => ".local",
            Self::Constant => ".const",
            Self::Param => ".param",
        }
    }
}

/// Cache operator for PTX memory instructions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CacheQualifier {
    /// No cache operator; emits an empty suffix.
    None,
    /// Cache at all levels (`.ca`).
    Ca,
    /// Cache at global level (`.cg`).
    Cg,
    /// Cache streaming, likely accessed once (`.cs`).
    Cs,
    /// Last use (`.lu`).
    Lu,
    /// Don't cache; fetch again (`.cv`).
    Cv,
}

impl CacheQualifier {
    /// Returns the dotted PTX cache operator, or `""` for [`CacheQualifier::None`].
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match self {
            Self::None => "",
            Self::Ca => ".ca",
            Self::Cg => ".cg",
            Self::Cs => ".cs",
            Self::Lu => ".lu",
            Self::Cv => ".cv",
        }
    }
}

/// Scope qualifier for PTX memory fences.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FenceScope {
    /// Threads of the same CTA (`.cta`).
    Cta,
    /// All threads on the current GPU (`.gpu`).
    Gpu,
    /// The whole system, including host and peer devices (`.sys`).
    Sys,
}

impl FenceScope {
    /// Returns the dotted PTX scope, e.g. `".gpu"`.
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match self {
            Self::Cta => ".cta",
            Self::Gpu => ".gpu",
            Self::Sys => ".sys",
        }
    }
}

/// Read-only PTX special registers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SpecialReg {
    /// Thread index within the CTA, x component (`%tid.x`).
    TidX,
    /// Thread index within the CTA, y component (`%tid.y`).
    TidY,
    /// Thread index within the CTA, z component (`%tid.z`).
    TidZ,
    /// CTA index within the grid, x component (`%ctaid.x`).
    CtaidX,
    /// CTA index within the grid, y component (`%ctaid.y`).
    CtaidY,
    /// CTA index within the grid, z component (`%ctaid.z`).
    CtaidZ,
    /// CTA dimension in threads, x component (`%ntid.x`).
    NtidX,
    /// CTA dimension in threads, y component (`%ntid.y`).
    NtidY,
    /// CTA dimension in threads, z component (`%ntid.z`).
    NtidZ,
    /// Grid dimension in CTAs, x component (`%nctaid.x`).
    NctaidX,
    /// Grid dimension in CTAs, y component (`%nctaid.y`).
    NctaidY,
    /// Grid dimension in CTAs, z component (`%nctaid.z`).
    NctaidZ,
    /// Warp identifier (`%warpid`).
    WarpId,
    /// Lane index within the warp (`%laneid`).
    LaneId,
    /// Streaming multiprocessor identifier (`%smid`).
    SmId,
    /// 32-bit cycle counter (`%clock`).
    Clock,
    /// 64-bit cycle counter (`%clock64`).
    Clock64,
    /// Size of the dynamically allocated shared memory (`%dynamic_smem_size`).
    DynamicSmemSize,
}

impl SpecialReg {
    /// Returns the PTX register name including the `%` sigil, e.g. `"%tid.x"`.
    #[must_use]
    pub const fn as_ptx_str(&self) -> &'static str {
        match self {
            Self::TidX => "%tid.x",
            Self::TidY => "%tid.y",
            Self::TidZ => "%tid.z",
            Self::CtaidX => "%ctaid.x",
            Self::CtaidY => "%ctaid.y",
            Self::CtaidZ => "%ctaid.z",
            Self::NtidX => "%ntid.x",
            Self::NtidY => "%ntid.y",
            Self::NtidZ => "%ntid.z",
            Self::NctaidX => "%nctaid.x",
            Self::NctaidY => "%nctaid.y",
            Self::NctaidZ => "%nctaid.z",
            Self::WarpId => "%warpid",
            Self::LaneId => "%laneid",
            Self::SmId => "%smid",
            Self::Clock => "%clock",
            Self::Clock64 => "%clock64",
            Self::DynamicSmemSize => "%dynamic_smem_size",
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ptx_type_as_ptx_str() {
        assert_eq!(PtxType::F32.as_ptx_str(), ".f32");
        assert_eq!(PtxType::U64.as_ptx_str(), ".u64");
        assert_eq!(PtxType::Pred.as_ptx_str(), ".pred");
        assert_eq!(PtxType::B128.as_ptx_str(), ".b128");
        assert_eq!(PtxType::E4M3.as_ptx_str(), ".e4m3");
        assert_eq!(PtxType::BF16x2.as_ptx_str(), ".bf16x2");
        assert_eq!(PtxType::S32.as_ptx_str(), ".s32");
    }

    #[test]
    fn ptx_type_size_bytes() {
        assert_eq!(PtxType::U8.size_bytes(), 1);
        assert_eq!(PtxType::F16.size_bytes(), 2);
        assert_eq!(PtxType::F32.size_bytes(), 4);
        assert_eq!(PtxType::F64.size_bytes(), 8);
        assert_eq!(PtxType::B128.size_bytes(), 16);
        assert_eq!(PtxType::Pred.size_bytes(), 1);
        assert_eq!(PtxType::F16x2.size_bytes(), 4);
        assert_eq!(PtxType::BF16x2.size_bytes(), 4);
        assert_eq!(PtxType::E2M1.size_bytes(), 1);
    }

    #[test]
    fn ptx_type_reg_type() {
        assert_eq!(PtxType::F32.reg_type(), PtxType::B32);
        assert_eq!(PtxType::F64.reg_type(), PtxType::B64);
        assert_eq!(PtxType::U64.reg_type(), PtxType::B64);
        assert_eq!(PtxType::Pred.reg_type(), PtxType::Pred);
        assert_eq!(PtxType::F16.reg_type(), PtxType::B16);
        assert_eq!(PtxType::B128.reg_type(), PtxType::B128);
        assert_eq!(PtxType::U8.reg_type(), PtxType::B32);
    }

    #[test]
    fn ptx_type_classification() {
        assert!(PtxType::U32.is_integer());
        assert!(PtxType::S64.is_integer());
        assert!(!PtxType::F32.is_integer());
        assert!(!PtxType::Pred.is_integer());

        assert!(PtxType::F32.is_float());
        assert!(PtxType::F16x2.is_float());
        assert!(PtxType::E4M3.is_float());
        assert!(!PtxType::U32.is_float());
        assert!(!PtxType::B32.is_float());

        assert!(PtxType::S32.is_signed());
        assert!(PtxType::F32.is_signed());
        assert!(!PtxType::U32.is_signed());
        assert!(!PtxType::B32.is_signed());
    }

    #[test]
    fn special_reg_ptx_str() {
        assert_eq!(SpecialReg::TidX.as_ptx_str(), "%tid.x");
        assert_eq!(SpecialReg::CtaidY.as_ptx_str(), "%ctaid.y");
        assert_eq!(SpecialReg::LaneId.as_ptx_str(), "%laneid");
        assert_eq!(SpecialReg::Clock64.as_ptx_str(), "%clock64");
        assert_eq!(
            SpecialReg::DynamicSmemSize.as_ptx_str(),
            "%dynamic_smem_size"
        );
    }

    #[test]
    fn rounding_mode_ptx_str() {
        assert_eq!(RoundingMode::Rn.as_ptx_str(), ".rn");
        assert_eq!(RoundingMode::Rz.as_ptx_str(), ".rz");
        // PTX spells the directed rounding modes `.rp` / `.rm`, not `.ru` / `.rd`.
        assert_eq!(RoundingMode::Ru.as_ptx_str(), ".rp");
        assert_eq!(RoundingMode::Rd.as_ptx_str(), ".rm");
    }

    #[test]
    fn memory_space_ptx_str() {
        assert_eq!(MemorySpace::Global.as_ptx_str(), ".global");
        assert_eq!(MemorySpace::Shared.as_ptx_str(), ".shared");
        assert_eq!(MemorySpace::Constant.as_ptx_str(), ".const");
        assert_eq!(MemorySpace::Param.as_ptx_str(), ".param");
    }

    #[test]
    fn cmp_op_ptx_str() {
        assert_eq!(CmpOp::Eq.as_ptx_str(), ".eq");
        assert_eq!(CmpOp::Ltu.as_ptx_str(), ".ltu");
        assert_eq!(CmpOp::Nan.as_ptx_str(), ".nan");
    }

    #[test]
    fn vector_width_ptx_str() {
        assert_eq!(VectorWidth::V1.as_ptx_str(), "");
        assert_eq!(VectorWidth::V2.as_ptx_str(), ".v2");
        assert_eq!(VectorWidth::V4.as_ptx_str(), ".v4");
    }

    #[test]
    fn mul_mode_ptx_str() {
        assert_eq!(MulMode::Lo.as_ptx_str(), ".lo");
        assert_eq!(MulMode::Hi.as_ptx_str(), ".hi");
        assert_eq!(MulMode::Wide.as_ptx_str(), ".wide");
    }

    #[test]
    fn cache_qualifier_ptx_str() {
        assert_eq!(CacheQualifier::None.as_ptx_str(), "");
        assert_eq!(CacheQualifier::Ca.as_ptx_str(), ".ca");
        assert_eq!(CacheQualifier::Cv.as_ptx_str(), ".cv");
    }

    #[test]
    fn fence_scope_ptx_str() {
        assert_eq!(FenceScope::Cta.as_ptx_str(), ".cta");
        assert_eq!(FenceScope::Gpu.as_ptx_str(), ".gpu");
        assert_eq!(FenceScope::Sys.as_ptx_str(), ".sys");
    }

    #[test]
    fn atom_op_ptx_str() {
        assert_eq!(AtomOp::Add.as_ptx_str(), ".add");
        assert_eq!(AtomOp::Min.as_ptx_str(), ".min");
        assert_eq!(AtomOp::Max.as_ptx_str(), ".max");
        assert_eq!(AtomOp::Inc.as_ptx_str(), ".inc");
        assert_eq!(AtomOp::Dec.as_ptx_str(), ".dec");
        assert_eq!(AtomOp::And.as_ptx_str(), ".and");
        assert_eq!(AtomOp::Or.as_ptx_str(), ".or");
        assert_eq!(AtomOp::Xor.as_ptx_str(), ".xor");
        assert_eq!(AtomOp::Exch.as_ptx_str(), ".exch");
    }

    #[test]
    fn test_fp4_e2m1_type() {
        assert_eq!(PtxType::E2M1.bit_width(), 4);
        assert!(PtxType::E2M1.is_float());
        assert_eq!(format!("{}", PtxType::E2M1), "e2m1");
    }

    #[test]
    fn test_bit_width_correctness() {
        assert_eq!(PtxType::Pred.bit_width(), 1);
        assert_eq!(PtxType::E2M3.bit_width(), 6);
        assert_eq!(PtxType::E3M2.bit_width(), 6);
        assert_eq!(PtxType::E4M3.bit_width(), 8);
        assert_eq!(PtxType::E5M2.bit_width(), 8);
        assert_eq!(PtxType::U8.bit_width(), 8);
        assert_eq!(PtxType::F16.bit_width(), 16);
        assert_eq!(PtxType::BF16.bit_width(), 16);
        assert_eq!(PtxType::F16x2.bit_width(), 32);
        assert_eq!(PtxType::F32.bit_width(), 32);
        assert_eq!(PtxType::TF32.bit_width(), 32);
        assert_eq!(PtxType::F64.bit_width(), 64);
        assert_eq!(PtxType::B128.bit_width(), 128);
    }

    #[test]
    fn test_display_format() {
        assert_eq!(format!("{}", PtxType::F32), "f32");
        assert_eq!(format!("{}", PtxType::U64), "u64");
        assert_eq!(format!("{}", PtxType::E4M3), "e4m3");
        assert_eq!(format!("{}", PtxType::BF16x2), "bf16x2");
        assert_eq!(format!("{}", PtxType::B128), "b128");
        assert_eq!(format!("{}", PtxType::Pred), "pred");
    }
}