ptx_parser/unparser/
function.rs

1use crate::unparser::common::push_register;
2use crate::{
3    lexer::PtxToken,
4    r#type::{function::*, variable::VariableDirective},
5    unparser::*,
6};
7
8fn push_register_components(tokens: &mut Vec<PtxToken>, name: &str) {
9    if let Some(stripped) = name.strip_prefix('%') {
10        let mut parts = stripped.split('.');
11        if let Some(first) = parts.next() {
12            let register_name = format!("%{first}");
13            push_register(tokens, &register_name);
14        }
15        for part in parts {
16            if part.is_empty() {
17                continue;
18            }
19            push_directive(tokens, part);
20        }
21    } else {
22        push_identifier(tokens, name);
23    }
24}
25
26fn unparse_param(tokens: &mut Vec<PtxToken>, param: &VariableDirective) {
27    let mut param_tokens = param.to_tokens();
28    if matches!(param_tokens.last(), Some(PtxToken::Semicolon)) {
29        param_tokens.pop();
30    }
31    tokens.extend(param_tokens);
32}
33
34fn unparse_param_list(tokens: &mut Vec<PtxToken>, params: &[VariableDirective]) {
35    for (idx, param) in params.iter().enumerate() {
36        if idx > 0 {
37            tokens.push(PtxToken::Comma);
38        }
39        unparse_param(tokens, param);
40    }
41}
42
43impl PtxUnparser for RegisterDirective {
44    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
45        push_directive(tokens, "reg");
46        if let Some(ty) = &self.ty {
47            push_directive(tokens, ty);
48        }
49        push_register_components(tokens, &self.name);
50        if let Some(range) = self.range {
51            tokens.push(PtxToken::LAngle);
52            push_decimal(tokens, range);
53            tokens.push(PtxToken::RAngle);
54        }
55        tokens.push(PtxToken::Semicolon);
56    }
57}
58
59impl PtxUnparser for StatementDirective {
60    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
61        match self {
62            StatementDirective::Reg(register) => register.unparse_tokens(tokens),
63            StatementDirective::Local(variable)
64            | StatementDirective::Param(variable)
65            | StatementDirective::Shared(variable) => variable.unparse_tokens(tokens),
66            StatementDirective::Pragma(pragma) => {
67                push_directive(tokens, "pragma");
68                for (idx, argument) in pragma.arguments.iter().enumerate() {
69                    if idx > 0 {
70                        tokens.push(PtxToken::Comma);
71                    }
72                    push_token_from_str(tokens, argument);
73                }
74                tokens.push(PtxToken::Semicolon);
75            }
76            StatementDirective::Loc(loc) => {
77                push_directive(tokens, "loc");
78                push_decimal(tokens, loc.file_index);
79                push_decimal(tokens, loc.line);
80                push_decimal(tokens, loc.column);
81                for option in &loc.options {
82                    push_token_from_str(tokens, option);
83                }
84                tokens.push(PtxToken::Semicolon);
85            }
86            StatementDirective::Dwarf(dwarf) => {
87                push_directive(tokens, "dwarf");
88                push_identifier(tokens, &dwarf.keyword);
89                for argument in &dwarf.arguments {
90                    tokens.push(PtxToken::Comma);
91                    push_token_from_str(tokens, argument);
92                }
93                tokens.push(PtxToken::Semicolon);
94            }
95            StatementDirective::Section(section) => {
96                push_directive(tokens, "section");
97                push_token_from_str(tokens, &section.name);
98                for argument in &section.arguments {
99                    tokens.push(PtxToken::Comma);
100                    push_token_from_str(tokens, argument);
101                }
102                tokens.push(PtxToken::Semicolon);
103            }
104        }
105    }
106}
107
108impl PtxUnparser for FunctionStatement {
109    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
110        match self {
111            FunctionStatement::Label(name) => {
112                push_identifier(tokens, name);
113                tokens.push(PtxToken::Colon);
114            }
115            FunctionStatement::Instruction(instruction) => instruction.unparse_tokens(tokens),
116            FunctionStatement::Directive(directive) => directive.unparse_tokens(tokens),
117            FunctionStatement::Block(block) => {
118                tokens.push(PtxToken::LBrace);
119                for statement in block {
120                    statement.unparse_tokens(tokens);
121                }
122                tokens.push(PtxToken::RBrace);
123            }
124        }
125    }
126}
127
128impl PtxUnparser for FunctionBody {
129    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
130        tokens.push(PtxToken::LBrace);
131        for statement in &self.statements {
132            statement.unparse_tokens(tokens);
133        }
134        tokens.push(PtxToken::RBrace);
135    }
136}
137
138impl PtxUnparser for FunctionDim3 {
139    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
140        push_decimal(tokens, self.x);
141        if let Some(y) = self.y {
142            tokens.push(PtxToken::Comma);
143            push_decimal(tokens, y);
144        }
145        if let Some(z) = self.z {
146            tokens.push(PtxToken::Comma);
147            push_decimal(tokens, z);
148        }
149    }
150}
151
152impl PtxUnparser for FunctionHeaderDirective {
153    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
154        match self {
155            FunctionHeaderDirective::Linkage(linkage) => linkage.unparse_tokens(tokens),
156            FunctionHeaderDirective::NoReturn => push_directive(tokens, "noreturn"),
157            FunctionHeaderDirective::AbiPreserve(value) => {
158                push_directive(tokens, "abipreserve");
159                push_decimal(tokens, *value);
160            }
161            FunctionHeaderDirective::AbiPreserveControl(value) => {
162                push_directive(tokens, "abipreserve_control");
163                push_decimal(tokens, *value);
164            }
165            FunctionHeaderDirective::MaxClusterRank(value) => {
166                push_directive(tokens, "maxclusterrank");
167                push_decimal(tokens, *value);
168            }
169            FunctionHeaderDirective::BlocksAreClusters => {
170                push_directive(tokens, "blocksareclusters")
171            }
172            FunctionHeaderDirective::ExplicitCluster(dim) => {
173                push_directive(tokens, "explicitcluster");
174                dim.unparse_tokens(tokens);
175            }
176            FunctionHeaderDirective::ReqNctaPerCluster(dim) => {
177                push_directive(tokens, "reqnctapercluster");
178                dim.unparse_tokens(tokens);
179            }
180            FunctionHeaderDirective::MaxNReg(value) => {
181                push_directive(tokens, "maxnreg");
182                push_decimal(tokens, *value);
183            }
184            FunctionHeaderDirective::MaxNTid(dim) => {
185                push_directive(tokens, "maxntid");
186                dim.unparse_tokens(tokens);
187            }
188            FunctionHeaderDirective::MinNCtaPerSm(value) => {
189                push_directive(tokens, "minnctapersm");
190                push_decimal(tokens, *value);
191            }
192            FunctionHeaderDirective::ReqNTid(dim) => {
193                push_directive(tokens, "reqntid");
194                dim.unparse_tokens(tokens);
195            }
196            FunctionHeaderDirective::MaxNCtaPerSm(value) => {
197                push_directive(tokens, "maxnctapersm");
198                push_decimal(tokens, *value);
199            }
200            FunctionHeaderDirective::Pragma(arguments) => {
201                push_directive(tokens, "pragma");
202                for argument in arguments {
203                    tokens.push(PtxToken::Identifier(argument.clone()));
204                }
205            }
206        }
207    }
208}
209
210impl PtxUnparser for FunctionAlias {
211    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
212        push_directive(tokens, "alias");
213        push_identifier(tokens, &self.alias);
214        tokens.push(PtxToken::Comma);
215        push_identifier(tokens, &self.target);
216        tokens.push(PtxToken::Semicolon);
217    }
218}
219
220impl PtxUnparser for EntryFunction {
221    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
222        for directive in &self.directives {
223            directive.unparse_tokens(tokens);
224        }
225        push_directive(tokens, "entry");
226        push_identifier(tokens, &self.name);
227        tokens.push(PtxToken::LParen);
228        unparse_param_list(tokens, &self.params);
229        tokens.push(PtxToken::RParen);
230        if self.body.statements.is_empty() {
231            tokens.push(PtxToken::LBrace);
232            tokens.push(PtxToken::RBrace);
233        } else {
234            self.body.unparse_tokens(tokens);
235        }
236    }
237}
238
239impl PtxUnparser for FuncFunction {
240    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
241        for directive in &self.directives {
242            directive.unparse_tokens(tokens);
243        }
244        push_directive(tokens, "func");
245        if let Some(ret) = &self.return_param {
246            tokens.push(PtxToken::LParen);
247            unparse_param(tokens, ret);
248            tokens.push(PtxToken::RParen);
249        }
250        push_identifier(tokens, &self.name);
251        tokens.push(PtxToken::LParen);
252        unparse_param_list(tokens, &self.params);
253        tokens.push(PtxToken::RParen);
254        if self.body.statements.is_empty() {
255            tokens.push(PtxToken::Semicolon);
256        } else {
257            self.body.unparse_tokens(tokens);
258        }
259    }
260}
261
262impl PtxUnparser for FunctionKernelDirective {
263    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
264        match self {
265            FunctionKernelDirective::Entry(entry) => entry.unparse_tokens(tokens),
266            FunctionKernelDirective::Func(func) => func.unparse_tokens(tokens),
267            FunctionKernelDirective::Alias(alias) => alias.unparse_tokens(tokens),
268        }
269    }
270}