ptx_parser/parser/function.rs

1use crate::parser::common::{
2    invalid_literal, parse_register_name, parse_u64_literal, try_parse_label,
3};
4use crate::r#type::common::{CodeLinkage, Instruction};
5use crate::unlexer::PtxUnlexer;
6use crate::{
7    lexer::{PtxToken, tokenize},
8    parser::{
9        ParseErrorKind, PtxParseError, PtxParser, PtxTokenStream, Span, expect_directive_value,
10        peek_directive, unexpected_value,
11    },
12    r#type::{
13        function::{
14            DwarfDirective, EntryFunction, FuncFunction, FunctionAlias, FunctionBody, FunctionDim3,
15            FunctionHeaderDirective, FunctionKernelDirective, FunctionStatement, LocationDirective,
16            PragmaDirective, RegisterDirective, StatementDirective, StatementSectionDirective,
17        },
18        variable::VariableDirective,
19    },
20};
21
22impl FunctionHeaderDirective {
23    fn parse_list(stream: &mut PtxTokenStream) -> Result<Vec<Self>, PtxParseError> {
24        let mut directives = Vec::new();
25        loop {
26            let Some((name, span)) = peek_directive(stream)? else {
27                break;
28            };
29            match name.as_str() {
30                "visible" | "extern" | "weak" => {
31                    let linkage = CodeLinkage::parse(stream)?;
32                    let linkage_span = linkage.span();
33                    directives.push(FunctionHeaderDirective::Linkage { linkage, span: linkage_span });
34                }
35                "entry" | "func" | "alias" => break,
36                other => {
37                    return Err(unexpected_value(
38                        span,
39                        &[".visible", ".extern", ".weak", ".entry", ".func", ".alias"],
40                        format!(".{other}"),
41                    ));
42                }
43            }
44        }
45        Ok(directives)
46    }
47}
48
49fn parse_register_range(stream: &mut PtxTokenStream) -> Result<Option<u32>, PtxParseError> {
50    if stream
51        .consume_if(|token| matches!(token, PtxToken::LAngle))
52        .is_none()
53    {
54        return Ok(None);
55    }
56
57    let (value, span) = parse_u64_literal(stream)?;
58    if value > u32::MAX as u64 {
59        return Err(invalid_literal(
60            span.clone(),
61            "register range exceeds u32::MAX",
62        ));
63    }
64    stream.expect(&PtxToken::RAngle)?;
65    Ok(Some(value as u32))
66}
67
68fn tokens_to_string(tokens: &[PtxToken], span: &Span) -> Result<String, PtxParseError> {
69    PtxUnlexer::to_string(tokens)
70        .map_err(|_| invalid_literal(span.clone(), "failed to serialize token sequence"))
71}
72
73fn parse_parameter_tokens(
74    tokens: &[PtxToken],
75    span: &Span,
76) -> Result<VariableDirective, PtxParseError> {
77    let serialized = tokens_to_string(tokens, span)?;
78    let source = format!("{};", serialized);
79    let tokenized = tokenize(&source)
80        .map_err(|_| invalid_literal(span.clone(), "failed to tokenize function parameter"))?;
81    let mut temp_stream = PtxTokenStream::new(&tokenized);
82    let directive = VariableDirective::parse(&mut temp_stream)?;
83    Ok(directive)
84}
85
86fn collect_parameter_tokens(
87    stream: &mut PtxTokenStream,
88) -> Result<(Vec<PtxToken>, Span), PtxParseError> {
89    let (first_token, first_span) = stream.peek()?;
90    if matches!(first_token, PtxToken::Comma | PtxToken::RParen) {
91        return Err(unexpected_value(
92            first_span.clone(),
93            &["function parameter"],
94            format!("{first_token:?}"),
95        ));
96    }
97
98    let mut tokens = Vec::new();
99    let mut paren_depth = 0usize;
100    let mut bracket_depth = 0usize;
101
102    loop {
103        let (next_token, _) = stream.peek()?;
104        if paren_depth == 0 && bracket_depth == 0 {
105            if matches!(next_token, PtxToken::Comma | PtxToken::RParen) {
106                break;
107            }
108        }
109
110        let (token, _) = stream.consume()?;
111        match token {
112            PtxToken::LParen => paren_depth += 1,
113            PtxToken::RParen => paren_depth = paren_depth.saturating_sub(1),
114            PtxToken::LBracket => bracket_depth += 1,
115            PtxToken::RBracket => bracket_depth = bracket_depth.saturating_sub(1),
116            _ => {}
117        }
118        tokens.push(token.clone());
119    }
120
121    Ok((tokens, first_span.clone()))
122}
123
124fn parse_parameter(stream: &mut PtxTokenStream) -> Result<VariableDirective, PtxParseError> {
125    let (tokens, span) = collect_parameter_tokens(stream)?;
126    if tokens.is_empty() {
127        return Err(unexpected_value(
128            span.clone(),
129            &["function parameter"],
130            "".to_string(),
131        ));
132    }
133    parse_parameter_tokens(&tokens, &span)
134}
135
136fn parse_parameter_list(
137    stream: &mut PtxTokenStream,
138) -> Result<Vec<VariableDirective>, PtxParseError> {
139    stream.expect(&PtxToken::LParen)?;
140    if stream
141        .consume_if(|token| matches!(token, PtxToken::RParen))
142        .is_some()
143    {
144        return Ok(Vec::new());
145    }
146
147    let mut params = Vec::new();
148    loop {
149        let param = parse_parameter(stream)?;
150        params.push(param);
151        if stream
152            .consume_if(|token| matches!(token, PtxToken::Comma))
153            .is_none()
154        {
155            break;
156        }
157    }
158    stream.expect(&PtxToken::RParen)?;
159    Ok(params)
160}
161
162fn parse_return_parameter(
163    stream: &mut PtxTokenStream,
164) -> Result<Option<VariableDirective>, PtxParseError> {
165    if stream
166        .consume_if(|token| matches!(token, PtxToken::LParen))
167        .is_none()
168    {
169        return Ok(None);
170    }
171
172    if stream
173        .consume_if(|token| matches!(token, PtxToken::RParen))
174        .is_some()
175    {
176        return Ok(None);
177    }
178
179    let param = parse_parameter(stream)?;
180    stream.expect(&PtxToken::RParen)?;
181    Ok(Some(param))
182}
183
184fn parse_optional_noreturn(
185    stream: &mut PtxTokenStream,
186    directives: &mut Vec<FunctionHeaderDirective>,
187) -> Result<bool, PtxParseError> {
188    if let Some((token, _)) = stream.peek().ok() {
189        if let PtxToken::Dot = token {
190            // Check if it's a directive
191            let saved_pos = stream.position();
192            let (_, dot_span) = stream.consume()?; // consume dot
193            if let Ok((name, name_span)) = stream.expect_identifier() {
194                if name == "noreturn" {
195                    if !directives
196                        .iter()
197                        .any(|directive| matches!(directive, FunctionHeaderDirective::NoReturn { .. }))
198                    {
199                        let noreturn_span = dot_span.start..name_span.end;
200                        directives.push(FunctionHeaderDirective::NoReturn { span: noreturn_span });
201                    }
202                    if stream
203                        .consume_if(|token| matches!(token, PtxToken::Semicolon))
204                        .is_some()
205                    {
206                        return Ok(true);
207                    }
208                } else {
209                    stream.set_position(saved_pos);
210                }
211            } else {
212                stream.set_position(saved_pos);
213            }
214        }
215    }
216    Ok(false)
217}
218
219fn parse_argument_strings(
220    stream: &mut PtxTokenStream,
221    base_span: &Span,
222    raw_tokens: &mut Vec<PtxToken>,
223) -> Result<Vec<String>, PtxParseError> {
224    let mut arguments = Vec::new();
225    let mut current_tokens: Vec<PtxToken> = Vec::new();
226    let mut current_span = base_span.clone();
227
228    while !stream.check(|token| matches!(token, PtxToken::Semicolon)) {
229        let (token, span) = stream.consume()?;
230        raw_tokens.push(token.clone());
231        if matches!(token, PtxToken::Comma) {
232            if !current_tokens.is_empty() {
233                let text = tokens_to_string(&current_tokens, &current_span)?;
234                arguments.push(text);
235                current_tokens.clear();
236            } else {
237                arguments.push(String::new());
238            }
239        } else {
240            if current_tokens.is_empty() {
241                current_span = span.clone();
242            }
243            current_tokens.push(token.clone());
244        }
245    }
246
247    if !current_tokens.is_empty() {
248        let text = tokens_to_string(&current_tokens, &current_span)?;
249        arguments.push(text);
250    }
251
252    stream.expect(&PtxToken::Semicolon)?;
253    raw_tokens.push(PtxToken::Semicolon);
254    Ok(arguments)
255}
256
257fn parse_block_statements(
258    stream: &mut PtxTokenStream,
259) -> Result<Vec<FunctionStatement>, PtxParseError> {
260    let mut statements = Vec::new();
261
262    loop {
263        if stream.check(|token| matches!(token, PtxToken::RBrace)) {
264            stream.consume()?;
265            break;
266        }
267
268        if stream.is_at_end() {
269            return Err(PtxParseError {
270                kind: ParseErrorKind::UnexpectedEof,
271                span: 0..0,
272            });
273        }
274
275        let position = stream.position();
276        match FunctionStatement::parse(stream) {
277            Ok(statement) => statements.push(statement),
278            Err(_err) => {
279                stream.set_position(position);
280                let (tokens, span) = collect_body_tokens(stream)?;
281                if !tokens.is_empty() {
282                    let pragma = PragmaDirective {
283                        arguments: Vec::new(),
284                        comment: None,
285                        span: span.clone(),
286                    };
287                    let directive = StatementDirective::Pragma { directive: pragma, span: span.clone() };
288                    statements.push(FunctionStatement::Directive { directive, span });
289                }
290                return Ok(statements);
291            }
292        }
293    }
294
295    Ok(statements)
296}
297
298fn collect_body_tokens(
299    stream: &mut PtxTokenStream,
300) -> Result<(Vec<PtxToken>, Span), PtxParseError> {
301    let mut tokens = Vec::new();
302    let mut depth = 1usize;
303    let mut first_span: Option<Span> = None;
304
305    while depth > 0 {
306        let (token, span) = stream.consume()?;
307        if first_span.is_none() {
308            first_span = Some(span.clone());
309        }
310        match token {
311            PtxToken::LBrace => {
312                depth += 1;
313                tokens.push(token.clone());
314            }
315            PtxToken::RBrace => {
316                depth -= 1;
317                if depth == 0 {
318                    break;
319                }
320                tokens.push(token.clone());
321            }
322            _ => tokens.push(token.clone()),
323        }
324    }
325
326    Ok((tokens, first_span.unwrap_or(0..0)))
327}
328
329impl PtxParser for FunctionBody {
330    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
331        match stream.peek() {
332            Ok((PtxToken::Semicolon, _)) => {
333                stream.consume()?;
334                Ok(FunctionBody::default())
335            }
336            Ok((PtxToken::LBrace, _)) => {
337                stream.consume()?;
338                let mut body = FunctionBody::default();
339                loop {
340                    if stream.check(|token| matches!(token, PtxToken::RBrace)) {
341                        stream.consume()?;
342                        break;
343                    }
344
345                    if stream.is_at_end() {
346                        return Err(PtxParseError {
347                            kind: ParseErrorKind::UnexpectedEof,
348                            span: 0..0,
349                        });
350                    }
351
352                    let position = stream.position();
353                    match FunctionStatement::parse(stream) {
354                        Ok(statement) => body.statements.push(statement),
355                        Err(_) => {
356                            stream.set_position(position);
357                            let (tokens, span) = collect_body_tokens(stream)?;
358                            if !tokens.is_empty() {
359                                let pragma = PragmaDirective {
360                                    arguments: Vec::new(),
361                                    comment: None,
362                                    span: span.clone(),
363                                };
364                                let directive = StatementDirective::Pragma { directive: pragma, span: span.clone() };
365                                body.statements.push(FunctionStatement::Directive { directive, span });
366                            }
367                            return Ok(body);
368                        }
369                    }
370                }
371
372                Ok(body)
373            }
374            Ok((token, _)) => {
375                let span = stream.peek()?.1.clone();
376                Err(unexpected_value(
377                    span,
378                    &[";", ".noreturn", "{"],
379                    format!("{token:?}"),
380                ))
381            }
382            Err(_) => Err(PtxParseError {
383                kind: ParseErrorKind::UnexpectedEof,
384                span: 0..0,
385            }),
386        }
387    }
388}
389
390impl PtxParser for EntryFunction {
391    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
392        let start_pos = stream.position();
393        let mut directives = FunctionHeaderDirective::parse_list(stream)?;
394        expect_directive_value(stream, "entry")?;
395        let (name, _) = stream.expect_identifier()?;
396        let params = parse_parameter_list(stream)?;
397        let body = if parse_optional_noreturn(stream, &mut directives)? {
398            FunctionBody::default()
399        } else {
400            FunctionBody::parse(stream)?
401        };
402        let end_pos = stream.position();
403        let span = start_pos.char_offset..end_pos.char_offset;
404        Ok(EntryFunction {
405            name,
406            directives,
407            params,
408            body,
409            span,
410        })
411    }
412}
413
414impl PtxParser for FuncFunction {
415    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
416        let start_pos = stream.position();
417        let mut directives = FunctionHeaderDirective::parse_list(stream)?;
418        expect_directive_value(stream, "func")?;
419
420        let return_param = parse_return_parameter(stream)?;
421
422        let (name, _) = stream.expect_identifier()?;
423        let params = parse_parameter_list(stream)?;
424        let body = if parse_optional_noreturn(stream, &mut directives)? {
425            FunctionBody::default()
426        } else {
427            FunctionBody::parse(stream)?
428        };
429        let end_pos = stream.position();
430        let span = start_pos.char_offset..end_pos.char_offset;
431        Ok(FuncFunction {
432            name,
433            directives,
434            return_param,
435            params,
436            body,
437            span,
438        })
439    }
440}
441
442impl PtxParser for FunctionAlias {
443    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
444        let start_pos = stream.position();
445        let _ = FunctionHeaderDirective::parse_list(stream)?;
446        expect_directive_value(stream, "alias")?;
447        let (alias, _) = stream.expect_identifier()?;
448        stream.expect(&PtxToken::Comma)?;
449        let (target, _) = stream.expect_identifier()?;
450        stream.expect(&PtxToken::Semicolon)?;
451        let end_pos = stream.position();
452        let span = start_pos.char_offset..end_pos.char_offset;
453        Ok(FunctionAlias {
454            alias,
455            target,
456            span,
457        })
458    }
459}
460
461impl PtxParser for FunctionKernelDirective {
462    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
463        let position = stream.position();
464        if let Ok(entry) = EntryFunction::parse(stream) {
465            let span = entry.span.clone();
466            return Ok(FunctionKernelDirective::Entry { function: entry, span });
467        }
468        stream.set_position(position);
469        if let Ok(func) = FuncFunction::parse(stream) {
470            let span = func.span.clone();
471            return Ok(FunctionKernelDirective::Func { function: func, span });
472        }
473        stream.set_position(position);
474        let alias = FunctionAlias::parse(stream)?;
475        let span = alias.span.clone();
476        Ok(FunctionKernelDirective::Alias { alias, span })
477    }
478}
479
480impl PtxParser for FunctionStatement {
481    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
482        let start_pos = stream.position();
483        if let Some(label) = try_parse_label(stream)? {
484            let end_pos = stream.position();
485            let span = start_pos.char_offset..end_pos.char_offset;
486            return Ok(FunctionStatement::Label { name: label, span });
487        }
488
489        if peek_directive(stream)?.is_some() {
490            let directive = StatementDirective::parse(stream)?;
491            let span = directive.span();
492            return Ok(FunctionStatement::Directive { directive, span });
493        }
494
495        if stream.check(|token| matches!(token, PtxToken::LBrace)) {
496            let (_, brace_span) = stream.consume()?;
497            let block_statements = parse_block_statements(stream)?;
498            let end_pos = stream.position();
499            let span = brace_span.start..end_pos.char_offset;
500            return Ok(FunctionStatement::Block { statements: block_statements, span });
501        }
502
503        let instruction = Instruction::parse(stream)?;
504        let span = instruction.span.clone();
505        Ok(FunctionStatement::Instruction { instruction, span })
506    }
507}
508
509impl PtxParser for StatementDirective {
510    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
511        let (name, span) = if let Some(value) = peek_directive(stream)? {
512            value
513        } else {
514            let (token, span) = stream
515                .peek()
516                .map(|(token, span)| (token.clone(), span.clone()))?;
517            return Err(unexpected_value(
518                span,
519                &["function directive"],
520                format!("{token:?}"),
521            ));
522        };
523
524        match name.as_str() {
525            "reg" => {
526                let directive = RegisterDirective::parse(stream)?;
527                let span = directive.span.clone();
528                Ok(StatementDirective::Reg { directive, span })
529            }
530            "local" => {
531                let directive = VariableDirective::parse(stream)?;
532                let span = directive.span.clone();
533                Ok(StatementDirective::Local { directive, span })
534            }
535            "param" => {
536                let directive = VariableDirective::parse(stream)?;
537                let span = directive.span.clone();
538                Ok(StatementDirective::Param { directive, span })
539            }
540            "shared" => {
541                let directive = VariableDirective::parse(stream)?;
542                let span = directive.span.clone();
543                Ok(StatementDirective::Shared { directive, span })
544            }
545            "pragma" => {
546                let (_, directive_span) = stream.expect_directive()?;
547                let mut raw_tokens =
548                    vec![PtxToken::Dot, PtxToken::Identifier("pragma".to_string())];
549                let arguments = parse_argument_strings(stream, &directive_span, &mut raw_tokens)?;
550                let end_pos = stream.position();
551                let span = directive_span.start..end_pos.char_offset;
552                let pragma = PragmaDirective {
553                    arguments,
554                    comment: None,
555                    span: span.clone(),
556                };
557                Ok(StatementDirective::Pragma { directive: pragma, span })
558            }
559            "loc" => {
560                let (_, directive_span) = stream.expect_directive()?;
561                let (file_token, file_span) = stream.consume()?;
562                let file_index = match file_token {
563                    PtxToken::DecimalInteger(value) => value.parse::<u32>().map_err(|_| {
564                        invalid_literal(
565                            file_span.clone(),
566                            "expected 32-bit unsigned integer literal",
567                        )
568                    })?,
569                    ref other => {
570                        return Err(unexpected_value(
571                            file_span.clone(),
572                            &["decimal literal"],
573                            format!("{other:?}"),
574                        ));
575                    }
576                };
577
578                let (line_token, line_span) = stream.consume()?;
579                let line = match line_token {
580                    PtxToken::DecimalInteger(value) => value.parse::<u32>().map_err(|_| {
581                        invalid_literal(
582                            line_span.clone(),
583                            "expected 32-bit unsigned integer literal",
584                        )
585                    })?,
586                    ref other => {
587                        return Err(unexpected_value(
588                            line_span.clone(),
589                            &["decimal literal"],
590                            format!("{other:?}"),
591                        ));
592                    }
593                };
594
595                let (column_token, column_span) = stream.consume()?;
596                let column = match column_token {
597                    PtxToken::DecimalInteger(value) => value.parse::<u32>().map_err(|_| {
598                        invalid_literal(
599                            column_span.clone(),
600                            "expected 32-bit unsigned integer literal",
601                        )
602                    })?,
603                    ref other => {
604                        return Err(unexpected_value(
605                            column_span.clone(),
606                            &["decimal literal"],
607                            format!("{other:?}"),
608                        ));
609                    }
610                };
611
612                let options = Vec::new();
613                let end_pos = if stream
614                    .consume_if(|token| matches!(token, PtxToken::Semicolon))
615                    .is_some()
616                {
617                    stream.position()
618                } else {
619                    stream.position()
620                };
621
622                let span = directive_span.start..end_pos.char_offset;
623                let loc = LocationDirective {
624                    file_index,
625                    line,
626                    column,
627                    options,
628                    comment: None,
629                    span: span.clone(),
630                };
631                Ok(StatementDirective::Loc { directive: loc, span })
632            }
633            "dwarf" => {
634                let (_, directive_span) = stream.expect_directive()?;
635                let mut raw_tokens = vec![PtxToken::Dot, PtxToken::Identifier("dwarf".to_string())];
636                let (keyword, keyword_span) = stream.expect_identifier()?;
637                raw_tokens.push(PtxToken::Identifier(keyword.clone()));
638                let arguments = parse_argument_strings(stream, &keyword_span, &mut raw_tokens)?;
639                let end_pos = stream.position();
640                let span = directive_span.start..end_pos.char_offset;
641                let dwarf = DwarfDirective {
642                    keyword,
643                    arguments,
644                    comment: None,
645                    span: span.clone(),
646                };
647                Ok(StatementDirective::Dwarf { directive: dwarf, span })
648            }
649            "section" => {
650                let (_, directive_span) = stream.expect_directive()?;
651                let mut raw_tokens =
652                    vec![PtxToken::Dot, PtxToken::Identifier("section".to_string())];
653                let arguments = parse_argument_strings(stream, &directive_span, &mut raw_tokens)?;
654                let mut iter = arguments.into_iter();
655                let name_str = iter.next().ok_or_else(|| {
656                    unexpected_value(directive_span.clone(), &["section name"], "".to_string())
657                })?;
658                let end_pos = stream.position();
659                let span = directive_span.start..end_pos.char_offset;
660                let section = StatementSectionDirective {
661                    name: name_str,
662                    arguments: iter.collect(),
663                    comment: None,
664                    span: span.clone(),
665                };
666                Ok(StatementDirective::Section { directive: section, span })
667            }
668            other => Err(unexpected_value(
669                span,
670                &[
671                    ".reg", ".local", ".param", ".shared", ".pragma", ".loc", ".dwarf", ".section",
672                ],
673                format!(".{other}"),
674            )),
675        }
676    }
677}
678
679impl PtxParser for RegisterDirective {
680    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
681        let start_pos = stream.position();
682        expect_directive_value(stream, "reg")?;
683
684        let ty = if stream.check(|token| matches!(token, PtxToken::Dot)) {
685            let (directive, _) = stream.expect_directive()?;
686            Some(directive)
687        } else {
688            None
689        };
690
691        let (name, _) = if stream.check(|token| matches!(token, PtxToken::Register(_))) {
692            parse_register_name(stream)?
693        } else {
694            stream.expect_identifier()?
695        };
696
697        let range = parse_register_range(stream)?;
698        stream.expect(&PtxToken::Semicolon)?;
699        let end_pos = stream.position();
700        let span = start_pos.char_offset..end_pos.char_offset;
701
702        Ok(RegisterDirective {
703            name,
704            ty,
705            range,
706            comment: None,
707            span,
708        })
709    }
710}
711
712impl PtxParser for FunctionDim3 {
713    fn parse(_stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
714        Err(unexpected_value(
715            0..0,
716            &["dimension literal"],
717            "parsing function dimension directives is not supported yet".to_string(),
718        ))
719    }
720}