kaio_core/ir/operand.rs
//! PTX operand types: register references, immediates, special registers,
//! and shared-memory allocation addresses.
2
3use std::fmt;
4
5use super::register::Register;
6
7/// An operand to a PTX instruction.
8#[derive(Debug, Clone)]
9pub enum Operand {
10 /// A virtual register.
11 Reg(Register),
12 /// 32-bit signed integer immediate.
13 ImmI32(i32),
14 /// 32-bit unsigned integer immediate.
15 ImmU32(u32),
16 /// 64-bit signed integer immediate.
17 ImmI64(i64),
18 /// 64-bit unsigned integer immediate.
19 ImmU64(u64),
20 /// 32-bit float immediate.
21 ImmF32(f32),
22 /// 64-bit float immediate.
23 ImmF64(f64),
24 /// A PTX special register (`%tid.x`, `%ntid.x`, etc.).
25 SpecialReg(SpecialReg),
26 /// Address of a named shared memory allocation.
27 ///
28 /// Used with `Mov` to load a shared allocation's base address into a
29 /// register: `mov.u32 %r0, sdata;`. Displays as the bare name.
30 SharedAddr(String),
31}
32
33impl fmt::Display for Operand {
34 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35 match self {
36 Self::Reg(r) => write!(f, "{r}"),
37 Self::ImmI32(v) => write!(f, "{v}"),
38 Self::ImmU32(v) => write!(f, "{v}"),
39 Self::ImmI64(v) => write!(f, "{v}"),
40 Self::ImmU64(v) => write!(f, "{v}"),
41 Self::ImmF32(v) => {
42 // PTX requires decimal point for floats
43 if v.fract() == 0.0 {
44 write!(f, "{v:.1}")
45 } else {
46 write!(f, "{v}")
47 }
48 }
49 Self::ImmF64(v) => {
50 if v.fract() == 0.0 {
51 write!(f, "{v:.1}")
52 } else {
53 write!(f, "{v}")
54 }
55 }
56 Self::SpecialReg(sr) => write!(f, "{}", sr.ptx_name()),
57 Self::SharedAddr(name) => write!(f, "{name}"),
58 }
59 }
60}
61
/// PTX special registers describing where the executing thread sits in the
/// launch: thread index, block dimensions, block index and grid dimensions,
/// each along the x, y and z axes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SpecialReg {
    /// `%tid.x`: index of the thread within its block along X.
    TidX,
    /// `%tid.y`: index of the thread within its block along Y.
    TidY,
    /// `%tid.z`: index of the thread within its block along Z.
    TidZ,
    /// `%ntid.x`: block dimension along X (threads per block).
    NtidX,
    /// `%ntid.y`: block dimension along Y.
    NtidY,
    /// `%ntid.z`: block dimension along Z.
    NtidZ,
    /// `%ctaid.x`: index of the block (CTA) within the grid along X.
    CtaidX,
    /// `%ctaid.y`: index of the block (CTA) within the grid along Y.
    CtaidY,
    /// `%ctaid.z`: index of the block (CTA) within the grid along Z.
    CtaidZ,
    /// `%nctaid.x`: grid dimension along X (blocks per grid).
    NctaidX,
    /// `%nctaid.y`: grid dimension along Y.
    NctaidY,
    /// `%nctaid.z`: grid dimension along Z.
    NctaidZ,
}

impl SpecialReg {
    /// Returns how this register is spelled in PTX source, e.g. `%ntid.y`.
    ///
    /// The result is a string literal with `'static` lifetime, so it can be
    /// embedded in emitted assembly without allocating.
    pub fn ptx_name(&self) -> &'static str {
        // Arms are grouped by axis rather than by register family.
        match *self {
            SpecialReg::TidX => "%tid.x",
            SpecialReg::NtidX => "%ntid.x",
            SpecialReg::CtaidX => "%ctaid.x",
            SpecialReg::NctaidX => "%nctaid.x",
            SpecialReg::TidY => "%tid.y",
            SpecialReg::NtidY => "%ntid.y",
            SpecialReg::CtaidY => "%ctaid.y",
            SpecialReg::NctaidY => "%nctaid.y",
            SpecialReg::TidZ => "%tid.z",
            SpecialReg::NtidZ => "%ntid.z",
            SpecialReg::CtaidZ => "%ctaid.z",
            SpecialReg::NctaidZ => "%nctaid.z",
        }
    }
}