1mod lex;
2
3use std::ops::Range;
4
5use lex::LexError;
6use lex::Token;
7use thiserror::Error;
8
/// A position in the PTX source, attached to parse errors.
#[derive(Debug, Clone, Copy)]
pub struct SourceLocation {
    /// Byte offset of the start of the located span.
    byte: usize,
    /// Line number; the first line is 1.
    line: usize,
    /// Column on `line`, as computed by `Parser::locate`.
    col: usize,
}

impl std::fmt::Display for SourceLocation {
    /// Renders the location as `@line:col (byte N)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { byte, line, col } = self;
        write!(f, "@{line}:{col} (byte {byte})")
    }
}
21
22#[derive(Error, Debug)]
23pub enum ParseErr {
24 #[error("Unexpected token \"{:?}\"", .0)]
25 UnexpectedToken(String, SourceLocation),
26 #[error("Unexpected end of file")]
27 UnexpectedEof,
28 #[error("Lex error \"{:?}\" at {:?}", .0, .1)]
29 LexError(LexError, SourceLocation),
30 #[error("Unknown token \"{:?}\" at {:?}", .0, .1)]
31 UnknownToken(String, SourceLocation),
32}
33
/// Result type returned by the parser routines in this module.
type ParseResult<T> = Result<T, ParseErr>;

/// Names (variables, functions, labels) are stored as owned strings.
type Ident = String;
37
/// A `major.minor` version number, taken from a `Token::Version` by
/// `Parser::parse_version`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Version {
    major: u32,
    minor: u32,
}

/// Payload of a pragma directive: the string literal that follows
/// `Token::Pragma` (the terminating `;` is not stored).
#[derive(Clone, Debug)]
pub struct Pragma(String);

/// Value of the address-size directive. 32 and 64 get dedicated variants;
/// any other integer is recorded as `Other`.
#[derive(Debug)]
pub enum AddressSize {
    Adr32,
    Adr64,
    Other,
}

/// A parsed module: its top-level directives in source order.
#[derive(Debug)]
pub struct Module(pub Vec<Directive>);
56
/// A function definition as parsed by `Parser::parse_function`.
#[derive(Debug)]
pub struct Function {
    pub ident: Ident,
    /// Set when the definition is prefixed by `Token::Visible`.
    pub visible: bool,
    /// `true` for `Token::Entry`, `false` for `Token::Func`.
    pub entry: bool,
    /// Set when `Token::Noreturn` follows the function name.
    pub noreturn: bool,
    /// Optional parenthesised return parameter written before the name.
    pub return_param: Option<FunctionParam>,
    pub params: Vec<FunctionParam>,
    /// The statement after the parameter list, parsed with
    /// `Parser::parse_statement`.
    pub body: Box<Statement>,
}

/// One parameter in a function signature (introduced by `Token::Param`).
#[derive(Debug)]
pub struct FunctionParam {
    pub ident: Ident,
    pub ty: Type,
    /// Declared alignment, if any.
    pub alignment: Option<u32>,
    /// Sizes from each `[N]` suffix after the name, in source order.
    pub array_bounds: Vec<u32>,
}
75
/// State-space qualifier of a declaration or memory operation.
/// See `Parser::parse_state_space` for the token-to-variant mapping.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StateSpace {
    Global,
    Local,
    Shared,
    Register,
    Constant,
    Parameter,
}

/// Scalar type suffix (bit, unsigned, signed, float, predicate).
/// See `Parser::parse_type` for the token-to-variant mapping.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Type {
    B128,
    B64,
    B32,
    B16,
    B8,
    U64,
    U32,
    U16,
    U8,
    S64,
    S32,
    S16,
    S8,
    F64,
    F32,
    F16x2,
    F16,
    Pred,
}

