1use ptx_90_parser_construct::func;
2
3use crate::{
4 alt, c,
5 lexer::PtxToken,
6 mapc, ok,
7 parser::{
8 ParseErrorKind, PtxParseError, PtxParser, PtxTokenStream, Span,
9 util::{
10 alt, between, colon_p, comma_p, directive_exact_p, directive_p, identifier_p,
11 integer_p, langle_p, lbrace_p, lparen_p, many, map, minus_p, optional,
12 parse_signed_integer, parse_u32_literal, parse_unsigned_integer, plus_p, pure,
13 rangle_p, rbrace_p, register_p, rparen_p, semicolon_p, sep_by, sep_by1, seq, seq5,
14 skip_first, skip_second, skip_semicolon, string_literal_p, try_map, u32_p,
15 },
16 },
17 seq_n,
18 r#type::{
19 AliasFunctionDirective, AttributeDirective, BranchTargetsDirective, CallPrototypeDirective,
20 CallTargetsDirective, DataType, DwarfDirective, DwarfDirectiveKind, EntryFunctionDirective,
21 EntryFunctionHeaderDirective, FuncFunctionDirective, FuncFunctionHeaderDirective,
22 FunctionBody, FunctionDim, FunctionStatement, FunctionSymbol, Instruction, Label,
23 LocationDirective, LocationInlinedAt, ParameterDirective, PragmaDirective,
24 PragmaDirectiveKind, RegisterDirective, RegisterTarget, SectionDirective, SectionEntry,
25 StatementDirective, StatementSectionDirectiveLine, VariableDirective, VariableSymbol,
26 },
27};
28
/// Parses a single statement-level directive appearing inside a function body.
///
/// Alternatives are tried in this order: `.loc`, `.reg`, `.local`, `.param`, `.shared`,
/// `.branchtargets`, `.calltargets`, `.callprototype`, `.dwarf`, `.pragma`, `.section`.
impl PtxParser for StatementDirective {
    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
        // `.branchtargets L1, L2, ...;` — at least one label, comma-separated.
        let branch_targets = try_map(
            skip_semicolon(skip_first(
                directive_exact_p("branchtargets"),
                sep_by1(Label::parse(), comma_p()),
            )),
            |labels, span| {
                let directive = BranchTargetsDirective {
                    labels,
                    span: span.clone(),
                };
                ok!(StatementDirective::BranchTargets { directive })
            },
        );

        // `.calltargets f1, f2, ...;` — at least one function symbol, comma-separated.
        let call_targets = try_map(
            skip_semicolon(skip_first(
                directive_exact_p("calltargets"),
                sep_by1(FunctionSymbol::parse(), comma_p()),
            )),
            |targets, span| {
                let directive = CallTargetsDirective {
                    targets,
                    span: span.clone(),
                };
                ok!(StatementDirective::CallTargets { directive })
            },
        );

        // `.callprototype <return> <params> [.noreturn] [.abi_preserve N]
        // [.abi_preserve_control N];` — the optional trailing clauses are accepted only in
        // this order. The shapes of <return> and <params> are defined by
        // `return_spec_parser` and `parameter_list_parser`.
        let call_prototype = try_map(
            skip_semicolon(skip_first(
                directive_exact_p("callprototype"),
                seq5(
                    return_spec_parser(),
                    parameter_list_parser(),
                    noreturn_parser(),
                    abi_preserve_parser(),
                    abi_preserve_control_parser(),
                ),
            )),
            |(return_param, params, noreturn, abi_preserve, abi_preserve_control), span| {
                let directive = CallPrototypeDirective {
                    return_param,
                    params,
                    noreturn,
                    abi_preserve,
                    abi_preserve_control,
                    span: span.clone(),
                };
                ok!(StatementDirective::CallPrototype { directive })
            },
        );

        // `.loc` — unlike the other directives here, no trailing `;` is consumed.
        let location = mapc!(location_directive(), StatementDirective::Loc { directive });

        // `.reg <type> <targets>;`
        let reg_stmt = mapc!(register_statement(), StatementDirective::Reg { directive });

