Skip to main content

ptx_parser/parser/
function.rs

1use ptx_90_parser_construct::func;
2
3use crate::{
4    alt, c,
5    lexer::PtxToken,
6    mapc, ok,
7    parser::{
8        ParseErrorKind, PtxParseError, PtxParser, PtxTokenStream, Span,
9        util::{
10            alt, between, colon_p, comma_p, directive_exact_p, directive_p, identifier_p,
11            integer_p, langle_p, lbrace_p, lparen_p, many, map, minus_p, optional,
12            parse_signed_integer, parse_u32_literal, parse_unsigned_integer, plus_p, pure,
13            rangle_p, rbrace_p, register_p, rparen_p, semicolon_p, sep_by, sep_by1, seq, seq5,
14            skip_first, skip_second, skip_semicolon, string_literal_p, try_map, u32_p,
15        },
16    },
17    seq_n,
18    r#type::{
19        AliasFunctionDirective, AttributeDirective, BranchTargetsDirective, CallPrototypeDirective,
20        CallTargetsDirective, DataType, DwarfDirective, DwarfDirectiveKind, EntryFunctionDirective,
21        EntryFunctionHeaderDirective, FuncFunctionDirective, FuncFunctionHeaderDirective,
22        FunctionBody, FunctionDim, FunctionStatement, FunctionSymbol, Instruction, Label,
23        LocationDirective, LocationInlinedAt, ParameterDirective, PragmaDirective,
24        PragmaDirectiveKind, RegisterDirective, RegisterTarget, SectionDirective, SectionEntry,
25        StatementDirective, StatementSectionDirectiveLine, VariableDirective, VariableSymbol,
26    },
27};
28
29impl PtxParser for StatementDirective {
30    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
31        let branch_targets = try_map(
32            skip_semicolon(skip_first(
33                directive_exact_p("branchtargets"),
34                sep_by1(Label::parse(), comma_p()),
35            )),
36            |labels, span| {
37                let directive = BranchTargetsDirective {
38                    labels,
39                    span: span.clone(),
40                };
41                ok!(StatementDirective::BranchTargets { directive })
42            },
43        );
44
45        let call_targets = try_map(
46            skip_semicolon(skip_first(
47                directive_exact_p("calltargets"),
48                sep_by1(FunctionSymbol::parse(), comma_p()),
49            )),
50            |targets, span| {
51                let directive = CallTargetsDirective {
52                    targets,
53                    span: span.clone(),
54                };
55                ok!(StatementDirective::CallTargets { directive })
56            },
57        );
58
59        let call_prototype = try_map(
60            skip_semicolon(skip_first(
61                directive_exact_p("callprototype"),
62                seq5(
63                    return_spec_parser(),
64                    parameter_list_parser(),
65                    noreturn_parser(),
66                    abi_preserve_parser(),
67                    abi_preserve_control_parser(),
68                ),
69            )),
70            |(return_param, params, noreturn, abi_preserve, abi_preserve_control), span| {
71                let directive = CallPrototypeDirective {
72                    return_param,
73                    params,
74                    noreturn,
75                    abi_preserve,
76                    abi_preserve_control,
77                    span: span.clone(),
78                };
79                ok!(StatementDirective::CallPrototype { directive })
80            },
81        );
82
83        let location = mapc!(location_directive(), StatementDirective::Loc { directive });
84
85        let reg_stmt = mapc!(register_statement(), StatementDirective::Reg { directive });
86
87        let local_stmt = mapc!(
88            skip_first(directive_exact_p("local"), VariableDirective::parse()),
89            StatementDirective::Local { directive }
90        );
91
92        let param_stmt = mapc!(
93            skip_first(directive_exact_p("param"), VariableDirective::parse()),
94            StatementDirective::Param { directive }
95        );
96
97        let shared_stmt = mapc!(
98            skip_first(directive_exact_p("shared"), VariableDirective::parse()),
99            StatementDirective::Shared { directive }
100        );
101
102        alt!(
103            location,
104            reg_stmt,
105            local_stmt,
106            param_stmt,
107            shared_stmt,
108            branch_targets,
109            call_targets,
110            call_prototype,
111            dwarf_directive(),
112            pragma_directive(),
113            section_directive()
114        )
115    }
116}
117
118impl PtxParser for SectionDirective {
119    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
120        map(
121            skip_first(
122                directive_exact_p("section"),
123                seq(section_name_parser(), section_body_parser()),
124            ),
125            |(name, entries), span| {
126                c!(SectionDirective {
127                    name = name,
128                    entries,
129                })
130            },
131        )
132    }
133}
134
135impl PtxParser for DwarfDirective {
136    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
137        skip_first(directive_exact_p("dwarf"), dwarf_kind_parser())
138    }
139}
140
141impl PtxParser for FunctionStatement {
142    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
143        let label_stmt = map(seq(Label::parse(), colon_p()), |(label, _), span| {
144            c!(FunctionStatement::Label { label })
145        });
146
147        let block_stmt = move |stream: &mut PtxTokenStream| {
148            map(
149                between(lbrace_p(), rbrace_p(), many(FunctionStatement::parse())),
150                |statements, span| c!(FunctionStatement::Block { statements }),
151            )(stream)
152        };
153
154        let directive_stmt = mapc!(
155            StatementDirective::parse(),
156            FunctionStatement::Directive { directive }
157        );
158
159        let instruction_stmt = mapc!(
160            Instruction::parse(),
161            FunctionStatement::Instruction { instruction }
162        );
163
164        alt!(label_stmt, block_stmt, directive_stmt, instruction_stmt)
165    }
166}
167
168fn return_spec_parser()
169-> impl Fn(&mut PtxTokenStream) -> Result<(Option<ParameterDirective>, Span), PtxParseError> {
170    alt(
171        map(ParameterDirective::parse(), |param, _| Some(param)),
172        map(underscore_placeholder(), |_, _| None),
173    )
174}
175
176fn underscore_placeholder() -> impl Fn(&mut PtxTokenStream) -> Result<((), Span), PtxParseError> {
177    try_map(identifier_p(), |name, span| {
178        if name == "_" {
179            Ok(())
180        } else {
181            Err(PtxParseError {
182                kind: ParseErrorKind::UnexpectedToken {
183                    expected: vec!["identifier `_`".into()],
184                    found: name,
185                },
186                span,
187            })
188        }
189    })
190}
191
192fn parameter_list_parser()
193-> impl Fn(&mut PtxTokenStream) -> Result<(Vec<ParameterDirective>, Span), PtxParseError> {
194    map(
195        between(
196            lparen_p(),
197            rparen_p(),
198            sep_by(ParameterDirective::parse(), comma_p()),
199        ),
200        |params, _| params,
201    )
202}
203
204fn noreturn_parser() -> impl Fn(&mut PtxTokenStream) -> Result<(bool, Span), PtxParseError> {
205    map(optional(directive_exact_p("noreturn")), |flag, _| {
206        flag.is_some()
207    })
208}
209
210fn abi_preserve_parser()
211-> impl Fn(&mut PtxTokenStream) -> Result<(Option<u32>, Span), PtxParseError> {
212    map(
213        optional(skip_first(directive_exact_p("abi_preserve"), u32_p())),
214        |value, _span| value,
215    )
216}
217
218fn abi_preserve_control_parser()
219-> impl Fn(&mut PtxTokenStream) -> Result<(Option<u32>, Span), PtxParseError> {
220    map(
221        optional(skip_first(
222            directive_exact_p("abi_preserve_control"),
223            u32_p(),
224        )),
225        |value, _span| value,
226    )
227}
228fn dwarf_directive()
229-> impl Fn(&mut PtxTokenStream) -> Result<(StatementDirective, Span), PtxParseError> {
230    mapc!(
231        skip_semicolon(DwarfDirective::parse()),
232        StatementDirective::Dwarf { directive }
233    )
234}
235
236fn dwarf_kind_parser()
237-> impl Fn(&mut PtxTokenStream) -> Result<(DwarfDirective, Span), PtxParseError> {
238    let byte_values = try_map(
239        seq(
240            directive_exact_p("byte"),
241            sep_by1(unsigned_integer_literal(), comma_p()),
242        ),
243        |(_, values), span| {
244            let mut parsed = Vec::new();
245            for (text, value_span) in values {
246                let value = parse_unsigned_integer(&text, value_span, 0, u8::MAX as u128)?;
247                parsed.push(value as u8);
248            }
249            ok!(DwarfDirective {
250                kind = DwarfDirectiveKind::ByteValues(parsed)
251            })
252        },
253    );
254
255    let four_byte_values = try_map(
256        seq(
257            four_byte_keyword(),
258            sep_by1(unsigned_integer_literal(), comma_p()),
259        ),
260        |(_, values), span| {
261            let mut parsed = Vec::new();
262            for (text, value_span) in values {
263                let value = parse_unsigned_integer(&text, value_span, 0, u32::MAX as u128)?;
264                parsed.push(value as u32);
265            }
266            ok!(DwarfDirective {
267                kind = DwarfDirectiveKind::FourByteValues(parsed)
268            })
269        },
270    );
271
272    let four_byte_label = try_map(
273        seq(four_byte_keyword(), Label::parse()),
274        |(_, label), span| {
275            ok!(DwarfDirective {
276                kind = DwarfDirectiveKind::FourByteLabel(label)
277            })
278        },
279    );
280
281    let quad_values = try_map(
282        seq(
283            directive_exact_p("quad"),
284            sep_by1(unsigned_integer_literal(), comma_p()),
285        ),
286        |(_, values), span| {
287            let mut parsed = Vec::new();
288            for (text, value_span) in values {
289                let value = parse_unsigned_integer(&text, value_span, 0, u64::MAX as u128)?;
290                parsed.push(value as u64);
291            }
292            ok!(DwarfDirective {
293                kind = DwarfDirectiveKind::QuadValues(parsed)
294            })
295        },
296    );
297
298    let quad_label = try_map(
299        seq(directive_exact_p("quad"), Label::parse()),
300        |(_, label), span| {
301            ok!(DwarfDirective {
302                kind = DwarfDirectiveKind::QuadLabel(label)
303            })
304        },
305    );
306
307    alt!(
308        byte_values,
309        four_byte_label,
310        quad_label,
311        four_byte_values,
312        quad_values
313    )
314}
315
316fn pragma_directive()
317-> impl Fn(&mut PtxTokenStream) -> Result<(StatementDirective, Span), PtxParseError> {
318    try_map(
319        skip_semicolon(seq(directive_exact_p("pragma"), string_literal_p())),
320        |(_, text), span| {
321            let kind = match text.trim() {
322                "nounroll" => PragmaDirectiveKind::Nounroll,
323                "enable_smem_spilling" => PragmaDirectiveKind::EnableSmemSpilling,
324                other if other.starts_with("used_bytes_mask") => {
325                    let mask = other["used_bytes_mask".len()..].trim().to_string();
326                    PragmaDirectiveKind::UsedBytesMask { mask }
327                }
328                other if other.starts_with("frequency") => {
329                    let value_str = other["frequency".len()..].trim();
330                    let value = parse_u32_literal(value_str, span)?;
331                    PragmaDirectiveKind::Frequency { value }
332                }
333                other => PragmaDirectiveKind::Raw(other.to_string()),
334            };
335            let directive = c!(PragmaDirective { kind });
336            ok!(StatementDirective::Pragma { directive })
337        },
338    )
339}
340
341fn section_directive()
342-> impl Fn(&mut PtxTokenStream) -> Result<(StatementDirective, Span), PtxParseError> {
343    mapc!(
344        SectionDirective::parse(),
345        StatementDirective::Section { directive }
346    )
347}
348
349fn register_statement()
350-> impl Fn(&mut PtxTokenStream) -> Result<(RegisterDirective, Span), PtxParseError> {
351    mapc!(
352        skip_semicolon(seq(
353            skip_first(directive_exact_p("reg"), DataType::parse()),
354            register_targets_parser(),
355        )),
356        RegisterDirective { ty, registers }
357    )
358}
359
360fn register_targets_parser()
361-> impl Fn(&mut PtxTokenStream) -> Result<(Vec<RegisterTarget>, Span), PtxParseError> {
362    map(
363        sep_by1(
364            seq(register_symbol(), optional(register_count())),
365            comma_p(),
366        ),
367        |entries, _span| {
368            let registers = entries
369                .into_iter()
370                .map(|(symbol, range)| {
371                    let symbol_span = symbol.span;
372                    RegisterTarget {
373                        name: symbol,
374                        range,
375                        span: symbol_span,
376                    }
377                })
378                .collect();
379            registers
380        },
381    )
382}
383
384fn register_symbol() -> impl Fn(&mut PtxTokenStream) -> Result<(VariableSymbol, Span), PtxParseError>
385{
386    alt(
387        map(register_p(), |name, span| VariableSymbol {
388            val: name,
389            span,
390        }),
391        map(identifier_p(), |val, span| VariableSymbol { val, span }),
392    )
393}
394
395fn register_count() -> impl Fn(&mut PtxTokenStream) -> Result<(u32, Span), PtxParseError> {
396    between(langle_p(), rangle_p(), u32_p())
397}
398
399fn location_directive()
400-> impl Fn(&mut PtxTokenStream) -> Result<(LocationDirective, Span), PtxParseError> {
401    mapc!(
402        seq_n!(
403            skip_first(directive_exact_p("loc"), u32_p()),
404            u32_p(),
405            u32_p(),
406            pure(Option::<LocationInlinedAt>::None)
407        ),
408        LocationDirective {
409            file_index,
410            line,
411            column,
412            inlined_at,
413        }
414    )
415}
416
417fn section_name_parser() -> impl Fn(&mut PtxTokenStream) -> Result<(String, Span), PtxParseError> {
418    alt(
419        map(directive_p(), func!(|name| format!(".{name}"))),
420        identifier_p(),
421    )
422}
423
424fn section_body_parser()
425-> impl Fn(&mut PtxTokenStream) -> Result<(Vec<SectionEntry>, Span), PtxParseError> {
426    between(lbrace_p(), rbrace_p(), many(section_entry_parser()))
427}
428
429fn skip_optional_semicolon<T, P>(
430    parser: P,
431) -> impl Fn(&mut PtxTokenStream) -> Result<(T, Span), PtxParseError>
432where
433    P: Fn(&mut PtxTokenStream) -> Result<(T, Span), PtxParseError>,
434{
435    move |stream| {
436        let (value, span) = parser(stream)?;
437        let _ = optional(semicolon_p())(stream)?;
438        Ok((value, span))
439    }
440}
441
442fn section_entry_parser()
443-> impl Fn(&mut PtxTokenStream) -> Result<(SectionEntry, Span), PtxParseError> {
444    alt(
445        label_entry(),
446        map(
447            section_directive_line(),
448            func!(|line| SectionEntry::Directive(line)),
449        ),
450    )
451}
452
453fn label_entry() -> impl Fn(&mut PtxTokenStream) -> Result<(SectionEntry, Span), PtxParseError> {
454    map(seq(Label::parse(), colon_p()), |(label, _), span| {
455        SectionEntry::Label { label, span }
456    })
457}
458
459fn section_directive_line()
460-> impl Fn(&mut PtxTokenStream) -> Result<(StatementSectionDirectiveLine, Span), PtxParseError> {
461    let b8 = try_map(
462        skip_optional_semicolon(skip_first(
463            directive_exact_p("b8"),
464            sep_by1(signed_integer_literal(), comma_p()),
465        )),
466        func!(|values| {
467            let mut out = Vec::new();
468            for (text, value_span) in values {
469                let value = parse_signed_integer(&text, value_span, -128, 255)?;
470                out.push(value as i16);
471            }
472            ok!(StatementSectionDirectiveLine::B8 { values = out })
473        }),
474    );
475
476    let b16 = try_map(
477        skip_optional_semicolon(skip_first(
478            directive_exact_p("b16"),
479            sep_by1(signed_integer_literal(), comma_p()),
480        )),
481        |values, span| {
482            let mut out = Vec::new();
483            for (text, value_span) in values {
484                let value = parse_signed_integer(&text, value_span, -32_768, 65_535)?;
485                out.push(value as i32);
486            }
487            ok!(StatementSectionDirectiveLine::B16 { values = out })
488        },
489    );
490
491    let b32 = try_map(
492        skip_optional_semicolon(skip_first(directive_exact_p("b32"), b32_section_suffix())),
493        |line, span| Ok(line.with_span(span)),
494    );
495
496    let b64 = try_map(
497        skip_optional_semicolon(skip_first(directive_exact_p("b64"), b64_section_suffix())),
498        |line, span| Ok(line.with_span(span)),
499    );
500
501    alt!(b8, b16, b32, b64)
502}
503
504fn b32_section_suffix()
505-> impl Fn(&mut PtxTokenStream) -> Result<(StatementSectionDirectiveLine, Span), PtxParseError> {
506    let immediate = try_map(
507        sep_by1(signed_integer_literal(), comma_p()),
508        |values, span| {
509            let mut out = Vec::new();
510            for (text, value_span) in values {
511                let value =
512                    parse_signed_integer(&text, value_span, i64::MIN as i128, i64::MAX as i128)?;
513                out.push(value as i64);
514            }
515            ok!(StatementSectionDirectiveLine::B32Immediate { values = out })
516        },
517    );
518
519    let label_diff = try_map(
520        seq_n!(Label::parse(), minus_p(), Label::parse()),
521        |(left, _, right), span| {
522            ok!(StatementSectionDirectiveLine::B32LabelDiff {
523                entries = (left, right)
524            })
525        },
526    );
527
528    let label_plus = try_map(
529        seq_n!(
530            Label::parse(),
531            alt(map(plus_p(), |_, _| 1i32), map(minus_p(), |_, _| -1i32)),
532            integer_p(),
533        ),
534        |(label, sign, digits), span| {
535            let limit = if sign < 0 {
536                (i32::MAX as u128) + 1
537            } else {
538                i32::MAX as u128
539            };
540            let magnitude = parse_unsigned_integer(&digits, span, 0, limit)? as i128;
541            let value = if sign < 0 { -magnitude } else { magnitude };
542            ok!(StatementSectionDirectiveLine::B32LabelPlusImm {
543                entries = (label, value as i32)
544            })
545        },
546    );
547
548    let label_only = map(
549        Label::parse(),
550        |label, span| c!(StatementSectionDirectiveLine::B32Label { labels = label }),
551    );
552
553    alt!(immediate, label_diff, label_plus, label_only)
554}
555
556fn b64_section_suffix()
557-> impl Fn(&mut PtxTokenStream) -> Result<(StatementSectionDirectiveLine, Span), PtxParseError> {
558    let immediate = try_map(
559        sep_by1(signed_integer_literal(), comma_p()),
560        |values, span| {
561            let mut out = Vec::new();
562            for (text, value_span) in values {
563                let value = parse_signed_integer(&text, value_span, i128::MIN, i128::MAX)?;
564                out.push(value);
565            }
566            ok!(StatementSectionDirectiveLine::B64Immediate { values = out })
567        },
568    );
569
570    let label_diff = try_map(
571        seq_n!(Label::parse(), minus_p(), Label::parse()),
572        |(left, _, right), span| {
573            ok!(StatementSectionDirectiveLine::B64LabelDiff {
574                entries = (left, right)
575            })
576        },
577    );
578
579    let label_plus = try_map(
580        seq_n!(
581            Label::parse(),
582            alt(map(plus_p(), |_, _| 1i32), map(minus_p(), |_, _| -1i32)),
583            integer_p(),
584        ),
585        |(label, sign, digits), span| {
586            let limit = if sign < 0 {
587                (i64::MAX as u128) + 1
588            } else {
589                i64::MAX as u128
590            };
591            let magnitude = parse_unsigned_integer(&digits, span, 0, limit)? as i128;
592            let value = if sign < 0 { -magnitude } else { magnitude };
593            ok!(StatementSectionDirectiveLine::B64LabelPlusImm {
594                entries = (label, value as i64)
595            })
596        },
597    );
598
599    let label_only = map(
600        Label::parse(),
601        |label, span| c!(StatementSectionDirectiveLine::B64Label { labels = label }),
602    );
603
604    alt!(immediate, label_diff, label_plus, label_only)
605}
606
607fn signed_integer_literal()
608-> impl Fn(&mut PtxTokenStream) -> Result<((String, Span), Span), PtxParseError> {
609    map(
610        seq(
611            optional(alt(map(minus_p(), |_, _| '-'), map(plus_p(), |_, _| '+'))),
612            integer_p(),
613        ),
614        |(sign, digits), span| {
615            let mut value = String::new();
616            if let Some(ch) = sign {
617                if ch == '-' {
618                    value.push('-');
619                }
620            }
621            value.push_str(&digits);
622            (value, span)
623        },
624    )
625}
626
627fn unsigned_integer_literal()
628-> impl Fn(&mut PtxTokenStream) -> Result<((String, Span), Span), PtxParseError> {
629    map(integer_p(), |digits, span| (digits, span))
630}
631
632fn four_byte_keyword() -> impl Fn(&mut PtxTokenStream) -> Result<((), Span), PtxParseError> {
633    move |stream| {
634        stream.try_with_span(|stream| {
635            stream.expect(&PtxToken::Dot)?;
636            let (value, value_span) = integer_p()(stream)?;
637            if value != "4" {
638                return Err(crate::unexpected_value!(value_span, &["4"], value));
639            }
640            let (name, name_span) = identifier_p()(stream)?;
641            if name != "byte" {
642                return Err(crate::unexpected_value!(name_span, &["byte"], name));
643            }
644            Ok(())
645        })
646    }
647}
648
649impl PtxParser for AliasFunctionDirective {
650    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
651        use crate::parser::util::{comma_p, directive_exact_p, semicolon_p, skip_first};
652
653        try_map(
654            seq_n!(
655                skip_first(directive_exact_p("alias"), FunctionSymbol::parse()),
656                skip_first(comma_p(), FunctionSymbol::parse()),
657                semicolon_p()
658            ),
659            |(alias, target, _), span| ok!(AliasFunctionDirective { alias, target }),
660        )
661    }
662}
663
664impl PtxParser for FunctionBody {
665    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
666        try_map(
667            between(lbrace_p(), rbrace_p(), many(FunctionStatement::parse())),
668            |statements, span| ok!(FunctionBody { statements }),
669        )
670    }
671}
672
673impl PtxParser for FuncFunctionDirective {
674    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
675        let return_spec = alt(
676            map(
677                between(
678                    lparen_p(),
679                    rparen_p(),
680                    optional(ParameterDirective::parse()),
681                ),
682                |param, _| param,
683            ),
684            map(optional(ParameterDirective::parse()), |param, _| param),
685        );
686
687        let body_or_prototype = alt(
688            map(FunctionBody::parse(), |body, _| Some(body)),
689            map(semicolon_p(), |_, _| None),
690        );
691
692        mapc!(
693            seq_n!(
694                skip_first(directive_exact_p("func"), many(AttributeDirective::parse())),
695                return_spec,
696                FunctionSymbol::parse(),
697                between(
698                    lparen_p(),
699                    rparen_p(),
700                    sep_by(ParameterDirective::parse(), comma_p()),
701                ),
702                many(FuncFunctionHeaderDirective::parse()),
703                body_or_prototype,
704            ),
705            FuncFunctionDirective {
706                attributes,
707                return_param,
708                name,
709                params,
710                directives,
711                body
712            }
713        )
714    }
715}
716
717impl PtxParser for EntryFunctionDirective {
718    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
719        mapc!(
720            seq_n!(
721                skip_first(directive_exact_p("entry"), FunctionSymbol::parse()),
722                between(
723                    lparen_p(),
724                    rparen_p(),
725                    sep_by(ParameterDirective::parse(), comma_p()),
726                ),
727                many(EntryFunctionHeaderDirective::parse()),
728                optional(FunctionBody::parse()),
729            ),
730            EntryFunctionDirective {
731                name,
732                params,
733                directives,
734                body,
735            }
736        )
737    }
738}
739
740impl PtxParser for FuncFunctionHeaderDirective {
741    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
742        alt!(
743            mapc!(
744                directive_exact_p("noreturn"),
745                FuncFunctionHeaderDirective::NoReturn {}
746            ),
747            mapc!(
748                skip_first(
749                    directive_exact_p("pragma"),
750                    seq(sep_by1(string_literal_p(), comma_p()), semicolon_p())
751                ),
752                FuncFunctionHeaderDirective::Pragma { args, _ }
753            ),
754            mapc!(
755                skip_first(directive_exact_p("abi_preserve"), u32_p()),
756                FuncFunctionHeaderDirective::AbiPreserve { value }
757            ),
758            mapc!(
759                skip_first(directive_exact_p("abi_preserve_control"), u32_p()),
760                FuncFunctionHeaderDirective::AbiPreserveControl { value }
761            )
762        )
763    }
764}
765
766impl PtxParser for EntryFunctionHeaderDirective {
767    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
768        alt!(
769            mapc!(
770                skip_first(directive_exact_p("maxnreg"), u32_p()),
771                EntryFunctionHeaderDirective::MaxNReg { value }
772            ),
773            try_map(
774                skip_first(directive_exact_p("maxntid"), sep_by1(u32_p(), comma_p())),
775                |dim_strs, span| {
776                    let dim = parse_function_dim(&dim_strs, span)?;
777                    ok!(EntryFunctionHeaderDirective::MaxNTid { dim })
778                }
779            ),
780            try_map(
781                skip_first(directive_exact_p("reqntid"), sep_by1(u32_p(), comma_p())),
782                |dim_strs, span| {
783                    let dim = parse_function_dim(&dim_strs, span)?;
784                    ok!(EntryFunctionHeaderDirective::ReqNTid { dim })
785                }
786            ),
787            mapc!(
788                skip_first(directive_exact_p("minnctapersm"), u32_p()),
789                EntryFunctionHeaderDirective::MinNCtaPerSm { value }
790            ),
791            mapc!(
792                skip_first(directive_exact_p("maxnctapersm"), u32_p()),
793                EntryFunctionHeaderDirective::MaxNCtaPerSm { value }
794            ),
795            mapc!(
796                skip_first(
797                    directive_exact_p("pragma"),
798                    skip_second(sep_by1(string_literal_p(), comma_p()), semicolon_p())
799                ),
800                EntryFunctionHeaderDirective::Pragma { args }
801            )
802        )
803    }
804}
805
806fn parse_function_dim(dims: &[u32], span: Span) -> Result<FunctionDim, PtxParseError> {
807    match dims.len() {
808        1 => {
809            let x = dims[0];
810            Ok(FunctionDim::X { x, span })
811        }
812        2 => {
813            let x = dims[0];
814            let y = dims[1];
815            Ok(FunctionDim::XY { x, y, span })
816        }
817        3 => {
818            let x = dims[0];
819            let y = dims[1];
820            let z = dims[2];
821            Ok(FunctionDim::XYZ { x, y, z, span })
822        }
823        _ => Err(PtxParseError {
824            kind: ParseErrorKind::InvalidLiteral(format!(
825                "expected 1-3 dimensions, got {}",
826                dims.len()
827            )),
828            span,
829        }),
830    }
831}