/// Vector qualifier of a variable declaration (`Token::V2` / `Token::V4`).
#[derive(Debug, Clone, Copy)]
pub enum Vector {
    V2,
    V4,
}
113
/// Special registers that can appear as instruction operands.
///
/// NOTE(review): `Parser::parse_operand` in this file never produces
/// `StackPtr` or any of the `NumCta*` variants.
#[derive(Debug, Clone, Copy)]
pub enum SpecialReg {
    StackPtr,
    ThreadId,
    ThreadIdX,
    ThreadIdY,
    ThreadIdZ,
    NumThread,
    NumThreadX,
    NumThreadY,
    NumThreadZ,
    CtaId,
    CtaIdX,
    CtaIdY,
    CtaIdZ,
    NumCta,
    NumCtaX,
    NumCtaY,
    NumCtaZ,
}
134
135impl From<SpecialReg> for Operand {
136 fn from(value: SpecialReg) -> Self {
137 Operand::SpecialReg(value)
138 }
139}
140
/// A variable declaration as parsed by `Parser::parse_variable`.
#[derive(Debug, Clone)]
pub struct VarDecl {
    pub state_space: StateSpace,
    pub ty: Type,
    pub vector: Option<Vector>,
    pub ident: Ident,
    /// Value of an alignment qualifier written after the state space.
    pub alignment: Option<u32>,
    /// Sizes from each `[N]` suffix after the name, in source order.
    pub array_bounds: Vec<u32>,
    /// Count from a `Token::RegMultiplicity` suffix directly after the name.
    pub multiplicity: Option<u32>,
}

/// A memory address operand.
#[derive(Debug, Clone)]
pub enum AddressOperand {
    /// `[base]`
    Address(Ident),
    /// `[base + integer]`
    AddressOffset(Ident, i64),
    /// `[base + ident]`
    AddressOffsetVar(Ident, Ident),
    /// Presumably `base[index]` — TODO confirm against consumers.
    ArrayIndex(Ident, usize),
}
159
160impl AddressOperand {
161 pub fn get_ident(&self) -> &Ident {
162 match self {
163 AddressOperand::Address(ident) => ident,
164 AddressOperand::AddressOffset(ident, _) => ident,
165 AddressOperand::AddressOffsetVar(ident, _) => ident,
166 AddressOperand::ArrayIndex(ident, _) => ident,
167 }
168 }
169}
170
/// A single instruction operand.
#[derive(Debug, Clone)]
pub enum Operand {
    SpecialReg(SpecialReg),
    Variable(Ident),
    Immediate(Immediate),
    Address(AddressOperand),
}

/// A literal constant operand.
///
/// NOTE(review): `Parser::parse_operand` in this file produces `Int64`,
/// `Float32` and `Float64` only; `UInt64` is never constructed here.
#[derive(Debug, Clone, Copy)]
pub enum Immediate {
    Float32(f32),
    Float64(f64),
    Int64(i64),
    UInt64(u64),
}

/// Predicate guard in front of an instruction.
#[derive(Debug, Clone)]
pub enum Guard {
    /// `@pred`: execute when the predicate is true.
    Normal(Ident),
    /// `@!pred`: execute when the predicate is false.
    Negated(Ident),
}
192
/// A directive, either at module level or inside a function body.
#[derive(Debug)]
pub enum Directive {
    VarDecl(VarDecl),
    Version(Version),
    Target(String),
    AddressSize(AddressSize),
    Function(Function),
    Pragma(Pragma),
}

/// One instruction: optional guard, the operation with its qualifiers, and
/// the operand list up to the terminating `;`.
#[derive(Debug, Clone)]
pub struct Instruction {
    pub guard: Option<Guard>,
    pub specifier: Operation,
    pub operands: Vec<Operand>,
}

/// A statement inside a function body.
#[derive(Debug)]
pub enum Statement {
    Directive(Directive),
    Instruction(Instruction),
    /// A `{ ... }` block of nested statements.
    Grouping(Vec<Statement>),
    /// A `name:` label.
    Label(Ident),
}
217
/// Comparison used by `setp` (see `Parser::parse_predicate`).
#[derive(Debug, Clone, Copy)]
pub enum PredicateOp {
    LessThan,
    LessThanEqual,
    GreaterThan,
    GreaterThanEqual,
    Equal,
    NotEqual,
}

/// Which part of the product `mul`/`mad` keep (`Token::Low`/`High`/`Wide`).
#[derive(Debug, Clone, Copy)]
pub enum MulMode {
    Low,
    High,
    Wide,
}

/// Rounding modifier, parsed from `Token::Rn`/`Rz`/`Rm`/`Rp`.
#[derive(Debug, Clone, Copy)]
pub enum RoundingMode {
    /// Parsed from `Token::Rn`.
    // NOTE(review): presumably meant `NearestEven`; renaming the public
    // variant would break callers, so it is left as is.
    NearestEvent,
    Zero,
    NegInf,
    PosInf,
}