        // `.local`, `.param` and `.shared` hand everything after the keyword to
        // `VariableDirective::parse`.
        // NOTE(review): the terminating `;` is not consumed here — presumably
        // `VariableDirective::parse` handles it; confirm.
        let local_stmt = mapc!(
            skip_first(directive_exact_p("local"), VariableDirective::parse()),
            StatementDirective::Local { directive }
        );

        let param_stmt = mapc!(
            skip_first(directive_exact_p("param"), VariableDirective::parse()),
            StatementDirective::Param { directive }
        );

        let shared_stmt = mapc!(
            skip_first(directive_exact_p("shared"), VariableDirective::parse()),
            StatementDirective::Shared { directive }
        );

        alt!(
            location,
            reg_stmt,
            local_stmt,
            param_stmt,
            shared_stmt,
            branch_targets,
            call_targets,
            call_prototype,
            dwarf_directive(),
            pragma_directive(),
            section_directive()
        )
    }
}
117
/// Parses `.section <name> { <entries> }`.
///
/// The name is produced by `section_name_parser`. The braced body is a possibly empty
/// sequence of label and data-line entries produced by `section_body_parser`.
impl PtxParser for SectionDirective {
    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
        map(
            skip_first(
                directive_exact_p("section"),
                seq(section_name_parser(), section_body_parser()),
            ),
            // `c!` builds the struct; `span` is intentionally bound here for its use.
            |(name, entries), span| {
                c!(SectionDirective {
                    name = name,
                    entries,
                })
            },
        )
    }
}
134
/// Parses `.dwarf <payload>`. The payload forms are defined by `dwarf_kind_parser`.
///
/// No trailing `;` is consumed here. The statement-level wrapper `dwarf_directive` handles it.
impl PtxParser for DwarfDirective {
    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
        skip_first(directive_exact_p("dwarf"), dwarf_kind_parser())
    }
}
140
/// Parses one statement inside a function body.
///
/// Alternatives, tried in order:
/// * `label:` — a label definition;
/// * `{ ... }` — a nested block of zero or more statements;
/// * a statement-level directive (`StatementDirective`);
/// * an instruction.
impl PtxParser for FunctionStatement {
    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
        let label_stmt = map(seq(Label::parse(), colon_p()), |(label, _), span| {
            c!(FunctionStatement::Label { label })
        });

        // Blocks contain statements, so this parser is recursive. The nested parser is
        // constructed inside the closure on every call rather than up front. Presumably this
        // keeps the returned `impl Fn` type from having to contain itself.
        let block_stmt = move |stream: &mut PtxTokenStream| {
            map(
                between(lbrace_p(), rbrace_p(), many(FunctionStatement::parse())),
                |statements, span| c!(FunctionStatement::Block { statements }),
            )(stream)
        };

        let directive_stmt = mapc!(
            StatementDirective::parse(),
            FunctionStatement::Directive { directive }
        );

        let instruction_stmt = mapc!(
            Instruction::parse(),
            FunctionStatement::Instruction { instruction }
        );

