ptx_parser/unparser/
function.rs

1use crate::unparser::common::push_register;
2use crate::{
3    lexer::PtxToken,
4    r#type::{function::*, variable::ParameterDirective},
5    unparser::*,
6};
7
8fn push_register_components(tokens: &mut Vec<PtxToken>, name: &str) {
9    if let Some(stripped) = name.strip_prefix('%') {
10        let mut parts = stripped.split('.');
11        if let Some(first) = parts.next() {
12            let register_name = format!("%{first}");
13            push_register(tokens, &register_name);
14        }
15        for part in parts {
16            if part.is_empty() {
17                continue;
18            }
19            push_directive(tokens, part);
20        }
21    } else {
22        push_identifier(tokens, name);
23    }
24}
25
26fn unparse_param(tokens: &mut Vec<PtxToken>, param: &ParameterDirective) {
27    match param {
28        ParameterDirective::Parameter {
29            align,
30            ty,
31            ptr,
32            space,
33            name,
34            array,
35            ..
36        } => {
37            push_directive(tokens, "param");
38            ty.unparse_tokens(tokens);
39            if *ptr {
40                push_directive(tokens, "ptr");
41            }
42            if let Some(address_space) = space {
43                address_space.unparse_tokens(tokens);
44            }
45            if let Some(value) = align {
46                push_directive(tokens, "align");
47                push_decimal(tokens, *value);
48            }
49            push_identifier(tokens, &name.val);
50            for extent in array {
51                tokens.push(PtxToken::LBracket);
52                if let Some(value) = extent {
53                    push_decimal(tokens, *value);
54                }
55                tokens.push(PtxToken::RBracket);
56            }
57        }
58        ParameterDirective::Register { ty, name, .. } => {
59            push_directive(tokens, "reg");
60            ty.unparse_tokens(tokens);
61            push_register_components(tokens, &name.val);
62        }
63    }
64}
65
66fn unparse_param_list(tokens: &mut Vec<PtxToken>, params: &[ParameterDirective]) {
67    for (idx, param) in params.iter().enumerate() {
68        if idx > 0 {
69            tokens.push(PtxToken::Comma);
70        }
71        unparse_param(tokens, param);
72    }
73}
74
75fn unparse_section_line(tokens: &mut Vec<PtxToken>, line: &StatementSectionDirectiveLine) {
76    match line {
77        StatementSectionDirectiveLine::B8 { values, .. } => {
78            push_directive(tokens, "b8");
79            for (idx, value) in values.iter().enumerate() {
80                if idx > 0 {
81                    tokens.push(PtxToken::Comma);
82                }
83                push_signed_decimal_i64(tokens, *value as i64);
84            }
85            tokens.push(PtxToken::Semicolon);
86        }
87        StatementSectionDirectiveLine::B16 { values, .. } => {
88            push_directive(tokens, "b16");
89            for (idx, value) in values.iter().enumerate() {
90                if idx > 0 {
91                    tokens.push(PtxToken::Comma);
92                }
93                push_signed_decimal_i64(tokens, *value as i64);
94            }
95            tokens.push(PtxToken::Semicolon);
96        }
97        StatementSectionDirectiveLine::B32Immediate { values, .. } => {
98            push_directive(tokens, "b32");
99            for (idx, value) in values.iter().enumerate() {
100                if idx > 0 {
101                    tokens.push(PtxToken::Comma);
102                }
103                push_signed_decimal_i64(tokens, *value);
104            }
105            tokens.push(PtxToken::Semicolon);
106        }
107        StatementSectionDirectiveLine::B64Immediate { values, .. } => {
108            push_directive(tokens, "b64");
109            for (idx, value) in values.iter().enumerate() {
110                if idx > 0 {
111                    tokens.push(PtxToken::Comma);
112                }
113                push_signed_decimal_i128(tokens, *value);
114            }
115            tokens.push(PtxToken::Semicolon);
116        }
117        StatementSectionDirectiveLine::B32Label { labels, .. } => {
118            push_directive(tokens, "b32");
119            push_identifier(tokens, &labels.val);
120            tokens.push(PtxToken::Semicolon);
121        }
122        StatementSectionDirectiveLine::B64Label { labels, .. } => {
123            push_directive(tokens, "b64");
124            push_identifier(tokens, &labels.val);
125            tokens.push(PtxToken::Semicolon);
126        }
127        StatementSectionDirectiveLine::B32LabelPlusImm { entries, .. } => {
128            push_directive(tokens, "b32");
129            let (label, offset) = entries;
130            push_identifier(tokens, &label.val);
131            if *offset >= 0 {
132                tokens.push(PtxToken::Plus);
133                push_decimal(tokens, *offset);
134            } else {
135                tokens.push(PtxToken::Minus);
136                let magnitude = (*offset as i128).abs();
137                push_decimal(tokens, magnitude);
138            }
139            tokens.push(PtxToken::Semicolon);
140        }
141        StatementSectionDirectiveLine::B64LabelPlusImm { entries, .. } => {
142            push_directive(tokens, "b64");
143            let (label, offset) = entries;
144            push_identifier(tokens, &label.val);
145            if *offset >= 0 {
146                tokens.push(PtxToken::Plus);
147                push_decimal(tokens, *offset);
148            } else {
149                tokens.push(PtxToken::Minus);
150                let magnitude = (*offset as i128).abs();
151                push_decimal(tokens, magnitude);
152            }
153            tokens.push(PtxToken::Semicolon);
154        }
155        StatementSectionDirectiveLine::B32LabelDiff { entries, .. } => {
156            push_directive(tokens, "b32");
157            let (left, right) = entries;
158            push_identifier(tokens, &left.val);
159            tokens.push(PtxToken::Minus);
160            push_identifier(tokens, &right.val);
161            tokens.push(PtxToken::Semicolon);
162        }
163        StatementSectionDirectiveLine::B64LabelDiff { entries, .. } => {
164            push_directive(tokens, "b64");
165            let (left, right) = entries;
166            push_identifier(tokens, &left.val);
167            tokens.push(PtxToken::Minus);
168            push_identifier(tokens, &right.val);
169            tokens.push(PtxToken::Semicolon);
170        }
171    }
172}
173
174fn push_signed_decimal_i64(tokens: &mut Vec<PtxToken>, value: i64) {
175    if value < 0 {
176        tokens.push(PtxToken::Minus);
177        push_decimal(tokens, (-value) as i128);
178    } else {
179        push_decimal(tokens, value);
180    }
181}
182
183fn push_signed_decimal_i128(tokens: &mut Vec<PtxToken>, value: i128) {
184    if value < 0 {
185        tokens.push(PtxToken::Minus);
186        push_decimal(tokens, -value);
187    } else {
188        push_decimal(tokens, value);
189    }
190}
191
192impl PtxUnparser for RegisterDirective {
193    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
194        push_directive(tokens, "reg");
195        self.ty.unparse_tokens(tokens);
196        for (idx, target) in self.registers.iter().enumerate() {
197            if idx > 0 {
198                tokens.push(PtxToken::Comma);
199            }
200            push_register_components(tokens, &target.name.val);
201            if let Some(range) = target.range {
202                tokens.push(PtxToken::LAngle);
203                push_decimal(tokens, range);
204                tokens.push(PtxToken::RAngle);
205            }
206        }
207        tokens.push(PtxToken::Semicolon);
208    }
209}
210
211impl PtxUnparser for StatementDirective {
212    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
213        match self {
214            StatementDirective::Reg {
215                directive: register,
216                ..
217            } => register.unparse_tokens(tokens),
218            StatementDirective::Local {
219                directive: variable,
220                ..
221            } => {
222                push_directive(tokens, "local");
223                variable.unparse_tokens(tokens);
224            }
225            StatementDirective::Param {
226                directive: variable,
227                ..
228            } => {
229                push_directive(tokens, "param");
230                variable.unparse_tokens(tokens);
231            }
232            StatementDirective::Shared {
233                directive: variable,
234                ..
235            } => {
236                push_directive(tokens, "shared");
237                variable.unparse_tokens(tokens);
238            }
239            StatementDirective::Pragma {
240                directive: pragma, ..
241            } => {
242                push_directive(tokens, "pragma");
243                let text = match &pragma.kind {
244                    PragmaDirectiveKind::Nounroll => "nounroll".to_string(),
245                    PragmaDirectiveKind::EnableSmemSpilling => "enable_smem_spilling".to_string(),
246                    PragmaDirectiveKind::UsedBytesMask { mask } => {
247                        format!("used_bytes_mask {}", mask)
248                    }
249                    PragmaDirectiveKind::Frequency { value } => {
250                        format!("frequency {}", value)
251                    }
252                    PragmaDirectiveKind::Raw(text) => text.clone(),
253                };
254                tokens.push(PtxToken::StringLiteral(text));
255                tokens.push(PtxToken::Semicolon);
256            }
257            StatementDirective::BranchTargets { directive, .. } => {
258                push_directive(tokens, "branchtargets");
259                for (idx, label) in directive.labels.iter().enumerate() {
260                    if idx > 0 {
261                        tokens.push(PtxToken::Comma);
262                    }
263                    push_token_from_str(tokens, &label.val);
264                }
265                tokens.push(PtxToken::Semicolon);
266            }
267            StatementDirective::CallTargets { directive, .. } => {
268                push_directive(tokens, "calltargets");
269                for (idx, target) in directive.targets.iter().enumerate() {
270                    if idx > 0 {
271                        tokens.push(PtxToken::Comma);
272                    }
273                    push_token_from_str(tokens, &target.val);
274                }
275                tokens.push(PtxToken::Semicolon);
276            }
277            StatementDirective::Loc { directive: loc, .. } => {
278                push_directive(tokens, "loc");
279                push_decimal(tokens, loc.file_index);
280                push_decimal(tokens, loc.line);
281                push_decimal(tokens, loc.column);
282                if let Some(inline) = &loc.inlined_at {
283                    tokens.push(PtxToken::Comma);
284                    push_identifier(tokens, "inlined_at");
285                    push_decimal(tokens, inline.file_index);
286                    push_decimal(tokens, inline.line);
287                    push_decimal(tokens, inline.column);
288                    tokens.push(PtxToken::Comma);
289                    push_identifier(tokens, &inline.function_name.val);
290                    push_identifier(tokens, &inline.label.val);
291                    if let Some(offset) = inline.label_offset {
292                        if offset >= 0 {
293                            tokens.push(PtxToken::Plus);
294                            push_decimal(tokens, offset);
295                        } else {
296                            tokens.push(PtxToken::Minus);
297                            push_decimal(tokens, offset.abs());
298                        }
299                    }
300                }
301                tokens.push(PtxToken::Semicolon);
302            }
303            StatementDirective::Dwarf {
304                directive: dwarf, ..
305            } => {
306                dwarf.unparse_tokens(tokens);
307            }
308            StatementDirective::Section {
309                directive: section, ..
310            } => {
311                section.unparse_tokens(tokens);
312            }
313            StatementDirective::CallPrototype { directive, .. } => {
314                push_directive(tokens, "callprototype");
315                if let Some(ret) = &directive.return_param {
316                    unparse_param(tokens, ret);
317                } else {
318                    push_identifier(tokens, "_");
319                }
320                tokens.push(PtxToken::LParen);
321                unparse_param_list(tokens, &directive.params);
322                tokens.push(PtxToken::RParen);
323                if directive.noreturn {
324                    push_directive(tokens, "noreturn");
325                }
326                if let Some(value) = directive.abi_preserve {
327                    push_directive(tokens, "abi_preserve");
328                    push_decimal(tokens, value);
329                }
330                if let Some(value) = directive.abi_preserve_control {
331                    push_directive(tokens, "abi_preserve_control");
332                    push_decimal(tokens, value);
333                }
334                tokens.push(PtxToken::Semicolon);
335            }
336        }
337    }
338}
339
340impl PtxUnparser for SectionDirective {
341    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
342        push_directive(tokens, "section");
343        push_token_from_str(tokens, &self.name);
344        tokens.push(PtxToken::LBrace);
345        for entry in &self.entries {
346            match entry {
347                SectionEntry::Label { label, .. } => {
348                    push_identifier(tokens, &label.val);
349                    tokens.push(PtxToken::Colon);
350                }
351                SectionEntry::Directive(line) => {
352                    unparse_section_line(tokens, line);
353                }
354            }
355        }
356        tokens.push(PtxToken::RBrace);
357    }
358}
359
360impl PtxUnparser for FunctionStatement {
361    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
362        match self {
363            FunctionStatement::Label { label, .. } => {
364                push_identifier(tokens, &label.val);
365                tokens.push(PtxToken::Colon);
366            }
367            FunctionStatement::Instruction { instruction, .. } => {
368                instruction.unparse_tokens(tokens)
369            }
370            FunctionStatement::Directive { directive, .. } => directive.unparse_tokens(tokens),
371            FunctionStatement::Block {
372                statements: block, ..
373            } => {
374                tokens.push(PtxToken::LBrace);
375                for statement in block {
376                    statement.unparse_tokens(tokens);
377                }
378                tokens.push(PtxToken::RBrace);
379            }
380        }
381    }
382}
383
384impl PtxUnparser for FunctionBody {
385    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
386        tokens.push(PtxToken::LBrace);
387        for statement in &self.statements {
388            statement.unparse_tokens(tokens);
389        }
390        tokens.push(PtxToken::RBrace);
391    }
392}
393
394impl PtxUnparser for FunctionDim {
395    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
396        match self {
397            FunctionDim::X { x, .. } => {
398                push_decimal(tokens, *x);
399            }
400            FunctionDim::XY { x, y, .. } => {
401                push_decimal(tokens, *x);
402                tokens.push(PtxToken::Comma);
403                push_decimal(tokens, *y);
404            }
405            FunctionDim::XYZ { x, y, z, .. } => {
406                push_decimal(tokens, *x);
407                tokens.push(PtxToken::Comma);
408                push_decimal(tokens, *y);
409                tokens.push(PtxToken::Comma);
410                push_decimal(tokens, *z);
411            }
412        }
413    }
414}
415
416impl PtxUnparser for EntryFunctionHeaderDirective {
417    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
418        match self {
419            EntryFunctionHeaderDirective::MaxNReg { value, .. } => {
420                push_directive(tokens, "maxnreg");
421                push_decimal(tokens, *value);
422            }
423            EntryFunctionHeaderDirective::MaxNTid { dim, .. } => {
424                push_directive(tokens, "maxntid");
425                dim.unparse_tokens(tokens);
426            }
427            EntryFunctionHeaderDirective::ReqNTid { dim, .. } => {
428                push_directive(tokens, "reqntid");
429                dim.unparse_tokens(tokens);
430            }
431            EntryFunctionHeaderDirective::MinNCtaPerSm { value, .. } => {
432                push_directive(tokens, "minnctapersm");
433                push_decimal(tokens, *value);
434            }
435            EntryFunctionHeaderDirective::MaxNCtaPerSm { value, .. } => {
436                push_directive(tokens, "maxnctapersm");
437                push_decimal(tokens, *value);
438            }
439            EntryFunctionHeaderDirective::Pragma {
440                args: arguments, ..
441            } => {
442                push_directive(tokens, "pragma");
443                for argument in arguments {
444                    tokens.push(PtxToken::StringLiteral(argument.clone()));
445                }
446            }
447            EntryFunctionHeaderDirective::ReqNctaPerCluster { dim, .. } => {
448                push_directive(tokens, "reqnctapercluster");
449                dim.unparse_tokens(tokens);
450            }
451            EntryFunctionHeaderDirective::ExplicitCluster { .. } => {
452                push_directive(tokens, "explicitcluster");
453            }
454            EntryFunctionHeaderDirective::MaxClusterRank { value, .. } => {
455                push_directive(tokens, "maxclusterrank");
456                push_decimal(tokens, *value);
457            }
458            EntryFunctionHeaderDirective::BlocksAreClusters { .. } => {
459                push_directive(tokens, "blocksareclusters")
460            }
461        }
462    }
463}
464
465impl PtxUnparser for FuncFunctionHeaderDirective {
466    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
467        match self {
468            FuncFunctionHeaderDirective::NoReturn { .. } => push_directive(tokens, "noreturn"),
469            FuncFunctionHeaderDirective::Pragma {
470                args: arguments, ..
471            } => {
472                push_directive(tokens, "pragma");
473                for argument in arguments {
474                    tokens.push(PtxToken::StringLiteral(argument.clone()));
475                }
476            }
477            FuncFunctionHeaderDirective::AbiPreserve { value, .. } => {
478                push_directive(tokens, "abi_preserve");
479                push_decimal(tokens, *value);
480            }
481            FuncFunctionHeaderDirective::AbiPreserveControl { value, .. } => {
482                push_directive(tokens, "abi_preserve_control");
483                push_decimal(tokens, *value);
484            }
485        }
486    }
487}
488
489impl PtxUnparser for AliasFunctionDirective {
490    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
491        push_directive(tokens, "alias");
492        push_identifier(tokens, &self.alias.val);
493        tokens.push(PtxToken::Comma);
494        push_identifier(tokens, &self.target.val);
495        tokens.push(PtxToken::Semicolon);
496    }
497}
498
499impl PtxUnparser for EntryFunctionDirective {
500    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
501        for directive in &self.directives {
502            directive.unparse_tokens(tokens);
503        }
504        push_directive(tokens, "entry");
505        push_identifier(tokens, &self.name.val);
506        tokens.push(PtxToken::LParen);
507        unparse_param_list(tokens, &self.params);
508        tokens.push(PtxToken::RParen);
509        match &self.body {
510            Some(body) => body.unparse_tokens(tokens),
511            None => tokens.push(PtxToken::Semicolon),
512        }
513    }
514}
515
516impl PtxUnparser for FuncFunctionDirective {
517    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
518        for attribute in &self.attributes {
519            attribute.unparse_tokens(tokens);
520        }
521        for directive in &self.directives {
522            directive.unparse_tokens(tokens);
523        }
524        push_directive(tokens, "func");
525        if let Some(ret) = &self.return_param {
526            tokens.push(PtxToken::LParen);
527            unparse_param(tokens, ret);
528            tokens.push(PtxToken::RParen);
529        }
530        push_identifier(tokens, &self.name.val);
531        tokens.push(PtxToken::LParen);
532        unparse_param_list(tokens, &self.params);
533        tokens.push(PtxToken::RParen);
534        match &self.body {
535            Some(body) => body.unparse_tokens(tokens),
536            None => tokens.push(PtxToken::Semicolon),
537        }
538    }
539}