// ptx_parser/unparser/common.rs

1use super::PtxUnparser;
2use crate::{
3    lexer::{PtxToken, tokenize},
4    r#type::common::{
5        AddressBase, AddressOffset, AddressOperand, AddressSpace, AttributeDirective, Axis,
6        CodeLinkage, CodeOrDataLinkage, DataLinkage, DataType, FunctionSymbol, GeneralOperand,
7        Immediate, Label, Operand, PredicateRegister, RegisterOperand, Sign, SpecialRegister,
8        TexHandler2, TexHandler3, TexHandler3Optional, TexType, VariableSymbol, VectorOperand,
9    },
10};
11
12fn push_tokenized(tokens: &mut Vec<PtxToken>, text: &str) {
13    if text.trim().is_empty() {
14        return;
15    }
16    let lexemes =
17        tokenize(text).unwrap_or_else(|_| panic!("failed to tokenize literal {:?}", text));
18    tokens.extend(lexemes.into_iter().map(|(token, _)| token));
19}
20
21pub(crate) fn push_directive(tokens: &mut Vec<PtxToken>, name: &str) {
22    let raw = if name.starts_with('.') {
23        name.to_string()
24    } else {
25        format!(".{}", name)
26    };
27    push_tokenized(tokens, &raw);
28}
29
30pub(crate) fn push_token_from_str(tokens: &mut Vec<PtxToken>, value: &str) {
31    push_tokenized(tokens, value);
32}
33
34pub(crate) fn push_identifier(tokens: &mut Vec<PtxToken>, name: &str) {
35    tokens.push(PtxToken::Identifier(name.to_string()));
36}
37
38pub(crate) fn push_register(tokens: &mut Vec<PtxToken>, name: &str) {
39    tokens.push(PtxToken::Register(name.to_string()));
40}
41
42pub(crate) fn push_decimal<T: ToString>(tokens: &mut Vec<PtxToken>, value: T) {
43    tokens.push(PtxToken::DecimalInteger(value.to_string()));
44}
45
46pub(crate) fn push_opcode(tokens: &mut Vec<PtxToken>, opcode: &str) {
47    push_identifier(tokens, opcode);
48}
49
50fn push_register_with_axis(tokens: &mut Vec<PtxToken>, base: &str, axis: Axis) {
51    push_register(tokens, base);
52    match axis {
53        Axis::None => {}
54        Axis::X => push_directive(tokens, "x"),
55        Axis::Y => push_directive(tokens, "y"),
56        Axis::Z => push_directive(tokens, "z"),
57    };
58}
59
60fn numeric_token(literal: &str) -> PtxToken {
61    if literal.starts_with("0f")
62        || literal.starts_with("0F")
63        || literal.starts_with("0d")
64        || literal.starts_with("0D")
65    {
66        PtxToken::HexFloat(literal.to_string())
67    } else if literal.starts_with("0x") || literal.starts_with("0X") {
68        PtxToken::HexInteger(literal.to_string())
69    } else if literal.starts_with("0b") || literal.starts_with("0B") {
70        PtxToken::BinaryInteger(literal.to_string())
71    } else if literal.len() > 1
72        && literal.starts_with('0')
73        && literal.chars().all(|c| c >= '0' && c <= '7')
74    {
75        PtxToken::OctalInteger(literal.to_string())
76    } else if literal.contains('e') || literal.contains('E') {
77        PtxToken::FloatExponent(literal.to_string())
78    } else if literal.contains('.') {
79        PtxToken::Float(literal.to_string())
80    } else {
81        PtxToken::DecimalInteger(literal.to_string())
82    }
83}
84
85fn push_numeric(tokens: &mut Vec<PtxToken>, literal: &str) {
86    tokens.push(numeric_token(literal));
87}
88
89impl PtxUnparser for CodeLinkage {
90    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
91        match self {
92            CodeLinkage::Visible => push_directive(tokens, "visible"),
93            CodeLinkage::Extern => push_directive(tokens, "extern"),
94            CodeLinkage::Weak => push_directive(tokens, "weak"),
95        }
96    }
97}
98
99impl PtxUnparser for DataLinkage {
100    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
101        match self {
102            DataLinkage::Visible => push_directive(tokens, "visible"),
103            DataLinkage::Extern => push_directive(tokens, "extern"),
104            DataLinkage::Weak => push_directive(tokens, "weak"),
105            DataLinkage::Common => push_directive(tokens, "common"),
106        }
107    }
108}
109
110impl PtxUnparser for CodeOrDataLinkage {
111    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
112        match self {
113            CodeOrDataLinkage::Visible => push_directive(tokens, "visible"),
114            CodeOrDataLinkage::Extern => push_directive(tokens, "extern"),
115            CodeOrDataLinkage::Weak => push_directive(tokens, "weak"),
116            CodeOrDataLinkage::Common => push_directive(tokens, "common"),
117        }
118    }
119}
120
121impl PtxUnparser for TexType {
122    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
123        match self {
124            TexType::TexRef => push_directive(tokens, "texref"),
125            TexType::SamplerRef => push_directive(tokens, "samplerref"),
126            TexType::SurfRef => push_directive(tokens, "surfref"),
127        }
128    }
129}
130
131impl PtxUnparser for AddressSpace {
132    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
133        match self {
134            AddressSpace::Global => push_directive(tokens, "global"),
135            AddressSpace::Const => push_directive(tokens, "const"),
136            AddressSpace::Shared => push_directive(tokens, "shared"),
137            AddressSpace::Local => push_directive(tokens, "local"),
138            AddressSpace::Param => push_directive(tokens, "param"),
139            AddressSpace::Reg => push_directive(tokens, "reg"),
140        }
141    }
142}
143
144impl PtxUnparser for AttributeDirective {
145    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
146        match self {
147            AttributeDirective::Unified(uuid1, uuid2) => {
148                push_directive(tokens, "unified");
149                tokens.push(PtxToken::LParen);
150                let first = uuid1.to_string();
151                push_numeric(tokens, &first);
152                tokens.push(PtxToken::Comma);
153                let second = uuid2.to_string();
154                push_numeric(tokens, &second);
155                tokens.push(PtxToken::RParen);
156            }
157            AttributeDirective::Managed => push_directive(tokens, "managed"),
158        }
159    }
160}
161
162impl PtxUnparser for DataType {
163    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
164        let directive = match self {
165            DataType::U8 => "u8",
166            DataType::U16 => "u16",
167            DataType::U32 => "u32",
168            DataType::U64 => "u64",
169            DataType::S8 => "s8",
170            DataType::S16 => "s16",
171            DataType::S32 => "s32",
172            DataType::S64 => "s64",
173            DataType::F16 => "f16",
174            DataType::F16x2 => "f16x2",
175            DataType::F32 => "f32",
176            DataType::F64 => "f64",
177            DataType::B8 => "b8",
178            DataType::B16 => "b16",
179            DataType::B32 => "b32",
180            DataType::B64 => "b64",
181            DataType::B128 => "b128",
182            DataType::Pred => "pred",
183        };
184        push_directive(tokens, directive);
185    }
186}
187
188impl PtxUnparser for Sign {
189    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
190        match self {
191            Sign::Negative => tokens.push(PtxToken::Minus),
192            Sign::Positive => tokens.push(PtxToken::Plus),
193        }
194    }
195}
196
197impl PtxUnparser for Immediate {
198    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
199        let literal = self.0.as_str();
200        if let Some(rest) = literal.strip_prefix('-') {
201            tokens.push(PtxToken::Minus);
202            push_numeric(tokens, rest);
203        } else if let Some(rest) = literal.strip_prefix('+') {
204            tokens.push(PtxToken::Plus);
205            push_numeric(tokens, rest);
206        } else {
207            push_numeric(tokens, literal);
208        }
209    }
210}
211
212impl PtxUnparser for RegisterOperand {
213    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
214        push_register(tokens, &self.0);
215    }
216}
217
218impl PtxUnparser for PredicateRegister {
219    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
220        push_register(tokens, &self.0);
221    }
222}
223
224impl PtxUnparser for Label {
225    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
226        push_identifier(tokens, &self.0);
227    }
228}
229
230impl PtxUnparser for SpecialRegister {
231    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
232        let name = match self {
233            SpecialRegister::AggrSmemSize => "%aggr_smem_size".to_string(),
234            SpecialRegister::DynamicSmemSize => "%dynamic_smem_size".to_string(),
235            SpecialRegister::LanemaskGt => "%lanemask_gt".to_string(),
236            SpecialRegister::ReservedSmemOffsetBegin => "%reserved_smem_offset_begin".to_string(),
237            SpecialRegister::Clock => "%clock".to_string(),
238            SpecialRegister::Envreg(index) => format!("%envreg{}", index),
239            SpecialRegister::LanemaskLe => "%lanemask_le".to_string(),
240            SpecialRegister::ReservedSmemOffsetCap => "%reserved_smem_offset_cap".to_string(),
241            SpecialRegister::Clock64 => "%clock64".to_string(),
242            SpecialRegister::Globaltimer => "%globaltimer".to_string(),
243            SpecialRegister::LanemaskLt => "%lanemask_lt".to_string(),
244            SpecialRegister::ReservedSmemOffsetEnd => "%reserved_smem_offset_end".to_string(),
245            SpecialRegister::ClusterCtaid(axis) => {
246                push_register_with_axis(tokens, "%cluster_ctaid", *axis);
247                return;
248            }
249            SpecialRegister::GlobaltimerHi => "%globaltimer_hi".to_string(),
250            SpecialRegister::Nclusterid => "%nclusterid".to_string(),
251            SpecialRegister::Smid => "%smid".to_string(),
252            SpecialRegister::ClusterCtarank(axis) => {
253                push_register_with_axis(tokens, "%cluster_ctarank", *axis);
254                return;
255            }
256            SpecialRegister::GlobaltimerLo => "%globaltimer_lo".to_string(),
257            SpecialRegister::Nctaid(axis) => {
258                push_register_with_axis(tokens, "%nctaid", *axis);
259                return;
260            }
261            SpecialRegister::Tid(axis) => {
262                push_register_with_axis(tokens, "%tid", *axis);
263                return;
264            }
265            SpecialRegister::ClusterNctaid(axis) => {
266                push_register_with_axis(tokens, "%cluster_nctaid", *axis);
267                return;
268            }
269            SpecialRegister::Gridid => "%gridid".to_string(),
270            SpecialRegister::Nsmid => "%nsmid".to_string(),
271            SpecialRegister::TotalSmemSize => "%total_smem_size".to_string(),
272            SpecialRegister::ClusterNctarank(axis) => {
273                push_register_with_axis(tokens, "%cluster_nctarank", *axis);
274                return;
275            }
276            SpecialRegister::IsExplicitCluster => "%is_explicit_cluster".to_string(),
277            SpecialRegister::Ntid(axis) => {
278                push_register_with_axis(tokens, "%ntid", *axis);
279                return;
280            }
281            SpecialRegister::Warpid => "%warpid".to_string(),
282            SpecialRegister::Clusterid => "%clusterid".to_string(),
283            SpecialRegister::Laneid => "%laneid".to_string(),
284            SpecialRegister::Nwarpid => "%nwarpid".to_string(),
285            SpecialRegister::WARPSZ => "%WARPSZ".to_string(),
286            SpecialRegister::Ctaid(axis) => {
287                push_register_with_axis(tokens, "%ctaid", *axis);
288                return;
289            }
290            SpecialRegister::LanemaskEq => "%lanemask_eq".to_string(),
291            SpecialRegister::Pm(index) => format!("%pm{}", index),
292            SpecialRegister::Pm64(index) => format!("%pm{}_64", index),
293            SpecialRegister::CurrentGraphExec => "%current_graph_exec".to_string(),
294            SpecialRegister::LanemaskGe => "%lanemask_ge".to_string(),
295            SpecialRegister::ReservedSmemOffset(index) => {
296                format!("%reserved_smem_offset_{}", index)
297            }
298        };
299        push_register(tokens, &name);
300    }
301}
302
303impl PtxUnparser for Operand {
304    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
305        match self {
306            Operand::Register(register) => register.unparse_tokens(tokens),
307            Operand::Immediate(immediate) => immediate.unparse_tokens(tokens),
308            Operand::Symbol(symbol) => push_identifier(tokens, symbol),
309            Operand::SymbolOffset(symbol, offset) => {
310                push_identifier(tokens, symbol);
311                tokens.push(PtxToken::Plus);
312                offset.unparse_tokens(tokens);
313            }
314        }
315    }
316}
317
318impl PtxUnparser for VectorOperand {
319    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
320        tokens.push(PtxToken::LBrace);
321        match self {
322            VectorOperand::Vector1(item) => item.unparse_tokens(tokens),
323            VectorOperand::Vector2(items) => {
324                for (idx, item) in items.iter().enumerate() {
325                    if idx > 0 {
326                        tokens.push(PtxToken::Comma);
327                    }
328                    item.unparse_tokens(tokens);
329                }
330            }
331            VectorOperand::Vector3(items) => {
332                for (idx, item) in items.iter().enumerate() {
333                    if idx > 0 {
334                        tokens.push(PtxToken::Comma);
335                    }
336                    item.unparse_tokens(tokens);
337                }
338            }
339            VectorOperand::Vector4(items) => {
340                for (idx, item) in items.iter().enumerate() {
341                    if idx > 0 {
342                        tokens.push(PtxToken::Comma);
343                    }
344                    item.unparse_tokens(tokens);
345                }
346            }
347            VectorOperand::Vector8(items) => {
348                for (idx, item) in items.iter().enumerate() {
349                    if idx > 0 {
350                        tokens.push(PtxToken::Comma);
351                    }
352                    item.unparse_tokens(tokens);
353                }
354            }
355        }
356        tokens.push(PtxToken::RBrace);
357    }
358}
359
360impl PtxUnparser for GeneralOperand {
361    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
362        match self {
363            GeneralOperand::Vec(vector) => vector.unparse_tokens(tokens),
364            GeneralOperand::Single(operand) => operand.unparse_tokens(tokens),
365        }
366    }
367}
368
369impl PtxUnparser for TexHandler2 {
370    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
371        tokens.push(PtxToken::LBracket);
372        let TexHandler2(items) = self;
373        for (idx, item) in items.iter().enumerate() {
374            if idx > 0 {
375                tokens.push(PtxToken::Comma);
376            }
377            item.unparse_tokens(tokens);
378        }
379        tokens.push(PtxToken::RBracket);
380    }
381}
382
383impl PtxUnparser for TexHandler3 {
384    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
385        tokens.push(PtxToken::LBracket);
386        self.handle.unparse_tokens(tokens);
387        tokens.push(PtxToken::Comma);
388        self.sampler.unparse_tokens(tokens);
389        tokens.push(PtxToken::Comma);
390        self.coords.unparse_tokens(tokens);
391        tokens.push(PtxToken::RBracket);
392    }
393}
394
395impl PtxUnparser for TexHandler3Optional {
396    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
397        tokens.push(PtxToken::LBracket);
398        self.handle.unparse_tokens(tokens);
399        tokens.push(PtxToken::Comma);
400        if let Some(sampler) = &self.sampler {
401            sampler.unparse_tokens(tokens);
402            tokens.push(PtxToken::Comma);
403        }
404        self.coords.unparse_tokens(tokens);
405        tokens.push(PtxToken::RBracket);
406    }
407}
408
409impl PtxUnparser for AddressBase {
410    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
411        match self {
412            AddressBase::Register(register) => register.unparse_tokens(tokens),
413            AddressBase::Variable(symbol) => symbol.unparse_tokens(tokens),
414        }
415    }
416}
417
418impl PtxUnparser for AddressOffset {
419    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
420        match self {
421            AddressOffset::Register(register) => {
422                tokens.push(PtxToken::Plus);
423                register.unparse_tokens(tokens);
424            }
425            AddressOffset::Immediate(sign, immediate) => {
426                sign.unparse_tokens(tokens);
427                immediate.unparse_tokens(tokens);
428            }
429        }
430    }
431}
432
433impl PtxUnparser for AddressOperand {
434    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
435        match self {
436            AddressOperand::Array(symbol, immediate) => {
437                symbol.unparse_tokens(tokens);
438                tokens.push(PtxToken::LBracket);
439                immediate.unparse_tokens(tokens);
440                tokens.push(PtxToken::RBracket);
441            }
442            AddressOperand::ImmediateAddress(immediate) => {
443                tokens.push(PtxToken::LBracket);
444                immediate.unparse_tokens(tokens);
445                tokens.push(PtxToken::RBracket);
446            }
447            AddressOperand::Offset(base, offset) => {
448                tokens.push(PtxToken::LBracket);
449                base.unparse_tokens(tokens);
450                if let Some(offset) = offset {
451                    offset.unparse_tokens(tokens);
452                }
453                tokens.push(PtxToken::RBracket);
454            }
455        }
456    }
457}
458
459impl PtxUnparser for FunctionSymbol {
460    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
461        push_identifier(tokens, &self.0);
462    }
463}
464
465impl PtxUnparser for VariableSymbol {
466    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
467        push_identifier(tokens, &self.0);
468    }
469}
470
471impl PtxUnparser for crate::r#type::common::Instruction {
472    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
473        // Emit predicate if present
474        if let Some(predicate) = &self.predicate {
475            tokens.push(PtxToken::At);
476            if predicate.negated {
477                tokens.push(PtxToken::Exclaim);
478            }
479            predicate.operand.unparse_tokens(tokens);
480        }
481
482        // Emit the instruction
483        self.inst.unparse_tokens(tokens);
484    }
485}