// ptx_parser/unparser/function.rs

1use crate::unparser::common::push_register;
2use crate::{
3    lexer::PtxToken,
4    r#type::{function::*, variable::VariableDirective},
5    unparser::*,
6};
7
8fn push_register_components(tokens: &mut Vec<PtxToken>, name: &str) {
9    if let Some(stripped) = name.strip_prefix('%') {
10        let mut parts = stripped.split('.');
11        if let Some(first) = parts.next() {
12            let register_name = format!("%{first}");
13            push_register(tokens, &register_name);
14        }
15        for part in parts {
16            if part.is_empty() {
17                continue;
18            }
19            push_directive(tokens, part);
20        }
21    } else {
22        push_identifier(tokens, name);
23    }
24}
25
26fn unparse_register_directive(tokens: &mut Vec<PtxToken>, directive: &RegisterDirective) {
27    push_directive(tokens, "reg");
28    if let Some(ty) = &directive.ty {
29        push_directive(tokens, ty);
30    }
31    push_register_components(tokens, &directive.name);
32    if let Some(range) = directive.range {
33        tokens.push(PtxToken::LAngle);
34        push_decimal(tokens, range);
35        tokens.push(PtxToken::RAngle);
36    }
37    tokens.push(PtxToken::Semicolon);
38}
39
40fn unparse_entry_directive(tokens: &mut Vec<PtxToken>, directive: &FunctionEntryDirective) {
41    match directive {
42        FunctionEntryDirective::Reg(register) => unparse_register_directive(tokens, register),
43        FunctionEntryDirective::Local(variable) => variable.unparse_tokens(tokens),
44        FunctionEntryDirective::Param(variable) => variable.unparse_tokens(tokens),
45        FunctionEntryDirective::Shared(variable) => variable.unparse_tokens(tokens),
46        FunctionEntryDirective::Pragma(_) => {
47            panic!("unimplemented: unparsing .pragma function entry directives");
48        }
49        FunctionEntryDirective::Loc(_) => {
50            panic!("unimplemented: unparsing .loc function entry directives");
51        }
52        FunctionEntryDirective::Dwarf(_) => {
53            panic!("unimplemented: unparsing dwarf function entry directives");
54        }
55    }
56}
57
58fn unparse_extern_call_setup(tokens: &mut Vec<PtxToken>, setup: &ExternCallSetup) {
59    match setup {
60        ExternCallSetup::Param(variable) => variable.unparse_tokens(tokens),
61        ExternCallSetup::Store(instruction) => instruction.unparse_tokens(tokens),
62    }
63}
64
65fn unparse_extern_call_block(tokens: &mut Vec<PtxToken>, block: &ExternCallBlock) {
66    tokens.push(PtxToken::LBrace);
67    for directive in &block.declarations {
68        unparse_entry_directive(tokens, directive);
69    }
70    for entry in &block.setup {
71        unparse_extern_call_setup(tokens, entry);
72    }
73    block.call.unparse_tokens(tokens);
74    for instruction in &block.post_call {
75        instruction.unparse_tokens(tokens);
76    }
77    tokens.push(PtxToken::RBrace);
78}
79
80fn unparse_function_statement(tokens: &mut Vec<PtxToken>, statement: &FunctionStatement) {
81    match statement {
82        FunctionStatement::Label(name) => {
83            push_identifier(tokens, name);
84            tokens.push(PtxToken::Colon);
85        }
86        FunctionStatement::Instruction(instruction) => instruction.unparse_tokens(tokens),
87        FunctionStatement::ExternCallBlock(block) => unparse_extern_call_block(tokens, block),
88        FunctionStatement::Directive(_) => {
89            panic!("unimplemented: unparsing function statement directives");
90        }
91    }
92}
93
94fn unparse_function_dim(tokens: &mut Vec<PtxToken>, dim: &FunctionDim3) {
95    push_decimal(tokens, dim.x);
96    if let Some(y) = dim.y {
97        tokens.push(PtxToken::Comma);
98        push_decimal(tokens, y);
99    }
100    if let Some(z) = dim.z {
101        tokens.push(PtxToken::Comma);
102        push_decimal(tokens, z);
103    }
104}
105
106fn unparse_param(tokens: &mut Vec<PtxToken>, param: &VariableDirective) {
107    let mut param_tokens = param.to_tokens();
108    if matches!(param_tokens.last(), Some(PtxToken::Semicolon)) {
109        param_tokens.pop();
110    }
111    tokens.extend(param_tokens);
112}
113
114fn unparse_param_list(tokens: &mut Vec<PtxToken>, params: &[VariableDirective]) {
115    for (idx, param) in params.iter().enumerate() {
116        if idx > 0 {
117            tokens.push(PtxToken::Comma);
118        }
119        unparse_param(tokens, param);
120    }
121}
122
123fn unparse_function_header_directive(
124    tokens: &mut Vec<PtxToken>,
125    directive: &FunctionHeaderDirective,
126) {
127    match directive {
128        FunctionHeaderDirective::Linkage(linkage) => linkage.unparse_tokens(tokens),
129        FunctionHeaderDirective::NoReturn => push_directive(tokens, "noreturn"),
130        FunctionHeaderDirective::AbiPreserve(value) => {
131            push_directive(tokens, "abipreserve");
132            push_decimal(tokens, *value);
133        }
134        FunctionHeaderDirective::AbiPreserveControl(value) => {
135            push_directive(tokens, "abipreserve_control");
136            push_decimal(tokens, *value);
137        }
138        FunctionHeaderDirective::MaxClusterRank(value) => {
139            push_directive(tokens, "maxclusterrank");
140            push_decimal(tokens, *value);
141        }
142        FunctionHeaderDirective::BlocksAreClusters => push_directive(tokens, "blocksareclusters"),
143        FunctionHeaderDirective::ExplicitCluster(dim) => {
144            push_directive(tokens, "explicitcluster");
145            unparse_function_dim(tokens, dim);
146        }
147        FunctionHeaderDirective::ReqNctaPerCluster(dim) => {
148            push_directive(tokens, "reqnctapercluster");
149            unparse_function_dim(tokens, dim);
150        }
151        FunctionHeaderDirective::MaxNReg(value) => {
152            push_directive(tokens, "maxnreg");
153            push_decimal(tokens, *value);
154        }
155        FunctionHeaderDirective::MaxNTid(dim) => {
156            push_directive(tokens, "maxntid");
157            unparse_function_dim(tokens, dim);
158        }
159        FunctionHeaderDirective::MinNCtaPerSm(value) => {
160            push_directive(tokens, "minnctapersm");
161            push_decimal(tokens, *value);
162        }
163        FunctionHeaderDirective::ReqNTid(dim) => {
164            push_directive(tokens, "reqntid");
165            unparse_function_dim(tokens, dim);
166        }
167        FunctionHeaderDirective::MaxNCtaPerSm(value) => {
168            push_directive(tokens, "maxnctapersm");
169            push_decimal(tokens, *value);
170        }
171        FunctionHeaderDirective::Pragma(arguments) => {
172            push_directive(tokens, "pragma");
173            for argument in arguments {
174                tokens.push(PtxToken::Identifier(argument.clone()));
175            }
176        }
177    }
178}
179
180fn unparse_function_headers(tokens: &mut Vec<PtxToken>, directives: &[FunctionHeaderDirective]) {
181    for directive in directives {
182        unparse_function_header_directive(tokens, directive);
183    }
184}
185
186fn unparse_function_body(tokens: &mut Vec<PtxToken>, body: &FunctionBody, prefer_braces: bool) {
187    if body.entry_directives.is_empty() && body.statements.is_empty() {
188        if prefer_braces {
189            tokens.push(PtxToken::LBrace);
190            tokens.push(PtxToken::RBrace);
191        } else {
192            tokens.push(PtxToken::Semicolon);
193        }
194        return;
195    }
196
197    tokens.push(PtxToken::LBrace);
198    for directive in &body.entry_directives {
199        unparse_entry_directive(tokens, directive);
200    }
201    for statement in &body.statements {
202        unparse_function_statement(tokens, statement);
203    }
204    tokens.push(PtxToken::RBrace);
205}
206
207impl PtxUnparser for FunctionAlias {
208    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
209        push_directive(tokens, "alias");
210        push_identifier(tokens, &self.alias);
211        tokens.push(PtxToken::Comma);
212        push_identifier(tokens, &self.target);
213        tokens.push(PtxToken::Semicolon);
214    }
215}
216
217impl PtxUnparser for EntryFunction {
218    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
219        unparse_function_headers(tokens, &self.directives);
220        push_directive(tokens, "entry");
221        push_identifier(tokens, &self.name);
222        tokens.push(PtxToken::LParen);
223        unparse_param_list(tokens, &self.params);
224        tokens.push(PtxToken::RParen);
225        unparse_function_body(tokens, &self.body, true);
226    }
227}
228
229impl PtxUnparser for FuncFunction {
230    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
231        unparse_function_headers(tokens, &self.directives);
232        push_directive(tokens, "func");
233        if let Some(ret) = &self.return_param {
234            tokens.push(PtxToken::LParen);
235            unparse_param(tokens, ret);
236            tokens.push(PtxToken::RParen);
237        }
238        push_identifier(tokens, &self.name);
239        tokens.push(PtxToken::LParen);
240        unparse_param_list(tokens, &self.params);
241        tokens.push(PtxToken::RParen);
242        unparse_function_body(tokens, &self.body, false);
243    }
244}
245
246impl PtxUnparser for FunctionKernelDirective {
247    fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
248        match self {
249            FunctionKernelDirective::Entry(entry) => entry.unparse_tokens(tokens),
250            FunctionKernelDirective::Func(func) => func.unparse_tokens(tokens),
251            FunctionKernelDirective::Alias(alias) => alias.unparse_tokens(tokens),
252        }
253    }
254}