ptx_parser/unparser/
common.rs

1use crate::{
2    lexer::{PtxToken, tokenize},
3    r#type::{
4        common::{
5            AddressBase, AddressOffset, AddressOperand, AttributeDirective, Axis, CodeLinkage,
6            DataLinkage, DataType, FunctionSymbol, GeneralOperand, Immediate, Label, Operand,
7            PredicateRegister, RegisterOperand, Sign, SpecialRegister, TexHandler2, TexHandler3,
8            TexHandler3Optional, VariableSymbol, VectorOperand,
9        },
10        function::{DwarfDirective, DwarfDirectiveKind},
11        variable::ParamStateSpace,
12    },
13    unparser::PtxUnparser,
14};
15
16fn push_tokenized(tokens: &mut Vec<PtxToken>, text: &str) {
17    if text.trim().is_empty() {
18        return;
19    }
20    let lexemes =
21        tokenize(text).unwrap_or_else(|_| panic!("failed to tokenize literal {:?}", text));
22    tokens.extend(lexemes.into_iter().map(|(token, _)| token));
23}
24
25pub(crate) fn push_directive(tokens: &mut Vec<PtxToken>, name: &str) {
26    let raw = if name.starts_with('.') {
27        name.to_string()
28    } else {
29        format!(".{}", name)
30    };
31    push_tokenized(tokens, &raw);
32}
33
34pub(crate) fn push_token_from_str(tokens: &mut Vec<PtxToken>, value: &str) {
35    push_tokenized(tokens, value);
36}
37
38pub(crate) fn push_identifier(tokens: &mut Vec<PtxToken>, name: &str) {
39    tokens.push(PtxToken::Identifier(name.to_string()));
40}
41
42pub(crate) fn push_register(tokens: &mut Vec<PtxToken>, name: &str) {
43    tokens.push(PtxToken::Register(name.to_string()));
44}
45
46pub(crate) fn push_decimal<T: ToString>(tokens: &mut Vec<PtxToken>, value: T) {
47    tokens.push(PtxToken::DecimalInteger(value.to_string()));
48}
49
50fn push_hex_literal(tokens: &mut Vec<PtxToken>, value: u64) {
51    tokens.push(PtxToken::HexInteger(format!("0x{:x}", value)));
52}
53
54pub(crate) fn push_opcode(tokens: &mut Vec<PtxToken>, opcode: &str) {
55    push_identifier(tokens, opcode);
56}
57
58fn push_register_with_axis(tokens: &mut Vec<PtxToken>, base: &str, axis: &Axis) {
59    push_register(tokens, base);
60    match axis {
61        Axis::None { .. } => {}
62        Axis::X { .. } => push_directive(tokens, "x"),
63        Axis::Y { .. } => push_directive(tokens, "y"),
64        Axis::Z { .. } => push_directive(tokens, "z"),
65    };
66}
67
68fn numeric_token(literal: &str) -> PtxToken {
69    if literal.starts_with("0f") || literal.starts_with("0F") {
70        PtxToken::HexFloatSingle(literal.to_string())
71    } else if literal.starts_with("0d") || literal.starts_with("0D") {
72        PtxToken::HexFloatDouble(literal.to_string())
73    } else if literal.starts_with("0x") || literal.starts_with("0X") {
74        PtxToken::HexInteger(literal.to_string())
75    } else if literal.starts_with("0b") || literal.starts_with("0B") {
76        PtxToken::BinaryInteger(literal.to_string())
77    } else if literal.len() > 1
78        && literal.starts_with('0')
79        && literal.chars().all(|c| c >= '0' && c <= '7')
80    {
81        PtxToken::OctalInteger(literal.to_string())
82    } else if literal.contains('e') || literal.contains('E') {
83        PtxToken::FloatExponent(literal.to_string())
84    } else if literal.contains('.') {
85        PtxToken::Float(literal.to_string())
86    } else {
87        PtxToken::DecimalInteger(literal.to_string())
88    }
89}
90
91fn push_numeric(tokens: &mut Vec<PtxToken>, literal: &str) {
92    tokens.push(numeric_token(literal));
93}
94
95fn push_dwarf_values<I>(tokens: &mut Vec<PtxToken>, iter: I)
96where
97    I: IntoIterator<Item = u64>,
98{
99    for (idx, value) in iter.into_iter().enumerate() {
100        if idx > 0 {
101            tokens.push(PtxToken::Comma);
102        }
103        push_hex_literal(tokens, value);
104    }
105}
106
107impl PtxUnparser for DwarfDirective {
108    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
109        push_directive(tokens, "dwarf");
110        match &self.kind {
111            DwarfDirectiveKind::ByteValues(values) => {
112                push_directive(tokens, "byte");
113                push_dwarf_values(tokens, values.iter().map(|v| u64::from(*v)));
114            }
115            DwarfDirectiveKind::FourByteValues(values) => {
116                push_directive(tokens, "4byte");
117                push_dwarf_values(tokens, values.iter().map(|v| u64::from(*v)));
118            }
119            DwarfDirectiveKind::QuadValues(values) => {
120                push_directive(tokens, "quad");
121                push_dwarf_values(tokens, values.iter().copied());
122            }
123            DwarfDirectiveKind::FourByteLabel(label) => {
124                push_directive(tokens, "4byte");
125                push_identifier(tokens, &label.val);
126            }
127            DwarfDirectiveKind::QuadLabel(label) => {
128                push_directive(tokens, "quad");
129                push_identifier(tokens, &label.val);
130            }
131        }
132        tokens.push(PtxToken::Semicolon);
133    }
134}
135
136impl PtxUnparser for CodeLinkage {
137    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
138        match self {
139            CodeLinkage::Visible { .. } => push_directive(tokens, "visible"),
140            CodeLinkage::Extern { .. } => push_directive(tokens, "extern"),
141            CodeLinkage::Weak { .. } => push_directive(tokens, "weak"),
142        }
143    }
144}
145
146impl PtxUnparser for DataLinkage {
147    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
148        match self {
149            DataLinkage::Visible { .. } => push_directive(tokens, "visible"),
150            DataLinkage::Extern { .. } => push_directive(tokens, "extern"),
151            DataLinkage::Weak { .. } => push_directive(tokens, "weak"),
152            DataLinkage::Common { .. } => push_directive(tokens, "common"),
153        }
154    }
155}
156
157impl PtxUnparser for ParamStateSpace {
158    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
159        match self {
160            ParamStateSpace::Const { .. } => push_directive(tokens, "const"),
161            ParamStateSpace::Global { .. } => push_directive(tokens, "global"),
162            ParamStateSpace::Local { .. } => push_directive(tokens, "local"),
163            ParamStateSpace::Shared { .. } => push_directive(tokens, "shared"),
164        }
165    }
166}
167
168impl PtxUnparser for AttributeDirective {
169    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
170        match self {
171            AttributeDirective::Unified { uuid1, uuid2, .. } => {
172                push_directive(tokens, "unified");
173                tokens.push(PtxToken::LParen);
174                let first = uuid1.to_string();
175                push_numeric(tokens, &first);
176                tokens.push(PtxToken::Comma);
177                let second = uuid2.to_string();
178                push_numeric(tokens, &second);
179                tokens.push(PtxToken::RParen);
180            }
181            AttributeDirective::Managed { .. } => push_directive(tokens, "managed"),
182        }
183    }
184}
185
186impl PtxUnparser for DataType {
187    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
188        let directive = match self {
189            DataType::U8 { .. } => "u8",
190            DataType::U16 { .. } => "u16",
191            DataType::U32 { .. } => "u32",
192            DataType::U64 { .. } => "u64",
193            DataType::S8 { .. } => "s8",
194            DataType::S16 { .. } => "s16",
195            DataType::S32 { .. } => "s32",
196            DataType::S64 { .. } => "s64",
197            DataType::F16 { .. } => "f16",
198            DataType::F16x2 { .. } => "f16x2",
199            DataType::F32 { .. } => "f32",
200            DataType::F64 { .. } => "f64",
201            DataType::B8 { .. } => "b8",
202            DataType::B16 { .. } => "b16",
203            DataType::B32 { .. } => "b32",
204            DataType::B64 { .. } => "b64",
205            DataType::B128 { .. } => "b128",
206            DataType::Pred { .. } => "pred",
207            // Texture types (merged from TexType)
208            DataType::TexRef { .. } => "texref",
209            DataType::SamplerRef { .. } => "samplerref",
210            DataType::SurfRef { .. } => "surfref",
211        };
212        push_directive(tokens, directive);
213    }
214}
215
216impl PtxUnparser for Sign {
217    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
218        match self {
219            Sign::Negative { .. } => tokens.push(PtxToken::Minus),
220            Sign::Positive { .. } => tokens.push(PtxToken::Plus),
221        }
222    }
223}
224
225impl PtxUnparser for Immediate {
226    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
227        let literal = self.value.as_str();
228        if let Some(rest) = literal.strip_prefix('-') {
229            tokens.push(PtxToken::Minus);
230            push_numeric(tokens, rest);
231        } else if let Some(rest) = literal.strip_prefix('+') {
232            tokens.push(PtxToken::Plus);
233            push_numeric(tokens, rest);
234        } else {
235            push_numeric(tokens, literal);
236        }
237    }
238}
239
240impl PtxUnparser for RegisterOperand {
241    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
242        let mut repr = self.name.clone();
243        if let Some(component) = &self.component {
244            repr.push('.');
245            repr.push_str(component);
246        }
247        push_register(tokens, &repr);
248    }
249}
250
251impl PtxUnparser for PredicateRegister {
252    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
253        push_register(tokens, &self.name);
254    }
255}
256
257impl PtxUnparser for Label {
258    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
259        push_identifier(tokens, &self.val);
260    }
261}
262
263impl PtxUnparser for SpecialRegister {
264    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
265        let name = match self {
266            SpecialRegister::AggrSmemSize { .. } => "%aggr_smem_size".to_string(),
267            SpecialRegister::DynamicSmemSize { .. } => "%dynamic_smem_size".to_string(),
268            SpecialRegister::LanemaskGt { .. } => "%lanemask_gt".to_string(),
269            SpecialRegister::ReservedSmemOffsetBegin { .. } => {
270                "%reserved_smem_offset_begin".to_string()
271            }
272            SpecialRegister::Clock { .. } => "%clock".to_string(),
273            SpecialRegister::Envreg { index, .. } => format!("%envreg{}", index),
274            SpecialRegister::LanemaskLe { .. } => "%lanemask_le".to_string(),
275            SpecialRegister::ReservedSmemOffsetCap { .. } => {
276                "%reserved_smem_offset_cap".to_string()
277            }
278            SpecialRegister::Clock64 { .. } => "%clock64".to_string(),
279            SpecialRegister::Globaltimer { .. } => "%globaltimer".to_string(),
280            SpecialRegister::LanemaskLt { .. } => "%lanemask_lt".to_string(),
281            SpecialRegister::ReservedSmemOffsetEnd { .. } => {
282                "%reserved_smem_offset_end".to_string()
283            }
284            SpecialRegister::ClusterCtaid { axis, .. } => {
285                push_register_with_axis(tokens, "%cluster_ctaid", axis);
286                return;
287            }
288            SpecialRegister::GlobaltimerHi { .. } => "%globaltimer_hi".to_string(),
289            SpecialRegister::Nclusterid { .. } => "%nclusterid".to_string(),
290            SpecialRegister::Smid { .. } => "%smid".to_string(),
291            SpecialRegister::ClusterCtarank { axis, .. } => {
292                push_register_with_axis(tokens, "%cluster_ctarank", axis);
293                return;
294            }
295            SpecialRegister::GlobaltimerLo { .. } => "%globaltimer_lo".to_string(),
296            SpecialRegister::Nctaid { axis, .. } => {
297                push_register_with_axis(tokens, "%nctaid", axis);
298                return;
299            }
300            SpecialRegister::Tid { axis, .. } => {
301                push_register_with_axis(tokens, "%tid", axis);
302                return;
303            }
304            SpecialRegister::ClusterNctaid { axis, .. } => {
305                push_register_with_axis(tokens, "%cluster_nctaid", axis);
306                return;
307            }
308            SpecialRegister::Gridid { .. } => "%gridid".to_string(),
309            SpecialRegister::Nsmid { .. } => "%nsmid".to_string(),
310            SpecialRegister::TotalSmemSize { .. } => "%total_smem_size".to_string(),
311            SpecialRegister::ClusterNctarank { axis, .. } => {
312                push_register_with_axis(tokens, "%cluster_nctarank", axis);
313                return;
314            }
315            SpecialRegister::IsExplicitCluster { .. } => "%is_explicit_cluster".to_string(),
316            SpecialRegister::Ntid { axis, .. } => {
317                push_register_with_axis(tokens, "%ntid", axis);
318                return;
319            }
320            SpecialRegister::Warpid { .. } => "%warpid".to_string(),
321            SpecialRegister::Clusterid { .. } => "%clusterid".to_string(),
322            SpecialRegister::Laneid { .. } => "%laneid".to_string(),
323            SpecialRegister::Nwarpid { .. } => "%nwarpid".to_string(),
324            SpecialRegister::WARPSZ { .. } => "%WARPSZ".to_string(),
325            SpecialRegister::Ctaid { axis, .. } => {
326                push_register_with_axis(tokens, "%ctaid", axis);
327                return;
328            }
329            SpecialRegister::LanemaskEq { .. } => "%lanemask_eq".to_string(),
330            SpecialRegister::Pm { index, .. } => format!("%pm{}", index),
331            SpecialRegister::Pm64 { index, .. } => format!("%pm{}_64", index),
332            SpecialRegister::CurrentGraphExec { .. } => "%current_graph_exec".to_string(),
333            SpecialRegister::LanemaskGe { .. } => "%lanemask_ge".to_string(),
334            SpecialRegister::ReservedSmemOffset { index, .. } => {
335                format!("%reserved_smem_offset_{}", index)
336            }
337        };
338        push_register(tokens, &name);
339    }
340}
341
342impl PtxUnparser for Operand {
343    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
344        match self {
345            Operand::Register {
346                operand: register, ..
347            } => register.unparse_tokens(tokens),
348            Operand::Immediate {
349                operand: immediate, ..
350            } => immediate.unparse_tokens(tokens),
351            Operand::Symbol { name: symbol, .. } => push_identifier(tokens, symbol),
352            Operand::SymbolOffset { symbol, offset, .. } => {
353                push_identifier(tokens, symbol);
354                tokens.push(PtxToken::Plus);
355                offset.unparse_tokens(tokens);
356            }
357        }
358    }
359}
360
361impl PtxUnparser for VectorOperand {
362    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
363        tokens.push(PtxToken::LBrace);
364        match self {
365            VectorOperand::Vector1 { operand: item, .. } => item.unparse_tokens(tokens),
366            VectorOperand::Vector2 {
367                operands: items, ..
368            } => {
369                for (idx, item) in items.iter().enumerate() {
370                    if idx > 0 {
371                        tokens.push(PtxToken::Comma);
372                    }
373                    item.unparse_tokens(tokens);
374                }
375            }
376            VectorOperand::Vector3 {
377                operands: items, ..
378            } => {
379                for (idx, item) in items.iter().enumerate() {
380                    if idx > 0 {
381                        tokens.push(PtxToken::Comma);
382                    }
383                    item.unparse_tokens(tokens);
384                }
385            }
386            VectorOperand::Vector4 {
387                operands: items, ..
388            } => {
389                for (idx, item) in items.iter().enumerate() {
390                    if idx > 0 {
391                        tokens.push(PtxToken::Comma);
392                    }
393                    item.unparse_tokens(tokens);
394                }
395            }
396            VectorOperand::Vector8 {
397                operands: items, ..
398            } => {
399                for (idx, item) in items.iter().enumerate() {
400                    if idx > 0 {
401                        tokens.push(PtxToken::Comma);
402                    }
403                    item.unparse_tokens(tokens);
404                }
405            }
406        }
407        tokens.push(PtxToken::RBrace);
408    }
409}
410
411impl PtxUnparser for GeneralOperand {
412    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
413        match self {
414            GeneralOperand::Vec {
415                operand: vector, ..
416            } => vector.unparse_tokens(tokens),
417            GeneralOperand::Single { operand, .. } => operand.unparse_tokens(tokens),
418        }
419    }
420}
421
422impl PtxUnparser for TexHandler2 {
423    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
424        tokens.push(PtxToken::LBracket);
425        for (idx, item) in self.operands.iter().enumerate() {
426            if idx > 0 {
427                tokens.push(PtxToken::Comma);
428            }
429            item.unparse_tokens(tokens);
430        }
431        tokens.push(PtxToken::RBracket);
432    }
433}
434
435impl PtxUnparser for TexHandler3 {
436    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
437        tokens.push(PtxToken::LBracket);
438        self.handle.unparse_tokens(tokens);
439        tokens.push(PtxToken::Comma);
440        self.sampler.unparse_tokens(tokens);
441        tokens.push(PtxToken::Comma);
442        self.coords.unparse_tokens(tokens);
443        tokens.push(PtxToken::RBracket);
444    }
445}
446
447impl PtxUnparser for TexHandler3Optional {
448    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
449        tokens.push(PtxToken::LBracket);
450        self.handle.unparse_tokens(tokens);
451        tokens.push(PtxToken::Comma);
452        if let Some(sampler) = &self.sampler {
453            sampler.unparse_tokens(tokens);
454            tokens.push(PtxToken::Comma);
455        }
456        self.coords.unparse_tokens(tokens);
457        tokens.push(PtxToken::RBracket);
458    }
459}
460
461impl PtxUnparser for AddressBase {
462    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
463        match self {
464            AddressBase::Register {
465                operand: register, ..
466            } => register.unparse_tokens(tokens),
467            AddressBase::Variable { symbol, .. } => symbol.unparse_tokens(tokens),
468        }
469    }
470}
471
472impl PtxUnparser for AddressOffset {
473    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
474        match self {
475            AddressOffset::Register {
476                operand: register, ..
477            } => {
478                tokens.push(PtxToken::Plus);
479                register.unparse_tokens(tokens);
480            }
481            AddressOffset::Immediate {
482                sign,
483                value: immediate,
484                ..
485            } => {
486                sign.unparse_tokens(tokens);
487                immediate.unparse_tokens(tokens);
488            }
489        }
490    }
491}
492
493impl PtxUnparser for AddressOperand {
494    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
495        match self {
496            AddressOperand::Array { base, index, .. } => {
497                base.unparse_tokens(tokens);
498                tokens.push(PtxToken::LBracket);
499                index.unparse_tokens(tokens);
500                tokens.push(PtxToken::RBracket);
501            }
502            AddressOperand::ImmediateAddress { addr, .. } => {
503                tokens.push(PtxToken::LBracket);
504                addr.unparse_tokens(tokens);
505                tokens.push(PtxToken::RBracket);
506            }
507            AddressOperand::Offset { base, offset, .. } => {
508                tokens.push(PtxToken::LBracket);
509                base.unparse_tokens(tokens);
510                if let Some(offset) = offset {
511                    offset.unparse_tokens(tokens);
512                }
513                tokens.push(PtxToken::RBracket);
514            }
515        }
516    }
517}
518
519impl PtxUnparser for FunctionSymbol {
520    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
521        push_identifier(tokens, &self.val);
522    }
523}
524
525impl PtxUnparser for VariableSymbol {
526    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
527        push_identifier(tokens, &self.val);
528    }
529}
530
531impl PtxUnparser for crate::r#type::common::Instruction {
532    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
533        // Emit predicate if present
534        if let Some(predicate) = &self.predicate {
535            tokens.push(PtxToken::At);
536            if predicate.negated {
537                tokens.push(PtxToken::Exclaim);
538            }
539            predicate.operand.unparse_tokens(tokens);
540        }
541
542        // Emit the instruction
543        self.inst.unparse_tokens(tokens);
544    }
545}