ptoxide/ast/mod.rs

1mod lex;
2
3use std::ops::Range;
4
5use lex::LexError;
6use lex::Token;
7use thiserror::Error;
8
/// Position of a token in the source text, attached to parse errors.
#[derive(Debug, Clone, Copy)]
pub struct SourceLocation {
    /// Byte offset of the start of the token within the source.
    byte: usize,
    /// Line number, counted from 1.
    line: usize,
    /// Column within `line`; see `Parser::locate` for how it is computed
    /// (tab expansion, base).
    col: usize,
}
15
16impl std::fmt::Display for SourceLocation {
17    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18        write!(f, "@{}:{} (byte {})", self.line, self.col, self.byte)
19    }
20}
21
22#[derive(Error, Debug)]
23pub enum ParseErr {
24    #[error("Unexpected token \"{:?}\"", .0)]
25    UnexpectedToken(String, SourceLocation),
26    #[error("Unexpected end of file")]
27    UnexpectedEof,
28    #[error("Lex error \"{:?}\" at {:?}", .0, .1)]
29    LexError(LexError, SourceLocation),
30    #[error("Unknown token \"{:?}\" at {:?}", .0, .1)]
31    UnknownToken(String, SourceLocation),
32}
33
/// Result type used by the parser's methods.
type ParseResult<T> = Result<T, ParseErr>;

/// Names (variables, functions, labels, parameters) are stored as owned strings.
type Ident = String;
37
/// PTX ISA version from the version directive, split into major and minor parts.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Version {
    major: u32,
    minor: u32,
}

/// Payload of a pragma directive: the string-literal token, stored as the
/// lexer produced it and not interpreted further.
#[derive(Clone, Debug)]
pub struct Pragma(String);

/// Value of the address-size directive.
#[derive(Debug)]
pub enum AddressSize {
    Adr32,
    Adr64,
    /// Any size other than 32 or 64. The numeric value is not retained.
    Other,
}
53
/// A parsed PTX module: its top-level directives in source order.
#[derive(Debug)]
pub struct Module(pub Vec<Directive>);

/// A `.func` or `.entry` definition.
#[derive(Debug)]
pub struct Function {
    pub ident: Ident,
    /// Set when the definition is prefixed with the visible qualifier.
    pub visible: bool,
    /// `true` for an entry (kernel), `false` for a plain function.
    pub entry: bool,
    /// Set when the no-return qualifier follows the function name.
    pub noreturn: bool,
    /// Parenthesised parameter written before the function name, if any.
    pub return_param: Option<FunctionParam>,
    pub params: Vec<FunctionParam>,
    /// The statement following the signature (normally a `{ ... }` grouping).
    pub body: Box<Statement>,
}

/// One `.param` declaration in a function signature.
#[derive(Debug)]
pub struct FunctionParam {
    pub ident: Ident,
    pub ty: Type,
    /// Alignment in bytes, when the parser records one.
    pub alignment: Option<u32>,
    /// One entry per `[N]` suffix after the name, outermost first.
    pub array_bounds: Vec<u32>,
}
75
/// Memory/state space qualifier; see `Parser::parse_state_space` for the token
/// that produces each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StateSpace {
    Global,
    Local,
    Shared,
    Register,
    Constant,
    Parameter,
}

/// Scalar type qualifier: untyped bits (`B*`), unsigned (`U*`), signed (`S*`),
/// floating point (`F*`) and predicate. See `Parser::parse_type` for the token
/// that produces each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Type {
    B128,
    B64,
    B32,
    B16,
    B8,
    U64,
    U32,
    U16,
    U8,
    S64,
    S32,
    S16,
    S8,
    F64,
    F32,
    F16x2,
    F16,
    Pred,
}