/// An instruction opcode together with its qualifiers.
#[derive(Debug, Clone)]
pub enum Operation {
    Load(StateSpace, Type),
    Store(StateSpace, Type),
    Move(Type),
    Add(Type),
    Sub(Type),
    Or(Type),
    And(Type),
    Not(Type),
    FusedMulAdd(RoundingMode, Type),
    Negate(Type),
    Multiply(MulMode, Type),
    MultiplyAdd(MulMode, Type),
    /// `to` is the first type suffix after `cvt`, `from` the second.
    Convert {
        from: Type,
        to: Type,
    },
    /// `cvta` without the `Token::To` qualifier.
    ConvertAddress(Type, StateSpace),
    /// `cvta` with the `Token::To` qualifier.
    ConvertAddressTo(Type, StateSpace),
    SetPredicate(PredicateOp, Type),
    ShiftLeft(Type),
    /// `ident` is the callee; `ret_param` is the parenthesised name before it
    /// and `params` the parenthesised argument names after it.
    Call {
        uniform: bool,
        ident: Ident,
        ret_param: Option<Ident>,
        params: Vec<Ident>,
    },
    BarrierSync,
    /// The branch target is carried in the instruction's operand list.
    Branch,
    Return,
}
275
/// Byte span of a token within the source text.
// NOTE(review): the lifetime parameter is unused by the alias.
type TokenPos<'a> = Range<usize>;