        // Labels are tried before instructions, so `name:` is read as a label definition.
        alt!(label_stmt, block_stmt, directive_stmt, instruction_stmt)
    }
}
167
168fn return_spec_parser()
169-> impl Fn(&mut PtxTokenStream) -> Result<(Option<ParameterDirective>, Span), PtxParseError> {
170 alt(
171 map(ParameterDirective::parse(), |param, _| Some(param)),
172 map(underscore_placeholder(), |_, _| None),
173 )
174}
175
176fn underscore_placeholder() -> impl Fn(&mut PtxTokenStream) -> Result<((), Span), PtxParseError> {
177 try_map(identifier_p(), |name, span| {
178 if name == "_" {
179 Ok(())
180 } else {
181 Err(PtxParseError {
182 kind: ParseErrorKind::UnexpectedToken {
183 expected: vec!["identifier `_`".into()],
184 found: name,
185 },
186 span,
187 })
188 }
189 })
190}
191
192fn parameter_list_parser()
193-> impl Fn(&mut PtxTokenStream) -> Result<(Vec<ParameterDirective>, Span), PtxParseError> {
194 map(
195 between(
196 lparen_p(),
197 rparen_p(),
198 sep_by(ParameterDirective::parse(), comma_p()),
199 ),
200 |params, _| params,
201 )
202}
203
204fn noreturn_parser() -> impl Fn(&mut PtxTokenStream) -> Result<(bool, Span), PtxParseError> {
205 map(optional(directive_exact_p("noreturn")), |flag, _| {
206 flag.is_some()
207 })
208}
209
210fn abi_preserve_parser()
211-> impl Fn(&mut PtxTokenStream) -> Result<(Option<u32>, Span), PtxParseError> {
212 map(
213 optional(skip_first(directive_exact_p("abi_preserve"), u32_p())),
214 |value, _span| value,
215 )
216}
217
218fn abi_preserve_control_parser()
219-> impl Fn(&mut PtxTokenStream) -> Result<(Option<u32>, Span), PtxParseError> {
220 map(
221 optional(skip_first(
222 directive_exact_p("abi_preserve_control"),
223 u32_p(),
224 )),
225 |value, _span| value,
226 )
227}
/// Statement-level `.dwarf ...;`: wraps `DwarfDirective::parse` and consumes the trailing `;`.
fn dwarf_directive()
-> impl Fn(&mut PtxTokenStream) -> Result<(StatementDirective, Span), PtxParseError> {
    mapc!(
        skip_semicolon(DwarfDirective::parse()),
        StatementDirective::Dwarf { directive }
    )
}
235
236fn dwarf_kind_parser()
237-> impl Fn(&mut PtxTokenStream) -> Result<(DwarfDirective, Span), PtxParseError> {
238 let byte_values = try_map(
239 seq(
240 directive_exact_p("byte"),
241 sep_by1(unsigned_integer_literal(), comma_p()),
242 ),
243 |(_, values), span| {
244 let mut parsed = Vec::new();
245 for (text, value_span) in values {
246 let value = parse_unsigned_integer(&text, value_span, 0, u8::MAX as u128)?;
247 parsed.push(value as u8);
248 }
249 ok!(DwarfDirective {
250 kind = DwarfDirectiveKind::ByteValues(parsed)
251 })
252 },
253 );
254
255 let four_byte_values = try_map(
256 seq(
257 four_byte_keyword(),
258 sep_by1(unsigned_integer_literal(), comma_p()),
259 ),
260 |(_, values), span| {
261 let mut parsed = Vec::new();
262 for (text, value_span) in values {
263 let value = parse_unsigned_integer(&text, value_span, 0, u32::MAX as u128)?;
264 parsed.push(value as u32);
265 }
266 ok!(DwarfDirective {
267 kind = DwarfDirectiveKind::FourByteValues(parsed)
268 })
269 },
270 );
271
272 let four_byte_label = try_map(
273 seq(four_byte_keyword(), Label::parse()),
274 |(_, label), span| {
275 ok!(DwarfDirective {
276 kind = DwarfDirectiveKind::FourByteLabel(label)
277 })
278 },
279 );
280
281 let quad_values = try_map(
282 seq(
283 directive_exact_p("quad"),
284 sep_by1(unsigned_integer_literal(), comma_p()),
285 ),
286 |(_, values), span| {
287 let mut parsed = Vec::new();
288 for (text, value_span) in values {
289 let value = parse_unsigned_integer(&text, value_span, 0, u64::MAX as u128)?;
290 parsed.push(value as u64);
291 }
292 ok!(DwarfDirective {
293 kind = DwarfDirectiveKind::QuadValues(parsed)
294 })
295 },
296 );
297
298 let quad_label = try_map(
299 seq(directive_exact_p("quad"), Label::parse()),
300 |(_, label), span| {
301 ok!(DwarfDirective {
302 kind = DwarfDirectiveKind::QuadLabel(label)
303 })
304 },
305 );
306
307 alt!(
308 byte_values,
309 four_byte_label,
310 quad_label,
311 four_byte_values,
312 quad_values
313 )
314}
315
316fn pragma_directive()
317-> impl Fn(&mut PtxTokenStream) -> Result<(StatementDirective, Span), PtxParseError> {
318 try_map(
319 skip_semicolon(seq(directive_exact_p("pragma"), string_literal_p())),
320 |(_, text), span| {
321 let kind = match text.trim() {
322 "nounroll" => PragmaDirectiveKind::Nounroll,
323 "enable_smem_spilling" => PragmaDirectiveKind::EnableSmemSpilling,
324 other if other.starts_with("used_bytes_mask") => {
325 let mask = other["used_bytes_mask".len()..].trim().to_string();
326 PragmaDirectiveKind::UsedBytesMask { mask }
327 }
328 other if other.starts_with("frequency") => {
329 let value_str = other["frequency".len()..].trim();
330 let value = parse_u32_literal(value_str, span)?;
331 PragmaDirectiveKind::Frequency { value }
332 }
333 other => PragmaDirectiveKind::Raw(other.to_string()),
334 };
335 let directive = c!(PragmaDirective { kind });
336 ok!(StatementDirective::Pragma { directive })
337 },
338 )
339}
340
/// Statement-level `.section`: wraps `SectionDirective::parse` into `StatementDirective::Section`.
///
/// No trailing `;` is consumed; the section body ends at its closing brace.
fn section_directive()
-> impl Fn(&mut PtxTokenStream) -> Result<(StatementDirective, Span), PtxParseError> {
    mapc!(
        SectionDirective::parse(),
        StatementDirective::Section { directive }
    )
}
348
/// Parses `.reg <type> <targets>;`, e.g. `.reg .b32 %r<10>, %p;`.
///
/// The type is parsed by `DataType::parse`. The comma-separated targets are parsed by
/// `register_targets_parser`.
fn register_statement()
-> impl Fn(&mut PtxTokenStream) -> Result<(RegisterDirective, Span), PtxParseError> {
    mapc!(
        skip_semicolon(seq(
            skip_first(directive_exact_p("reg"), DataType::parse()),
            register_targets_parser(),
        )),
        // Tuple elements map onto the fields in order: (type, targets).
        RegisterDirective { ty, registers }
    )
}
359
360fn register_targets_parser()
361-> impl Fn(&mut PtxTokenStream) -> Result<(Vec<RegisterTarget>, Span), PtxParseError> {
362 map(
363 sep_by1(
364 seq(register_symbol(), optional(register_count())),
365 comma_p(),
366 ),
367 |entries, _span| {
368 let registers = entries
369 .into_iter()
370 .map(|(symbol, range)| {
371 let symbol_span = symbol.span;
372 RegisterTarget {
373 name: symbol,
374 range,
375 span: symbol_span,
376 }
377 })
378 .collect();
379 registers
380 },
381 )
382}
383
384fn register_symbol() -> impl Fn(&mut PtxTokenStream) -> Result<(VariableSymbol, Span), PtxParseError>
385{
386 alt(
387 map(register_p(), |name, span| VariableSymbol {
388 val: name,
389 span,
390 }),
391 map(identifier_p(), |val, span| VariableSymbol { val, span }),
392 )
393}
394
/// Parses the angle-bracketed register count of a `.reg` target, e.g. the `<10>` in `%r<10>`.
fn register_count() -> impl Fn(&mut PtxTokenStream) -> Result<(u32, Span), PtxParseError> {
    between(langle_p(), rangle_p(), u32_p())
}
398
/// Parses `.loc <file_index> <line> <column>`.
///
/// Only the three-integer form is recognised. No `inlined_at` clause is parsed, so
/// `inlined_at` is always `None`. No trailing `;` is consumed.
fn location_directive()
-> impl Fn(&mut PtxTokenStream) -> Result<(LocationDirective, Span), PtxParseError> {
    mapc!(
        seq_n!(
            skip_first(directive_exact_p("loc"), u32_p()),
            u32_p(),
            u32_p(),
            // Consumes nothing; supplies the fixed `None` for `inlined_at`.
            pure(Option::<LocationInlinedAt>::None)
        ),
        LocationDirective {
            file_index,
            line,
            column,
            inlined_at,
        }
    )
}
416
/// Parses a section name.
///
/// The name is either a directive token, re-prefixed with `.` so `.debug_info` round-trips
/// as `".debug_info"`, or a plain identifier.
fn section_name_parser() -> impl Fn(&mut PtxTokenStream) -> Result<(String, Span), PtxParseError> {
    alt(
        // NOTE(review): presumably `func!` adapts a one-argument closure to the
        // `(value, span)` mapper signature; confirm against its definition.
        map(directive_p(), func!(|name| format!(".{name}"))),
        identifier_p(),
    )
}
423
/// Parses the braced body of a `.section`: zero or more entries between `{` and `}`.
fn section_body_parser()
-> impl Fn(&mut PtxTokenStream) -> Result<(Vec<SectionEntry>, Span), PtxParseError> {
    between(lbrace_p(), rbrace_p(), many(section_entry_parser()))
}
428
/// Parses one `.section` body entry.
///
/// The entry is either a `label:` definition (tried first) or a `.b8`/`.b16`/`.b32`/`.b64`
/// data line.
fn section_entry_parser()
-> impl Fn(&mut PtxTokenStream) -> Result<(SectionEntry, Span), PtxParseError> {
    alt(
        label_entry(),
        map(
            section_directive_line(),
            func!(|line| SectionEntry::Directive(line)),
        ),
    )
}
439
440fn label_entry() -> impl Fn(&mut PtxTokenStream) -> Result<(SectionEntry, Span), PtxParseError> {
441 map(seq(Label::parse(), colon_p()), |(label, _), span| {
442 SectionEntry::Label { label, span }
443 })
444}
445
/// Parses one `;`-terminated data line inside a `.section` body.
///
/// * `.b8 v, ...` — signed/unsigned values in `[-128, 255]`, stored as `i16`;
/// * `.b16 v, ...` — values in `[-32768, 65535]`, stored as `i32`;
/// * `.b32 ...` / `.b64 ...` — operands parsed by `b32_section_suffix` / `b64_section_suffix`.
fn section_directive_line()
-> impl Fn(&mut PtxTokenStream) -> Result<(StatementSectionDirectiveLine, Span), PtxParseError> {
    let b8 = try_map(
        skip_semicolon(skip_first(
            directive_exact_p("b8"),
            sep_by1(signed_integer_literal(), comma_p()),
        )),
        func!(|values| {
            let mut out = Vec::new();
            for (text, value_span) in values {
                let value = parse_signed_integer(&text, value_span, -128, 255)?;
                out.push(value as i16);
            }
            ok!(StatementSectionDirectiveLine::B8 { values = out })
        }),
    );

    let b16 = try_map(
        skip_semicolon(skip_first(
            directive_exact_p("b16"),
            sep_by1(signed_integer_literal(), comma_p()),
        )),
        |values, span| {
            let mut out = Vec::new();
            for (text, value_span) in values {
                let value = parse_signed_integer(&text, value_span, -32_768, 65_535)?;
                out.push(value as i32);
            }
            ok!(StatementSectionDirectiveLine::B16 { values = out })
        },
    );

    // The suffix parsers produce the line; its span is replaced with the span of the whole
    // `.b32 ...;` line, including keyword and semicolon.
    let b32 = try_map(
        skip_semicolon(skip_first(directive_exact_p("b32"), b32_section_suffix())),
        |line, span| Ok(line.with_span(span)),
    );

    let b64 = try_map(
        skip_semicolon(skip_first(directive_exact_p("b64"), b64_section_suffix())),
        |line, span| Ok(line.with_span(span)),
    );

    alt!(b8, b16, b32, b64)
}
490
491fn b32_section_suffix()
492-> impl Fn(&mut PtxTokenStream) -> Result<(StatementSectionDirectiveLine, Span), PtxParseError> {
493 let immediate = try_map(
494 sep_by1(signed_integer_literal(), comma_p()),
495 |values, span| {
496 let mut out = Vec::new();
497 for (text, value_span) in values {
498 let value =
499 parse_signed_integer(&text, value_span, i64::MIN as i128, i64::MAX as i128)?;
500 out.push(value as i64);
501 }
502 ok!(StatementSectionDirectiveLine::B32Immediate { values = out })
503 },
504 );
505
506 let label_diff = try_map(
507 seq_n!(Label::parse(), minus_p(), Label::parse()),
508 |(left, _, right), span| {
509 ok!(StatementSectionDirectiveLine::B32LabelDiff {
510 entries = (left, right)
511 })
512 },
513 );
514
515 let label_plus = try_map(
516 seq_n!(
517 Label::parse(),
518 alt(map(plus_p(), |_, _| 1i32), map(minus_p(), |_, _| -1i32)),
519 integer_p(),
520 ),
521 |(label, sign, digits), span| {
522 let limit = if sign < 0 {
523 (i32::MAX as u128) + 1
524 } else {
525 i32::MAX as u128
526 };
527 let magnitude = parse_unsigned_integer(&digits, span, 0, limit)? as i128;
528 let value = if sign < 0 { -magnitude } else { magnitude };
529 ok!(StatementSectionDirectiveLine::B32LabelPlusImm {
530 entries = (label, value as i32)
531 })
532 },
533 );
534
535 let label_only = map(
536 Label::parse(),
537 |label, span| c!(StatementSectionDirectiveLine::B32Label { labels = label }),
538 );
539
540 alt!(immediate, label_diff, label_plus, label_only)
541}
542
543fn b64_section_suffix()
544-> impl Fn(&mut PtxTokenStream) -> Result<(StatementSectionDirectiveLine, Span), PtxParseError> {
545 let immediate = try_map(
546 sep_by1(signed_integer_literal(), comma_p()),
547 |values, span| {
548 let mut out = Vec::new();
549 for (text, value_span) in values {
550 let value = parse_signed_integer(&text, value_span, i128::MIN, i128::MAX)?;
551 out.push(value);
552 }
553 ok!(StatementSectionDirectiveLine::B64Immediate { values = out })
554 },
555 );
556
557 let label_diff = try_map(
558 seq_n!(Label::parse(), minus_p(), Label::parse()),
559 |(left, _, right), span| {
560 ok!(StatementSectionDirectiveLine::B64LabelDiff {
561 entries = (left, right)
562 })
563 },
564 );
565
566 let label_plus = try_map(
567 seq_n!(
568 Label::parse(),
569 alt(map(plus_p(), |_, _| 1i32), map(minus_p(), |_, _| -1i32)),
570 integer_p(),
571 ),
572 |(label, sign, digits), span| {
573 let limit = if sign < 0 {
574 (i64::MAX as u128) + 1
575 } else {
576 i64::MAX as u128
577 };
578 let magnitude = parse_unsigned_integer(&digits, span, 0, limit)? as i128;
579 let value = if sign < 0 { -magnitude } else { magnitude };
580 ok!(StatementSectionDirectiveLine::B64LabelPlusImm {
581 entries = (label, value as i64)
582 })
583 },
584 );
585
586 let label_only = map(
587 Label::parse(),
588 |label, span| c!(StatementSectionDirectiveLine::B64Label { labels = label }),
589 );
590
591 alt!(immediate, label_diff, label_plus, label_only)
592}
593
594fn signed_integer_literal()
595-> impl Fn(&mut PtxTokenStream) -> Result<((String, Span), Span), PtxParseError> {
596 map(
597 seq(
598 optional(alt(map(minus_p(), |_, _| '-'), map(plus_p(), |_, _| '+'))),
599 integer_p(),
600 ),
601 |(sign, digits), span| {
602 let mut value = String::new();
603 if let Some(ch) = sign {
604 if ch == '-' {
605 value.push('-');
606 }
607 }
608 value.push_str(&digits);
609 (value, span)
610 },
611 )
612}
613
/// Parses an integer literal without a sign.
///
/// Returns the literal text paired with its span, so range errors can be reported at the
/// individual value.
fn unsigned_integer_literal()
-> impl Fn(&mut PtxTokenStream) -> Result<((String, Span), Span), PtxParseError> {
    map(integer_p(), |digits, span| (digits, span))
}
618
619fn four_byte_keyword() -> impl Fn(&mut PtxTokenStream) -> Result<((), Span), PtxParseError> {
620 move |stream| {
621 stream.try_with_span(|stream| {
622 stream.expect(&PtxToken::Dot)?;
623 let (value, value_span) = integer_p()(stream)?;
624 if value != "4" {
625 return Err(crate::unexpected_value!(value_span, &["4"], value));
626 }
627 let (name, name_span) = identifier_p()(stream)?;
628 if name != "byte" {
629 return Err(crate::unexpected_value!(name_span, &["byte"], name));
630 }
631 Ok(())
632 })
633 }
634}
635
636impl PtxParser for AliasFunctionDirective {
637 fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
638 use crate::parser::util::{comma_p, directive_exact_p, semicolon_p, skip_first};
639
640 try_map(
641 seq_n!(
642 skip_first(directive_exact_p("alias"), FunctionSymbol::parse()),
643 skip_first(comma_p(), FunctionSymbol::parse()),
644 semicolon_p()
645 ),
646 |(alias, target, _), span| ok!(AliasFunctionDirective { alias, target }),
647 )
648 }
649}
650
/// Parses a function body: zero or more `FunctionStatement`s between `{` and `}`.
impl PtxParser for FunctionBody {
    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
        try_map(
            between(lbrace_p(), rbrace_p(), many(FunctionStatement::parse())),
            |statements, span| ok!(FunctionBody { statements }),
        )
    }
}
659
/// Parses a `.func` definition or prototype:
///
/// `.func <attributes>* [ret] <name> (<params>) <header-directives>* ({ body } | ;)`
///
/// `ret` may be written as `( [param] )` or as a bare optional parameter. A `;` instead of a
/// body yields `body == None`.
impl PtxParser for FuncFunctionDirective {
    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
        // Parenthesised form first; the bare form falls back to `None` when absent.
        let return_spec = alt(
            map(
                between(
                    lparen_p(),
                    rparen_p(),
                    optional(ParameterDirective::parse()),
                ),
                |param, _| param,
            ),
            map(optional(ParameterDirective::parse()), |param, _| param),
        );

