ptx_parser/unparser/
common.rs

1use super::PtxUnparser;
2use crate::{
3    lexer::{PtxToken, tokenize},
4    r#type::common::{
5        AddressBase, AddressOffset, AddressOperand, AddressSpace, AttributeDirective, Axis,
6        CodeLinkage, CodeOrDataLinkage, DataLinkage, DataType, FunctionSymbol, GeneralOperand,
7        Immediate, Label, Operand, PredicateRegister, RegisterOperand, Sign, SpecialRegister,
8        TexHandler2, TexHandler3, TexHandler3Optional, TexType, VariableSymbol, VectorOperand,
9    },
10};
11
12fn push_tokenized(tokens: &mut Vec<PtxToken>, text: &str) {
13    if text.trim().is_empty() {
14        return;
15    }
16    let lexemes =
17        tokenize(text).unwrap_or_else(|_| panic!("failed to tokenize literal {:?}", text));
18    tokens.extend(lexemes.into_iter().map(|(token, _)| token));
19}
20
21pub(crate) fn push_directive(tokens: &mut Vec<PtxToken>, name: &str) {
22    let raw = if name.starts_with('.') {
23        name.to_string()
24    } else {
25        format!(".{}", name)
26    };
27    push_tokenized(tokens, &raw);
28}
29
30pub(crate) fn push_token_from_str(tokens: &mut Vec<PtxToken>, value: &str) {
31    push_tokenized(tokens, value);
32}
33
34pub(crate) fn push_identifier(tokens: &mut Vec<PtxToken>, name: &str) {
35    tokens.push(PtxToken::Identifier(name.to_string()));
36}
37
38pub(crate) fn push_register(tokens: &mut Vec<PtxToken>, name: &str) {
39    tokens.push(PtxToken::Register(name.to_string()));
40}
41
42pub(crate) fn push_decimal<T: ToString>(tokens: &mut Vec<PtxToken>, value: T) {
43    tokens.push(PtxToken::DecimalInteger(value.to_string()));
44}
45
46pub(crate) fn push_opcode(tokens: &mut Vec<PtxToken>, opcode: &str) {
47    push_identifier(tokens, opcode);
48}
49
50fn push_register_with_axis(tokens: &mut Vec<PtxToken>, base: &str, axis: &Axis) {
51    push_register(tokens, base);
52    match axis {
53        Axis::None { .. } => {}
54        Axis::X { .. } => push_directive(tokens, "x"),
55        Axis::Y { .. } => push_directive(tokens, "y"),
56        Axis::Z { .. } => push_directive(tokens, "z"),
57    };
58}
59
60fn numeric_token(literal: &str) -> PtxToken {
61    if literal.starts_with("0f")
62        || literal.starts_with("0F")
63        || literal.starts_with("0d")
64        || literal.starts_with("0D")
65    {
66        PtxToken::HexFloat(literal.to_string())
67    } else if literal.starts_with("0x") || literal.starts_with("0X") {
68        PtxToken::HexInteger(literal.to_string())
69    } else if literal.starts_with("0b") || literal.starts_with("0B") {
70        PtxToken::BinaryInteger(literal.to_string())
71    } else if literal.len() > 1
72        && literal.starts_with('0')
73        && literal.chars().all(|c| c >= '0' && c <= '7')
74    {
75        PtxToken::OctalInteger(literal.to_string())
76    } else if literal.contains('e') || literal.contains('E') {
77        PtxToken::FloatExponent(literal.to_string())
78    } else if literal.contains('.') {
79        PtxToken::Float(literal.to_string())
80    } else {
81        PtxToken::DecimalInteger(literal.to_string())
82    }
83}
84
85fn push_numeric(tokens: &mut Vec<PtxToken>, literal: &str) {
86    tokens.push(numeric_token(literal));
87}
88
89impl PtxUnparser for CodeLinkage {
90    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
91        match self {
92            CodeLinkage::Visible { .. } => push_directive(tokens, "visible"),
93            CodeLinkage::Extern { .. } => push_directive(tokens, "extern"),
94            CodeLinkage::Weak { .. } => push_directive(tokens, "weak"),
95        }
96    }
97}
98
99impl PtxUnparser for DataLinkage {
100    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
101        match self {
102            DataLinkage::Visible { .. } => push_directive(tokens, "visible"),
103            DataLinkage::Extern { .. } => push_directive(tokens, "extern"),
104            DataLinkage::Weak { .. } => push_directive(tokens, "weak"),
105            DataLinkage::Common { .. } => push_directive(tokens, "common"),
106        }
107    }
108}
109
110impl PtxUnparser for CodeOrDataLinkage {
111    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
112        match self {
113            CodeOrDataLinkage::Visible { .. } => push_directive(tokens, "visible"),
114            CodeOrDataLinkage::Extern { .. } => push_directive(tokens, "extern"),
115            CodeOrDataLinkage::Weak { .. } => push_directive(tokens, "weak"),
116            CodeOrDataLinkage::Common { .. } => push_directive(tokens, "common"),
117        }
118    }
119}
120
121impl PtxUnparser for TexType {
122    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
123        match self {
124            TexType::TexRef { .. } => push_directive(tokens, "texref"),
125            TexType::SamplerRef { .. } => push_directive(tokens, "samplerref"),
126            TexType::SurfRef { .. } => push_directive(tokens, "surfref"),
127        }
128    }
129}
130
131impl PtxUnparser for AddressSpace {
132    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
133        match self {
134            AddressSpace::Global { .. } => push_directive(tokens, "global"),
135            AddressSpace::Const { .. } => push_directive(tokens, "const"),
136            AddressSpace::Shared { .. } => push_directive(tokens, "shared"),
137            AddressSpace::Local { .. } => push_directive(tokens, "local"),
138            AddressSpace::Param { .. } => push_directive(tokens, "param"),
139            AddressSpace::Reg { .. } => push_directive(tokens, "reg"),
140        }
141    }
142}
143
144impl PtxUnparser for AttributeDirective {
145    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
146        match self {
147            AttributeDirective::Unified { uuid1, uuid2, .. } => {
148                push_directive(tokens, "unified");
149                tokens.push(PtxToken::LParen);
150                let first = uuid1.to_string();
151                push_numeric(tokens, &first);
152                tokens.push(PtxToken::Comma);
153                let second = uuid2.to_string();
154                push_numeric(tokens, &second);
155                tokens.push(PtxToken::RParen);
156            }
157            AttributeDirective::Managed { .. } => push_directive(tokens, "managed"),
158        }
159    }
160}
161
162impl PtxUnparser for DataType {
163    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
164        let directive = match self {
165            DataType::U8 { .. } => "u8",
166            DataType::U16 { .. } => "u16",
167            DataType::U32 { .. } => "u32",
168            DataType::U64 { .. } => "u64",
169            DataType::S8 { .. } => "s8",
170            DataType::S16 { .. } => "s16",
171            DataType::S32 { .. } => "s32",
172            DataType::S64 { .. } => "s64",
173            DataType::F16 { .. } => "f16",
174            DataType::F16x2 { .. } => "f16x2",
175            DataType::F32 { .. } => "f32",
176            DataType::F64 { .. } => "f64",
177            DataType::B8 { .. } => "b8",
178            DataType::B16 { .. } => "b16",
179            DataType::B32 { .. } => "b32",
180            DataType::B64 { .. } => "b64",
181            DataType::B128 { .. } => "b128",
182            DataType::Pred { .. } => "pred",
183        };
184        push_directive(tokens, directive);
185    }
186}
187
188impl PtxUnparser for Sign {
189    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
190        match self {
191            Sign::Negative { .. } => tokens.push(PtxToken::Minus),
192            Sign::Positive { .. } => tokens.push(PtxToken::Plus),
193        }
194    }
195}
196
197impl PtxUnparser for Immediate {
198    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
199        let literal = self.value.as_str();
200        if let Some(rest) = literal.strip_prefix('-') {
201            tokens.push(PtxToken::Minus);
202            push_numeric(tokens, rest);
203        } else if let Some(rest) = literal.strip_prefix('+') {
204            tokens.push(PtxToken::Plus);
205            push_numeric(tokens, rest);
206        } else {
207            push_numeric(tokens, literal);
208        }
209    }
210}
211
212impl PtxUnparser for RegisterOperand {
213    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
214        push_register(tokens, &self.name);
215    }
216}
217
218impl PtxUnparser for PredicateRegister {
219    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
220        push_register(tokens, &self.name);
221    }
222}
223
224impl PtxUnparser for Label {
225    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
226        push_identifier(tokens, &self.name);
227    }
228}
229
230impl PtxUnparser for SpecialRegister {
231    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
232        let name = match self {
233            SpecialRegister::AggrSmemSize { .. } => "%aggr_smem_size".to_string(),
234            SpecialRegister::DynamicSmemSize { .. } => "%dynamic_smem_size".to_string(),
235            SpecialRegister::LanemaskGt { .. } => "%lanemask_gt".to_string(),
236            SpecialRegister::ReservedSmemOffsetBegin { .. } => "%reserved_smem_offset_begin".to_string(),
237            SpecialRegister::Clock { .. } => "%clock".to_string(),
238            SpecialRegister::Envreg { index, .. } => format!("%envreg{}", index),
239            SpecialRegister::LanemaskLe { .. } => "%lanemask_le".to_string(),
240            SpecialRegister::ReservedSmemOffsetCap { .. } => "%reserved_smem_offset_cap".to_string(),
241            SpecialRegister::Clock64 { .. } => "%clock64".to_string(),
242            SpecialRegister::Globaltimer { .. } => "%globaltimer".to_string(),
243            SpecialRegister::LanemaskLt { .. } => "%lanemask_lt".to_string(),
244            SpecialRegister::ReservedSmemOffsetEnd { .. } => "%reserved_smem_offset_end".to_string(),
245            SpecialRegister::ClusterCtaid { axis, .. } => {
246                push_register_with_axis(tokens, "%cluster_ctaid", axis);
247                return;
248            }
249            SpecialRegister::GlobaltimerHi { .. } => "%globaltimer_hi".to_string(),
250            SpecialRegister::Nclusterid { .. } => "%nclusterid".to_string(),
251            SpecialRegister::Smid { .. } => "%smid".to_string(),
252            SpecialRegister::ClusterCtarank { axis, .. } => {
253                push_register_with_axis(tokens, "%cluster_ctarank", axis);
254                return;
255            }
256            SpecialRegister::GlobaltimerLo { .. } => "%globaltimer_lo".to_string(),
257            SpecialRegister::Nctaid { axis, .. } => {
258                push_register_with_axis(tokens, "%nctaid", axis);
259                return;
260            }
261            SpecialRegister::Tid { axis, .. } => {
262                push_register_with_axis(tokens, "%tid", axis);
263                return;
264            }
265            SpecialRegister::ClusterNctaid { axis, .. } => {
266                push_register_with_axis(tokens, "%cluster_nctaid", axis);
267                return;
268            }
269            SpecialRegister::Gridid { .. } => "%gridid".to_string(),
270            SpecialRegister::Nsmid { .. } => "%nsmid".to_string(),
271            SpecialRegister::TotalSmemSize { .. } => "%total_smem_size".to_string(),
272            SpecialRegister::ClusterNctarank { axis, .. } => {
273                push_register_with_axis(tokens, "%cluster_nctarank", axis);
274                return;
275            }
276            SpecialRegister::IsExplicitCluster { .. } => "%is_explicit_cluster".to_string(),
277            SpecialRegister::Ntid { axis, .. } => {
278                push_register_with_axis(tokens, "%ntid", axis);
279                return;
280            }
281            SpecialRegister::Warpid { .. } => "%warpid".to_string(),
282            SpecialRegister::Clusterid { .. } => "%clusterid".to_string(),
283            SpecialRegister::Laneid { .. } => "%laneid".to_string(),
284            SpecialRegister::Nwarpid { .. } => "%nwarpid".to_string(),
285            SpecialRegister::WARPSZ { .. } => "%WARPSZ".to_string(),
286            SpecialRegister::Ctaid { axis, .. } => {
287                push_register_with_axis(tokens, "%ctaid", axis);
288                return;
289            }
290            SpecialRegister::LanemaskEq { .. } => "%lanemask_eq".to_string(),
291            SpecialRegister::Pm { index, .. } => format!("%pm{}", index),
292            SpecialRegister::Pm64 { index, .. } => format!("%pm{}_64", index),
293            SpecialRegister::CurrentGraphExec { .. } => "%current_graph_exec".to_string(),
294            SpecialRegister::LanemaskGe { .. } => "%lanemask_ge".to_string(),
295            SpecialRegister::ReservedSmemOffset { index, .. } => {
296                format!("%reserved_smem_offset_{}", index)
297            }
298        };
299        push_register(tokens, &name);
300    }
301}
302
303impl PtxUnparser for Operand {
304    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
305        match self {
306            Operand::Register { operand: register, .. } => register.unparse_tokens(tokens),
307            Operand::Immediate { operand: immediate, .. } => immediate.unparse_tokens(tokens),
308            Operand::Symbol { name: symbol, .. } => push_identifier(tokens, symbol),
309            Operand::SymbolOffset { symbol, offset, .. } => {
310                push_identifier(tokens, symbol);
311                tokens.push(PtxToken::Plus);
312                offset.unparse_tokens(tokens);
313            }
314        }
315    }
316}
317
318impl PtxUnparser for VectorOperand {
319    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
320        tokens.push(PtxToken::LBrace);
321        match self {
322            VectorOperand::Vector1 { operand: item, .. } => item.unparse_tokens(tokens),
323            VectorOperand::Vector2 { operands: items, .. } => {
324                for (idx, item) in items.iter().enumerate() {
325                    if idx > 0 {
326                        tokens.push(PtxToken::Comma);
327                    }
328                    item.unparse_tokens(tokens);
329                }
330            }
331            VectorOperand::Vector3 { operands: items, .. } => {
332                for (idx, item) in items.iter().enumerate() {
333                    if idx > 0 {
334                        tokens.push(PtxToken::Comma);
335                    }
336                    item.unparse_tokens(tokens);
337                }
338            }
339            VectorOperand::Vector4 { operands: items, .. } => {
340                for (idx, item) in items.iter().enumerate() {
341                    if idx > 0 {
342                        tokens.push(PtxToken::Comma);
343                    }
344                    item.unparse_tokens(tokens);
345                }
346            }
347            VectorOperand::Vector8 { operands: items, .. } => {
348                for (idx, item) in items.iter().enumerate() {
349                    if idx > 0 {
350                        tokens.push(PtxToken::Comma);
351                    }
352                    item.unparse_tokens(tokens);
353                }
354            }
355        }
356        tokens.push(PtxToken::RBrace);
357    }
358}
359
360impl PtxUnparser for GeneralOperand {
361    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
362        match self {
363            GeneralOperand::Vec { operand: vector, .. } => vector.unparse_tokens(tokens),
364            GeneralOperand::Single { operand, .. } => operand.unparse_tokens(tokens),
365        }
366    }
367}
368
369impl PtxUnparser for TexHandler2 {
370    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
371        tokens.push(PtxToken::LBracket);
372        for (idx, item) in self.operands.iter().enumerate() {
373            if idx > 0 {
374                tokens.push(PtxToken::Comma);
375            }
376            item.unparse_tokens(tokens);
377        }
378        tokens.push(PtxToken::RBracket);
379    }
380}
381
382impl PtxUnparser for TexHandler3 {
383    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
384        tokens.push(PtxToken::LBracket);
385        self.handle.unparse_tokens(tokens);
386        tokens.push(PtxToken::Comma);
387        self.sampler.unparse_tokens(tokens);
388        tokens.push(PtxToken::Comma);
389        self.coords.unparse_tokens(tokens);
390        tokens.push(PtxToken::RBracket);
391    }
392}
393
394impl PtxUnparser for TexHandler3Optional {
395    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
396        tokens.push(PtxToken::LBracket);
397        self.handle.unparse_tokens(tokens);
398        tokens.push(PtxToken::Comma);
399        if let Some(sampler) = &self.sampler {
400            sampler.unparse_tokens(tokens);
401            tokens.push(PtxToken::Comma);
402        }
403        self.coords.unparse_tokens(tokens);
404        tokens.push(PtxToken::RBracket);
405    }
406}
407
408impl PtxUnparser for AddressBase {
409    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
410        match self {
411            AddressBase::Register { operand: register, .. } => register.unparse_tokens(tokens),
412            AddressBase::Variable { symbol, .. } => symbol.unparse_tokens(tokens),
413        }
414    }
415}
416
417impl PtxUnparser for AddressOffset {
418    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
419        match self {
420            AddressOffset::Register { operand: register, .. } => {
421                tokens.push(PtxToken::Plus);
422                register.unparse_tokens(tokens);
423            }
424            AddressOffset::Immediate { sign, value: immediate, .. } => {
425                sign.unparse_tokens(tokens);
426                immediate.unparse_tokens(tokens);
427            }
428        }
429    }
430}
431
432impl PtxUnparser for AddressOperand {
433    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
434        match self {
435            AddressOperand::Array { base, index, .. } => {
436                base.unparse_tokens(tokens);
437                tokens.push(PtxToken::LBracket);
438                index.unparse_tokens(tokens);
439                tokens.push(PtxToken::RBracket);
440            }
441            AddressOperand::ImmediateAddress { addr, .. } => {
442                tokens.push(PtxToken::LBracket);
443                addr.unparse_tokens(tokens);
444                tokens.push(PtxToken::RBracket);
445            }
446            AddressOperand::Offset { base, offset, .. } => {
447                tokens.push(PtxToken::LBracket);
448                base.unparse_tokens(tokens);
449                if let Some(offset) = offset {
450                    offset.unparse_tokens(tokens);
451                }
452                tokens.push(PtxToken::RBracket);
453            }
454        }
455    }
456}
457
458impl PtxUnparser for FunctionSymbol {
459    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
460        push_identifier(tokens, &self.name);
461    }
462}
463
464impl PtxUnparser for VariableSymbol {
465    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
466        push_identifier(tokens, &self.name);
467    }
468}
469
470impl PtxUnparser for crate::r#type::common::Instruction {
471    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
472        // Emit predicate if present
473        if let Some(predicate) = &self.predicate {
474            tokens.push(PtxToken::At);
475            if predicate.negated {
476                tokens.push(PtxToken::Exclaim);
477            }
478            predicate.operand.unparse_tokens(tokens);
479        }
480
481        // Emit the instruction
482        self.inst.unparse_tokens(tokens);
483    }
484}