Skip to main content

kaio_core/ir/
operand.rs

//! PTX operand types (register references, immediates, special registers, and shared-memory addresses).
2
3use std::fmt;
4
5use super::register::Register;
6
/// An operand to a PTX instruction.
///
/// The `Display` implementation for this type produces the text that is
/// written for the operand: registers and special registers by name,
/// immediates as decimal literals, shared allocations by their bare name.
#[derive(Debug, Clone)]
pub enum Operand {
    /// A virtual register.
    Reg(Register),
    /// 32-bit signed integer immediate, displayed as a decimal value.
    ImmI32(i32),
    /// 32-bit unsigned integer immediate, displayed as a decimal value.
    ImmU32(u32),
    /// 64-bit signed integer immediate, displayed as a decimal value.
    ImmI64(i64),
    /// 64-bit unsigned integer immediate, displayed as a decimal value.
    ImmU64(u64),
    /// 32-bit float immediate.
    ///
    /// Whole-number values are displayed with a trailing `.0` so the literal
    /// always contains a decimal point.
    ImmF32(f32),
    /// 64-bit float immediate.
    ///
    /// Whole-number values are displayed with a trailing `.0`, like
    /// [`Operand::ImmF32`].
    ImmF64(f64),
    /// A PTX special register (`%tid.x`, `%ntid.x`, etc.).
    SpecialReg(SpecialReg),
    /// Address of a named shared memory allocation.
    ///
    /// Used with `Mov` to load a shared allocation's base address into a
    /// register: `mov.u32 %r0, sdata;`. Displays as the bare name.
    SharedAddr(String),
}
32
33impl fmt::Display for Operand {
34    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35        match self {
36            Self::Reg(r) => write!(f, "{r}"),
37            Self::ImmI32(v) => write!(f, "{v}"),
38            Self::ImmU32(v) => write!(f, "{v}"),
39            Self::ImmI64(v) => write!(f, "{v}"),
40            Self::ImmU64(v) => write!(f, "{v}"),
41            Self::ImmF32(v) => {
42                // PTX requires decimal point for floats
43                if v.fract() == 0.0 {
44                    write!(f, "{v:.1}")
45                } else {
46                    write!(f, "{v}")
47                }
48            }
49            Self::ImmF64(v) => {
50                if v.fract() == 0.0 {
51                    write!(f, "{v:.1}")
52                } else {
53                    write!(f, "{v}")
54                }
55            }
56            Self::SpecialReg(sr) => write!(f, "{}", sr.ptx_name()),
57            Self::SharedAddr(name) => write!(f, "{name}"),
58        }
59    }
60}
61
/// Special registers PTX exposes for thread, block and grid indexing.
///
/// Every variant corresponds to one `%name.axis` register; use
/// [`SpecialReg::ptx_name`] to obtain that spelling.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SpecialReg {
    /// Thread index along X (`%tid.x`).
    TidX,
    /// Thread index along Y (`%tid.y`).
    TidY,
    /// Thread index along Z (`%tid.z`).
    TidZ,
    /// Block dimension along X, i.e. threads per block (`%ntid.x`).
    NtidX,
    /// Block dimension along Y (`%ntid.y`).
    NtidY,
    /// Block dimension along Z (`%ntid.z`).
    NtidZ,
    /// Block/CTA index along X (`%ctaid.x`).
    CtaidX,
    /// Block/CTA index along Y (`%ctaid.y`).
    CtaidY,
    /// Block/CTA index along Z (`%ctaid.z`).
    CtaidZ,
    /// Grid dimension along X, i.e. blocks per grid (`%nctaid.x`).
    NctaidX,
    /// Grid dimension along Y (`%nctaid.y`).
    NctaidY,
    /// Grid dimension along Z (`%nctaid.z`).
    NctaidZ,
}

impl SpecialReg {
    /// Returns the spelling of this register in PTX text, such as `%tid.x`.
    pub fn ptx_name(&self) -> &'static str {
        // Exhaustive on purpose: adding a variant must force a name here.
        match *self {
            SpecialReg::TidX => "%tid.x",
            SpecialReg::TidY => "%tid.y",
            SpecialReg::TidZ => "%tid.z",

            SpecialReg::NtidX => "%ntid.x",
            SpecialReg::NtidY => "%ntid.y",
            SpecialReg::NtidZ => "%ntid.z",

            SpecialReg::CtaidX => "%ctaid.x",
            SpecialReg::CtaidY => "%ctaid.y",
            SpecialReg::CtaidZ => "%ctaid.z",

            SpecialReg::NctaidX => "%nctaid.x",
            SpecialReg::NctaidY => "%nctaid.y",
            SpecialReg::NctaidZ => "%nctaid.z",
        }
    }
}