// ptx_parser/parser/function.rs

1use crate::parser::common::{
2    invalid_literal, parse_register_name, parse_u64_literal, try_parse_label,
3};
4use crate::r#type::common::{CodeLinkage, Instruction};
5use crate::unlexer::PtxUnlexer;
6use crate::{
7    lexer::{PtxToken, tokenize},
8    parser::{
9        ParseErrorKind, PtxParseError, PtxParser, PtxTokenStream, Span, expect_directive_value,
10        peek_directive, unexpected_value,
11    },
12    r#type::{
13        function::{
14            DwarfDirective, EntryFunction, FuncFunction, FunctionAlias, FunctionBody, FunctionDim3,
15            FunctionHeaderDirective, FunctionKernelDirective, FunctionStatement, LocationDirective,
16            PragmaDirective, RegisterDirective, StatementDirective, StatementSectionDirective,
17        },
18        variable::VariableDirective,
19    },
20};
21
22impl FunctionHeaderDirective {
23    fn parse_list(stream: &mut PtxTokenStream) -> Result<Vec<Self>, PtxParseError> {
24        let mut directives = Vec::new();
25        loop {
26            let Some((name, span)) = peek_directive(stream)? else {
27                break;
28            };
29            match name.as_str() {
30                "visible" | "extern" | "weak" => {
31                    let linkage = CodeLinkage::parse(stream)?;
32                    directives.push(FunctionHeaderDirective::Linkage(linkage));
33                }
34                "entry" | "func" | "alias" => break,
35                other => {
36                    return Err(unexpected_value(
37                        span,
38                        &[".visible", ".extern", ".weak", ".entry", ".func", ".alias"],
39                        format!(".{other}"),
40                    ));
41                }
42            }
43        }
44        Ok(directives)
45    }
46}
47
48fn parse_register_range(stream: &mut PtxTokenStream) -> Result<Option<u32>, PtxParseError> {
49    if stream
50        .consume_if(|token| matches!(token, PtxToken::LAngle))
51        .is_none()
52    {
53        return Ok(None);
54    }
55
56    let (value, span) = parse_u64_literal(stream)?;
57    if value > u32::MAX as u64 {
58        return Err(invalid_literal(
59            span.clone(),
60            "register range exceeds u32::MAX",
61        ));
62    }
63    stream.expect(&PtxToken::RAngle)?;
64    Ok(Some(value as u32))
65}
66
67fn tokens_to_string(tokens: &[PtxToken], span: &Span) -> Result<String, PtxParseError> {
68    PtxUnlexer::to_string(tokens)
69        .map_err(|_| invalid_literal(span.clone(), "failed to serialize token sequence"))
70}
71
72fn parse_parameter_tokens(
73    tokens: &[PtxToken],
74    span: &Span,
75) -> Result<VariableDirective, PtxParseError> {
76    let serialized = tokens_to_string(tokens, span)?;
77    let source = format!("{};", serialized);
78    let tokenized = tokenize(&source)
79        .map_err(|_| invalid_literal(span.clone(), "failed to tokenize function parameter"))?;
80    let mut temp_stream = PtxTokenStream::new(&tokenized);
81    let mut directive = VariableDirective::parse(&mut temp_stream)?;
82    directive.raw = serialized;
83    Ok(directive)
84}
85
86fn collect_parameter_tokens(
87    stream: &mut PtxTokenStream,
88) -> Result<(Vec<PtxToken>, Span), PtxParseError> {
89    let (first_token, first_span) = stream.peek()?;
90    if matches!(first_token, PtxToken::Comma | PtxToken::RParen) {
91        return Err(unexpected_value(
92            first_span.clone(),
93            &["function parameter"],
94            format!("{first_token:?}"),
95        ));
96    }
97
98    let mut tokens = Vec::new();
99    let mut paren_depth = 0usize;
100    let mut bracket_depth = 0usize;
101
102    loop {
103        let (next_token, _) = stream.peek()?;
104        if paren_depth == 0 && bracket_depth == 0 {
105            if matches!(next_token, PtxToken::Comma | PtxToken::RParen) {
106                break;
107            }
108        }
109
110        let (token, _) = stream.consume()?;
111        match token {
112            PtxToken::LParen => paren_depth += 1,
113            PtxToken::RParen => paren_depth = paren_depth.saturating_sub(1),
114            PtxToken::LBracket => bracket_depth += 1,
115            PtxToken::RBracket => bracket_depth = bracket_depth.saturating_sub(1),
116            _ => {}
117        }
118        tokens.push(token.clone());
119    }
120
121    Ok((tokens, first_span.clone()))
122}
123
124fn parse_parameter(stream: &mut PtxTokenStream) -> Result<VariableDirective, PtxParseError> {
125    let (tokens, span) = collect_parameter_tokens(stream)?;
126    if tokens.is_empty() {
127        return Err(unexpected_value(
128            span.clone(),
129            &["function parameter"],
130            "".to_string(),
131        ));
132    }
133    parse_parameter_tokens(&tokens, &span)
134}
135
136fn parse_parameter_list(
137    stream: &mut PtxTokenStream,
138) -> Result<Vec<VariableDirective>, PtxParseError> {
139    stream.expect(&PtxToken::LParen)?;
140    if stream
141        .consume_if(|token| matches!(token, PtxToken::RParen))
142        .is_some()
143    {
144        return Ok(Vec::new());
145    }
146
147    let mut params = Vec::new();
148    loop {
149        let param = parse_parameter(stream)?;
150        params.push(param);
151        if stream
152            .consume_if(|token| matches!(token, PtxToken::Comma))
153            .is_none()
154        {
155            break;
156        }
157    }
158    stream.expect(&PtxToken::RParen)?;
159    Ok(params)
160}
161
162fn parse_return_parameter(
163    stream: &mut PtxTokenStream,
164) -> Result<Option<VariableDirective>, PtxParseError> {
165    if stream
166        .consume_if(|token| matches!(token, PtxToken::LParen))
167        .is_none()
168    {
169        return Ok(None);
170    }
171
172    if stream
173        .consume_if(|token| matches!(token, PtxToken::RParen))
174        .is_some()
175    {
176        return Ok(None);
177    }
178
179    let param = parse_parameter(stream)?;
180    stream.expect(&PtxToken::RParen)?;
181    Ok(Some(param))
182}
183
184fn parse_optional_noreturn(
185    stream: &mut PtxTokenStream,
186    directives: &mut Vec<FunctionHeaderDirective>,
187) -> Result<bool, PtxParseError> {
188    if let Some((token, _)) = stream.peek().ok() {
189        if let PtxToken::Dot = token {
190            // Check if it's a directive
191            let saved_pos = stream.position();
192            stream.consume()?; // consume dot
193            if let Ok((name, _)) = stream.expect_identifier() {
194                if name == "noreturn" {
195                    if !directives
196                        .iter()
197                        .any(|directive| matches!(directive, FunctionHeaderDirective::NoReturn))
198                    {
199                        directives.push(FunctionHeaderDirective::NoReturn);
200                    }
201                    if stream
202                        .consume_if(|token| matches!(token, PtxToken::Semicolon))
203                        .is_some()
204                    {
205                        return Ok(true);
206                    }
207                } else {
208                    stream.set_position(saved_pos);
209                }
210            } else {
211                stream.set_position(saved_pos);
212            }
213        }
214    }
215    Ok(false)
216}
217
218fn parse_argument_strings(
219    stream: &mut PtxTokenStream,
220    base_span: &Span,
221    raw_tokens: &mut Vec<PtxToken>,
222) -> Result<Vec<String>, PtxParseError> {
223    let mut arguments = Vec::new();
224    let mut current_tokens: Vec<PtxToken> = Vec::new();
225    let mut current_span = base_span.clone();
226
227    while !stream.check(|token| matches!(token, PtxToken::Semicolon)) {
228        let (token, span) = stream.consume()?;
229        raw_tokens.push(token.clone());
230        if matches!(token, PtxToken::Comma) {
231            if !current_tokens.is_empty() {
232                let text = tokens_to_string(&current_tokens, &current_span)?;
233                arguments.push(text);
234                current_tokens.clear();
235            } else {
236                arguments.push(String::new());
237            }
238        } else {
239            if current_tokens.is_empty() {
240                current_span = span.clone();
241            }
242            current_tokens.push(token.clone());
243        }
244    }
245
246    if !current_tokens.is_empty() {
247        let text = tokens_to_string(&current_tokens, &current_span)?;
248        arguments.push(text);
249    }
250
251    stream.expect(&PtxToken::Semicolon)?;
252    raw_tokens.push(PtxToken::Semicolon);
253    Ok(arguments)
254}
255
256fn parse_block_statements(
257    stream: &mut PtxTokenStream,
258) -> Result<Vec<FunctionStatement>, PtxParseError> {
259    let mut statements = Vec::new();
260
261    loop {
262        if stream.check(|token| matches!(token, PtxToken::RBrace)) {
263            stream.consume()?;
264            break;
265        }
266
267        if stream.is_at_end() {
268            return Err(PtxParseError {
269                kind: ParseErrorKind::UnexpectedEof,
270                span: 0..0,
271            });
272        }
273
274        let position = stream.position();
275        match FunctionStatement::parse(stream) {
276            Ok(statement) => statements.push(statement),
277            Err(_err) => {
278                stream.set_position(position);
279                let (tokens, span) = collect_body_tokens(stream)?;
280                if !tokens.is_empty() {
281                    let raw = tokens_to_string(&tokens, &span)?;
282                    statements.push(FunctionStatement::Directive(StatementDirective::Pragma(
283                        PragmaDirective {
284                            arguments: Vec::new(),
285                            comment: None,
286                            raw,
287                        },
288                    )));
289                }
290                return Ok(statements);
291            }
292        }
293    }
294
295    Ok(statements)
296}
297
298fn collect_body_tokens(
299    stream: &mut PtxTokenStream,
300) -> Result<(Vec<PtxToken>, Span), PtxParseError> {
301    let mut tokens = Vec::new();
302    let mut depth = 1usize;
303    let mut first_span: Option<Span> = None;
304
305    while depth > 0 {
306        let (token, span) = stream.consume()?;
307        if first_span.is_none() {
308            first_span = Some(span.clone());
309        }
310        match token {
311            PtxToken::LBrace => {
312                depth += 1;
313                tokens.push(token.clone());
314            }
315            PtxToken::RBrace => {
316                depth -= 1;
317                if depth == 0 {
318                    break;
319                }
320                tokens.push(token.clone());
321            }
322            _ => tokens.push(token.clone()),
323        }
324    }
325
326    Ok((tokens, first_span.unwrap_or(0..0)))
327}
328
329impl PtxParser for FunctionBody {
330    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
331        match stream.peek() {
332            Ok((PtxToken::Semicolon, _)) => {
333                stream.consume()?;
334                Ok(FunctionBody::default())
335            }
336            Ok((PtxToken::LBrace, _)) => {
337                stream.consume()?;
338                let mut body = FunctionBody::default();
339                loop {
340                    if stream.check(|token| matches!(token, PtxToken::RBrace)) {
341                        stream.consume()?;
342                        break;
343                    }
344
345                    if stream.is_at_end() {
346                        return Err(PtxParseError {
347                            kind: ParseErrorKind::UnexpectedEof,
348                            span: 0..0,
349                        });
350                    }
351
352                    let position = stream.position();
353                    match FunctionStatement::parse(stream) {
354                        Ok(statement) => body.statements.push(statement),
355                        Err(_) => {
356                            stream.set_position(position);
357                            let (tokens, span) = collect_body_tokens(stream)?;
358                            if !tokens.is_empty() {
359                                let raw = tokens_to_string(&tokens, &span)?;
360                                body.statements.push(FunctionStatement::Directive(
361                                    StatementDirective::Pragma(PragmaDirective {
362                                        arguments: Vec::new(),
363                                        comment: None,
364                                        raw,
365                                    }),
366                                ));
367                            }
368                            return Ok(body);
369                        }
370                    }
371                }
372
373                Ok(body)
374            }
375            Ok((token, _)) => {
376                let span = stream.peek()?.1.clone();
377                Err(unexpected_value(
378                    span,
379                    &[";", ".noreturn", "{"],
380                    format!("{token:?}"),
381                ))
382            }
383            Err(_) => Err(PtxParseError {
384                kind: ParseErrorKind::UnexpectedEof,
385                span: 0..0,
386            }),
387        }
388    }
389}
390
391impl PtxParser for EntryFunction {
392    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
393        let mut directives = FunctionHeaderDirective::parse_list(stream)?;
394        expect_directive_value(stream, "entry")?;
395        let (name, _) = stream.expect_identifier()?;
396        let params = parse_parameter_list(stream)?;
397        let body = if parse_optional_noreturn(stream, &mut directives)? {
398            FunctionBody::default()
399        } else {
400            FunctionBody::parse(stream)?
401        };
402        Ok(EntryFunction {
403            name,
404            directives,
405            params,
406            body,
407        })
408    }
409}
410
411impl PtxParser for FuncFunction {
412    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
413        let mut directives = FunctionHeaderDirective::parse_list(stream)?;
414        expect_directive_value(stream, "func")?;
415
416        let return_param = parse_return_parameter(stream)?;
417
418        let (name, _) = stream.expect_identifier()?;
419        let params = parse_parameter_list(stream)?;
420        let body = if parse_optional_noreturn(stream, &mut directives)? {
421            FunctionBody::default()
422        } else {
423            FunctionBody::parse(stream)?
424        };
425        Ok(FuncFunction {
426            name,
427            directives,
428            return_param,
429            params,
430            body,
431        })
432    }
433}
434
435impl PtxParser for FunctionAlias {
436    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
437        let _ = FunctionHeaderDirective::parse_list(stream)?;
438        expect_directive_value(stream, "alias")?;
439        let (alias, _) = stream.expect_identifier()?;
440        stream.expect(&PtxToken::Comma)?;
441        let (target, _) = stream.expect_identifier()?;
442        stream.expect(&PtxToken::Semicolon)?;
443        Ok(FunctionAlias {
444            alias,
445            target,
446            raw: String::new(),
447        })
448    }
449}
450
451impl PtxParser for FunctionKernelDirective {
452    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
453        let position = stream.position();
454        if let Ok(entry) = EntryFunction::parse(stream) {
455            return Ok(FunctionKernelDirective::Entry(entry));
456        }
457        stream.set_position(position);
458        if let Ok(func) = FuncFunction::parse(stream) {
459            return Ok(FunctionKernelDirective::Func(func));
460        }
461        stream.set_position(position);
462        let alias = FunctionAlias::parse(stream)?;
463        Ok(FunctionKernelDirective::Alias(alias))
464    }
465}
466
467impl PtxParser for FunctionStatement {
468    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
469        if let Some(label) = try_parse_label(stream)? {
470            return Ok(FunctionStatement::Label(label));
471        }
472
473        if peek_directive(stream)?.is_some() {
474            let directive = StatementDirective::parse(stream)?;
475            return Ok(FunctionStatement::Directive(directive));
476        }
477
478        if stream.check(|token| matches!(token, PtxToken::LBrace)) {
479            stream.consume()?;
480            let block_statements = parse_block_statements(stream)?;
481            return Ok(FunctionStatement::Block(block_statements));
482        }
483
484        let instruction = Instruction::parse(stream)?;
485        Ok(FunctionStatement::Instruction(instruction))
486    }
487}
488
489impl PtxParser for StatementDirective {
490    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
491        let (name, span) = if let Some(value) = peek_directive(stream)? {
492            value
493        } else {
494            let (token, span) = stream
495                .peek()
496                .map(|(token, span)| (token.clone(), span.clone()))?;
497            return Err(unexpected_value(
498                span,
499                &["function directive"],
500                format!("{token:?}"),
501            ));
502        };
503
504        match name.as_str() {
505            "reg" => RegisterDirective::parse(stream).map(StatementDirective::Reg),
506            "local" => VariableDirective::parse(stream).map(StatementDirective::Local),
507            "param" => VariableDirective::parse(stream).map(StatementDirective::Param),
508            "shared" => VariableDirective::parse(stream).map(StatementDirective::Shared),
509            "pragma" => {
510                let (_, directive_span) = stream.expect_directive()?;
511                let mut raw_tokens =
512                    vec![PtxToken::Dot, PtxToken::Identifier("pragma".to_string())];
513                let arguments = parse_argument_strings(stream, &directive_span, &mut raw_tokens)?;
514                let raw = tokens_to_string(&raw_tokens, &directive_span)?;
515                Ok(StatementDirective::Pragma(PragmaDirective {
516                    arguments,
517                    comment: None,
518                    raw,
519                }))
520            }
521            "loc" => {
522                let (_, directive_span) = stream.expect_directive()?;
523                let mut raw_tokens = vec![PtxToken::Dot, PtxToken::Identifier("loc".to_string())];
524                let (file_token, file_span) = stream.consume()?;
525                raw_tokens.push(file_token.clone());
526                let file_index = match file_token {
527                    PtxToken::DecimalInteger(value) => value.parse::<u32>().map_err(|_| {
528                        invalid_literal(
529                            file_span.clone(),
530                            "expected 32-bit unsigned integer literal",
531                        )
532                    })?,
533                    ref other => {
534                        return Err(unexpected_value(
535                            file_span.clone(),
536                            &["decimal literal"],
537                            format!("{other:?}"),
538                        ));
539                    }
540                };
541
542                let (line_token, line_span) = stream.consume()?;
543                raw_tokens.push(line_token.clone());
544                let line = match line_token {
545                    PtxToken::DecimalInteger(value) => value.parse::<u32>().map_err(|_| {
546                        invalid_literal(
547                            line_span.clone(),
548                            "expected 32-bit unsigned integer literal",
549                        )
550                    })?,
551                    ref other => {
552                        return Err(unexpected_value(
553                            line_span.clone(),
554                            &["decimal literal"],
555                            format!("{other:?}"),
556                        ));
557                    }
558                };
559
560                let (column_token, column_span) = stream.consume()?;
561                raw_tokens.push(column_token.clone());
562                let column = match column_token {
563                    PtxToken::DecimalInteger(value) => value.parse::<u32>().map_err(|_| {
564                        invalid_literal(
565                            column_span.clone(),
566                            "expected 32-bit unsigned integer literal",
567                        )
568                    })?,
569                    ref other => {
570                        return Err(unexpected_value(
571                            column_span.clone(),
572                            &["decimal literal"],
573                            format!("{other:?}"),
574                        ));
575                    }
576                };
577
578                let options = Vec::new();
579                if stream
580                    .consume_if(|token| matches!(token, PtxToken::Semicolon))
581                    .is_some()
582                {
583                    raw_tokens.push(PtxToken::Semicolon);
584                }
585
586                let raw = tokens_to_string(&raw_tokens, &directive_span)?;
587                Ok(StatementDirective::Loc(LocationDirective {
588                    file_index,
589                    line,
590                    column,
591                    options,
592                    comment: None,
593                    raw,
594                }))
595            }
596            "dwarf" => {
597                let (_, directive_span) = stream.expect_directive()?;
598                let mut raw_tokens = vec![PtxToken::Dot, PtxToken::Identifier("dwarf".to_string())];
599                let (keyword, keyword_span) = stream.expect_identifier()?;
600                raw_tokens.push(PtxToken::Identifier(keyword.clone()));
601                let arguments = parse_argument_strings(stream, &keyword_span, &mut raw_tokens)?;
602                let raw = tokens_to_string(&raw_tokens, &directive_span)?;
603                Ok(StatementDirective::Dwarf(DwarfDirective {
604                    keyword,
605                    arguments,
606                    comment: None,
607                    raw,
608                }))
609            }
610            "section" => {
611                let (_, directive_span) = stream.expect_directive()?;
612                let mut raw_tokens =
613                    vec![PtxToken::Dot, PtxToken::Identifier("section".to_string())];
614                let arguments = parse_argument_strings(stream, &directive_span, &mut raw_tokens)?;
615                let mut iter = arguments.into_iter();
616                let name_str = iter.next().ok_or_else(|| {
617                    unexpected_value(directive_span.clone(), &["section name"], "".to_string())
618                })?;
619                let raw = tokens_to_string(&raw_tokens, &directive_span)?;
620                Ok(StatementDirective::Section(StatementSectionDirective {
621                    name: name_str,
622                    arguments: iter.collect(),
623                    comment: None,
624                    raw,
625                }))
626            }
627            other => Err(unexpected_value(
628                span,
629                &[
630                    ".reg", ".local", ".param", ".shared", ".pragma", ".loc", ".dwarf", ".section",
631                ],
632                format!(".{other}"),
633            )),
634        }
635    }
636}
637
638impl PtxParser for RegisterDirective {
639    fn parse(stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
640        expect_directive_value(stream, "reg")?;
641
642        let ty = if stream.check(|token| matches!(token, PtxToken::Dot)) {
643            let (directive, _) = stream.expect_directive()?;
644            Some(directive)
645        } else {
646            None
647        };
648
649        let (name, _) = if stream.check(|token| matches!(token, PtxToken::Register(_))) {
650            parse_register_name(stream)?
651        } else {
652            stream.expect_identifier()?
653        };
654
655        let range = parse_register_range(stream)?;
656        stream.expect(&PtxToken::Semicolon)?;
657
658        Ok(RegisterDirective {
659            name,
660            ty,
661            range,
662            comment: None,
663            raw: String::new(),
664        })
665    }
666}
667
668impl PtxParser for FunctionDim3 {
669    fn parse(_stream: &mut PtxTokenStream) -> Result<Self, PtxParseError> {
670        Err(unexpected_value(
671            0..0,
672            &["dimension literal"],
673            "parsing function dimension directives is not supported yet".to_string(),
674        ))
675    }
676}