/// Recursive-descent parser over the spanned token stream of `src`.
struct Parser<'a> {
    /// Full source text, used to slice token text and compute locations.
    src: &'a str,
    /// Lexer output with one token of lookahead.
    inner: std::iter::Peekable<logos::SpannedIter<'a, Token<'a>>>,
}
282
283impl<'a> Parser<'a> {
284 pub fn new(src: &'a str) -> Self {
285 use logos::Logos;
286 Self {
287 src,
288 inner: Token::lexer(src).spanned().peekable(),
289 }
290 }
291
292 fn locate(&self, span: Range<usize>) -> SourceLocation {
293 let text = self.src.as_bytes();
294
295 let mut line = 1;
296 let mut col = 0;
297
298 let end = span.end.min(text.len());
299
300 for &c in &text[..end] {
301 match c {
302 b'\n' => {
303 line += 1;
304 col = 0;
305 },
306 b'\t' => {
307 col = (col / 4) * 4 + 4;
308 }
309 _ => col += 1,
310 }
311 }
312
313 SourceLocation {
314 byte: span.start,
315 line,
316 col,
317 }
318 }
319
320 fn unexpected(&self, (token, pos): (Token, TokenPos)) -> ParseErr {
321 ParseErr::UnexpectedToken(token.to_string(), self.locate(pos))
322 }
323
324 fn get(&mut self) -> ParseResult<Option<(Token<'a>, TokenPos)>> {
325 match self.inner.peek().cloned() {
326 Some((Ok(tok), pos)) => Ok(Some((tok, pos))),
327 Some((Err(LexError::Unknown), pos)) => Err(ParseErr::UnknownToken(
328 self.src[pos.clone()].to_string(),
329 self.locate(pos),
330 )),
331 Some((Err(err), pos)) => Err(ParseErr::LexError(err, self.locate(pos))),
332 None => Ok(None),
333 }
334 }
335
336 fn must_get(&mut self) -> Result<(Token<'a>, TokenPos), ParseErr> {
337 self.get()?.ok_or(ParseErr::UnexpectedEof)
338 }
339
340 fn skip(&mut self) {
341 self.inner.next();
342 }
343
344 fn consume(&mut self, token: Token) -> ParseResult<()> {
345 let head = self.must_get()?;
346 if head.0 == token {
347 self.skip();
348 Ok(())
349 } else {
350 Err(self.unexpected(head))
351 }
352 }
353
354 fn consume_match(&mut self, token: Token) -> ParseResult<bool> {
355 let Some(head) = self.get()? else {
356 return Ok(false);
357 };
358 if head.0 == token {
359 self.skip();
360 Ok(true)
361 } else {
362 Ok(false)
363 }
364 }
365
366 fn pop(&mut self) -> ParseResult<Option<(Token<'a>, TokenPos)>> {
367 match self.inner.next() {
368 Some((Ok(tok), pos)) => Ok(Some((tok, pos))),
369 Some((Err(err), pos)) => Err(ParseErr::LexError(err, self.locate(pos))),
370 None => Ok(None),
371 }
372 }
373
374 fn must_pop(&mut self) -> Result<(Token<'a>, TokenPos), ParseErr> {
375 self.pop()?.ok_or(ParseErr::UnexpectedEof)
376 }
377
378 fn parse_pragma(&mut self) -> ParseResult<Pragma> {
379 self.consume(Token::Pragma)?;
380 let t = self.must_pop()?;
381 match t.0 {
382 Token::StringLiteral(s) => {
383 self.consume(Token::Semicolon)?;
384 Ok(Pragma(s.to_string()))
385 }
386 _ => Err(self.unexpected(t)),
387 }
388 }
389
390 fn parse_version(&mut self) -> ParseResult<Version> {
391 let t = self.must_pop()?;
392 match t.0 {
393 Token::Version((major, minor)) => Ok(Version { major, minor }),
394 _ => Err(self.unexpected(t)),
395 }
396 }
397
398 fn parse_target(&mut self) -> ParseResult<String> {
399 self.consume(Token::Target)?;
400 let t = self.must_pop()?;
401 match t.0 {
402 Token::Identifier(target) => Ok(target.to_string()),
403 _ => Err(self.unexpected(t)),
404 }
405 }
406
407 fn parse_address_size(&mut self) -> ParseResult<AddressSize> {
408 self.consume(Token::AddressSize)?;
409 let t = self.must_pop()?;
410 let Token::IntegerConst(size) = t.0 else {
411 return Err(self.unexpected(t));
412 };
413 match size {
414 32 => Ok(AddressSize::Adr32),
415 64 => Ok(AddressSize::Adr64),
416 _ => Ok(AddressSize::Other),
417 }
418 }
419
420 fn parse_module(&mut self) -> ParseResult<Module> {
421 let mut directives = Vec::new();
422 while self.get()?.is_some() {
423 match self.parse_directive() {
424 Ok(directive) => {
425 directives.push(directive);
426 }
427 Err(e) => return Err(e),
428 }
429 }
430 Ok(Module(directives))
431 }
432
433 fn parse_array_bounds(&mut self) -> ParseResult<Vec<u32>> {
434 let mut bounds = Vec::new();
435 loop {
436 match self.get()? {
437 Some((Token::LeftBracket, _)) => self.skip(),
438 _ => break Ok(bounds),
439 }
440 let t = self.must_pop()?;
441 let Token::IntegerConst(bound) = t.0 else {
442 return Err(self.unexpected(t));
443 };
444 self.consume(Token::RightBracket)?;
445 bounds.push(bound as u32);
447 }
448 }
449
450 fn parse_state_space(&mut self) -> ParseResult<StateSpace> {
451 let t = self.must_pop()?;
452 match t.0 {
453 Token::Global => Ok(StateSpace::Global),
454 Token::Local => Ok(StateSpace::Local),
455 Token::Shared => Ok(StateSpace::Shared),
456 Token::Reg => Ok(StateSpace::Register),
457 Token::Param => Ok(StateSpace::Parameter),
458 Token::Const => Ok(StateSpace::Constant),
459 _ => Err(self.unexpected(t)),
460 }
461 }
462
463 fn parse_alignment(&mut self) -> ParseResult<u32> {
464 self.consume(Token::Align)?;
465 let t = self.must_pop()?;
466 let alignment = match t.0 {
467 Token::IntegerConst(i) => i as u32,
468 _ => return Err(self.unexpected(t)),
469 };
470 Ok(alignment)
471 }
472
473 fn parse_type(&mut self) -> ParseResult<Type> {
474 let t = self.must_pop()?;
475 let ty = match t.0 {
476 Token::Bit8 => Type::B8,
477 Token::Bit16 => Type::B16,
478 Token::Bit32 => Type::B32,
479 Token::Bit64 => Type::B64,
480 Token::Bit128 => Type::B128,
481 Token::Unsigned8 => Type::U8,
482 Token::Unsigned16 => Type::U16,
483 Token::Unsigned32 => Type::U32,
484 Token::Unsigned64 => Type::U64,
485 Token::Signed8 => Type::S8,
486 Token::Signed16 => Type::S16,
487 Token::Signed32 => Type::S32,
488 Token::Signed64 => Type::S64,
489 Token::Float16 => Type::F16,
490 Token::Float16x2 => Type::F16x2,
491 Token::Float32 => Type::F32,
492 Token::Float64 => Type::F64,
493 Token::Predicate => Type::Pred,
494 _ => return Err(self.unexpected(t)),
495 };
496 Ok(ty)
497 }
498
499 fn parse_rounding_mode(&mut self) -> ParseResult<RoundingMode> {
500 let t = self.must_pop()?;
501 let mode = match t.0 {
502 Token::Rn => RoundingMode::NearestEvent,
503 Token::Rz => RoundingMode::Zero,
504 Token::Rm => RoundingMode::NegInf,
505 Token::Rp => RoundingMode::PosInf,
506 _ => return Err(self.unexpected(t)),
507 };
508 Ok(mode)
509 }
510
511 fn parse_mul_mode(&mut self) -> ParseResult<MulMode> {
512 let t = self.must_pop()?;
513 let mode = match t.0 {
514 Token::Low => MulMode::Low,
515 Token::High => MulMode::High,
516 Token::Wide => MulMode::Wide,
517 _ => return Err(self.unexpected(t)),
518 };
519 Ok(mode)
520 }
521
522 fn parse_variable(&mut self) -> ParseResult<VarDecl> {
523 let state_space = self.parse_state_space()?;
524
525 let t = self.get()?;
526 let alignment = if let Some((Token::Align, _)) = t {
527 Some(self.parse_alignment()?)
528 } else {
529 None
530 };
531
532 let t = self.get()?;
533 let vector = match t {
534 Some((Token::V2, _)) => {
535 self.skip();
536 Some(Vector::V2)
537 }
538 Some((Token::V4, _)) => {
539 self.skip();
540 Some(Vector::V4)
541 }
542 _ => None,
543 };
544
545 let ty = self.parse_type()?;
546
547 let t = self.must_pop()?;
548 let ident = match t.0 {
549 Token::Identifier(s) => s.to_string(),
550 _ => return Err(self.unexpected(t)),
551 };
552
553 let t = self.must_get()?;
554 let multiplicity = match t.0 {
555 Token::RegMultiplicity(m) => {
556 self.skip();
557 Some(m)
558 }
559 _ => None,
560 };
561
562 let array_bounds = self.parse_array_bounds()?;
563
564 self.consume(Token::Semicolon)?;
565
566 Ok(VarDecl {
567 state_space,
568 ty,
569 vector,
570 alignment,
571 array_bounds,
572 ident: ident.to_string(),
573 multiplicity,
574 })
575 }
576
577 fn parse_guard(&mut self) -> ParseResult<Guard> {
578 self.consume(Token::At)?;
579 let t = self.must_pop()?;
580 let guard = match t.0 {
581 Token::Identifier(s) => Guard::Normal(s.to_string()),
582 Token::Bang => {
583 let t = self.must_pop()?;
584 let ident = match t.0 {
585 Token::Identifier(s) => s,
586 _ => return Err(self.unexpected(t)),
587 };
588 Guard::Negated(ident.to_string())
589 }
590 _ => return Err(self.unexpected(t)),
591 };
592 Ok(guard)
593 }
594
595 fn parse_predicate(&mut self) -> ParseResult<PredicateOp> {
596 let t = self.must_pop()?;
597 let pred = match t.0 {
598 Token::Ge => PredicateOp::GreaterThanEqual,
599 Token::Gt => PredicateOp::GreaterThan,
600 Token::Le => PredicateOp::LessThanEqual,
601 Token::Lt => PredicateOp::LessThan,
602 Token::Eq => PredicateOp::Equal,
603 Token::Ne => PredicateOp::NotEqual,
604 _ => return Err(self.unexpected(t)),
605 };
606 Ok(pred)
607 }
608
609 fn parse_operation(&mut self) -> ParseResult<Operation> {
610 let t = self.must_pop()?;
611 match t.0 {
612 Token::Ld => {
613 let state_space = self.parse_state_space()?;
614 let ty = self.parse_type()?;
615 Ok(Operation::Load(state_space, ty))
616 }
617 Token::St => {
618 let state_space = self.parse_state_space()?;
619 let ty = self.parse_type()?;
620 Ok(Operation::Store(state_space, ty))
621 }
622 Token::Mov => {
623 let ty = self.parse_type()?;
624 Ok(Operation::Move(ty))
625 }
626 Token::Add => {
627 let ty = self.parse_type()?;
628 Ok(Operation::Add(ty))
629 }
630 Token::Sub => {
631 let ty = self.parse_type()?;
632 Ok(Operation::Sub(ty))
633 }
634 Token::Or => {
635 let ty = self.parse_type()?;
636 Ok(Operation::Or(ty))
637 }
638 Token::And => {
639 let ty = self.parse_type()?;
640 Ok(Operation::And(ty))
641 }
642 Token::Not => {
643 let ty = self.parse_type()?;
644 Ok(Operation::Not(ty))
645 }
646 Token::Mul => {
647 let mode = self.parse_mul_mode()?;
648 let ty = self.parse_type()?;
649 Ok(Operation::Multiply(mode, ty))
650 }
651 Token::Mad => {
652 let mode = self.parse_mul_mode()?;
653 let ty = self.parse_type()?;
654 Ok(Operation::MultiplyAdd(mode, ty))
655 }
656 Token::Fma => {
657 let mode = self.parse_rounding_mode()?;
658 let ty = self.parse_type()?;
659 Ok(Operation::FusedMulAdd(mode, ty))
660 }
661 Token::Neg => {
662 let ty = self.parse_type()?;
663 Ok(Operation::Negate(ty))
664 }
665 Token::Cvt => {
666 let to = self.parse_type()?;
667 let from = self.parse_type()?;
668 Ok(Operation::Convert { to, from })
669 }
670 Token::Call => {
671 let uniform = self.consume_match(Token::Uniform)?;
672 let ret_param = if let Token::LeftParen = self.must_get()?.0 {
673 self.skip();
674 let t = self.must_pop()?;
675 let ident = match t.0 {
676 Token::Identifier(s) => s.to_string(),
677 _ => return Err(self.unexpected(t)),
678 };
679 self.consume(Token::RightParen)?;
680 self.consume(Token::Comma)?;
681 Some(ident)
682 } else {
683 None
684 };
685 let t = self.must_pop()?;
686 let ident = match t.0 {
687 Token::Identifier(s) => s.to_string(),
688 _ => return Err(self.unexpected(t)),
689 };
690 self.consume(Token::Comma)?;
691 let mut params = Vec::new();
692 if let Token::LeftParen = self.must_get()?.0 {
693 self.skip();
694 loop {
695 let t = self.must_pop()?;
696 let ident = match t.0 {
697 Token::Identifier(s) => s.to_string(),
698 _ => return Err(self.unexpected(t)),
699 };
700 params.push(ident);
701 let t = self.must_pop()?;
702 match t.0 {
703 Token::RightParen => break,
704 Token::Comma => {}
705 _ => return Err(self.unexpected(t)),
706 }
707 }
708 };
709
710 Ok(Operation::Call {
711 uniform,
712 ident: ident.to_string(),
713 ret_param,
714 params,
715 })
716 }
717 Token::Cvta => match self.must_get()?.0 {
718 Token::To => {
719 self.skip();
720 let state_space = self.parse_state_space()?;
721 let ty = self.parse_type()?;
722 Ok(Operation::ConvertAddressTo(ty, state_space))
723 }
724 _ => {
725 let state_space = self.parse_state_space()?;
726 let ty = self.parse_type()?;
727 Ok(Operation::ConvertAddress(ty, state_space))
728 }
729 },
730 Token::Setp => {
731 let pred = self.parse_predicate()?;
732 let ty = self.parse_type()?;
733 Ok(Operation::SetPredicate(pred, ty))
734 }
735 Token::Shl => {
736 let ty = self.parse_type()?;
737 Ok(Operation::ShiftLeft(ty))
738 }
739 Token::Bra => {
740 self.consume_match(Token::Uniform)?;
741 Ok(Operation::Branch)
742 }
743 Token::Ret => Ok(Operation::Return),
744 Token::Bar => {
745 self.consume_match(Token::Cta)?;
747 self.consume(Token::Sync)?;
748 Ok(Operation::BarrierSync)
749 }
750 _ => Err(self.unexpected(t)),
751 }
752 }
753
754 fn parse_operand(&mut self) -> ParseResult<Operand> {
755 let t = self.must_pop()?;
756 let operand = match t.0 {
757 Token::ThreadId => SpecialReg::ThreadId.into(),
758 Token::ThreadIdX => SpecialReg::ThreadIdX.into(),
759 Token::ThreadIdY => SpecialReg::ThreadIdY.into(),
760 Token::ThreadIdZ => SpecialReg::ThreadIdZ.into(),
761 Token::NumThreads => SpecialReg::NumThread.into(),
762 Token::NumThreadsX => SpecialReg::NumThreadX.into(),
763 Token::NumThreadsY => SpecialReg::NumThreadY.into(),
764 Token::NumThreadsZ => SpecialReg::NumThreadZ.into(),
765 Token::CtaId => SpecialReg::CtaId.into(),
766 Token::CtaIdX => SpecialReg::CtaIdX.into(),
767 Token::CtaIdY => SpecialReg::CtaIdY.into(),
768 Token::CtaIdZ => SpecialReg::CtaIdZ.into(),
769 Token::IntegerConst(i) => Operand::Immediate(Immediate::Int64(i)),
770 Token::Float64Const(f) => Operand::Immediate(Immediate::Float64(f)),
771 Token::Float32Const(f) => Operand::Immediate(Immediate::Float32(f)),
772 Token::Identifier(s) => {
773 let t = self.get()?;
774 if let Some((Token::LeftBracket, _)) = t {
775 todo!("array syntax in operands")
776 } else {
777 Operand::Variable(s.to_string())
778 }
779 }
780 Token::LeftBracket => {
781 let t = self.must_pop()?;
782 let Token::Identifier(s) = t.0 else {
783 return Err(self.unexpected(t));
784 };
785 let ident = s.to_string();
786
787 let t = self.must_get()?;
788 let res = if let Token::Plus = t.0 {
789 self.skip();
790 let t = self.must_pop()?;
791 match t.0 {
792 Token::IntegerConst(i) => {
793 Operand::Address(AddressOperand::AddressOffset(ident, i))
794 }
795 Token::Identifier(s) => {
796 Operand::Address(AddressOperand::AddressOffsetVar(ident, s.to_string()))
797 }
798 _ => return Err(self.unexpected(t)),
799 }
800 } else {
801 Operand::Address(AddressOperand::Address(ident))
802 };
803 self.consume(Token::RightBracket)?;
804 res
805 }
806 _ => return Err(self.unexpected(t)),
807 };
808 Ok(operand)
809 }
810
811 fn parse_operands(&mut self) -> ParseResult<Vec<Operand>> {
812 let mut operands = Vec::new();
813 loop {
814 let t = self.must_get()?;
815 match t.0 {
816 Token::Semicolon => {
817 self.skip();
818 break Ok(operands);
819 }
820 Token::Comma => self.skip(),
821 _ => {}
822 }
823 let op = self.parse_operand()?;
824 operands.push(op);
825 }
826 }
827
828 fn parse_grouping(&mut self) -> ParseResult<Vec<Statement>> {
829 self.consume(Token::LeftBrace)?; let mut statements = Vec::new();
831 loop {
832 let t = self.must_get()?;
833 if let Token::RightBrace = t.0 {
834 self.skip();
835 break Ok(statements);
836 }
837 statements.push(self.parse_statement()?);
838 }
839 }
840
841 fn parse_directive(&mut self) -> ParseResult<Directive> {
842 let t = self.must_get()?;
843 let res = match t.0 {
844 Token::Version(_) => {
845 let version = self.parse_version()?;
846 Directive::Version(version)
847 }
848 Token::Target => {
849 let target = self.parse_target()?;
850 Directive::Target(target)
851 }
852 Token::AddressSize => {
853 let addr_size = self.parse_address_size()?;
854 Directive::AddressSize(addr_size)
855 }
856 Token::Func | Token::Visible | Token::Entry => {
857 let function = self.parse_function()?;
858 Directive::Function(function)
859 }
860 Token::Pragma => {
861 let pragma = self.parse_pragma()?;
862 Directive::Pragma(pragma)
863 }
864 _ => {
865 let var = self.parse_variable()?;
866 Directive::VarDecl(var)
867 }
868 };
869 Ok(res)
870 }
871
872 fn parse_instruction(&mut self) -> ParseResult<Instruction> {
873 let t = self.must_get()?;
874 let guard = if let Token::At = t.0 {
875 Some(self.parse_guard()?)
876 } else {
877 None
878 };
879
880 let specifier = self.parse_operation()?;
881 let operands = self.parse_operands()?;
882
883 Ok(Instruction {
884 guard,
885 specifier,
886 operands,
887 })
888 }
889
890 fn parse_statement(&mut self) -> ParseResult<Statement> {
891 let t = self.must_get()?;
892 match t.0 {
893 Token::LeftBrace => {
894 let grouping = self.parse_grouping()?;
895 Ok(Statement::Grouping(grouping))
896 }
897 t if t.is_directive() => {
898 let dir = self.parse_directive()?;
899 Ok(Statement::Directive(dir))
900 }
901 Token::Identifier(i) => {
902 let i = i.to_string();
903 self.skip();
904 self.consume(Token::Colon)?;
905 Ok(Statement::Label(i.to_string()))
906 }
907 _ => {
908 let instr = self.parse_instruction()?;
909 Ok(Statement::Instruction(instr))
910 }
911 }
912 }
913
914 fn parse_function_param(&mut self) -> ParseResult<FunctionParam> {
915 self.consume(Token::Param)?; let alignment = None; let ty = self.parse_type()?;
920 let ident = loop {
921 let t = self.must_pop()?;
922 if let Token::Identifier(s) = t.0 {
923 break s.to_string();
924 }
925 };
926
927 let array_bounds = self.parse_array_bounds()?;
928
929 Ok(FunctionParam {
930 alignment,
931 ident: ident.to_string(),
932 ty,
933 array_bounds,
934 })
935 }
936
937 fn parse_function_params(&mut self) -> ParseResult<Vec<FunctionParam>> {
938 if !self.consume_match(Token::LeftParen)? {
940 return Ok(Vec::new());
941 }
942 if self.consume_match(Token::RightParen)? {
944 return Ok(Vec::new());
945 }
946
947 let mut params = Vec::new();
948 loop {
949 params.push(self.parse_function_param()?);
950 let t = self.must_pop()?;
951 match t.0 {
952 Token::Comma => {}
953 Token::RightParen => break Ok(params),
954 _ => return Err(self.unexpected(t)),
955 }
956 }
957 }
958
959 fn parse_return_param(&mut self) -> ParseResult<Option<FunctionParam>> {
960 let t = self.must_get()?;
961 if let Token::LeftParen = t.0 {
962 self.skip();
963 } else {
964 return Ok(None);
965 }
966 let param = self.parse_function_param()?;
967 self.consume(Token::RightParen)?;
968 Ok(Some(param))
969 }
970
971 fn parse_function(&mut self) -> ParseResult<Function> {
972 let visible = if let Token::Visible = self.must_get()?.0 {
973 self.skip();
974 true
975 } else {
976 false
977 };
978 let t = self.must_pop()?;
979 let entry = match t.0 {
980 Token::Entry => true,
981 Token::Func => false,
982 _ => return Err(self.unexpected(t)),
983 };
984
985 let return_param = self.parse_return_param()?;
986
987 let t = self.must_pop()?;
988 let ident = match t.0 {
989 Token::Identifier(s) => s.to_string(),
990 _ => return Err(self.unexpected(t)),
991 };
992
993 let noreturn = if let Token::Noreturn = self.must_get()?.0 {
994 self.skip();
995 true
996 } else {
997 false
998 };
999
1000 let params = self.parse_function_params()?;
1001 let body = self.parse_statement()?;
1002
1003 Ok(Function {
1004 ident: ident.to_string(),
1005 visible,
1006 entry,
1007 return_param,
1008 noreturn,
1009 params,
1010 body: Box::new(body),
1011 })
1012 }
1013}
1014
1015pub fn parse_program(src: &str) -> Result<Module, ParseErr> {
1016 Parser::new(src).parse_module()
1017}
1018
#[cfg(test)]
mod test {
    use super::*;

    /// Reads the PTX file at `path` (relative to the working directory, which
    /// is the package root under `cargo test`) and checks that it parses.
    fn assert_parses(path: &str) {
        let contents = std::fs::read_to_string(path).unwrap();
        let _ = parse_program(&contents).unwrap();
    }

    #[test]
    fn test_parse_add() {
        assert_parses("kernels/add.ptx");
    }

    #[test]
    fn test_parse_transpose() {
        assert_parses("kernels/transpose.ptx");
    }

    #[test]
    fn test_parse_add_simple() {
        assert_parses("kernels/add_simple.ptx");
    }

    #[test]
    fn test_parse_fncall() {
        assert_parses("kernels/fncall.ptx");
    }

    #[test]
    fn test_parse_gemm() {
        assert_parses("kernels/gemm.ptx");
    }
}