1use crate::unparser::common::push_register;
2use crate::{
3 lexer::PtxToken,
4 r#type::{function::*, variable::ParameterDirective},
5 unparser::*,
6};
7
8fn push_register_components(tokens: &mut Vec<PtxToken>, name: &str) {
9 if let Some(stripped) = name.strip_prefix('%') {
10 let mut parts = stripped.split('.');
11 if let Some(first) = parts.next() {
12 let register_name = format!("%{first}");
13 push_register(tokens, ®ister_name);
14 }
15 for part in parts {
16 if part.is_empty() {
17 continue;
18 }
19 push_directive(tokens, part);
20 }
21 } else {
22 push_identifier(tokens, name);
23 }
24}
25
26fn unparse_param(tokens: &mut Vec<PtxToken>, param: &ParameterDirective) {
27 match param {
28 ParameterDirective::Parameter {
29 align,
30 ty,
31 ptr,
32 space,
33 name,
34 array,
35 ..
36 } => {
37 push_directive(tokens, "param");
38 ty.unparse_tokens(tokens);
39 if *ptr {
40 push_directive(tokens, "ptr");
41 }
42 if let Some(address_space) = space {
43 address_space.unparse_tokens(tokens);
44 }
45 if let Some(value) = align {
46 push_directive(tokens, "align");
47 push_decimal(tokens, *value);
48 }
49 push_identifier(tokens, &name.val);
50 for extent in array {
51 tokens.push(PtxToken::LBracket);
52 if let Some(value) = extent {
53 push_decimal(tokens, *value);
54 }
55 tokens.push(PtxToken::RBracket);
56 }
57 }
58 ParameterDirective::Register { ty, name, .. } => {
59 push_directive(tokens, "reg");
60 ty.unparse_tokens(tokens);
61 push_register_components(tokens, &name.val);
62 }
63 }
64}
65
66fn unparse_param_list(tokens: &mut Vec<PtxToken>, params: &[ParameterDirective]) {
67 for (idx, param) in params.iter().enumerate() {
68 if idx > 0 {
69 tokens.push(PtxToken::Comma);
70 }
71 unparse_param(tokens, param);
72 }
73}
74
75fn unparse_section_line(tokens: &mut Vec<PtxToken>, line: &StatementSectionDirectiveLine) {
76 match line {
77 StatementSectionDirectiveLine::B8 { values, .. } => {
78 push_directive(tokens, "b8");
79 for (idx, value) in values.iter().enumerate() {
80 if idx > 0 {
81 tokens.push(PtxToken::Comma);
82 }
83 push_signed_decimal_i64(tokens, *value as i64);
84 }
85 tokens.push(PtxToken::Semicolon);
86 }
87 StatementSectionDirectiveLine::B16 { values, .. } => {
88 push_directive(tokens, "b16");
89 for (idx, value) in values.iter().enumerate() {
90 if idx > 0 {
91 tokens.push(PtxToken::Comma);
92 }
93 push_signed_decimal_i64(tokens, *value as i64);
94 }
95 tokens.push(PtxToken::Semicolon);
96 }
97 StatementSectionDirectiveLine::B32Immediate { values, .. } => {
98 push_directive(tokens, "b32");
99 for (idx, value) in values.iter().enumerate() {
100 if idx > 0 {
101 tokens.push(PtxToken::Comma);
102 }
103 push_signed_decimal_i64(tokens, *value);
104 }
105 tokens.push(PtxToken::Semicolon);
106 }
107 StatementSectionDirectiveLine::B64Immediate { values, .. } => {
108 push_directive(tokens, "b64");
109 for (idx, value) in values.iter().enumerate() {
110 if idx > 0 {
111 tokens.push(PtxToken::Comma);
112 }
113 push_signed_decimal_i128(tokens, *value);
114 }
115 tokens.push(PtxToken::Semicolon);
116 }
117 StatementSectionDirectiveLine::B32Label { labels, .. } => {
118 push_directive(tokens, "b32");
119 push_identifier(tokens, &labels.val);
120 tokens.push(PtxToken::Semicolon);
121 }
122 StatementSectionDirectiveLine::B64Label { labels, .. } => {
123 push_directive(tokens, "b64");
124 push_identifier(tokens, &labels.val);
125 tokens.push(PtxToken::Semicolon);
126 }
127 StatementSectionDirectiveLine::B32LabelPlusImm { entries, .. } => {
128 push_directive(tokens, "b32");
129 let (label, offset) = entries;
130 push_identifier(tokens, &label.val);
131 if *offset >= 0 {
132 tokens.push(PtxToken::Plus);
133 push_decimal(tokens, *offset);
134 } else {
135 tokens.push(PtxToken::Minus);
136 let magnitude = (*offset as i128).abs();
137 push_decimal(tokens, magnitude);
138 }
139 tokens.push(PtxToken::Semicolon);
140 }
141 StatementSectionDirectiveLine::B64LabelPlusImm { entries, .. } => {
142 push_directive(tokens, "b64");
143 let (label, offset) = entries;
144 push_identifier(tokens, &label.val);
145 if *offset >= 0 {
146 tokens.push(PtxToken::Plus);
147 push_decimal(tokens, *offset);
148 } else {
149 tokens.push(PtxToken::Minus);
150 let magnitude = (*offset as i128).abs();
151 push_decimal(tokens, magnitude);
152 }
153 tokens.push(PtxToken::Semicolon);
154 }
155 StatementSectionDirectiveLine::B32LabelDiff { entries, .. } => {
156 push_directive(tokens, "b32");
157 let (left, right) = entries;
158 push_identifier(tokens, &left.val);
159 tokens.push(PtxToken::Minus);
160 push_identifier(tokens, &right.val);
161 tokens.push(PtxToken::Semicolon);
162 }
163 StatementSectionDirectiveLine::B64LabelDiff { entries, .. } => {
164 push_directive(tokens, "b64");
165 let (left, right) = entries;
166 push_identifier(tokens, &left.val);
167 tokens.push(PtxToken::Minus);
168 push_identifier(tokens, &right.val);
169 tokens.push(PtxToken::Semicolon);
170 }
171 }
172}
173
174fn push_signed_decimal_i64(tokens: &mut Vec<PtxToken>, value: i64) {
175 if value < 0 {
176 tokens.push(PtxToken::Minus);
177 push_decimal(tokens, (-value) as i128);
178 } else {
179 push_decimal(tokens, value);
180 }
181}
182
183fn push_signed_decimal_i128(tokens: &mut Vec<PtxToken>, value: i128) {
184 if value < 0 {
185 tokens.push(PtxToken::Minus);
186 push_decimal(tokens, -value);
187 } else {
188 push_decimal(tokens, value);
189 }
190}
191
192impl PtxUnparser for RegisterDirective {
193 fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
194 push_directive(tokens, "reg");
195 self.ty.unparse_tokens(tokens);
196 for (idx, target) in self.registers.iter().enumerate() {
197 if idx > 0 {
198 tokens.push(PtxToken::Comma);
199 }
200 push_register_components(tokens, &target.name.val);
201 if let Some(range) = target.range {
202 tokens.push(PtxToken::LAngle);
203 push_decimal(tokens, range);
204 tokens.push(PtxToken::RAngle);
205 }
206 }
207 tokens.push(PtxToken::Semicolon);
208 }
209}
210
211impl PtxUnparser for StatementDirective {
212 fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
213 match self {
214 StatementDirective::Reg {
215 directive: register,
216 ..
217 } => register.unparse_tokens(tokens),
218 StatementDirective::Local {
219 directive: variable,
220 ..
221 } => {
222 push_directive(tokens, "local");
223 variable.unparse_tokens(tokens);
224 }
225 StatementDirective::Param {
226 directive: variable,
227 ..
228 } => {
229 push_directive(tokens, "param");
230 variable.unparse_tokens(tokens);
231 }
232 StatementDirective::Shared {
233 directive: variable,
234 ..
235 } => {
236 push_directive(tokens, "shared");
237 variable.unparse_tokens(tokens);
238 }
239 StatementDirective::Pragma {
240 directive: pragma, ..
241 } => {
242 push_directive(tokens, "pragma");
243 let text = match &pragma.kind {
244 PragmaDirectiveKind::Nounroll => "nounroll".to_string(),
245 PragmaDirectiveKind::EnableSmemSpilling => "enable_smem_spilling".to_string(),
246 PragmaDirectiveKind::UsedBytesMask { mask } => {
247 format!("used_bytes_mask {}", mask)
248 }
249 PragmaDirectiveKind::Frequency { value } => {
250 format!("frequency {}", value)
251 }
252 PragmaDirectiveKind::Raw(text) => text.clone(),
253 };
254 tokens.push(PtxToken::StringLiteral(text));
255 tokens.push(PtxToken::Semicolon);
256 }
257 StatementDirective::BranchTargets { directive, .. } => {
258 push_directive(tokens, "branchtargets");
259 for (idx, label) in directive.labels.iter().enumerate() {
260 if idx > 0 {
261 tokens.push(PtxToken::Comma);
262 }
263 push_token_from_str(tokens, &label.val);
264 }
265 tokens.push(PtxToken::Semicolon);
266 }
267 StatementDirective::CallTargets { directive, .. } => {
268 push_directive(tokens, "calltargets");
269 for (idx, target) in directive.targets.iter().enumerate() {
270 if idx > 0 {
271 tokens.push(PtxToken::Comma);
272 }
273 push_token_from_str(tokens, &target.val);
274 }
275 tokens.push(PtxToken::Semicolon);
276 }
277 StatementDirective::Loc { directive: loc, .. } => {
278 push_directive(tokens, "loc");
279 push_decimal(tokens, loc.file_index);
280 push_decimal(tokens, loc.line);
281 push_decimal(tokens, loc.column);
282 if let Some(inline) = &loc.inlined_at {
283 tokens.push(PtxToken::Comma);
284 push_identifier(tokens, "inlined_at");
285 push_decimal(tokens, inline.file_index);
286 push_decimal(tokens, inline.line);
287 push_decimal(tokens, inline.column);
288 tokens.push(PtxToken::Comma);
289 push_identifier(tokens, &inline.function_name.val);
290 push_identifier(tokens, &inline.label.val);
291 if let Some(offset) = inline.label_offset {
292 if offset >= 0 {
293 tokens.push(PtxToken::Plus);
294 push_decimal(tokens, offset);
295 } else {
296 tokens.push(PtxToken::Minus);
297 push_decimal(tokens, offset.abs());
298 }
299 }
300 }
301 tokens.push(PtxToken::Semicolon);
302 }
303 StatementDirective::Dwarf {
304 directive: dwarf, ..
305 } => {
306 dwarf.unparse_tokens(tokens);
307 }
308 StatementDirective::Section {
309 directive: section, ..
310 } => {
311 section.unparse_tokens(tokens);
312 }
313 StatementDirective::CallPrototype { directive, .. } => {
314 push_directive(tokens, "callprototype");
315 if let Some(ret) = &directive.return_param {
316 unparse_param(tokens, ret);
317 } else {
318 push_identifier(tokens, "_");
319 }
320 tokens.push(PtxToken::LParen);
321 unparse_param_list(tokens, &directive.params);
322 tokens.push(PtxToken::RParen);
323 if directive.noreturn {
324 push_directive(tokens, "noreturn");
325 }
326 if let Some(value) = directive.abi_preserve {
327 push_directive(tokens, "abi_preserve");
328 push_decimal(tokens, value);
329 }
330 if let Some(value) = directive.abi_preserve_control {
331 push_directive(tokens, "abi_preserve_control");
332 push_decimal(tokens, value);
333 }
334 tokens.push(PtxToken::Semicolon);
335 }
336 }
337 }
338}
339
340impl PtxUnparser for SectionDirective {
341 fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
342 push_directive(tokens, "section");
343 push_token_from_str(tokens, &self.name);
344 tokens.push(PtxToken::LBrace);
345 for entry in &self.entries {
346 match entry {
347 SectionEntry::Label { label, .. } => {
348 push_identifier(tokens, &label.val);
349 tokens.push(PtxToken::Colon);
350 }
351 SectionEntry::Directive(line) => {
352 unparse_section_line(tokens, line);
353 }
354 }
355 }
356 tokens.push(PtxToken::RBrace);
357 }
358}
359
360impl PtxUnparser for FunctionStatement {
361 fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
362 match self {
363 FunctionStatement::Label { label, .. } => {
364 push_identifier(tokens, &label.val);
365 tokens.push(PtxToken::Colon);
366 }
367 FunctionStatement::Instruction { instruction, .. } => {
368 instruction.unparse_tokens(tokens)
369 }
370 FunctionStatement::Directive { directive, .. } => directive.unparse_tokens(tokens),
371 FunctionStatement::Block {
372 statements: block, ..
373 } => {
374 tokens.push(PtxToken::LBrace);
375 for statement in block {
376 statement.unparse_tokens(tokens);
377 }
378 tokens.push(PtxToken::RBrace);
379 }
380 }
381 }
382}
383
384impl PtxUnparser for FunctionBody {
385 fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
386 tokens.push(PtxToken::LBrace);
387 for statement in &self.statements {
388 statement.unparse_tokens(tokens);
389 }
390 tokens.push(PtxToken::RBrace);
391 }
392}
393
394impl PtxUnparser for FunctionDim {
395 fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
396 match self {
397 FunctionDim::X { x, .. } => {
398 push_decimal(tokens, *x);
399 }
400 FunctionDim::XY { x, y, .. } => {
401 push_decimal(tokens, *x);
402 tokens.push(PtxToken::Comma);
403 push_decimal(tokens, *y);
404 }
405 FunctionDim::XYZ { x, y, z, .. } => {
406 push_decimal(tokens, *x);
407 tokens.push(PtxToken::Comma);
408 push_decimal(tokens, *y);
409 tokens.push(PtxToken::Comma);
410 push_decimal(tokens, *z);
411 }
412 }
413 }
414}
415
416impl PtxUnparser for EntryFunctionHeaderDirective {
417 fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
418 match self {
419 EntryFunctionHeaderDirective::MaxNReg { value, .. } => {
420 push_directive(tokens, "maxnreg");
421 push_decimal(tokens, *value);
422 }
423 EntryFunctionHeaderDirective::MaxNTid { dim, .. } => {
424 push_directive(tokens, "maxntid");
425 dim.unparse_tokens(tokens);
426 }
427 EntryFunctionHeaderDirective::ReqNTid { dim, .. } => {
428 push_directive(tokens, "reqntid");
429 dim.unparse_tokens(tokens);
430 }
431 EntryFunctionHeaderDirective::MinNCtaPerSm { value, .. } => {
432 push_directive(tokens, "minnctapersm");
433 push_decimal(tokens, *value);
434 }
435 EntryFunctionHeaderDirective::MaxNCtaPerSm { value, .. } => {
436 push_directive(tokens, "maxnctapersm");
437 push_decimal(tokens, *value);
438 }
439 EntryFunctionHeaderDirective::Pragma {
440 args: arguments, ..
441 } => {
442 push_directive(tokens, "pragma");
443 for argument in arguments {
444 tokens.push(PtxToken::StringLiteral(argument.clone()));
445 }
446 }
447 EntryFunctionHeaderDirective::ReqNctaPerCluster { dim, .. } => {
448 push_directive(tokens, "reqnctapercluster");
449 dim.unparse_tokens(tokens);
450 }
451 EntryFunctionHeaderDirective::ExplicitCluster { .. } => {
452 push_directive(tokens, "explicitcluster");
453 }
454 EntryFunctionHeaderDirective::MaxClusterRank { value, .. } => {
455 push_directive(tokens, "maxclusterrank");
456 push_decimal(tokens, *value);
457 }
458 EntryFunctionHeaderDirective::BlocksAreClusters { .. } => {
459 push_directive(tokens, "blocksareclusters")
460 }
461 }
462 }
463}
464
465impl PtxUnparser for FuncFunctionHeaderDirective {
466 fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
467 match self {
468 FuncFunctionHeaderDirective::NoReturn { .. } => push_directive(tokens, "noreturn"),
469 FuncFunctionHeaderDirective::Pragma {
470 args: arguments, ..
471 } => {
472 push_directive(tokens, "pragma");
473 for argument in arguments {
474 tokens.push(PtxToken::StringLiteral(argument.clone()));
475 }
476 }
477 FuncFunctionHeaderDirective::AbiPreserve { value, .. } => {
478 push_directive(tokens, "abi_preserve");
479 push_decimal(tokens, *value);
480 }
481 FuncFunctionHeaderDirective::AbiPreserveControl { value, .. } => {
482 push_directive(tokens, "abi_preserve_control");
483 push_decimal(tokens, *value);
484 }
485 }
486 }
487}
488
489impl PtxUnparser for AliasFunctionDirective {
490 fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
491 push_directive(tokens, "alias");
492 push_identifier(tokens, &self.alias.val);
493 tokens.push(PtxToken::Comma);
494 push_identifier(tokens, &self.target.val);
495 tokens.push(PtxToken::Semicolon);
496 }
497}
498
499impl PtxUnparser for EntryFunctionDirective {
500 fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
501 for directive in &self.directives {
502 directive.unparse_tokens(tokens);
503 }
504 push_directive(tokens, "entry");
505 push_identifier(tokens, &self.name.val);
506 tokens.push(PtxToken::LParen);
507 unparse_param_list(tokens, &self.params);
508 tokens.push(PtxToken::RParen);
509 match &self.body {
510 Some(body) => body.unparse_tokens(tokens),
511 None => tokens.push(PtxToken::Semicolon),
512 }
513 }
514}
515
516impl PtxUnparser for FuncFunctionDirective {
517 fn unparse_tokens(&self, tokens: &mut Vec<PtxToken>) {
518 for attribute in &self.attributes {
519 attribute.unparse_tokens(tokens);
520 }
521 for directive in &self.directives {
522 directive.unparse_tokens(tokens);
523 }
524 push_directive(tokens, "func");
525 if let Some(ret) = &self.return_param {
526 tokens.push(PtxToken::LParen);
527 unparse_param(tokens, ret);
528 tokens.push(PtxToken::RParen);
529 }
530 push_identifier(tokens, &self.name.val);
531 tokens.push(PtxToken::LParen);
532 unparse_param_list(tokens, &self.params);
533 tokens.push(PtxToken::RParen);
534 match &self.body {
535 Some(body) => body.unparse_tokens(tokens),
536 None => tokens.push(PtxToken::Semicolon),
537 }
538 }
539}