ptx_parser/parser/function.rs

1use ptx_90_parser_construct::func;
2
3use crate::{
4    alt, c,
5    lexer::PtxToken,
6    mapc, ok,
7    parser::{
8        ParseErrorKind, PtxParseError, PtxParser, PtxTokenStream, Span,
9        util::{
10            alt, between, colon_p, comma_p, directive_exact_p, directive_p, identifier_p,
11            integer_p, langle_p, lbrace_p, lparen_p, many, map, minus_p, optional,
12            parse_signed_integer, parse_u32_literal, parse_unsigned_integer, plus_p, pure,
13            rangle_p, rbrace_p, register_p, rparen_p, semicolon_p, sep_by, sep_by1, seq, seq5,
14            skip_first, skip_second, skip_semicolon, string_literal_p, try_map, u32_p,
15        },
16    },
17    seq_n,
18    r#type::{
19        AliasFunctionDirective, AttributeDirective, BranchTargetsDirective, CallPrototypeDirective,
20        CallTargetsDirective, DataType, DwarfDirective, DwarfDirectiveKind, EntryFunctionDirective,
21        EntryFunctionHeaderDirective, FuncFunctionDirective, FuncFunctionHeaderDirective,
22        FunctionBody, FunctionDim, FunctionStatement, FunctionSymbol, Instruction, Label,
23        LocationDirective, LocationInlinedAt, ParameterDirective, PragmaDirective,
24        PragmaDirectiveKind, RegisterDirective, RegisterTarget, SectionDirective, SectionEntry,
25        StatementDirective, StatementSectionDirectiveLine, VariableDirective, VariableSymbol,
26    },
27};
28
29impl PtxParser for StatementDirective {
30    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
31        let branch_targets = try_map(
32            skip_semicolon(skip_first(
33                directive_exact_p("branchtargets"),
34                sep_by1(Label::parse(), comma_p()),
35            )),
36            |labels, span| {
37                let directive = BranchTargetsDirective {
38                    labels,
39                    span: span.clone(),
40                };
41                ok!(StatementDirective::BranchTargets { directive })
42            },
43        );
44
45        let call_targets = try_map(
46            skip_semicolon(skip_first(
47                directive_exact_p("calltargets"),
48                sep_by1(FunctionSymbol::parse(), comma_p()),
49            )),
50            |targets, span| {
51                let directive = CallTargetsDirective {
52                    targets,
53                    span: span.clone(),
54                };
55                ok!(StatementDirective::CallTargets { directive })
56            },
57        );
58
59        let call_prototype = try_map(
60            skip_semicolon(skip_first(
61                directive_exact_p("callprototype"),
62                seq5(
63                    return_spec_parser(),
64                    parameter_list_parser(),
65                    noreturn_parser(),
66                    abi_preserve_parser(),
67                    abi_preserve_control_parser(),
68                ),
69            )),
70            |(return_param, params, noreturn, abi_preserve, abi_preserve_control), span| {
71                let directive = CallPrototypeDirective {
72                    return_param,
73                    params,
74                    noreturn,
75                    abi_preserve,
76                    abi_preserve_control,
77                    span: span.clone(),
78                };
79                ok!(StatementDirective::CallPrototype { directive })
80            },
81        );
82
83        let location = mapc!(location_directive(), StatementDirective::Loc { directive });
84
85        let reg_stmt = mapc!(register_statement(), StatementDirective::Reg { directive });
86
87        let local_stmt = mapc!(
88            skip_first(directive_exact_p("local"), VariableDirective::parse()),
89            StatementDirective::Local { directive }
90        );
91
92        let param_stmt = mapc!(
93            skip_first(directive_exact_p("param"), VariableDirective::parse()),
94            StatementDirective::Param { directive }
95        );
96
97        let shared_stmt = mapc!(
98            skip_first(directive_exact_p("shared"), VariableDirective::parse()),
99            StatementDirective::Shared { directive }
100        );
101
102        alt!(
103            location,
104            reg_stmt,
105            local_stmt,
106            param_stmt,
107            shared_stmt,
108            branch_targets,
109            call_targets,
110            call_prototype,
111            dwarf_directive(),
112            pragma_directive(),
113            section_directive()
114        )
115    }
116}
117
118impl PtxParser for SectionDirective {
119    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
120        map(
121            skip_first(
122                directive_exact_p("section"),
123                seq(section_name_parser(), section_body_parser()),
124            ),
125            |(name, entries), span| {
126                c!(SectionDirective {
127                    name = name,
128                    entries,
129                })
130            },
131        )
132    }
133}
134
135impl PtxParser for DwarfDirective {
136    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
137        skip_first(directive_exact_p("dwarf"), dwarf_kind_parser())
138    }
139}
140
141impl PtxParser for FunctionStatement {
142    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
143        let label_stmt = map(seq(Label::parse(), colon_p()), |(label, _), span| {
144            c!(FunctionStatement::Label { label })
145        });
146
147        let block_stmt = move |stream: &mut PtxTokenStream| {
148            map(
149                between(lbrace_p(), rbrace_p(), many(FunctionStatement::parse())),
150                |statements, span| c!(FunctionStatement::Block { statements }),
151            )(stream)
152        };
153
154        let directive_stmt = mapc!(
155            StatementDirective::parse(),
156            FunctionStatement::Directive { directive }
157        );
158
159        let instruction_stmt = mapc!(
160            Instruction::parse(),
161            FunctionStatement::Instruction { instruction }
162        );
163
164        alt!(label_stmt, block_stmt, directive_stmt, instruction_stmt)
165    }
166}
167
168fn return_spec_parser()
169-> impl Fn(&mut PtxTokenStream) -> Result<(Option<ParameterDirective>, Span), PtxParseError> {
170    alt(
171        map(ParameterDirective::parse(), |param, _| Some(param)),
172        map(underscore_placeholder(), |_, _| None),
173    )
174}
175
176fn underscore_placeholder() -> impl Fn(&mut PtxTokenStream) -> Result<((), Span), PtxParseError> {
177    try_map(identifier_p(), |name, span| {
178        if name == "_" {
179            Ok(())
180        } else {
181            Err(PtxParseError {
182                kind: ParseErrorKind::UnexpectedToken {
183                    expected: vec!["identifier `_`".into()],
184                    found: name,
185                },
186                span,
187            })
188        }
189    })
190}
191
192fn parameter_list_parser()
193-> impl Fn(&mut PtxTokenStream) -> Result<(Vec<ParameterDirective>, Span), PtxParseError> {
194    map(
195        between(
196            lparen_p(),
197            rparen_p(),
198            sep_by(ParameterDirective::parse(), comma_p()),
199        ),
200        |params, _| params,
201    )
202}
203
204fn noreturn_parser() -> impl Fn(&mut PtxTokenStream) -> Result<(bool, Span), PtxParseError> {
205    map(optional(directive_exact_p("noreturn")), |flag, _| {
206        flag.is_some()
207    })
208}
209
210fn abi_preserve_parser()
211-> impl Fn(&mut PtxTokenStream) -> Result<(Option<u32>, Span), PtxParseError> {
212    map(
213        optional(skip_first(directive_exact_p("abi_preserve"), u32_p())),
214        |value, _span| value,
215    )
216}
217
218fn abi_preserve_control_parser()
219-> impl Fn(&mut PtxTokenStream) -> Result<(Option<u32>, Span), PtxParseError> {
220    map(
221        optional(skip_first(
222            directive_exact_p("abi_preserve_control"),
223            u32_p(),
224        )),
225        |value, _span| value,
226    )
227}
228fn dwarf_directive()
229-> impl Fn(&mut PtxTokenStream) -> Result<(StatementDirective, Span), PtxParseError> {
230    mapc!(
231        skip_semicolon(DwarfDirective::parse()),
232        StatementDirective::Dwarf { directive }
233    )
234}
235
236fn dwarf_kind_parser()
237-> impl Fn(&mut PtxTokenStream) -> Result<(DwarfDirective, Span), PtxParseError> {
238    let byte_values = try_map(
239        seq(
240            directive_exact_p("byte"),
241            sep_by1(unsigned_integer_literal(), comma_p()),
242        ),
243        |(_, values), span| {
244            let mut parsed = Vec::new();
245            for (text, value_span) in values {
246                let value = parse_unsigned_integer(&text, value_span, 0, u8::MAX as u128)?;
247                parsed.push(value as u8);
248            }
249            ok!(DwarfDirective {
250                kind = DwarfDirectiveKind::ByteValues(parsed)
251            })
252        },
253    );
254
255    let four_byte_values = try_map(
256        seq(
257            four_byte_keyword(),
258            sep_by1(unsigned_integer_literal(), comma_p()),
259        ),
260        |(_, values), span| {
261            let mut parsed = Vec::new();
262            for (text, value_span) in values {
263                let value = parse_unsigned_integer(&text, value_span, 0, u32::MAX as u128)?;
264                parsed.push(value as u32);
265            }
266            ok!(DwarfDirective {
267                kind = DwarfDirectiveKind::FourByteValues(parsed)
268            })
269        },
270    );
271
272    let four_byte_label = try_map(
273        seq(four_byte_keyword(), Label::parse()),
274        |(_, label), span| {
275            ok!(DwarfDirective {
276                kind = DwarfDirectiveKind::FourByteLabel(label)
277            })
278        },
279    );
280
281    let quad_values = try_map(
282        seq(
283            directive_exact_p("quad"),
284            sep_by1(unsigned_integer_literal(), comma_p()),
285        ),
286        |(_, values), span| {
287            let mut parsed = Vec::new();
288            for (text, value_span) in values {
289                let value = parse_unsigned_integer(&text, value_span, 0, u64::MAX as u128)?;
290                parsed.push(value as u64);
291            }
292            ok!(DwarfDirective {
293                kind = DwarfDirectiveKind::QuadValues(parsed)
294            })
295        },
296    );
297
298    let quad_label = try_map(
299        seq(directive_exact_p("quad"), Label::parse()),
300        |(_, label), span| {
301            ok!(DwarfDirective {
302                kind = DwarfDirectiveKind::QuadLabel(label)
303            })
304        },
305    );
306
307    alt!(
308        byte_values,
309        four_byte_label,
310        quad_label,
311        four_byte_values,
312        quad_values
313    )
314}
315
316fn pragma_directive()
317-> impl Fn(&mut PtxTokenStream) -> Result<(StatementDirective, Span), PtxParseError> {
318    try_map(
319        skip_semicolon(seq(directive_exact_p("pragma"), string_literal_p())),
320        |(_, text), span| {
321            let kind = match text.trim() {
322                "nounroll" => PragmaDirectiveKind::Nounroll,
323                "enable_smem_spilling" => PragmaDirectiveKind::EnableSmemSpilling,
324                other if other.starts_with("used_bytes_mask") => {
325                    let mask = other["used_bytes_mask".len()..].trim().to_string();
326                    PragmaDirectiveKind::UsedBytesMask { mask }
327                }
328                other if other.starts_with("frequency") => {
329                    let value_str = other["frequency".len()..].trim();
330                    let value = parse_u32_literal(value_str, span)?;
331                    PragmaDirectiveKind::Frequency { value }
332                }
333                other => PragmaDirectiveKind::Raw(other.to_string()),
334            };
335            let directive = c!(PragmaDirective { kind });
336            ok!(StatementDirective::Pragma { directive })
337        },
338    )
339}
340
341fn section_directive()
342-> impl Fn(&mut PtxTokenStream) -> Result<(StatementDirective, Span), PtxParseError> {
343    mapc!(
344        SectionDirective::parse(),
345        StatementDirective::Section { directive }
346    )
347}
348
349fn register_statement()
350-> impl Fn(&mut PtxTokenStream) -> Result<(RegisterDirective, Span), PtxParseError> {
351    mapc!(
352        skip_semicolon(seq(
353            skip_first(directive_exact_p("reg"), DataType::parse()),
354            register_targets_parser(),
355        )),
356        RegisterDirective { ty, registers }
357    )
358}
359
360fn register_targets_parser()
361-> impl Fn(&mut PtxTokenStream) -> Result<(Vec<RegisterTarget>, Span), PtxParseError> {
362    map(
363        sep_by1(
364            seq(register_symbol(), optional(register_count())),
365            comma_p(),
366        ),
367        |entries, _span| {
368            let registers = entries
369                .into_iter()
370                .map(|(symbol, range)| {
371                    let symbol_span = symbol.span;
372                    RegisterTarget {
373                        name: symbol,
374                        range,
375                        span: symbol_span,
376                    }
377                })
378                .collect();
379            registers
380        },
381    )
382}
383
384fn register_symbol() -> impl Fn(&mut PtxTokenStream) -> Result<(VariableSymbol, Span), PtxParseError>
385{
386    alt(
387        map(register_p(), |name, span| VariableSymbol {
388            val: name,
389            span,
390        }),
391        map(identifier_p(), |val, span| VariableSymbol { val, span }),
392    )
393}
394
395fn register_count() -> impl Fn(&mut PtxTokenStream) -> Result<(u32, Span), PtxParseError> {
396    between(langle_p(), rangle_p(), u32_p())
397}
398
399fn location_directive()
400-> impl Fn(&mut PtxTokenStream) -> Result<(LocationDirective, Span), PtxParseError> {
401    mapc!(
402        seq_n!(
403            skip_first(directive_exact_p("loc"), u32_p()),
404            u32_p(),
405            u32_p(),
406            pure(Option::<LocationInlinedAt>::None)
407        ),
408        LocationDirective {
409            file_index,
410            line,
411            column,
412            inlined_at,
413        }
414    )
415}
416
417fn section_name_parser() -> impl Fn(&mut PtxTokenStream) -> Result<(String, Span), PtxParseError> {
418    alt(
419        map(directive_p(), func!(|name| format!(".{name}"))),
420        identifier_p(),
421    )
422}
423
424fn section_body_parser()
425-> impl Fn(&mut PtxTokenStream) -> Result<(Vec<SectionEntry>, Span), PtxParseError> {
426    between(lbrace_p(), rbrace_p(), many(section_entry_parser()))
427}
428
429fn section_entry_parser()
430-> impl Fn(&mut PtxTokenStream) -> Result<(SectionEntry, Span), PtxParseError> {
431    alt(
432        label_entry(),
433        map(
434            section_directive_line(),
435            func!(|line| SectionEntry::Directive(line)),
436        ),
437    )
438}
439
440fn label_entry() -> impl Fn(&mut PtxTokenStream) -> Result<(SectionEntry, Span), PtxParseError> {
441    map(seq(Label::parse(), colon_p()), |(label, _), span| {
442        SectionEntry::Label { label, span }
443    })
444}
445
446fn section_directive_line()
447-> impl Fn(&mut PtxTokenStream) -> Result<(StatementSectionDirectiveLine, Span), PtxParseError> {
448    let b8 = try_map(
449        skip_semicolon(skip_first(
450            directive_exact_p("b8"),
451            sep_by1(signed_integer_literal(), comma_p()),
452        )),
453        func!(|values| {
454            let mut out = Vec::new();
455            for (text, value_span) in values {
456                let value = parse_signed_integer(&text, value_span, -128, 255)?;
457                out.push(value as i16);
458            }
459            ok!(StatementSectionDirectiveLine::B8 { values = out })
460        }),
461    );
462
463    let b16 = try_map(
464        skip_semicolon(skip_first(
465            directive_exact_p("b16"),
466            sep_by1(signed_integer_literal(), comma_p()),
467        )),
468        |values, span| {
469            let mut out = Vec::new();
470            for (text, value_span) in values {
471                let value = parse_signed_integer(&text, value_span, -32_768, 65_535)?;
472                out.push(value as i32);
473            }
474            ok!(StatementSectionDirectiveLine::B16 { values = out })
475        },
476    );
477
478    let b32 = try_map(
479        skip_semicolon(skip_first(directive_exact_p("b32"), b32_section_suffix())),
480        |line, span| Ok(line.with_span(span)),
481    );
482
483    let b64 = try_map(
484        skip_semicolon(skip_first(directive_exact_p("b64"), b64_section_suffix())),
485        |line, span| Ok(line.with_span(span)),
486    );
487
488    alt!(b8, b16, b32, b64)
489}
490
491fn b32_section_suffix()
492-> impl Fn(&mut PtxTokenStream) -> Result<(StatementSectionDirectiveLine, Span), PtxParseError> {
493    let immediate = try_map(
494        sep_by1(signed_integer_literal(), comma_p()),
495        |values, span| {
496            let mut out = Vec::new();
497            for (text, value_span) in values {
498                let value =
499                    parse_signed_integer(&text, value_span, i64::MIN as i128, i64::MAX as i128)?;
500                out.push(value as i64);
501            }
502            ok!(StatementSectionDirectiveLine::B32Immediate { values = out })
503        },
504    );
505
506    let label_diff = try_map(
507        seq_n!(Label::parse(), minus_p(), Label::parse()),
508        |(left, _, right), span| {
509            ok!(StatementSectionDirectiveLine::B32LabelDiff {
510                entries = (left, right)
511            })
512        },
513    );
514
515    let label_plus = try_map(
516        seq_n!(
517            Label::parse(),
518            alt(map(plus_p(), |_, _| 1i32), map(minus_p(), |_, _| -1i32)),
519            integer_p(),
520        ),
521        |(label, sign, digits), span| {
522            let limit = if sign < 0 {
523                (i32::MAX as u128) + 1
524            } else {
525                i32::MAX as u128
526            };
527            let magnitude = parse_unsigned_integer(&digits, span, 0, limit)? as i128;
528            let value = if sign < 0 { -magnitude } else { magnitude };
529            ok!(StatementSectionDirectiveLine::B32LabelPlusImm {
530                entries = (label, value as i32)
531            })
532        },
533    );
534
535    let label_only = map(
536        Label::parse(),
537        |label, span| c!(StatementSectionDirectiveLine::B32Label { labels = label }),
538    );
539
540    alt!(immediate, label_diff, label_plus, label_only)
541}
542
543fn b64_section_suffix()
544-> impl Fn(&mut PtxTokenStream) -> Result<(StatementSectionDirectiveLine, Span), PtxParseError> {
545    let immediate = try_map(
546        sep_by1(signed_integer_literal(), comma_p()),
547        |values, span| {
548            let mut out = Vec::new();
549            for (text, value_span) in values {
550                let value = parse_signed_integer(&text, value_span, i128::MIN, i128::MAX)?;
551                out.push(value);
552            }
553            ok!(StatementSectionDirectiveLine::B64Immediate { values = out })
554        },
555    );
556
557    let label_diff = try_map(
558        seq_n!(Label::parse(), minus_p(), Label::parse()),
559        |(left, _, right), span| {
560            ok!(StatementSectionDirectiveLine::B64LabelDiff {
561                entries = (left, right)
562            })
563        },
564    );
565
566    let label_plus = try_map(
567        seq_n!(
568            Label::parse(),
569            alt(map(plus_p(), |_, _| 1i32), map(minus_p(), |_, _| -1i32)),
570            integer_p(),
571        ),
572        |(label, sign, digits), span| {
573            let limit = if sign < 0 {
574                (i64::MAX as u128) + 1
575            } else {
576                i64::MAX as u128
577            };
578            let magnitude = parse_unsigned_integer(&digits, span, 0, limit)? as i128;
579            let value = if sign < 0 { -magnitude } else { magnitude };
580            ok!(StatementSectionDirectiveLine::B64LabelPlusImm {
581                entries = (label, value as i64)
582            })
583        },
584    );
585
586    let label_only = map(
587        Label::parse(),
588        |label, span| c!(StatementSectionDirectiveLine::B64Label { labels = label }),
589    );
590
591    alt!(immediate, label_diff, label_plus, label_only)
592}
593
594fn signed_integer_literal()
595-> impl Fn(&mut PtxTokenStream) -> Result<((String, Span), Span), PtxParseError> {
596    map(
597        seq(
598            optional(alt(map(minus_p(), |_, _| '-'), map(plus_p(), |_, _| '+'))),
599            integer_p(),
600        ),
601        |(sign, digits), span| {
602            let mut value = String::new();
603            if let Some(ch) = sign {
604                if ch == '-' {
605                    value.push('-');
606                }
607            }
608            value.push_str(&digits);
609            (value, span)
610        },
611    )
612}
613
614fn unsigned_integer_literal()
615-> impl Fn(&mut PtxTokenStream) -> Result<((String, Span), Span), PtxParseError> {
616    map(integer_p(), |digits, span| (digits, span))
617}
618
619fn four_byte_keyword() -> impl Fn(&mut PtxTokenStream) -> Result<((), Span), PtxParseError> {
620    move |stream| {
621        stream.try_with_span(|stream| {
622            stream.expect(&PtxToken::Dot)?;
623            let (value, value_span) = integer_p()(stream)?;
624            if value != "4" {
625                return Err(crate::unexpected_value!(value_span, &["4"], value));
626            }
627            let (name, name_span) = identifier_p()(stream)?;
628            if name != "byte" {
629                return Err(crate::unexpected_value!(name_span, &["byte"], name));
630            }
631            Ok(())
632        })
633    }
634}
635
636impl PtxParser for AliasFunctionDirective {
637    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
638        use crate::parser::util::{comma_p, directive_exact_p, semicolon_p, skip_first};
639
640        try_map(
641            seq_n!(
642                skip_first(directive_exact_p("alias"), FunctionSymbol::parse()),
643                skip_first(comma_p(), FunctionSymbol::parse()),
644                semicolon_p()
645            ),
646            |(alias, target, _), span| ok!(AliasFunctionDirective { alias, target }),
647        )
648    }
649}
650
651impl PtxParser for FunctionBody {
652    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
653        try_map(
654            between(lbrace_p(), rbrace_p(), many(FunctionStatement::parse())),
655            |statements, span| ok!(FunctionBody { statements }),
656        )
657    }
658}
659
660impl PtxParser for FuncFunctionDirective {
661    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
662        let return_spec = alt(
663            map(
664                between(
665                    lparen_p(),
666                    rparen_p(),
667                    optional(ParameterDirective::parse()),
668                ),
669                |param, _| param,
670            ),
671            map(optional(ParameterDirective::parse()), |param, _| param),
672        );
673
674        let body_or_prototype = alt(
675            map(FunctionBody::parse(), |body, _| Some(body)),
676            map(semicolon_p(), |_, _| None),
677        );
678
679        mapc!(
680            seq_n!(
681                skip_first(directive_exact_p("func"), many(AttributeDirective::parse())),
682                return_spec,
683                FunctionSymbol::parse(),
684                between(
685                    lparen_p(),
686                    rparen_p(),
687                    sep_by(ParameterDirective::parse(), comma_p()),
688                ),
689                many(FuncFunctionHeaderDirective::parse()),
690                body_or_prototype,
691            ),
692            FuncFunctionDirective {
693                attributes,
694                return_param,
695                name,
696                params,
697                directives,
698                body
699            }
700        )
701    }
702}
703
704impl PtxParser for EntryFunctionDirective {
705    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
706        mapc!(
707            seq_n!(
708                skip_first(directive_exact_p("entry"), FunctionSymbol::parse()),
709                between(
710                    lparen_p(),
711                    rparen_p(),
712                    sep_by(ParameterDirective::parse(), comma_p()),
713                ),
714                many(EntryFunctionHeaderDirective::parse()),
715                optional(FunctionBody::parse()),
716            ),
717            EntryFunctionDirective {
718                name,
719                params,
720                directives,
721                body,
722            }
723        )
724    }
725}
726
727impl PtxParser for FuncFunctionHeaderDirective {
728    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
729        alt!(
730            mapc!(
731                directive_exact_p("noreturn"),
732                FuncFunctionHeaderDirective::NoReturn {}
733            ),
734            mapc!(
735                skip_first(
736                    directive_exact_p("pragma"),
737                    seq(sep_by1(string_literal_p(), comma_p()), semicolon_p())
738                ),
739                FuncFunctionHeaderDirective::Pragma { args, _ }
740            ),
741            mapc!(
742                skip_first(directive_exact_p("abi_preserve"), u32_p()),
743                FuncFunctionHeaderDirective::AbiPreserve { value }
744            ),
745            mapc!(
746                skip_first(directive_exact_p("abi_preserve_control"), u32_p()),
747                FuncFunctionHeaderDirective::AbiPreserveControl { value }
748            )
749        )
750    }
751}
752
753impl PtxParser for EntryFunctionHeaderDirective {
754    fn parse() -> impl Fn(&mut PtxTokenStream) -> Result<(Self, Span), PtxParseError> {
755        alt!(
756            mapc!(
757                skip_first(directive_exact_p("maxnreg"), u32_p()),
758                EntryFunctionHeaderDirective::MaxNReg { value }
759            ),
760            try_map(
761                skip_first(directive_exact_p("maxntid"), sep_by1(u32_p(), comma_p())),
762                |dim_strs, span| {
763                    let dim = parse_function_dim(&dim_strs, span)?;
764                    ok!(EntryFunctionHeaderDirective::MaxNTid { dim })
765                }
766            ),
767            try_map(
768                skip_first(directive_exact_p("reqntid"), sep_by1(u32_p(), comma_p())),
769                |dim_strs, span| {
770                    let dim = parse_function_dim(&dim_strs, span)?;
771                    ok!(EntryFunctionHeaderDirective::ReqNTid { dim })
772                }
773            ),
774            mapc!(
775                skip_first(directive_exact_p("minnctapersm"), u32_p()),
776                EntryFunctionHeaderDirective::MinNCtaPerSm { value }
777            ),
778            mapc!(
779                skip_first(directive_exact_p("maxnctapersm"), u32_p()),
780                EntryFunctionHeaderDirective::MaxNCtaPerSm { value }
781            ),
782            mapc!(
783                skip_first(
784                    directive_exact_p("pragma"),
785                    skip_second(sep_by1(string_literal_p(), comma_p()), semicolon_p())
786                ),
787                EntryFunctionHeaderDirective::Pragma { args }
788            )
789        )
790    }
791}
792
793fn parse_function_dim(dims: &[u32], span: Span) -> Result<FunctionDim, PtxParseError> {
794    match dims.len() {
795        1 => {
796            let x = dims[0];
797            Ok(FunctionDim::X { x, span })
798        }
799        2 => {
800            let x = dims[0];
801            let y = dims[1];
802            Ok(FunctionDim::XY { x, y, span })
803        }
804        3 => {
805            let x = dims[0];
806            let y = dims[1];
807            let z = dims[2];
808            Ok(FunctionDim::XYZ { x, y, z, span })
809        }
810        _ => Err(PtxParseError {
811            kind: ParseErrorKind::InvalidLiteral(format!(
812                "expected 1-3 dimensions, got {}",
813                dims.len()
814            )),
815            span,
816        }),
817    }
818}