Skip to main content

oxicuda_ptx/ir/
operand.rs

//! PTX instruction operands.
//!
//! Operands represent the inputs and outputs of PTX instructions other than
//! destination registers. They can be registers, immediate constants,
//! memory addresses (a base register plus an optional byte offset), or
//! symbolic references such as parameter names or labels.
6
7use std::fmt;
8
9use super::register::Register;
10
11/// An operand in a PTX instruction.
12///
13/// Operands appear as source arguments in arithmetic, memory, and control-flow
14/// instructions. The [`Operand::Address`] variant models the `[base + offset]`
15/// addressing syntax used in load/store instructions.
16#[derive(Debug, Clone)]
17pub enum Operand {
18    /// A register operand.
19    Register(Register),
20    /// An immediate (literal) value.
21    Immediate(ImmValue),
22    /// A memory address with a base register and optional byte offset.
23    Address {
24        /// The base address register (typically 64-bit).
25        base: Register,
26        /// Optional byte offset added to the base address.
27        offset: Option<i64>,
28    },
29    /// A symbolic reference (e.g., a parameter name or label).
30    Symbol(String),
31}
32
/// An immediate (literal) value embedded in a PTX instruction.
///
/// Each variant fixes both the bit width and the kind (unsigned, signed, or
/// floating-point) of the literal. `PartialEq` is derived so immediates can be
/// compared directly; as with the underlying float types, a NaN payload never
/// compares equal to itself (`F32(NAN) != F32(NAN)`).
#[derive(Debug, Clone, PartialEq)]
pub enum ImmValue {
    /// 32-bit unsigned integer literal.
    U32(u32),
    /// 64-bit unsigned integer literal.
    U64(u64),
    /// 32-bit signed integer literal.
    S32(i32),
    /// 64-bit signed integer literal.
    S64(i64),
    /// 32-bit floating-point literal.
    F32(f32),
    /// 64-bit floating-point literal.
    F64(f64),
}
49
50impl From<Register> for Operand {
51    fn from(reg: Register) -> Self {
52        Self::Register(reg)
53    }
54}
55
56impl fmt::Display for ImmValue {
57    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58        match self {
59            Self::U32(v) => write!(f, "{v}"),
60            Self::U64(v) => write!(f, "{v}"),
61            Self::S32(v) => write!(f, "{v}"),
62            Self::S64(v) => write!(f, "{v}"),
63            Self::F32(v) => {
64                // PTX uses C-style float literals; ensure a decimal point is present.
65                if v.fract() == 0.0 {
66                    write!(f, "{v:.1}")
67                } else {
68                    write!(f, "{v}")
69                }
70            }
71            Self::F64(v) => {
72                if v.fract() == 0.0 {
73                    write!(f, "{v:.1}")
74                } else {
75                    write!(f, "{v}")
76                }
77            }
78        }
79    }
80}
81
82impl fmt::Display for Operand {
83    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84        match self {
85            Self::Register(reg) => write!(f, "{reg}"),
86            Self::Immediate(imm) => write!(f, "{imm}"),
87            Self::Address { base, offset } => match offset {
88                Some(off) if *off != 0 => write!(f, "[{base}+{off}]"),
89                _ => write!(f, "[{base}]"),
90            },
91            Self::Symbol(sym) => write!(f, "{sym}"),
92        }
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::types::PtxType;

    /// Builds a [`Register`] fixture from a static name and a PTX type.
    fn reg(name: &'static str, ty: PtxType) -> Register {
        Register {
            name: name.into(),
            ty,
        }
    }

    #[test]
    fn operand_display_register() {
        let op = Operand::Register(reg("%f0", PtxType::F32));
        assert_eq!(format!("{op}"), "%f0");
    }

    #[test]
    fn operand_display_immediate() {
        assert_eq!(format!("{}", ImmValue::U32(42)), "42");
        assert_eq!(format!("{}", ImmValue::F32(3.0)), "3.0");
        assert_eq!(format!("{}", ImmValue::F32(1.5)), "1.5");
        assert_eq!(format!("{}", ImmValue::S32(-7)), "-7");
    }

    #[test]
    fn operand_display_immediate_wide_and_f64() {
        assert_eq!(format!("{}", ImmValue::U64(u64::MAX)), "18446744073709551615");
        assert_eq!(format!("{}", ImmValue::S64(i64::MIN)), "-9223372036854775808");
        // Whole-number f64 values must still carry a decimal point.
        assert_eq!(format!("{}", ImmValue::F64(2.0)), "2.0");
        assert_eq!(format!("{}", ImmValue::F64(0.25)), "0.25");
    }

    #[test]
    fn operand_display_immediate_operand_delegates() {
        let op = Operand::Immediate(ImmValue::S32(-3));
        assert_eq!(format!("{op}"), "-3");
    }

    #[test]
    fn operand_display_address() {
        let op_no_offset = Operand::Address {
            base: reg("%rd0", PtxType::U64),
            offset: None,
        };
        assert_eq!(format!("{op_no_offset}"), "[%rd0]");

        let op_with_offset = Operand::Address {
            base: reg("%rd0", PtxType::U64),
            offset: Some(16),
        };
        assert_eq!(format!("{op_with_offset}"), "[%rd0+16]");
    }

    #[test]
    fn operand_display_address_zero_offset_is_elided() {
        let op = Operand::Address {
            base: reg("%rd0", PtxType::U64),
            offset: Some(0),
        };
        assert_eq!(format!("{op}"), "[%rd0]");
    }

    #[test]
    fn operand_display_address_negative_offset() {
        // Negative offsets keep the explicit `+` separator before the sign.
        let op = Operand::Address {
            base: reg("%rd0", PtxType::U64),
            offset: Some(-8),
        };
        assert_eq!(format!("{op}"), "[%rd0+-8]");
    }

    #[test]
    fn operand_display_symbol() {
        let op = Operand::Symbol("_param_0".into());
        assert_eq!(format!("{op}"), "_param_0");
    }

    #[test]
    fn operand_from_register() {
        let op: Operand = reg("%r0", PtxType::U32).into();
        assert!(matches!(op, Operand::Register(_)));
    }
}