        let body_or_prototype = alt(
            map(FunctionBody::parse(), |body, _| Some(body)),
            map(semicolon_p(), |_, _| None),
        );

        mapc!(
            seq_n!(
                skip_first(directive_exact_p("func"), many(AttributeDirective::parse())),
                return_spec,
                FunctionSymbol::parse(),
                between(
                    lparen_p(),
                    rparen_p(),
                    sep_by(ParameterDirective::parse(), comma_p()),
                ),
                many(FuncFunctionHeaderDirective::parse()),
                body_or_prototype,
            ),
            // Tuple elements map onto these fields in order.
            FuncFunctionDirective {
                attributes,
                return_param,
                name,
                params,
                directives,
                body
            }
        )
    }
}
703
/// Parses a `.entry` kernel: `.entry <name> (<params>) <header-directives>* [{ body }]`.
///
/// NOTE(review): unlike `.func`, a `;`-terminated prototype is not consumed here. The body
/// is simply optional; confirm that callers handle any trailing `;`.
impl PtxParser for EntryFunctionDirective {
    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
        mapc!(
            seq_n!(
                skip_first(directive_exact_p("entry"), FunctionSymbol::parse()),
                between(
                    lparen_p(),
                    rparen_p(),
                    sep_by(ParameterDirective::parse(), comma_p()),
                ),
                many(EntryFunctionHeaderDirective::parse()),
                optional(FunctionBody::parse()),
            ),
            // Tuple elements map onto these fields in order.
            EntryFunctionDirective {
                name,
                params,
                directives,
                body,
            }
        )
    }
}
726
/// Parses one `.func` header directive.
///
/// Recognised forms: `.noreturn`, `.pragma "a"[, "b" ...];`, `.abi_preserve N`, and
/// `.abi_preserve_control N`.
impl PtxParser for FuncFunctionHeaderDirective {
    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
        alt!(
            mapc!(
                directive_exact_p("noreturn"),
                FuncFunctionHeaderDirective::NoReturn {}
            ),
            // The `;` is parsed as part of the tuple and discarded via `_`.
            mapc!(
                skip_first(
                    directive_exact_p("pragma"),
                    seq(sep_by1(string_literal_p(), comma_p()), semicolon_p())
                ),
                FuncFunctionHeaderDirective::Pragma { args, _ }
            ),
            // NOTE(review): ordering relies on `directive_exact_p` matching whole names, so
            // that `abi_preserve` does not also match `.abi_preserve_control`; confirm.
            mapc!(
                skip_first(directive_exact_p("abi_preserve"), u32_p()),
                FuncFunctionHeaderDirective::AbiPreserve { value }
            ),
            mapc!(
                skip_first(directive_exact_p("abi_preserve_control"), u32_p()),
                FuncFunctionHeaderDirective::AbiPreserveControl { value }
            )
        )
    }
}
752
/// Parses one `.entry` header (performance-tuning) directive.
///
/// Recognised forms: `.maxnreg N`, `.maxntid x[, y[, z]]`, `.reqntid x[, y[, z]]`,
/// `.minnctapersm N`, `.maxnctapersm N`, and `.pragma "a"[, "b" ...];`.
impl PtxParser for EntryFunctionHeaderDirective {
    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
        alt!(
            mapc!(
                skip_first(directive_exact_p("maxnreg"), u32_p()),
                EntryFunctionHeaderDirective::MaxNReg { value }
            ),
            // The thread-count lists are validated to 1-3 dimensions by `parse_function_dim`.
            try_map(
                skip_first(directive_exact_p("maxntid"), sep_by1(u32_p(), comma_p())),
                |dim_strs, span| {
                    let dim = parse_function_dim(&dim_strs, span)?;
                    ok!(EntryFunctionHeaderDirective::MaxNTid { dim })
                }
            ),
            try_map(
                skip_first(directive_exact_p("reqntid"), sep_by1(u32_p(), comma_p())),
                |dim_strs, span| {
                    let dim = parse_function_dim(&dim_strs, span)?;
                    ok!(EntryFunctionHeaderDirective::ReqNTid { dim })
                }
            ),
            mapc!(
                skip_first(directive_exact_p("minnctapersm"), u32_p()),
                EntryFunctionHeaderDirective::MinNCtaPerSm { value }
            ),
            mapc!(
                skip_first(directive_exact_p("maxnctapersm"), u32_p()),
                EntryFunctionHeaderDirective::MaxNCtaPerSm { value }
            ),
            mapc!(
                skip_first(
                    directive_exact_p("pragma"),
                    skip_second(sep_by1(string_literal_p(), comma_p()), semicolon_p())
                ),
                EntryFunctionHeaderDirective::Pragma { args }
            )
        )
    }
}
792
793fn parse_function_dim(dims: &[u32], span: Span) -> Result<FunctionDim, PtxParseError> {
794 match dims.len() {
795 1 => {
796 let x = dims[0];
797 Ok(FunctionDim::X { x, span })
798 }
799 2 => {
800 let x = dims[0];
801 let y = dims[1];
802 Ok(FunctionDim::XY { x, y, span })
803 }
804 3 => {
805 let x = dims[0];
806 let y = dims[1];
807 let z = dims[2];
808 Ok(FunctionDim::XYZ { x, y, z, span })
809 }
810 _ => Err(PtxParseError {
811 kind: ParseErrorKind::InvalidLiteral(format!(
812 "expected 1-3 dimensions, got {}",
813 dims.len()
814 )),
815 span,
816 }),
817 }
818}