/// Vector width qualifier on a variable declaration (two or four elements).
#[derive(Debug, Clone, Copy)]
pub enum Vector {
    V2,
    V4,
}
113
/// Special (built-in) registers that may appear as instruction operands.
///
/// NOTE(review): `Parser::parse_operand` in this file never produces
/// `StackPtr` or any `NumCta*` variant; they may be constructed elsewhere.
#[derive(Debug, Clone, Copy)]
pub enum SpecialReg {
    StackPtr,
    ThreadId,
    ThreadIdX,
    ThreadIdY,
    ThreadIdZ,
    NumThread,
    NumThreadX,
    NumThreadY,
    NumThreadZ,
    CtaId,
    CtaIdX,
    CtaIdY,
    CtaIdZ,
    NumCta,
    NumCtaX,
    NumCtaY,
    NumCtaZ,
}
134
135impl From<SpecialReg> for Operand {
136    fn from(value: SpecialReg) -> Self {
137        Operand::SpecialReg(value)
138    }
139}
140
/// A variable declaration: state space, optional alignment/vector width,
/// type, name, then optional multiplicity and array bounds.
#[derive(Debug, Clone)]
pub struct VarDecl {
    pub state_space: StateSpace,
    pub ty: Type,
    pub vector: Option<Vector>,
    pub ident: Ident,
    /// Alignment in bytes from an `.align` prefix, if present.
    pub alignment: Option<u32>,
    /// One entry per `[N]` suffix, outermost first.
    pub array_bounds: Vec<u32>,
    /// Count carried by the lexer's `RegMultiplicity` token
    /// (presumably the `%r<N>` register-range syntax — confirm in `lex`).
    pub multiplicity: Option<u32>,
}

/// A bracketed memory operand, or an indexed array reference.
#[derive(Debug, Clone)]
pub enum AddressOperand {
    /// `[base]`
    Address(Ident),
    /// `[base + constant]`
    AddressOffset(Ident, i64),
    /// `[base + variable]`
    AddressOffsetVar(Ident, Ident),
    /// `base[index]`. Currently never constructed by `Parser::parse_operand`.
    ArrayIndex(Ident, usize),
}
159
160impl AddressOperand {
161    pub fn get_ident(&self) -> &Ident {
162        match self {
163            AddressOperand::Address(ident) => ident,
164            AddressOperand::AddressOffset(ident, _) => ident,
165            AddressOperand::AddressOffsetVar(ident, _) => ident,
166            AddressOperand::ArrayIndex(ident, _) => ident,
167        }
168    }
169}
170
/// A single instruction operand.
#[derive(Debug, Clone)]
pub enum Operand {
    SpecialReg(SpecialReg),
    /// A bare identifier: a register/variable name or a branch target label.
    Variable(Ident),
    Immediate(Immediate),
    /// A bracketed address expression.
    Address(AddressOperand),
}

/// A literal constant operand.
#[derive(Debug, Clone, Copy)]
pub enum Immediate {
    Float32(f32),
    Float64(f64),
    /// Integer literals are stored signed by `Parser::parse_operand`.
    Int64(i64),
    /// Not produced by `Parser::parse_operand` in this file.
    UInt64(u64),
}

/// Predicate guard prefixed to an instruction.
#[derive(Debug, Clone)]
pub enum Guard {
    /// `@p`: execute when predicate `p` is true.
    Normal(Ident),
    /// `@!p`: execute when predicate `p` is false.
    Negated(Ident),
}
192
/// A top-level or in-body directive.
#[derive(Debug)]
pub enum Directive {
    /// Anything `Parser::parse_directive` does not otherwise recognise is
    /// parsed as a variable declaration.
    VarDecl(VarDecl),
    Version(Version),
    /// Target name from the target directive (only the first identifier).
    Target(String),
    AddressSize(AddressSize),
    Function(Function),
    Pragma(Pragma),
}

/// One instruction: optional guard, operation with its qualifiers, and operands.
#[derive(Debug, Clone)]
pub struct Instruction {
    pub guard: Option<Guard>,
    pub specifier: Operation,
    pub operands: Vec<Operand>,
}

/// A statement inside a function body.
#[derive(Debug)]
pub enum Statement {
    Directive(Directive),
    Instruction(Instruction),
    /// A `{ ... }` block of nested statements.
    Grouping(Vec<Statement>),
    /// `name:`: a branch target.
    Label(Ident),
}
217
/// Comparison performed by `setp`; see `Parser::parse_predicate`.
#[derive(Debug, Clone, Copy)]
pub enum PredicateOp {
    LessThan,
    LessThanEqual,
    GreaterThan,
    GreaterThanEqual,
    Equal,
    NotEqual,
}

/// Which part of the product `mul`/`mad` keeps (from the `Low`/`High`/`Wide`
/// tokens).
#[derive(Debug, Clone, Copy)]
pub enum MulMode {
    Low,
    High,
    Wide,
}

