ptx_parser/parser/
function.rs

1use crate::{
2    alt, c,
3    lexer::PtxToken,
4    mapc, ok,
5    parser::{
6        ParseErrorKind, PtxParseError, PtxParser, PtxTokenStream, Span,
7        util::{
8            alt, between, colon_p, comma_p, directive_exact_p, directive_p, identifier_p,
9            integer_p, langle_p, lbrace_p, lparen_p, many, map, minus_p, optional,
10            parse_signed_integer, parse_u32_literal, parse_unsigned_integer, plus_p, rangle_p,
11            rbrace_p, register_p, rparen_p, semicolon_p, sep_by, sep_by1, seq, seq5, skip_first,
12            skip_semicolon, string_literal_p, try_map,
13        },
14    },
15    seq_n,
16    r#type::{
17        AliasFunctionDirective, AttributeDirective, BranchTargetsDirective, CallPrototypeDirective,
18        CallTargetsDirective, DataType, DwarfDirective, DwarfDirectiveKind, EntryFunctionDirective,
19        EntryFunctionHeaderDirective, FuncFunctionDirective, FuncFunctionHeaderDirective,
20        FunctionBody, FunctionDim, FunctionStatement, FunctionSymbol, Instruction, Label,
21        LocationDirective, ParameterDirective, PragmaDirective, PragmaDirectiveKind,
22        RegisterDirective, RegisterTarget, SectionDirective, SectionEntry, StatementDirective,
23        StatementSectionDirectiveLine, VariableDirective, VariableSymbol,
24    },
25};
26
27impl PtxParser for StatementDirective {
28    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
29        let branch_targets = try_map(
30            skip_semicolon(skip_first(
31                directive_exact_p("branchtargets"),
32                sep_by1(Label::parse(), comma_p()),
33            )),
34            |labels, span| {
35                let directive = BranchTargetsDirective {
36                    labels,
37                    span: span.clone(),
38                };
39                ok!(StatementDirective::BranchTargets { directive })
40            },
41        );
42
43        let call_targets = try_map(
44            skip_semicolon(skip_first(
45                directive_exact_p("calltargets"),
46                sep_by1(FunctionSymbol::parse(), comma_p()),
47            )),
48            |targets, span| {
49                let directive = CallTargetsDirective {
50                    targets,
51                    span: span.clone(),
52                };
53                ok!(StatementDirective::CallTargets { directive })
54            },
55        );
56
57        let call_prototype = try_map(
58            skip_semicolon(skip_first(
59                directive_exact_p("callprototype"),
60                seq5(
61                    return_spec_parser(),
62                    parameter_list_parser(),
63                    noreturn_parser(),
64                    abi_preserve_parser(),
65                    abi_preserve_control_parser(),
66                ),
67            )),
68            |(return_param, params, noreturn, abi_preserve, abi_preserve_control), span| {
69                let directive = CallPrototypeDirective {
70                    return_param,
71                    params,
72                    noreturn,
73                    abi_preserve,
74                    abi_preserve_control,
75                    span: span.clone(),
76                };
77                ok!(StatementDirective::CallPrototype { directive })
78            },
79        );
80
81        let location = mapc!(location_directive(), StatementDirective::Loc { directive });
82
83        let reg_stmt = mapc!(register_statement(), StatementDirective::Reg { directive });
84
85        let local_stmt = mapc!(
86            skip_first(directive_exact_p("local"), VariableDirective::parse()),
87            StatementDirective::Local { directive }
88        );
89
90        let param_stmt = mapc!(
91            skip_first(directive_exact_p("param"), VariableDirective::parse()),
92            StatementDirective::Param { directive }
93        );
94
95        let shared_stmt = mapc!(
96            skip_first(directive_exact_p("shared"), VariableDirective::parse()),
97            StatementDirective::Shared { directive }
98        );
99
100        alt!(
101            location,
102            reg_stmt,
103            local_stmt,
104            param_stmt,
105            shared_stmt,
106            branch_targets,
107            call_targets,
108            call_prototype,
109            dwarf_directive(),
110            pragma_directive(),
111            section_directive()
112        )
113    }
114}
115
116impl PtxParser for SectionDirective {
117    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
118        map(
119            skip_first(
120                directive_exact_p("section"),
121                seq(section_name_parser(), section_body_parser()),
122            ),
123            |(name, entries), span| {
124                c!(SectionDirective {
125                    name = name,
126                    entries,
127                })
128            },
129        )
130    }
131}
132
133impl PtxParser for DwarfDirective {
134    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
135        skip_first(directive_exact_p("dwarf"), dwarf_kind_parser())
136    }
137}
138
139impl PtxParser for FunctionStatement {
140    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
141        let label_stmt = map(seq(Label::parse(), colon_p()), |(label, _), span| {
142            c!(FunctionStatement::Label { label })
143        });
144
145        let block_stmt = move |stream: &mut PtxTokenStream| {
146            map(
147                between(lbrace_p(), rbrace_p(), many(FunctionStatement::parse())),
148                |statements, span| c!(FunctionStatement::Block { statements }),
149            )(stream)
150        };
151
152        let directive_stmt = mapc!(
153            StatementDirective::parse(),
154            FunctionStatement::Directive { directive }
155        );
156
157        let instruction_stmt = mapc!(
158            Instruction::parse(),
159            FunctionStatement::Instruction { instruction }
160        );
161
162        alt!(label_stmt, block_stmt, directive_stmt, instruction_stmt)
163    }
164}
165
166fn return_spec_parser()
167-> impl Fn(&mut PtxTokenStream) -> Result<(Option<ParameterDirective>, Span), PtxParseError> {
168    alt(
169        map(ParameterDirective::parse(), |param, _| Some(param)),
170        map(underscore_placeholder(), |_, _| None),
171    )
172}
173
174fn underscore_placeholder() -> impl Fn(&mut PtxTokenStream) -> Result<((), Span), PtxParseError> {
175    try_map(identifier_p(), |name, span| {
176        if name == "_" {
177            Ok(())
178        } else {
179            Err(PtxParseError {
180                kind: ParseErrorKind::UnexpectedToken {
181                    expected: vec!["identifier `_`".into()],
182                    found: name,
183                },
184                span,
185            })
186        }
187    })
188}
189
190fn parameter_list_parser()
191-> impl Fn(&mut PtxTokenStream) -> Result<(Vec<ParameterDirective>, Span), PtxParseError> {
192    map(
193        between(
194            lparen_p(),
195            rparen_p(),
196            sep_by(ParameterDirective::parse(), comma_p()),
197        ),
198        |params, _| params,
199    )
200}
201
202fn noreturn_parser() -> impl Fn(&mut PtxTokenStream) -> Result<(bool, Span), PtxParseError> {
203    map(optional(directive_exact_p("noreturn")), |flag, _| {
204        flag.is_some()
205    })
206}
207
208fn abi_preserve_parser()
209-> impl Fn(&mut PtxTokenStream) -> Result<(Option<u32>, Span), PtxParseError> {
210    try_map(
211        optional(seq(directive_exact_p("abi_preserve"), integer_p())),
212        |value, span| {
213            let parsed = value
214                .map(|(_, literal)| parse_u32_literal(&literal, span))
215                .transpose()?;
216            Ok(parsed)
217        },
218    )
219}
220
221fn abi_preserve_control_parser()
222-> impl Fn(&mut PtxTokenStream) -> Result<(Option<u32>, Span), PtxParseError> {
223    try_map(
224        optional(seq(directive_exact_p("abi_preserve_control"), integer_p())),
225        |value, span| {
226            let parsed = value
227                .map(|(_, literal)| parse_u32_literal(&literal, span))
228                .transpose()?;
229            Ok(parsed)
230        },
231    )
232}
233fn dwarf_directive()
234-> impl Fn(&mut PtxTokenStream) -> Result<(StatementDirective, Span), PtxParseError> {
235    mapc!(
236        skip_semicolon(DwarfDirective::parse()),
237        StatementDirective::Dwarf { directive }
238    )
239}
240
241fn dwarf_kind_parser()
242-> impl Fn(&mut PtxTokenStream) -> Result<(DwarfDirective, Span), PtxParseError> {
243    let byte_values = try_map(
244        seq(
245            directive_exact_p("byte"),
246            sep_by1(unsigned_integer_literal(), comma_p()),
247        ),
248        |(_, values), span| {
249            let mut parsed = Vec::new();
250            for (text, value_span) in values {
251                let value = parse_unsigned_integer(&text, value_span, 0, u8::MAX as u128)?;
252                parsed.push(value as u8);
253            }
254            ok!(DwarfDirective {
255                kind = DwarfDirectiveKind::ByteValues(parsed)
256            })
257        },
258    );
259
260    let four_byte_values = try_map(
261        seq(
262            four_byte_keyword(),
263            sep_by1(unsigned_integer_literal(), comma_p()),
264        ),
265        |(_, values), span| {
266            let mut parsed = Vec::new();
267            for (text, value_span) in values {
268                let value = parse_unsigned_integer(&text, value_span, 0, u32::MAX as u128)?;
269                parsed.push(value as u32);
270            }
271            ok!(DwarfDirective {
272                kind = DwarfDirectiveKind::FourByteValues(parsed)
273            })
274        },
275    );
276
277    let four_byte_label = try_map(
278        seq(four_byte_keyword(), Label::parse()),
279        |(_, label), span| {
280            ok!(DwarfDirective {
281                kind = DwarfDirectiveKind::FourByteLabel(label)
282            })
283        },
284    );
285
286    let quad_values = try_map(
287        seq(
288            directive_exact_p("quad"),
289            sep_by1(unsigned_integer_literal(), comma_p()),
290        ),
291        |(_, values), span| {
292            let mut parsed = Vec::new();
293            for (text, value_span) in values {
294                let value = parse_unsigned_integer(&text, value_span, 0, u64::MAX as u128)?;
295                parsed.push(value as u64);
296            }
297            ok!(DwarfDirective {
298                kind = DwarfDirectiveKind::QuadValues(parsed)
299            })
300        },
301    );
302
303    let quad_label = try_map(
304        seq(directive_exact_p("quad"), Label::parse()),
305        |(_, label), span| {
306            ok!(DwarfDirective {
307                kind = DwarfDirectiveKind::QuadLabel(label)
308            })
309        },
310    );
311
312    alt!(
313        byte_values,
314        four_byte_label,
315        quad_label,
316        four_byte_values,
317        quad_values
318    )
319}
320
321fn pragma_directive()
322-> impl Fn(&mut PtxTokenStream) -> Result<(StatementDirective, Span), PtxParseError> {
323    try_map(
324        skip_semicolon(seq(directive_exact_p("pragma"), string_literal_p())),
325        |(_, text), span| {
326            let kind = match text.trim() {
327                "nounroll" => PragmaDirectiveKind::Nounroll,
328                "enable_smem_spilling" => PragmaDirectiveKind::EnableSmemSpilling,
329                other if other.starts_with("used_bytes_mask") => {
330                    let mask = other["used_bytes_mask".len()..].trim().to_string();
331                    PragmaDirectiveKind::UsedBytesMask { mask }
332                }
333                other if other.starts_with("frequency") => {
334                    let value_str = other["frequency".len()..].trim();
335                    let value = parse_u32_literal(value_str, span)?;
336                    PragmaDirectiveKind::Frequency { value }
337                }
338                other => PragmaDirectiveKind::Raw(other.to_string()),
339            };
340            let directive = c!(PragmaDirective { kind });
341            ok!(StatementDirective::Pragma { directive })
342        },
343    )
344}
345
346fn section_directive()
347-> impl Fn(&mut PtxTokenStream) -> Result<(StatementDirective, Span), PtxParseError> {
348    mapc!(
349        SectionDirective::parse(),
350        StatementDirective::Section { directive }
351    )
352}
353
354fn register_statement()
355-> impl Fn(&mut PtxTokenStream) -> Result<(RegisterDirective, Span), PtxParseError> {
356    try_map(
357        skip_semicolon(seq(
358            skip_first(directive_exact_p("reg"), DataType::parse()),
359            register_targets_parser(),
360        )),
361        |(ty, registers), span| {
362            ok!(RegisterDirective {
363                ty,
364                registers,
365            })
366        },
367    )
368}
369
370fn register_targets_parser()
371-> impl Fn(&mut PtxTokenStream) -> Result<(Vec<RegisterTarget>, Span), PtxParseError> {
372    map(
373        sep_by1(
374            seq(register_symbol(), optional(register_count())),
375            comma_p(),
376        ),
377        |entries, _span| {
378            let registers = entries
379                .into_iter()
380                .map(|(symbol, range)| {
381                    let symbol_span = symbol.span;
382                    RegisterTarget {
383                        name: symbol,
384                        range,
385                        span: symbol_span,
386                    }
387                })
388                .collect();
389            registers
390        },
391    )
392}
393
394fn register_symbol() -> impl Fn(&mut PtxTokenStream) -> Result<(VariableSymbol, Span), PtxParseError>
395{
396    alt(
397        map(register_p(), |name, span| VariableSymbol {
398            val: name,
399            span,
400        }),
401        map(identifier_p(), |val, span| VariableSymbol { val, span }),
402    )
403}
404
405fn register_count() -> impl Fn(&mut PtxTokenStream) -> Result<(u32, Span), PtxParseError> {
406    try_map(
407        between(langle_p(), rangle_p(), integer_p()),
408        |value, span| {
409            let count = parse_u32_literal(&value, span)?;
410            Ok(count)
411        },
412    )
413}
414
415fn location_directive()
416-> impl Fn(&mut PtxTokenStream) -> Result<(LocationDirective, Span), PtxParseError> {
417    try_map(
418        seq_n!(
419            skip_first(directive_exact_p("loc"), integer_p()),
420            integer_p(),
421            integer_p()
422        ),
423        |(file_idx, line_idx, col_idx), span| {
424            let file_index = parse_u32_literal(&file_idx, span)?;
425            let line = parse_u32_literal(&line_idx, span)?;
426            let column = parse_u32_literal(&col_idx, span)?;
427            Ok(LocationDirective {
428                file_index,
429                line,
430                column,
431                inlined_at: None,
432                span,
433            })
434        },
435    )
436}
437
438fn section_name_parser() -> impl Fn(&mut PtxTokenStream) -> Result<(String, Span), PtxParseError> {
439    alt(
440        map(directive_p(), |name, _| format!(".{name}")),
441        map(identifier_p(), |name, _| name),
442    )
443}
444
445fn section_body_parser()
446-> impl Fn(&mut PtxTokenStream) -> Result<(Vec<SectionEntry>, Span), PtxParseError> {
447    between(lbrace_p(), rbrace_p(), many(section_entry_parser()))
448}
449
450fn section_entry_parser()
451-> impl Fn(&mut PtxTokenStream) -> Result<(SectionEntry, Span), PtxParseError> {
452    alt(
453        label_entry(),
454        map(section_directive_line(), |line, _| {
455            SectionEntry::Directive(line)
456        }),
457    )
458}
459
460fn label_entry() -> impl Fn(&mut PtxTokenStream) -> Result<(SectionEntry, Span), PtxParseError> {
461    map(seq(Label::parse(), colon_p()), |(label, _), span| {
462        SectionEntry::Label { label, span }
463    })
464}
465
466fn section_directive_line()
467-> impl Fn(&mut PtxTokenStream) -> Result<(StatementSectionDirectiveLine, Span), PtxParseError> {
468    let b8 = try_map(
469        skip_semicolon(skip_first(
470            directive_exact_p("b8"),
471            sep_by1(signed_integer_literal(), comma_p()),
472        )),
473        |values, span| {
474            let mut out = Vec::new();
475            for (text, value_span) in values {
476                let value = parse_signed_integer(&text, value_span, -128, 255)?;
477                out.push(value as i16);
478            }
479            ok!(StatementSectionDirectiveLine::B8 { values = out })
480        },
481    );
482
483    let b16 = try_map(
484        skip_semicolon(skip_first(
485            directive_exact_p("b16"),
486            sep_by1(signed_integer_literal(), comma_p()),
487        )),
488        |values, span| {
489            let mut out = Vec::new();
490            for (text, value_span) in values {
491                let value = parse_signed_integer(&text, value_span, -32_768, 65_535)?;
492                out.push(value as i32);
493            }
494            ok!(StatementSectionDirectiveLine::B16 { values = out })
495        },
496    );
497
498    let b32 = try_map(
499        skip_semicolon(skip_first(directive_exact_p("b32"), b32_section_suffix())),
500        |line, span| Ok(line.with_span(span)),
501    );
502
503    let b64 = try_map(
504        skip_semicolon(skip_first(directive_exact_p("b64"), b64_section_suffix())),
505        |line, span| Ok(line.with_span(span)),
506    );
507
508    alt!(b8, b16, b32, b64)
509}
510
511fn b32_section_suffix()
512-> impl Fn(&mut PtxTokenStream) -> Result<(StatementSectionDirectiveLine, Span), PtxParseError> {
513    let immediate = try_map(
514        sep_by1(signed_integer_literal(), comma_p()),
515        |values, span| {
516            let mut out = Vec::new();
517            for (text, value_span) in values {
518                let value =
519                    parse_signed_integer(&text, value_span, i64::MIN as i128, i64::MAX as i128)?;
520                out.push(value as i64);
521            }
522            ok!(StatementSectionDirectiveLine::B32Immediate { values = out })
523        },
524    );
525
526    let label_diff = try_map(
527        seq_n!(Label::parse(), minus_p(), Label::parse()),
528        |(left, _, right), span| {
529            ok!(StatementSectionDirectiveLine::B32LabelDiff {
530                entries = (left, right)
531            })
532        },
533    );
534
535    let label_plus = try_map(
536        seq_n!(
537            Label::parse(),
538            alt(map(plus_p(), |_, _| 1i32), map(minus_p(), |_, _| -1i32)),
539            integer_p(),
540        ),
541        |(label, sign, digits), span| {
542            let limit = if sign < 0 {
543                (i32::MAX as u128) + 1
544            } else {
545                i32::MAX as u128
546            };
547            let magnitude = parse_unsigned_integer(&digits, span, 0, limit)? as i128;
548            let value = if sign < 0 { -magnitude } else { magnitude };
549            ok!(StatementSectionDirectiveLine::B32LabelPlusImm {
550                entries = (label, value as i32)
551            })
552        },
553    );
554
555    let label_only = map(
556        Label::parse(),
557        |label, span| c!(StatementSectionDirectiveLine::B32Label { labels = label }),
558    );
559
560    alt!(immediate, label_diff, label_plus, label_only)
561}
562
563fn b64_section_suffix()
564-> impl Fn(&mut PtxTokenStream) -> Result<(StatementSectionDirectiveLine, Span), PtxParseError> {
565    let immediate = try_map(
566        sep_by1(signed_integer_literal(), comma_p()),
567        |values, span| {
568            let mut out = Vec::new();
569            for (text, value_span) in values {
570                let value = parse_signed_integer(&text, value_span, i128::MIN, i128::MAX)?;
571                out.push(value);
572            }
573            ok!(StatementSectionDirectiveLine::B64Immediate { values = out })
574        },
575    );
576
577    let label_diff = try_map(
578        seq_n!(Label::parse(), minus_p(), Label::parse()),
579        |(left, _, right), span| {
580            ok!(StatementSectionDirectiveLine::B64LabelDiff {
581                entries = (left, right)
582            })
583        },
584    );
585
586    let label_plus = try_map(
587        seq_n!(
588            Label::parse(),
589            alt(map(plus_p(), |_, _| 1i32), map(minus_p(), |_, _| -1i32)),
590            integer_p(),
591        ),
592        |(label, sign, digits), span| {
593            let limit = if sign < 0 {
594                (i64::MAX as u128) + 1
595            } else {
596                i64::MAX as u128
597            };
598            let magnitude = parse_unsigned_integer(&digits, span, 0, limit)? as i128;
599            let value = if sign < 0 { -magnitude } else { magnitude };
600            ok!(StatementSectionDirectiveLine::B64LabelPlusImm {
601                entries = (label, value as i64)
602            })
603        },
604    );
605
606    let label_only = map(
607        Label::parse(),
608        |label, span| c!(StatementSectionDirectiveLine::B64Label { labels = label }),
609    );
610
611    alt!(immediate, label_diff, label_plus, label_only)
612}
613
614fn signed_integer_literal()
615-> impl Fn(&mut PtxTokenStream) -> Result<((String, Span), Span), PtxParseError> {
616    map(
617        seq(
618            optional(alt(map(minus_p(), |_, _| '-'), map(plus_p(), |_, _| '+'))),
619            integer_p(),
620        ),
621        |(sign, digits), span| {
622            let mut value = String::new();
623            if let Some(ch) = sign {
624                if ch == '-' {
625                    value.push('-');
626                }
627            }
628            value.push_str(&digits);
629            (value, span)
630        },
631    )
632}
633
634fn unsigned_integer_literal()
635-> impl Fn(&mut PtxTokenStream) -> Result<((String, Span), Span), PtxParseError> {
636    map(integer_p(), |digits, span| (digits, span))
637}
638
639fn four_byte_keyword() -> impl Fn(&mut PtxTokenStream) -> Result<((), Span), PtxParseError> {
640    move |stream| {
641        stream.try_with_span(|stream| {
642            stream.expect(&PtxToken::Dot)?;
643            let (value, value_span) = integer_p()(stream)?;
644            if value != "4" {
645                return Err(crate::unexpected_value!(value_span, &["4"], value));
646            }
647            let (name, name_span) = identifier_p()(stream)?;
648            if name != "byte" {
649                return Err(crate::unexpected_value!(name_span, &["byte"], name));
650            }
651            Ok(())
652        })
653    }
654}
655
656impl PtxParser for AliasFunctionDirective {
657    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
658        use crate::parser::util::{comma_p, directive_exact_p, semicolon_p, skip_first};
659
660        try_map(
661            seq_n!(
662                skip_first(directive_exact_p("alias"), FunctionSymbol::parse()),
663                skip_first(comma_p(), FunctionSymbol::parse()),
664                semicolon_p()
665            ),
666            |(alias, target, _), span| ok!(AliasFunctionDirective { alias, target }),
667        )
668    }
669}
670
671impl PtxParser for FunctionBody {
672    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
673        try_map(
674            between(lbrace_p(), rbrace_p(), many(FunctionStatement::parse())),
675            |statements, span| ok!(FunctionBody { statements }),
676        )
677    }
678}
679
680impl PtxParser for FuncFunctionDirective {
681    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
682        let return_spec = alt(
683            map(
684                between(
685                    lparen_p(),
686                    rparen_p(),
687                    optional(ParameterDirective::parse()),
688                ),
689                |param, _| param,
690            ),
691            map(optional(ParameterDirective::parse()), |param, _| param),
692        );
693
694        let body_or_prototype = alt(
695            map(FunctionBody::parse(), |body, _| Some(body)),
696            map(semicolon_p(), |_, _| None),
697        );
698
699        mapc!(
700            seq_n!(
701                skip_first(directive_exact_p("func"), many(AttributeDirective::parse())),
702                return_spec,
703                FunctionSymbol::parse(),
704                between(
705                    lparen_p(),
706                    rparen_p(),
707                    sep_by(ParameterDirective::parse(), comma_p()),
708                ),
709                many(FuncFunctionHeaderDirective::parse()),
710                body_or_prototype,
711            ),
712            FuncFunctionDirective {
713                attributes,
714                return_param,
715                name,
716                params,
717                directives,
718                body
719            }
720        )
721    }
722}
723
724impl PtxParser for EntryFunctionDirective {
725    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
726        mapc!(
727            seq_n!(
728                skip_first(directive_exact_p("entry"), FunctionSymbol::parse()),
729                between(
730                    lparen_p(),
731                    rparen_p(),
732                    sep_by(ParameterDirective::parse(), comma_p()),
733                ),
734                many(EntryFunctionHeaderDirective::parse()),
735                optional(FunctionBody::parse()),
736            ),
737            EntryFunctionDirective {
738                name,
739                params,
740                directives,
741                body,
742            }
743        )
744    }
745}
746
747impl PtxParser for FuncFunctionHeaderDirective {
748    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
749        alt!(
750            mapc!(
751                directive_exact_p("noreturn"),
752                FuncFunctionHeaderDirective::NoReturn {}
753            ),
754            mapc!(
755                skip_first(
756                    directive_exact_p("pragma"),
757                    seq(sep_by1(string_literal_p(), comma_p()), semicolon_p())
758                ),
759                FuncFunctionHeaderDirective::Pragma { args, _ }
760            ),
761            try_map(
762                skip_first(directive_exact_p("abi_preserve"), integer_p()),
763                |val_str, span| {
764                    let value = parse_u32_literal(&val_str, span)?;
765                    ok!(FuncFunctionHeaderDirective::AbiPreserve { value })
766                }
767            ),
768            try_map(
769                skip_first(directive_exact_p("abi_preserve_control"), integer_p()),
770                |val_str, span| {
771                    let value = parse_u32_literal(&val_str, span)?;
772                    ok!(FuncFunctionHeaderDirective::AbiPreserveControl { value })
773                }
774            )
775        )
776    }
777}
778
779impl PtxParser for EntryFunctionHeaderDirective {
780    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
781        alt!(
782            try_map(
783                skip_first(directive_exact_p("maxnreg"), integer_p()),
784                |val_str, span| {
785                    let value = parse_u32_literal(&val_str, span)?;
786                    ok!(EntryFunctionHeaderDirective::MaxNReg { value })
787                }
788            ),
789            try_map(
790                skip_first(
791                    directive_exact_p("maxntid"),
792                    sep_by1(integer_p(), comma_p())
793                ),
794                |dim_strs, span| {
795                    let dim = parse_function_dim(&dim_strs, span)?;
796                    ok!(EntryFunctionHeaderDirective::MaxNTid { dim })
797                }
798            ),
799            try_map(
800                skip_first(
801                    directive_exact_p("reqntid"),
802                    sep_by1(integer_p(), comma_p())
803                ),
804                |dim_strs, span| {
805                    let dim = parse_function_dim(&dim_strs, span)?;
806                    ok!(EntryFunctionHeaderDirective::ReqNTid { dim })
807                }
808            ),
809            try_map(
810                skip_first(directive_exact_p("minnctapersm"), integer_p()),
811                |val_str, span| {
812                    let value = parse_u32_literal(&val_str, span)?;
813                    ok!(EntryFunctionHeaderDirective::MinNCtaPerSm { value })
814                }
815            ),
816            try_map(
817                skip_first(directive_exact_p("maxnctapersm"), integer_p()),
818                |val_str, span| {
819                    let value = parse_u32_literal(&val_str, span)?;
820                    ok!(EntryFunctionHeaderDirective::MaxNCtaPerSm { value })
821                }
822            ),
823            try_map(
824                skip_first(
825                    directive_exact_p("pragma"),
826                    seq(sep_by1(string_literal_p(), comma_p()), semicolon_p())
827                ),
828                |(args, _), span| { ok!(EntryFunctionHeaderDirective::Pragma { args }) }
829            )
830        )
831    }
832}
833
834fn parse_function_dim(dims: &[String], span: Span) -> Result<FunctionDim, PtxParseError> {
835    match dims.len() {
836        1 => {
837            let x = parse_u32_literal(&dims[0], span)?;
838            Ok(FunctionDim::X { x, span })
839        }
840        2 => {
841            let x = parse_u32_literal(&dims[0], span)?;
842            let y = parse_u32_literal(&dims[1], span)?;
843            Ok(FunctionDim::XY { x, y, span })
844        }
845        3 => {
846            let x = parse_u32_literal(&dims[0], span)?;
847            let y = parse_u32_literal(&dims[1], span)?;
848            let z = parse_u32_literal(&dims[2], span)?;
849            Ok(FunctionDim::XYZ { x, y, z, span })
850        }
851        _ => Err(PtxParseError {
852            kind: ParseErrorKind::InvalidLiteral(format!(
853                "expected 1-3 dimensions, got {}",
854                dims.len()
855            )),
856            span,
857        }),
858    }
859}