1use crate::unparser::common::push_register;
2use crate::{
3 lexer::PtxToken,
4 r#type::{function::*, variable::VariableDirective},
5 unparser::*,
6};
7
8fn push_register_components(tokens: &mut Vec<PtxToken>, name: &str) {
9 if let Some(stripped) = name.strip_prefix('%') {
10 let mut parts = stripped.split('.');
11 if let Some(first) = parts.next() {
12 let register_name = format!("%{first}");
13 push_register(tokens, ®ister_name);
14 }
15 for part in parts {
16 if part.is_empty() {
17 continue;
18 }
19 push_directive(tokens, part);
20 }
21 } else {
22 push_identifier(tokens, name);
23 }
24}
25
26fn unparse_register_directive(tokens: &mut Vec<PtxToken>, directive: &RegisterDirective) {
27 push_directive(tokens, "reg");
28 if let Some(ty) = &directive.ty {
29 push_directive(tokens, ty);
30 }
31 push_register_components(tokens, &directive.name);
32 if let Some(range) = directive.range {
33 tokens.push(PtxToken::LAngle);
34 push_decimal(tokens, range);
35 tokens.push(PtxToken::RAngle);
36 }
37 tokens.push(PtxToken::Semicolon);
38}
39
40fn unparse_entry_directive(tokens: &mut Vec<PtxToken>, directive: &FunctionEntryDirective) {
41 match directive {
42 FunctionEntryDirective::Reg(register) => unparse_register_directive(tokens, register),
43 FunctionEntryDirective::Local(variable) => variable.unparse_tokens(tokens),
44 FunctionEntryDirective::Param(variable) => variable.unparse_tokens(tokens),
45 FunctionEntryDirective::Shared(variable) => variable.unparse_tokens(tokens),
46 FunctionEntryDirective::Pragma(_) => {
47 panic!("unimplemented: unparsing .pragma function entry directives");
48 }
49 FunctionEntryDirective::Loc(_) => {
50 panic!("unimplemented: unparsing .loc function entry directives");
51 }
52 FunctionEntryDirective::Dwarf(_) => {
53 panic!("unimplemented: unparsing dwarf function entry directives");
54 }
55 }
56}
57
58fn unparse_extern_call_setup(tokens: &mut Vec<PtxToken>, setup: &ExternCallSetup) {
59 match setup {
60 ExternCallSetup::Param(variable) => variable.unparse_tokens(tokens),
61 ExternCallSetup::Store(instruction) => instruction.unparse_tokens(tokens),
62 }
63}
64
65fn unparse_extern_call_block(tokens: &mut Vec<PtxToken>, block: &ExternCallBlock) {
66 tokens.push(PtxToken::LBrace);
67 for directive in &block.declarations {
68 unparse_entry_directive(tokens, directive);
69 }
70 for entry in &block.setup {
71 unparse_extern_call_setup(tokens, entry);
72 }
73 block.call.unparse_tokens(tokens);
74 for instruction in &block.post_call {
75 instruction.unparse_tokens(tokens);
76 }
77 tokens.push(PtxToken::RBrace);
78}
79
80fn unparse_function_statement(tokens: &mut Vec<PtxToken>, statement: &FunctionStatement) {
81 match statement {
82 FunctionStatement::Label(name) => {
83 push_identifier(tokens, name);
84 tokens.push(PtxToken::Colon);
85 }
86 FunctionStatement::Instruction(instruction) => instruction.unparse_tokens(tokens),
87 FunctionStatement::ExternCallBlock(block) => unparse_extern_call_block(tokens, block),
88 FunctionStatement::Directive(_) => {
89 panic!("unimplemented: unparsing function statement directives");
90 }
91 }
92}
93
94fn unparse_function_dim(tokens: &mut Vec<PtxToken>, dim: &FunctionDim3) {
95 push_decimal(tokens, dim.x);
96 if let Some(y) = dim.y {
97 tokens.push(PtxToken::Comma);
98 push_decimal(tokens, y);
99 }
100 if let Some(z) = dim.z {
101 tokens.push(PtxToken::Comma);
102 push_decimal(tokens, z);
103 }
104}
105
106fn unparse_param(tokens: &mut Vec<PtxToken>, param: &VariableDirective) {
107 let mut param_tokens = param.to_tokens();
108 if matches!(param_tokens.last(), Some(PtxToken::Semicolon)) {
109 param_tokens.pop();
110 }
111 tokens.extend(param_tokens);
112}
113
114fn unparse_param_list(tokens: &mut Vec<PtxToken>, params: &[VariableDirective]) {
115 for (idx, param) in params.iter().enumerate() {
116 if idx > 0 {
117 tokens.push(PtxToken::Comma);
118 }
119 unparse_param(tokens, param);
120 }
121}
122
123fn unparse_function_header_directive(
124 tokens: &mut Vec<PtxToken>,
125 directive: &FunctionHeaderDirective,
126) {
127 match directive {
128 FunctionHeaderDirective::Linkage(linkage) => linkage.unparse_tokens(tokens),
129 FunctionHeaderDirective::NoReturn => push_directive(tokens, "noreturn"),
130 FunctionHeaderDirective::AbiPreserve(value) => {
131 push_directive(tokens, "abipreserve");
132 push_decimal(tokens, *value);
133 }
134 FunctionHeaderDirective::AbiPreserveControl(value) => {
135 push_directive(tokens, "abipreserve_control");
136 push_decimal(tokens, *value);
137 }
138 FunctionHeaderDirective::MaxClusterRank(value) => {
139 push_directive(tokens, "maxclusterrank");
140 push_decimal(tokens, *value);
141 }
142 FunctionHeaderDirective::BlocksAreClusters => push_directive(tokens, "blocksareclusters"),
143 FunctionHeaderDirective::ExplicitCluster(dim) => {
144 push_directive(tokens, "explicitcluster");
145 unparse_function_dim(tokens, dim);
146 }
147 FunctionHeaderDirective::ReqNctaPerCluster(dim) => {
148 push_directive(tokens, "reqnctapercluster");
149 unparse_function_dim(tokens, dim);
150 }
151 FunctionHeaderDirective::MaxNReg(value) => {
152 push_directive(tokens, "maxnreg");
153 push_decimal(tokens, *value);
154 }
155 FunctionHeaderDirective::MaxNTid(dim) => {
156 push_directive(tokens, "maxntid");
157 unparse_function_dim(tokens, dim);
158 }
159 FunctionHeaderDirective::MinNCtaPerSm(value) => {
160 push_directive(tokens, "minnctapersm");
161 push_decimal(tokens, *value);
162 }
163 FunctionHeaderDirective::ReqNTid(dim) => {
164 push_directive(tokens, "reqntid");
165 unparse_function_dim(tokens, dim);
166 }
167 FunctionHeaderDirective::MaxNCtaPerSm(value) => {
168 push_directive(tokens, "maxnctapersm");
169 push_decimal(tokens, *value);
170 }
171 FunctionHeaderDirective::Pragma(arguments) => {
172 push_directive(tokens, "pragma");
173 for argument in arguments {
174 tokens.push(PtxToken::Identifier(argument.clone()));
175 }
176 }
177 }
178}
179
180fn unparse_function_headers(tokens: &mut Vec<PtxToken>, directives: &[FunctionHeaderDirective]) {
181 for directive in directives {
182 unparse_function_header_directive(tokens, directive);
183 }
184}
185
186fn unparse_function_body(tokens: &mut Vec<PtxToken>, body: &FunctionBody, prefer_braces: bool) {
187 if body.entry_directives.is_empty() && body.statements.is_empty() {
188 if prefer_braces {
189 tokens.push(PtxToken::LBrace);
190 tokens.push(PtxToken::RBrace);
191 } else {
192 tokens.push(PtxToken::Semicolon);
193 }
194 return;
195 }
196
197 tokens.push(PtxToken::LBrace);
198 for directive in &body.entry_directives {
199 unparse_entry_directive(tokens, directive);
200 }
201 for statement in &body.statements {
202 unparse_function_statement(tokens, statement);
203 }
204 tokens.push(PtxToken::RBrace);
205}
206
207impl PtxUnparser for FunctionAlias {
208 fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
209 push_directive(tokens, "alias");
210 push_identifier(tokens, &self.alias);
211 tokens.push(PtxToken::Comma);
212 push_identifier(tokens, &self.target);
213 tokens.push(PtxToken::Semicolon);
214 }
215}
216
217impl PtxUnparser for EntryFunction {
218 fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
219 unparse_function_headers(tokens, &self.directives);
220 push_directive(tokens, "entry");
221 push_identifier(tokens, &self.name);
222 tokens.push(PtxToken::LParen);
223 unparse_param_list(tokens, &self.params);
224 tokens.push(PtxToken::RParen);
225 unparse_function_body(tokens, &self.body, true);
226 }
227}
228
229impl PtxUnparser for FuncFunction {
230 fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
231 unparse_function_headers(tokens, &self.directives);
232 push_directive(tokens, "func");
233 if let Some(ret) = &self.return_param {
234 tokens.push(PtxToken::LParen);
235 unparse_param(tokens, ret);
236 tokens.push(PtxToken::RParen);
237 }
238 push_identifier(tokens, &self.name);
239 tokens.push(PtxToken::LParen);
240 unparse_param_list(tokens, &self.params);
241 tokens.push(PtxToken::RParen);
242 unparse_function_body(tokens, &self.body, false);
243 }
244}
245
246impl PtxUnparser for FunctionKernelDirective {
247 fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
248 match self {
249 FunctionKernelDirective::Entry(entry) => entry.unparse_tokens(tokens),
250 FunctionKernelDirective::Func(func) => func.unparse_tokens(tokens),
251 FunctionKernelDirective::Alias(alias) => alias.unparse_tokens(tokens),
252 }
253 }
254}