/// Floating-point rounding mode (from the `Rn`/`Rz`/`Rm`/`Rp` tokens).
#[derive(Debug, Clone, Copy)]
pub enum RoundingMode {
    /// Parsed from `Rn`. The name is presumably a misspelling of "nearest even";
    /// it is kept as-is because renaming a public variant would break callers.
    NearestEvent,
    Zero,
    NegInf,
    PosInf,
}
242
/// An instruction opcode together with the qualifiers parsed after it.
#[derive(Debug, Clone)]
pub enum Operation {
    /// `ld` + state space + type.
    Load(StateSpace, Type),
    /// `st` + state space + type.
    Store(StateSpace, Type),
    Move(Type),
    Add(Type),
    Sub(Type),
    Or(Type),
    And(Type),
    Not(Type),
    /// `fma` + rounding mode + type.
    FusedMulAdd(RoundingMode, Type),
    Negate(Type),
    /// `mul` + mode + type (the mode is required by the parser).
    Multiply(MulMode, Type),
    /// `mad` + mode + type.
    MultiplyAdd(MulMode, Type),
    /// `cvt`: the destination type is written first in the source, then the source type.
    Convert {
        from: Type,
        to: Type,
    },
    /// `cvta` + state space + type. Note the field order is the reverse of the source order.
    ConvertAddress(Type, StateSpace),
    /// `cvta.to` + state space + type.
    ConvertAddressTo(Type, StateSpace),
    /// `setp` + comparison + type.
    SetPredicate(PredicateOp, Type),
    ShiftLeft(Type),
    /// `call`. Its return and argument names are parsed into the operation
    /// itself rather than the instruction's operand list.
    Call {
        /// Set by the uniform qualifier.
        uniform: bool,
        ident: Ident,
        ret_param: Option<Ident>,
        params: Vec<Ident>,
    },
    /// `bar` (optional cta qualifier, discarded) + `sync`. The barrier id is an ordinary operand.
    BarrierSync,
    /// `bra` (optional uniform qualifier, discarded). The target label is an ordinary operand.
    Branch,
    Return,
}
275
/// Byte range of a token within the parser's source string.
///
/// The range is plain `usize` offsets and borrows nothing, so the alias takes
/// no lifetime parameter.
type TokenPos = Range<usize>;
277
/// Recursive-descent parser over a PTX source string.
struct Parser<'a> {
    /// Full source text, used to slice out unknown tokens and to compute
    /// line/column positions for errors.
    src: &'a str,
    /// Spanned token stream with one token of lookahead.
    inner: std::iter::Peekable<logos::SpannedIter<'a, Token<'a>>>,
}
282
283impl<'a> Parser<'a> {
284    pub fn new(src: &'a str) -> Self {
285        use logos::Logos;
286        Self {
287            src,
288            inner: Token::lexer(src).spanned().peekable(),
289        }
290    }
291
292    fn locate(&self, span: Range<usize>) -> SourceLocation {
293        let text = self.src.as_bytes();
294
295        let mut line = 1;
296        let mut col = 0;
297
298        let end = span.end.min(text.len());
299
300        for &c in &text[..end] {
301            match c {
302                b'\n' => {
303                    line += 1;
304                    col = 0;
305                },
306                b'\t' => {
307                    col = (col / 4) * 4 + 4;
308                }
309                _ => col += 1,
310            }
311        }
312
313        SourceLocation {
314            byte: span.start,
315            line,
316            col,
317        }
318    }
319
320    fn unexpected(&self, (token, pos): (Token, TokenPos)) -> ParseErr {
321        ParseErr::UnexpectedToken(token.to_string(), self.locate(pos))
322    }
323
324    fn get(&mut self) -> ParseResult<Option<(Token<'a>, TokenPos)>> {
325        match self.inner.peek().cloned() {
326            Some((Ok(tok), pos)) => Ok(Some((tok, pos))),
327            Some((Err(LexError::Unknown), pos)) => Err(ParseErr::UnknownToken(
328                self.src[pos.clone()].to_string(),
329                self.locate(pos),
330            )),
331            Some((Err(err), pos)) => Err(ParseErr::LexError(err, self.locate(pos))),
332            None => Ok(None),
333        }
334    }
335
336    fn must_get(&mut self) -> Result<(Token<'a>, TokenPos), ParseErr> {
337        self.get()?.ok_or(ParseErr::UnexpectedEof)
338    }
339
340    fn skip(&mut self) {
341        self.inner.next();
342    }
343
344    fn consume(&mut self, token: Token) -> ParseResult<()> {
345        let head = self.must_get()?;
346        if head.0 == token {
347            self.skip();
348            Ok(())
349        } else {
350            Err(self.unexpected(head))
351        }
352    }
353
354    fn consume_match(&mut self, token: Token) -> ParseResult<bool> {
355        let Some(head) = self.get()? else {
356            return Ok(false);
357        };
358        if head.0 == token {
359            self.skip();
360            Ok(true)
361        } else {
362            Ok(false)
363        }
364    }
365
366    fn pop(&mut self) -> ParseResult<Option<(Token<'a>, TokenPos)>> {
367        match self.inner.next() {
368            Some((Ok(tok), pos)) => Ok(Some((tok, pos))),
369            Some((Err(err), pos)) => Err(ParseErr::LexError(err, self.locate(pos))),
370            None => Ok(None),
371        }
372    }
373
374    fn must_pop(&mut self) -> Result<(Token<'a>, TokenPos), ParseErr> {
375        self.pop()?.ok_or(ParseErr::UnexpectedEof)
376    }
377
378    fn parse_pragma(&mut self) -> ParseResult<Pragma> {
379        self.consume(Token::Pragma)?;
380        let t = self.must_pop()?;
381        match t.0 {
382            Token::StringLiteral(s) => {
383                self.consume(Token::Semicolon)?;
384                Ok(Pragma(s.to_string()))
385            }
386            _ => Err(self.unexpected(t)),
387        }
388    }
389
390    fn parse_version(&mut self) -> ParseResult<Version> {
391        let t = self.must_pop()?;
392        match t.0 {
393            Token::Version((major, minor)) => Ok(Version { major, minor }),
394            _ => Err(self.unexpected(t)),
395        }
396    }
397
398    fn parse_target(&mut self) -> ParseResult<String> {
399        self.consume(Token::Target)?;
400        let t = self.must_pop()?;
401        match t.0 {
402            Token::Identifier(target) => Ok(target.to_string()),
403            _ => Err(self.unexpected(t)),
404        }
405    }
406
407    fn parse_address_size(&mut self) -> ParseResult<AddressSize> {
408        self.consume(Token::AddressSize)?;
409        let t = self.must_pop()?;
410        let Token::IntegerConst(size) = t.0 else {
411            return Err(self.unexpected(t));
412        };
413        match size {
414            32 => Ok(AddressSize::Adr32),
415            64 => Ok(AddressSize::Adr64),
416            _ => Ok(AddressSize::Other),
417        }
418    }
419
420    fn parse_module(&mut self) -> ParseResult<Module> {
421        let mut directives = Vec::new();
422        while self.get()?.is_some() {
423            match self.parse_directive() {
424                Ok(directive) => {
425                    directives.push(directive);
426                }
427                Err(e) => return Err(e),
428            }
429        }
430        Ok(Module(directives))
431    }
432
433    fn parse_array_bounds(&mut self) -> ParseResult<Vec<u32>> {
434        let mut bounds = Vec::new();
435        loop {
436            match self.get()? {
437                Some((Token::LeftBracket, _)) => self.skip(),
438                _ => break Ok(bounds),
439            }
440            let t = self.must_pop()?;
441            let Token::IntegerConst(bound) = t.0 else {
442                return Err(self.unexpected(t));
443            };
444            self.consume(Token::RightBracket)?;
445            // todo clean up raw casts
446            bounds.push(bound as u32);
447        }
448    }
449
450    fn parse_state_space(&mut self) -> ParseResult<StateSpace> {
451        let t = self.must_pop()?;
452        match t.0 {
453            Token::Global => Ok(StateSpace::Global),
454            Token::Local => Ok(StateSpace::Local),
455            Token::Shared => Ok(StateSpace::Shared),
456            Token::Reg => Ok(StateSpace::Register),
457            Token::Param => Ok(StateSpace::Parameter),
458            Token::Const => Ok(StateSpace::Constant),
459            _ => Err(self.unexpected(t)),
460        }
461    }
462
463    fn parse_alignment(&mut self) -> ParseResult<u32> {
464        self.consume(Token::Align)?;
465        let t = self.must_pop()?;
466        let alignment = match t.0 {
467            Token::IntegerConst(i) => i as u32,
468            _ => return Err(self.unexpected(t)),
469        };
470        Ok(alignment)
471    }
472
473    fn parse_type(&mut self) -> ParseResult<Type> {
474        let t = self.must_pop()?;
475        let ty = match t.0 {
476            Token::Bit8 => Type::B8,
477            Token::Bit16 => Type::B16,
478            Token::Bit32 => Type::B32,
479            Token::Bit64 => Type::B64,
480            Token::Bit128 => Type::B128,
481            Token::Unsigned8 => Type::U8,
482            Token::Unsigned16 => Type::U16,
483            Token::Unsigned32 => Type::U32,
484            Token::Unsigned64 => Type::U64,
485            Token::Signed8 => Type::S8,
486            Token::Signed16 => Type::S16,
487            Token::Signed32 => Type::S32,
488            Token::Signed64 => Type::S64,
489            Token::Float16 => Type::F16,
490            Token::Float16x2 => Type::F16x2,
491            Token::Float32 => Type::F32,
492            Token::Float64 => Type::F64,
493            Token::Predicate => Type::Pred,
494            _ => return Err(self.unexpected(t)),
495        };
496        Ok(ty)
497    }
498
499    fn parse_rounding_mode(&mut self) -> ParseResult<RoundingMode> {
500        let t = self.must_pop()?;
501        let mode = match t.0 {
502            Token::Rn => RoundingMode::NearestEvent,
503            Token::Rz => RoundingMode::Zero,
504            Token::Rm => RoundingMode::NegInf,
505            Token::Rp => RoundingMode::PosInf,
506            _ => return Err(self.unexpected(t)),
507        };
508        Ok(mode)
509    }
510
511    fn parse_mul_mode(&mut self) -> ParseResult<MulMode> {
512        let t = self.must_pop()?;
513        let mode = match t.0 {
514            Token::Low => MulMode::Low,
515            Token::High => MulMode::High,
516            Token::Wide => MulMode::Wide,
517            _ => return Err(self.unexpected(t)),
518        };
519        Ok(mode)
520    }
521
522    fn parse_variable(&mut self) -> ParseResult<VarDecl> {
523        let state_space = self.parse_state_space()?;
524
525        let t = self.get()?;
526        let alignment = if let Some((Token::Align, _)) = t {
527            Some(self.parse_alignment()?)
528        } else {
529            None
530        };
531
532        let t = self.get()?;
533        let vector = match t {
534            Some((Token::V2, _)) => {
535                self.skip();
536                Some(Vector::V2)
537            }
538            Some((Token::V4, _)) => {
539                self.skip();
540                Some(Vector::V4)
541            }
542            _ => None,
543        };
544
545        let ty = self.parse_type()?;
546
547        let t = self.must_pop()?;
548        let ident = match t.0 {
549            Token::Identifier(s) => s.to_string(),
550            _ => return Err(self.unexpected(t)),
551        };
552
553        let t = self.must_get()?;
554        let multiplicity = match t.0 {
555            Token::RegMultiplicity(m) => {
556                self.skip();
557                Some(m)
558            }
559            _ => None,
560        };
561
562        let array_bounds = self.parse_array_bounds()?;
563
564        self.consume(Token::Semicolon)?;
565
566        Ok(VarDecl {
567            state_space,
568            ty,
569            vector,
570            alignment,
571            array_bounds,
572            ident: ident.to_string(),
573            multiplicity,
574        })
575    }
576
577    fn parse_guard(&mut self) -> ParseResult<Guard> {
578        self.consume(Token::At)?;
579        let t = self.must_pop()?;
580        let guard = match t.0 {
581            Token::Identifier(s) => Guard::Normal(s.to_string()),
582            Token::Bang => {
583                let t = self.must_pop()?;
584                let ident = match t.0 {
585                    Token::Identifier(s) => s,
586                    _ => return Err(self.unexpected(t)),
587                };
588                Guard::Negated(ident.to_string())
589            }
590            _ => return Err(self.unexpected(t)),
591        };
592        Ok(guard)
593    }
594
595    fn parse_predicate(&mut self) -> ParseResult<PredicateOp> {
596        let t = self.must_pop()?;
597        let pred = match t.0 {
598            Token::Ge => PredicateOp::GreaterThanEqual,
599            Token::Gt => PredicateOp::GreaterThan,
600            Token::Le => PredicateOp::LessThanEqual,
601            Token::Lt => PredicateOp::LessThan,
602            Token::Eq => PredicateOp::Equal,
603            Token::Ne => PredicateOp::NotEqual,
604            _ => return Err(self.unexpected(t)),
605        };
606        Ok(pred)
607    }
608
609    fn parse_operation(&mut self) -> ParseResult<Operation> {
610        let t = self.must_pop()?;
611        match t.0 {
612            Token::Ld => {
613                let state_space = self.parse_state_space()?;
614                let ty = self.parse_type()?;
615                Ok(Operation::Load(state_space, ty))
616            }
617            Token::St => {
618                let state_space = self.parse_state_space()?;
619                let ty = self.parse_type()?;
620                Ok(Operation::Store(state_space, ty))
621            }
622            Token::Mov => {
623                let ty = self.parse_type()?;
624                Ok(Operation::Move(ty))
625            }
626            Token::Add => {
627                let ty = self.parse_type()?;
628                Ok(Operation::Add(ty))
629            }
630            Token::Sub => {
631                let ty = self.parse_type()?;
632                Ok(Operation::Sub(ty))
633            }
634            Token::Or => {
635                let ty = self.parse_type()?;
636                Ok(Operation::Or(ty))
637            }
638            Token::And => {
639                let ty = self.parse_type()?;
640                Ok(Operation::And(ty))
641            }
642            Token::Not => {
643                let ty = self.parse_type()?;
644                Ok(Operation::Not(ty))
645            }
646            Token::Mul => {
647                let mode = self.parse_mul_mode()?;
648                let ty = self.parse_type()?;
649                Ok(Operation::Multiply(mode, ty))
650            }
651            Token::Mad => {
652                let mode = self.parse_mul_mode()?;
653                let ty = self.parse_type()?;
654                Ok(Operation::MultiplyAdd(mode, ty))
655            }
656            Token::Fma => {
657                let mode = self.parse_rounding_mode()?;
658                let ty = self.parse_type()?;
659                Ok(Operation::FusedMulAdd(mode, ty))
660            }
661            Token::Neg => {
662                let ty = self.parse_type()?;
663                Ok(Operation::Negate(ty))
664            }
665            Token::Cvt => {
666                let to = self.parse_type()?;
667                let from = self.parse_type()?;
668                Ok(Operation::Convert { to, from })
669            }
670            Token::Call => {
671                let uniform = self.consume_match(Token::Uniform)?;
672                let ret_param = if let Token::LeftParen = self.must_get()?.0 {
673                    self.skip();
674                    let t = self.must_pop()?;
675                    let ident = match t.0 {
676                        Token::Identifier(s) => s.to_string(),
677                        _ => return Err(self.unexpected(t)),
678                    };
679                    self.consume(Token::RightParen)?;
680                    self.consume(Token::Comma)?;
681                    Some(ident)
682                } else {
683                    None
684                };
685                let t = self.must_pop()?;
686                let ident = match t.0 {
687                    Token::Identifier(s) => s.to_string(),
688                    _ => return Err(self.unexpected(t)),
689                };
690                self.consume(Token::Comma)?;
691                let mut params = Vec::new();
692                if let Token::LeftParen = self.must_get()?.0 {
693                    self.skip();
694                    loop {
695                        let t = self.must_pop()?;
696                        let ident = match t.0 {
697                            Token::Identifier(s) => s.to_string(),
698                            _ => return Err(self.unexpected(t)),
699                        };
700                        params.push(ident);
701                        let t = self.must_pop()?;
702                        match t.0 {
703                            Token::RightParen => break,
704                            Token::Comma => {}
705                            _ => return Err(self.unexpected(t)),
706                        }
707                    }
708                };
709
710                Ok(Operation::Call {
711                    uniform,
712                    ident: ident.to_string(),
713                    ret_param,
714                    params,
715                })
716            }
717            Token::Cvta => match self.must_get()?.0 {
718                Token::To => {
719                    self.skip();
720                    let state_space = self.parse_state_space()?;
721                    let ty = self.parse_type()?;
722                    Ok(Operation::ConvertAddressTo(ty, state_space))
723                }
724                _ => {
725                    let state_space = self.parse_state_space()?;
726                    let ty = self.parse_type()?;
727                    Ok(Operation::ConvertAddress(ty, state_space))
728                }
729            },
730            Token::Setp => {
731                let pred = self.parse_predicate()?;
732                let ty = self.parse_type()?;
733                Ok(Operation::SetPredicate(pred, ty))
734            }
735            Token::Shl => {
736                let ty = self.parse_type()?;
737                Ok(Operation::ShiftLeft(ty))
738            }
739            Token::Bra => {
740                self.consume_match(Token::Uniform)?;
741                Ok(Operation::Branch)
742            }
743            Token::Ret => Ok(Operation::Return),
744            Token::Bar => {
745                // cta token is meaningless
746                self.consume_match(Token::Cta)?;
747                self.consume(Token::Sync)?;
748                Ok(Operation::BarrierSync)
749            }
750            _ => Err(self.unexpected(t)),
751        }
752    }
753
754    fn parse_operand(&mut self) -> ParseResult<Operand> {
755        let t = self.must_pop()?;
756        let operand = match t.0 {
757            Token::ThreadId => SpecialReg::ThreadId.into(),
758            Token::ThreadIdX => SpecialReg::ThreadIdX.into(),
759            Token::ThreadIdY => SpecialReg::ThreadIdY.into(),
760            Token::ThreadIdZ => SpecialReg::ThreadIdZ.into(),
761            Token::NumThreads => SpecialReg::NumThread.into(),
762            Token::NumThreadsX => SpecialReg::NumThreadX.into(),
763            Token::NumThreadsY => SpecialReg::NumThreadY.into(),
764            Token::NumThreadsZ => SpecialReg::NumThreadZ.into(),
765            Token::CtaId => SpecialReg::CtaId.into(),
766            Token::CtaIdX => SpecialReg::CtaIdX.into(),
767            Token::CtaIdY => SpecialReg::CtaIdY.into(),
768            Token::CtaIdZ => SpecialReg::CtaIdZ.into(),
769            Token::IntegerConst(i) => Operand::Immediate(Immediate::Int64(i)),
770            Token::Float64Const(f) => Operand::Immediate(Immediate::Float64(f)),
771            Token::Float32Const(f) => Operand::Immediate(Immediate::Float32(f)),
772            Token::Identifier(s) => {
773                let t = self.get()?;
774                if let Some((Token::LeftBracket, _)) = t {
775                    todo!("array syntax in operands")
776                } else {
777                    Operand::Variable(s.to_string())
778                }
779            }
780            Token::LeftBracket => {
781                let t = self.must_pop()?;
782                let Token::Identifier(s) = t.0 else {
783                    return Err(self.unexpected(t));
784                };
785                let ident = s.to_string();
786
787                let t = self.must_get()?;
788                let res = if let Token::Plus = t.0 {
789                    self.skip();
790                    let t = self.must_pop()?;
791                    match t.0 {
792                        Token::IntegerConst(i) => {
793                            Operand::Address(AddressOperand::AddressOffset(ident, i))
794                        }
795                        Token::Identifier(s) => {
796                            Operand::Address(AddressOperand::AddressOffsetVar(ident, s.to_string()))
797                        }
798                        _ => return Err(self.unexpected(t)),
799                    }
800                } else {
801                    Operand::Address(AddressOperand::Address(ident))
802                };
803                self.consume(Token::RightBracket)?;
804                res
805            }
806            _ => return Err(self.unexpected(t)),
807        };
808        Ok(operand)
809    }
810
811    fn parse_operands(&mut self) -> ParseResult<Vec<Operand>> {
812        let mut operands = Vec::new();
813        loop {
814            let t = self.must_get()?;
815            match t.0 {
816                Token::Semicolon => {
817                    self.skip();
818                    break Ok(operands);
819                }
820                Token::Comma => self.skip(),
821                _ => {}
822            }
823            let op = self.parse_operand()?;
824            operands.push(op);
825        }
826    }
827
828    fn parse_grouping(&mut self) -> ParseResult<Vec<Statement>> {
829        self.consume(Token::LeftBrace)?; // Consume the left brace
830        let mut statements = Vec::new();
831        loop {
832            let t = self.must_get()?;
833            if let Token::RightBrace = t.0 {
834                self.skip();
835                break Ok(statements);
836            }
837            statements.push(self.parse_statement()?);
838        }
839    }
840
841    fn parse_directive(&mut self) -> ParseResult<Directive> {
842        let t = self.must_get()?;
843        let res = match t.0 {
844            Token::Version(_) => {
845                let version = self.parse_version()?;
846                Directive::Version(version)
847            }
848            Token::Target => {
849                let target = self.parse_target()?;
850                Directive::Target(target)
851            }
852            Token::AddressSize => {
853                let addr_size = self.parse_address_size()?;
854                Directive::AddressSize(addr_size)
855            }
856            Token::Func | Token::Visible | Token::Entry => {
857                let function = self.parse_function()?;
858                Directive::Function(function)
859            }
860            Token::Pragma => {
861                let pragma = self.parse_pragma()?;
862                Directive::Pragma(pragma)
863            }
864            _ => {
865                let var = self.parse_variable()?;
866                Directive::VarDecl(var)
867            }
868        };
869        Ok(res)
870    }
871
872    fn parse_instruction(&mut self) -> ParseResult<Instruction> {
873        let t = self.must_get()?;
874        let guard = if let Token::At = t.0 {
875            Some(self.parse_guard()?)
876        } else {
877            None
878        };
879
880        let specifier = self.parse_operation()?;
881        let operands = self.parse_operands()?;
882
883        Ok(Instruction {
884            guard,
885            specifier,
886            operands,
887        })
888    }
889
890    fn parse_statement(&mut self) -> ParseResult<Statement> {
891        let t = self.must_get()?;
892        match t.0 {
893            Token::LeftBrace => {
894                let grouping = self.parse_grouping()?;
895                Ok(Statement::Grouping(grouping))
896            }
897            t if t.is_directive() => {
898                let dir = self.parse_directive()?;
899                Ok(Statement::Directive(dir))
900            }
901            Token::Identifier(i) => {
902                let i = i.to_string();
903                self.skip();
904                self.consume(Token::Colon)?;
905                Ok(Statement::Label(i.to_string()))
906            }
907            _ => {
908                let instr = self.parse_instruction()?;
909                Ok(Statement::Instruction(instr))
910            }
911        }
912    }
913
914    fn parse_function_param(&mut self) -> ParseResult<FunctionParam> {
915        self.consume(Token::Param)?; // Consume the param keyword
916
917        let alignment = None; // todo parse alignment in function param
918
919        let ty = self.parse_type()?;
920        let ident = loop {
921            let t = self.must_pop()?;
922            if let Token::Identifier(s) = t.0 {
923                break s.to_string();
924            }
925        };
926
927        let array_bounds = self.parse_array_bounds()?;
928
929        Ok(FunctionParam {
930            alignment,
931            ident: ident.to_string(),
932            ty,
933            array_bounds,
934        })
935    }
936
937    fn parse_function_params(&mut self) -> ParseResult<Vec<FunctionParam>> {
938        // if there is no left parenthesis, there are no parameters
939        if !self.consume_match(Token::LeftParen)? {
940            return Ok(Vec::new());
941        }
942        // if we immediately see a right parenthesis, there are no parameters
943        if self.consume_match(Token::RightParen)? {
944            return Ok(Vec::new());
945        }
946
947        let mut params = Vec::new();
948        loop {
949            params.push(self.parse_function_param()?);
950            let t = self.must_pop()?;
951            match t.0 {
952                Token::Comma => {}
953                Token::RightParen => break Ok(params),
954                _ => return Err(self.unexpected(t)),
955            }
956        }
957    }
958
959    fn parse_return_param(&mut self) -> ParseResult<Option<FunctionParam>> {
960        let t = self.must_get()?;
961        if let Token::LeftParen = t.0 {
962            self.skip();
963        } else {
964            return Ok(None);
965        }
966        let param = self.parse_function_param()?;
967        self.consume(Token::RightParen)?;
968        Ok(Some(param))
969    }
970
971    fn parse_function(&mut self) -> ParseResult<Function> {
972        let visible = if let Token::Visible = self.must_get()?.0 {
973            self.skip();
974            true
975        } else {
976            false
977        };
978        let t = self.must_pop()?;
979        let entry = match t.0 {
980            Token::Entry => true,
981            Token::Func => false,
982            _ => return Err(self.unexpected(t)),
983        };
984
985        let return_param = self.parse_return_param()?;
986
987        let t = self.must_pop()?;
988        let ident = match t.0 {
989            Token::Identifier(s) => s.to_string(),
990            _ => return Err(self.unexpected(t)),
991        };
992
993        let noreturn = if let Token::Noreturn = self.must_get()?.0 {
994            self.skip();
995            true
996        } else {
997            false
998        };
999
1000        let params = self.parse_function_params()?;
1001        let body = self.parse_statement()?;
1002
1003        Ok(Function {
1004            ident: ident.to_string(),
1005            visible,
1006            entry,
1007            return_param,
1008            noreturn,
1009            params,
1010            body: Box::new(body),
1011        })
1012    }
1013}
1014
1015pub fn parse_program(src: &str) -> Result<Module, ParseErr> {
1016    Parser::new(src).parse_module()
1017}
1018
#[cfg(test)]
mod test {
    use super::*;

    /// Reads the kernel at `path` (relative to the crate root) and asserts
    /// that it parses without error.
    fn assert_parses(path: &str) {
        let contents = std::fs::read_to_string(path).unwrap();
        let _ = parse_program(&contents).unwrap();
    }

    #[test]
    fn test_parse_add() {
        assert_parses("kernels/add.ptx");
    }

    #[test]
    fn test_parse_transpose() {
        assert_parses("kernels/transpose.ptx");
    }

    #[test]
    fn test_parse_add_simple() {
        assert_parses("kernels/add_simple.ptx");
    }

    #[test]
    fn test_parse_fncall() {
        assert_parses("kernels/fncall.ptx");
    }

    #[test]
    fn test_parse_gemm() {
        assert_parses("kernels/gemm.ptx");
    }
}