fugue_sleigh/
ast.rs

1use std::fmt::Display;
2use std::num::ParseIntError;
3use std::ops::Range;
4
5use itertools::Itertools;
6
7use pest::error::Error;
8use pest::iterators::Pair;
9use pest::{Parser, Span};
10
11use thiserror::Error;
12use ustr::Ustr;
13
14use crate::parse::{Rule, SleighParser};
15
/// A parsed SLEIGH semantic code block: the statements produced by
/// [`CodeBlock::parse`], kept in source order.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct CodeBlock {
    /// Statements in the order they appear in the input; exposed read-only
    /// through [`CodeBlock::statements`].
    stmts: Vec<Stmt>,
}
20
/// A single statement of a semantic code block.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Stmt {
    /// Assignment to a named target (`name = source`), optionally sized or
    /// restricted to a bit range.
    Assign {
        /// The assigned identifier.
        name: Ident,
        /// `true` when parsed from the `assignment_with_local` rule
        /// (e.g. `local a = ...`).
        decl: bool,
        /// Size annotation from `name:size`; for a `name[start, nbits]`
        /// target this is `nbits.div_ceil(8)`.
        size: Option<u32>,
        /// `start..start + nbits` for a `name[start, nbits]` target.
        bits: Option<Range<u32>>,
        /// Right-hand side expression.
        source: Expr,
    },
    /// Declaration of a name with no initial value, optionally sized.
    Declare {
        name: Ident,
        size: Option<u32>,
    },

    /// Store through a pointer (`*[space]:size target = source`); `space` and
    /// `size` are only present when written explicitly.
    Store {
        space: Option<Ident>,
        size: Option<u32>,
        /// Address expression being written through.
        target: Expr,
        /// Value being stored.
        source: Expr,
    },

    /// Unconditional `goto target`.
    Branch {
        target: BranchTarget,
    },
    /// Conditional `if (condition) goto target`.
    CBranch {
        target: BranchTarget,
        condition: Expr,
    },

    /// `call target`.
    Call {
        target: BranchTarget,
    },
    /// `return [expr]`; the parser always produces a [`BranchTarget::Indirect`].
    Return {
        target: BranchTarget,
    },

    /// An intrinsic call used as a statement (`name(arguments...);`).
    Intrinsic {
        name: Ident,
        arguments: Vec<Expr>,
    },

    /// A label definition (`<label>`).
    Label {
        label: Ident,
    },
}
66
67impl Stmt {
68    pub fn is_branch(&self) -> bool {
69        matches!(
70            self,
71            Self::Branch { .. } | Stmt::CBranch { .. } | Stmt::Call { .. } | Stmt::Return { .. }
72        )
73    }
74
75    pub fn has_fall(&self) -> bool {
76        !matches!(self, Self::Return { .. } | Self::Branch { .. })
77    }
78
79    pub fn branch_target(&self) -> Option<Ident> {
80        if let Self::Branch {
81            target: BranchTarget::Label(label),
82        }
83        | Self::CBranch {
84            target: BranchTarget::Label(label),
85            ..
86        } = self
87        {
88            Some(*label)
89        } else {
90            None
91        }
92    }
93
94    pub fn label(&self) -> Option<Ident> {
95        if let Self::Label { label } = self {
96            Some(*label)
97        } else {
98            None
99        }
100    }
101}
102
/// An expression in a semantic code block.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Expr {
    /// A named varnode, optionally with an explicit `:size` annotation.
    Ident {
        value: Ident,
        size: Option<u32>,
    },
    /// An integer literal (decimal, `0x` hexadecimal or `0b` binary),
    /// optionally with an explicit `:size` annotation.
    Literal {
        value: u64,
        size: Option<u32>,
    },

    /// A unary operator applied to a single operand.
    UnOp {
        op: UnOp,
        value: Box<Expr>,
    },

    /// A binary operation `lvalue op rvalue` (boolean, bitwise, shift,
    /// integer or float arithmetic). Chains are folded left-associatively.
    BinOp {
        op: BinOp,
        lvalue: Box<Expr>,
        rvalue: Box<Expr>,
    },
    /// A comparison `lvalue rel rvalue`. Chains are folded left-associatively.
    BinRel {
        op: BinRel,
        lvalue: Box<Expr>,
        rvalue: Box<Expr>,
    },

    /// A load through a pointer (`*[space]:size source`); `space` and `size`
    /// are only present when written explicitly.
    Load {
        space: Option<Ident>,
        size: Option<u32>,
        /// Address expression being read from.
        source: Box<Expr>,
    },

    /// The address of a named varnode (the `addressof` grammar rules),
    /// optionally with an explicit size.
    AddressOf {
        value: Ident,
        size: Option<u32>,
    },
    /// A bit-field read `value[start, nbits]`: `range` is
    /// `start..start + nbits` and `size` is `nbits.div_ceil(8)`.
    BitsOf {
        value: Ident,
        range: Range<u32>,
        size: u32,
    },

    /// An intrinsic call `name(arguments...)` used as a value.
    Intrinsic {
        name: Ident,
        arguments: Vec<Expr>,
    },
}
151
/// Binary operators producing a value (see [`BinRel`] for comparisons).
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum BinOp {
    /// Boolean OR (`||`).
    BoolOr,
    /// Boolean AND (`&&`).
    BoolAnd,
    /// Boolean XOR (`^^`).
    BoolXor,

    /// Bitwise OR (the `parse_expr_or` level; presumably `|`).
    Or,
    /// Bitwise AND (the `parse_expr_and` level; presumably `&`).
    And,
    /// Bitwise XOR (the `parse_expr_xor` level; presumably `^`).
    Xor,

    /// `<<`
    ShiftLeft,
    /// `>>`
    ShiftRight,
    /// `>>>`
    SignedShiftRight,

    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `%`
    Rem,

    /// `s/`
    SignedDiv,
    /// `s%`
    SignedRem,

    /// `f+`
    FloatAdd,
    /// `f-`
    FloatSub,
    /// `f*`
    FloatMul,
    /// `f/`
    FloatDiv,
}
180
/// Comparison operators; each variant documents the source token it is
/// parsed from.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum BinRel {
    /// `==`
    Eq,
    /// `!=`
    NotEq,

    /// `<`
    Less,
    /// `<=`
    LessEq,
    /// `>`
    Greater,
    /// `>=`
    GreaterEq,

    /// `s<`
    SignedLess,
    /// `s<=`
    SignedLessEq,
    /// `s>`
    SignedGreater,
    /// `s>=`
    SignedGreaterEq,

    /// `f==`
    FloatEq,
    /// `f!=`
    FloatNotEq,

    /// `f<`
    FloatLess,
    /// `f<=`
    FloatLessEq,
    /// `f>`
    FloatGreater,
    /// `f>=`
    FloatGreaterEq,
}
204
/// Prefix unary operators (a `*` dereference is represented as
/// [`Expr::Load`], not as a `UnOp`).
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum UnOp {
    /// `!`
    BoolNot,
    /// `~`
    Not,
    /// `-`
    Neg,
    /// `f-`
    FloatNeg,
}
212
/// The destination of a `goto`, conditional `goto`, `call` or `return`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum BranchTarget {
    /// A direct destination: a bare identifier or an integer offset with an
    /// optional address space (e.g. `goto dest1`, `goto 0x0 [codespace]`).
    Direct(BranchLabel),
    /// A destination computed from an expression (e.g. `goto [expr]`);
    /// every `return` target is parsed into this form.
    Indirect(Expr),
    /// A reference to a code-block label written as `<name>`.
    Label(Ident),
}
219
/// A direct branch destination.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum BranchLabel {
    /// An integer offset, optionally qualified by an address space
    /// (`0x0 [codespace]` gives `space: Some(..)`, a bare `0x0` gives `None`).
    Offset { offset: u64, space: Option<Ident> },
    /// A bare identifier destination (e.g. `goto dest1`).
    Varnode { name: Ident },
}
225
/// Identifier type used throughout the AST: an interned [`Ustr`] string.
/// It is `Copy`, which lets accessors such as `Stmt::label` return it by value.
pub type Ident = Ustr;
227
/// Errors produced while building the AST from SLEIGH source.
#[derive(Debug, Error)]
pub enum AstError {
    /// The input was rejected by the pest grammar.
    #[error(transparent)]
    Parse(#[from] Error<Rule>),
    /// An integer literal or size annotation could not be parsed (for
    /// example, it does not fit the target integer type).
    #[error("{0}: {1}")]
    Integer(ErrorSpan, ParseIntError),
    /// A bit range was rejected (for example, a zero-bit width).
    #[error("{0}: invalid bit-range")]
    BitRange(ErrorSpan),
    /// An explicit size annotation was rejected (for example, a size of zero).
    #[error("{0}: invalid size")]
    Size(ErrorSpan),
    /// A label was defined more than once.
    ///
    /// NOTE(review): not constructed by the parser in this file; presumably
    /// raised by a later label-resolution pass — confirm.
    #[error("attempt to define label `{0}` more than once")]
    DuplicateLabel(Ustr),
    /// A branch referenced a label that is never defined.
    ///
    /// NOTE(review): not constructed by the parser in this file; presumably
    /// raised by a later label-resolution pass — confirm.
    #[error("reference to undefined label `{0}`")]
    UndefinedLabel(Ustr),
}
243
/// Location of an error in the parsed input.
///
/// `start`/`end` are the positions reported by pest's `Span::start`/`Span::end`
/// (byte offsets into the input). The type is a plain pair of `usize`s, so it
/// is `Copy` and comparable; that lets errors carrying a span be duplicated
/// and checked in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ErrorSpan {
    /// Start position of the offending input.
    pub start: usize,
    /// End position of the offending input.
    pub end: usize,
}
249
250impl Display for ErrorSpan {
251    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
252        write!(f, "{}..{}", self.start, self.end)
253    }
254}
255
256impl From<Span<'_>> for ErrorSpan {
257    fn from(value: Span<'_>) -> Self {
258        Self {
259            start: value.start(),
260            end: value.end(),
261        }
262    }
263}
264
265impl CodeBlock {
266    pub fn parse(input: &str) -> Result<Self, AstError> {
267        let parsed = SleighParser::parse(Rule::code_block, input)?
268            .into_iter()
269            .next() // code_block
270            .unwrap()
271            .into_inner()
272            .next() // statements
273            .unwrap();
274
275        Ok(Self {
276            stmts: parsed
277                .into_inner()
278                .map(|stmt| Self::parse_stmt(stmt))
279                .collect::<Result<Vec<_>, _>>()?,
280        })
281    }
282
283    pub fn statements(&self) -> &[Stmt] {
284        &self.stmts
285    }
286
287    fn parse_assignment(assign: Pair<'_, Rule>, decl: bool) -> Result<Stmt, AstError> {
288        let mut pairs = assign.into_inner();
289
290        let lvalue = pairs.next().unwrap().into_inner().next().unwrap();
291        let expr = Self::parse_expr(pairs.next().unwrap())?;
292
293        let assign = match lvalue.as_rule() {
294            Rule::sembitrange => {
295                let (name, bits, size) = Self::parse_sembitrange(lvalue)?;
296
297                Stmt::Assign {
298                    name,
299                    decl,
300                    size: Some(size),
301                    bits: Some(bits),
302                    source: expr,
303                }
304            }
305            Rule::sized_identifier => {
306                let (name, size) = Self::parse_sized(lvalue, Self::parse_identifier)?;
307
308                Stmt::Assign {
309                    name,
310                    decl,
311                    size: Some(size),
312                    bits: None,
313                    source: expr,
314                }
315            }
316            Rule::identifier => Stmt::Assign {
317                name: Self::parse_identifier(lvalue)?,
318                decl,
319                size: None,
320                bits: None,
321                source: expr,
322            },
323            Rule::sized_star_expr => {
324                let mut pairs = lvalue.into_inner();
325                let (space, size) = Self::parse_sized_star(pairs.next().unwrap())?;
326                let target = Self::parse_expr(pairs.next().unwrap())?;
327
328                Stmt::Store {
329                    space,
330                    size,
331                    target,
332                    source: expr,
333                }
334            }
335            _ => unimplemented!(),
336        };
337
338        Ok(assign)
339    }
340
341    fn parse_declaration(decl: Pair<'_, Rule>) -> Result<Stmt, AstError> {
342        let mut pairs = decl.into_inner();
343        let decl = pairs.next().unwrap();
344
345        let decl = match decl.as_rule() {
346            Rule::declaration_ => Stmt::Declare {
347                name: Self::parse_identifier(decl.into_inner().next().unwrap())?,
348                size: None,
349            },
350            Rule::declaration_with_size => {
351                let mut pairs = decl.into_inner();
352
353                let ident = Self::parse_identifier(pairs.next().unwrap())?;
354                let size = Self::parse_size(pairs.next().unwrap(), true)?;
355
356                Stmt::Declare {
357                    name: ident,
358                    size: Some(size),
359                }
360            }
361            _ => unreachable!(),
362        };
363
364        Ok(decl)
365    }
366
367    fn parse_expr(expr: Pair<'_, Rule>) -> Result<Expr, AstError> {
368        let mut pairs = expr.into_inner();
369        Self::parse_expr_boolor(pairs.next().unwrap())
370    }
371
372    fn parse_expr_boolor(expr: Pair<'_, Rule>) -> Result<Expr, AstError> {
373        let mut pairs = expr.into_inner();
374        let mut expr = Self::parse_expr_booland(pairs.next().unwrap())?;
375
376        for (_, pair) in pairs.tuples() {
377            expr = Expr::BinOp {
378                op: BinOp::BoolOr,
379                lvalue: Box::new(expr),
380                rvalue: Box::new(Self::parse_expr_booland(pair)?),
381            };
382        }
383
384        Ok(expr)
385    }
386
387    fn parse_expr_booland(expr: Pair<'_, Rule>) -> Result<Expr, AstError> {
388        let mut pairs = expr.into_inner();
389        let mut expr = Self::parse_expr_or(pairs.next().unwrap())?;
390
391        for (op, pair) in pairs.tuples() {
392            let op = match op.as_str() {
393                "&&" => BinOp::BoolAnd,
394                "^^" => BinOp::BoolXor,
395                _ => unreachable!(),
396            };
397
398            expr = Expr::BinOp {
399                op,
400                lvalue: Box::new(expr),
401                rvalue: Box::new(Self::parse_expr_or(pair)?),
402            };
403        }
404
405        Ok(expr)
406    }
407
408    fn parse_expr_or(expr: Pair<'_, Rule>) -> Result<Expr, AstError> {
409        let mut pairs = expr.into_inner();
410        let mut expr = Self::parse_expr_xor(pairs.next().unwrap())?;
411
412        for (_, pair) in pairs.tuples() {
413            expr = Expr::BinOp {
414                op: BinOp::Or,
415                lvalue: Box::new(expr),
416                rvalue: Box::new(Self::parse_expr_xor(pair)?),
417            };
418        }
419
420        Ok(expr)
421    }
422
423    fn parse_expr_xor(expr: Pair<'_, Rule>) -> Result<Expr, AstError> {
424        let mut pairs = expr.into_inner();
425        let mut expr = Self::parse_expr_and(pairs.next().unwrap())?;
426
427        for (_, pair) in pairs.tuples() {
428            expr = Expr::BinOp {
429                op: BinOp::Xor,
430                lvalue: Box::new(expr),
431                rvalue: Box::new(Self::parse_expr_and(pair)?),
432            };
433        }
434
435        Ok(expr)
436    }
437
438    fn parse_expr_and(expr: Pair<'_, Rule>) -> Result<Expr, AstError> {
439        let mut pairs = expr.into_inner();
440        let mut expr = Self::parse_expr_eq(pairs.next().unwrap())?;
441
442        for (_, pair) in pairs.tuples() {
443            expr = Expr::BinOp {
444                op: BinOp::And,
445                lvalue: Box::new(expr),
446                rvalue: Box::new(Self::parse_expr_eq(pair)?),
447            };
448        }
449
450        Ok(expr)
451    }
452
453    fn parse_expr_eq(expr: Pair<'_, Rule>) -> Result<Expr, AstError> {
454        let mut pairs = expr.into_inner();
455        let mut expr = Self::parse_expr_comp(pairs.next().unwrap())?;
456
457        for (op, pair) in pairs.tuples() {
458            let op = match op.as_str() {
459                "==" => BinRel::Eq,
460                "!=" => BinRel::NotEq,
461                "f==" => BinRel::FloatEq,
462                "f!=" => BinRel::FloatNotEq,
463                _ => unreachable!(),
464            };
465
466            expr = Expr::BinRel {
467                op,
468                lvalue: Box::new(expr),
469                rvalue: Box::new(Self::parse_expr_comp(pair)?),
470            };
471        }
472
473        Ok(expr)
474    }
475
476    fn parse_expr_comp(expr: Pair<'_, Rule>) -> Result<Expr, AstError> {
477        let mut pairs = expr.into_inner();
478        let mut expr = Self::parse_expr_shift(pairs.next().unwrap())?;
479
480        for (op, pair) in pairs.tuples() {
481            let op = match op.as_str() {
482                "<" => BinRel::Less,
483                "<=" => BinRel::LessEq,
484                ">" => BinRel::Greater,
485                ">=" => BinRel::GreaterEq,
486                "s<" => BinRel::SignedLess,
487                "s<=" => BinRel::SignedLessEq,
488                "s>" => BinRel::SignedGreater,
489                "s>=" => BinRel::SignedGreaterEq,
490                "f<" => BinRel::FloatLess,
491                "f<=" => BinRel::FloatLessEq,
492                "f>" => BinRel::FloatGreater,
493                "f>=" => BinRel::FloatGreaterEq,
494                _ => unreachable!(),
495            };
496
497            expr = Expr::BinRel {
498                op,
499                lvalue: Box::new(expr),
500                rvalue: Box::new(Self::parse_expr_shift(pair)?),
501            };
502        }
503
504        Ok(expr)
505    }
506
507    fn parse_expr_shift(expr: Pair<'_, Rule>) -> Result<Expr, AstError> {
508        let mut pairs = expr.into_inner();
509        let mut expr = Self::parse_expr_add(pairs.next().unwrap())?;
510
511        for (op, pair) in pairs.tuples() {
512            let op = match op.as_str() {
513                "<<" => BinOp::ShiftLeft,
514                ">>" => BinOp::ShiftRight,
515                ">>>" => BinOp::SignedShiftRight,
516                _ => unreachable!(),
517            };
518
519            expr = Expr::BinOp {
520                op,
521                lvalue: Box::new(expr),
522                rvalue: Box::new(Self::parse_expr_add(pair)?),
523            };
524        }
525
526        Ok(expr)
527    }
528
529    fn parse_expr_add(expr: Pair<'_, Rule>) -> Result<Expr, AstError> {
530        let mut pairs = expr.into_inner();
531        let mut expr = Self::parse_expr_mult(pairs.next().unwrap())?;
532
533        for (op, pair) in pairs.tuples() {
534            let op = match op.as_str() {
535                "+" => BinOp::Add,
536                "-" => BinOp::Sub,
537                "f+" => BinOp::FloatAdd,
538                "f-" => BinOp::FloatSub,
539                _ => unreachable!(),
540            };
541
542            expr = Expr::BinOp {
543                op,
544                lvalue: Box::new(expr),
545                rvalue: Box::new(Self::parse_expr_mult(pair)?),
546            };
547        }
548
549        Ok(expr)
550    }
551
552    fn parse_expr_mult(expr: Pair<'_, Rule>) -> Result<Expr, AstError> {
553        let mut pairs = expr.into_inner();
554        let mut expr = Self::parse_expr_unary(pairs.next().unwrap())?;
555
556        for (op, pair) in pairs.tuples() {
557            let op = match op.as_str() {
558                "*" => BinOp::Mul,
559                "/" => BinOp::Div,
560                "%" => BinOp::Rem,
561                "s/" => BinOp::SignedDiv,
562                "s%" => BinOp::SignedRem,
563                "f*" => BinOp::FloatMul,
564                "f/" => BinOp::FloatDiv,
565                _ => unreachable!(),
566            };
567
568            expr = Expr::BinOp {
569                op,
570                lvalue: Box::new(expr),
571                rvalue: Box::new(Self::parse_expr_unary(pair)?),
572            };
573        }
574
575        Ok(expr)
576    }
577
578    fn parse_expr_unary(expr: Pair<'_, Rule>) -> Result<Expr, AstError> {
579        let mut pairs = expr.into_inner();
580        let op_or_expr = pairs.next().unwrap();
581
582        if op_or_expr.as_rule() == Rule::unary_op {
583            let op = match op_or_expr.as_str() {
584                "!" => UnOp::BoolNot,
585                "~" => UnOp::Not,
586                "-" => UnOp::Neg,
587                "f-" => UnOp::FloatNeg,
588                _ => {
589                    let (space, size) =
590                        Self::parse_sized_star(op_or_expr.into_inner().next().unwrap())?;
591                    let source = Self::parse_expr_func(pairs.next().unwrap())?;
592
593                    return Ok(Expr::Load {
594                        space,
595                        size,
596                        source: Box::new(source),
597                    });
598                }
599            };
600
601            Ok(Expr::UnOp {
602                op,
603                value: Box::new(Self::parse_expr_func(pairs.next().unwrap())?),
604            })
605        } else {
606            Self::parse_expr_func(op_or_expr)
607        }
608    }
609
610    fn parse_expr_func(expr: Pair<'_, Rule>) -> Result<Expr, AstError> {
611        let mut pairs = expr.into_inner();
612        let expr = pairs.next().unwrap();
613
614        match expr.as_rule() {
615            Rule::expr_apply => Self::parse_expr_apply(expr),
616            Rule::expr_term => Self::parse_expr_term(expr),
617            _ => unreachable!("{expr:#?}"),
618        }
619    }
620
621    fn parse_expr_apply(expr: Pair<'_, Rule>) -> Result<Expr, AstError> {
622        let mut pairs = expr.into_inner();
623        let name = Self::parse_identifier(pairs.next().unwrap())?;
624
625        let arguments = pairs
626            .next()
627            .unwrap()
628            .into_inner()
629            .next()
630            .map(|exprs| {
631                exprs
632                    .into_inner()
633                    .map(|expr| Self::parse_expr(expr))
634                    .collect::<Result<Vec<_>, _>>()
635            })
636            .transpose()?
637            .unwrap_or_default();
638
639        Ok(Expr::Intrinsic { name, arguments })
640    }
641
642    fn parse_expr_term(expr: Pair<'_, Rule>) -> Result<Expr, AstError> {
643        let mut pairs = expr.into_inner();
644        let expr = pairs.next().unwrap();
645
646        match expr.as_rule() {
647            Rule::varnode => Self::parse_varnode(expr),
648            Rule::sembitrange => {
649                let (value, range, size) = Self::parse_sembitrange(expr)?;
650                Ok(Expr::BitsOf { value, range, size })
651            }
652            _ => Self::parse_expr(expr),
653        }
654    }
655
656    fn parse_varnode(target: Pair<'_, Rule>) -> Result<Expr, AstError> {
657        let mut pairs = target.into_inner();
658        let target = pairs.next().unwrap();
659
660        let result = match target.as_rule() {
661            Rule::integer => Expr::Literal {
662                value: Self::parse_integer(target)?,
663                size: None,
664            },
665            Rule::identifier => Expr::Ident {
666                value: Self::parse_identifier(target)?,
667                size: None,
668            },
669            Rule::addressof => Expr::AddressOf {
670                value: Self::parse_identifier(target.into_inner().next().unwrap())?,
671                size: None,
672            },
673            Rule::integer_with_size => {
674                let (value, size) = Self::parse_sized(target, Self::parse_integer)?;
675                Expr::Literal {
676                    value,
677                    size: Some(size),
678                }
679            }
680            Rule::identifier_with_size => {
681                let (value, size) = Self::parse_sized(target, Self::parse_identifier)?;
682                Expr::Ident {
683                    value,
684                    size: Some(size),
685                }
686            }
687            Rule::addressof_with_size => {
688                let mut pairs = target.into_inner();
689
690                let size = Self::parse_size(pairs.next().unwrap(), true)?;
691                let value = Self::parse_identifier(pairs.next().unwrap())?;
692
693                Expr::AddressOf {
694                    value,
695                    size: Some(size),
696                }
697            }
698            _ => {
699                todo!("varnode: {target:#?}")
700            }
701        };
702
703        Ok(result)
704    }
705
706    fn parse_integer(target: Pair<'_, Rule>) -> Result<u64, AstError> {
707        let integer = target.as_str();
708
709        if let Some(integer) = integer.strip_prefix("0x") {
710            return u64::from_str_radix(integer, 16)
711                .map_err(|e| AstError::Integer(target.as_span().into(), e));
712        }
713
714        if let Some(integer) = integer.strip_prefix("0b") {
715            return u64::from_str_radix(integer, 2)
716                .map_err(|e| AstError::Integer(target.as_span().into(), e));
717        }
718
719        u64::from_str_radix(integer, 10).map_err(|e| AstError::Integer(target.as_span().into(), e))
720    }
721
722    fn parse_size(target: Pair<'_, Rule>, check: bool) -> Result<u32, AstError> {
723        let integer = target.as_str();
724
725        if let Some(integer) = integer.strip_prefix("0x") {
726            return u32::from_str_radix(integer, 16)
727                .map_err(|e| AstError::Integer(target.as_span().into(), e));
728        }
729
730        if let Some(integer) = integer.strip_prefix("0b") {
731            return u32::from_str_radix(integer, 2)
732                .map_err(|e| AstError::Integer(target.as_span().into(), e));
733        }
734
735        let size = u32::from_str_radix(integer, 10)
736            .map_err(|e| AstError::Integer(target.as_span().into(), e))?;
737
738        if check && size == 0 {
739            Err(AstError::Size(target.as_span().into()))
740        } else {
741            Ok(size)
742        }
743    }
744
745    fn parse_identifier(target: Pair<'_, Rule>) -> Result<Ident, AstError> {
746        Ok(target.as_str().into())
747    }
748
749    fn parse_sembitrange(target: Pair<'_, Rule>) -> Result<(Ident, Range<u32>, u32), AstError> {
750        let mut pairs = target.into_inner();
751
752        let ident = Self::parse_identifier(pairs.next().unwrap())?;
753        let start = Self::parse_size(pairs.next().unwrap(), false)?;
754
755        let nbits_pair = pairs.next().unwrap();
756        let nbits_span = nbits_pair.as_span();
757        let nbits = Self::parse_size(nbits_pair, false)?;
758
759        if nbits == 0 {
760            return Err(AstError::BitRange(nbits_span.into()));
761        }
762
763        let size = nbits.div_ceil(8);
764
765        Ok((ident, start..start + nbits, size))
766    }
767
768    fn parse_sized<F, T>(target: Pair<'_, Rule>, rule: F) -> Result<(T, u32), AstError>
769    where
770        F: FnOnce(Pair<'_, Rule>) -> Result<T, AstError>,
771    {
772        let mut pairs = target.into_inner();
773
774        let t = rule(pairs.next().unwrap())?;
775        let s = Self::parse_size(pairs.next().unwrap(), true)?;
776
777        Ok((t, s))
778    }
779
780    fn parse_sized_star(target: Pair<'_, Rule>) -> Result<(Option<Ident>, Option<u32>), AstError> {
781        let mut pairs = target.into_inner();
782
783        let Some(ident_or_size) = pairs.next() else {
784            return Ok((None, None));
785        };
786
787        let ident = if ident_or_size.as_rule() == Rule::identifier {
788            Self::parse_identifier(ident_or_size)?
789        } else {
790            return Ok((None, Some(Self::parse_size(ident_or_size, true)?)));
791        };
792
793        let Some(size) = pairs.next() else {
794            return Ok((Some(ident), None));
795        };
796
797        Ok((Some(ident), Some(Self::parse_size(size, true)?)))
798    }
799
800    fn parse_branch_target(target: Pair<'_, Rule>) -> Result<BranchTarget, AstError> {
801        let mut pairs = target.into_inner();
802        let target = pairs.next().unwrap();
803
804        let target = match target.as_rule() {
805            Rule::identifier => BranchTarget::Direct(BranchLabel::Varnode {
806                name: Self::parse_identifier(target)?,
807            }),
808            Rule::offset_in_space => {
809                let mut pairs = target.into_inner();
810
811                let offset = Self::parse_integer(pairs.next().unwrap())?;
812                let space = Self::parse_identifier(pairs.next().unwrap())?;
813
814                BranchTarget::Direct(BranchLabel::Offset {
815                    offset,
816                    space: Some(space),
817                })
818            }
819            Rule::integer => {
820                let offset = Self::parse_integer(target)?;
821
822                BranchTarget::Direct(BranchLabel::Offset {
823                    offset,
824                    space: None,
825                })
826            }
827            Rule::label => {
828                let label = Self::parse_identifier(target.into_inner().next().unwrap())?;
829
830                BranchTarget::Label(label)
831            }
832            _ => {
833                let target = Self::parse_expr(target)?;
834                BranchTarget::Indirect(target)
835            }
836        };
837
838        Ok(target)
839    }
840
841    fn parse_stmt(pair: Pair<'_, Rule>) -> Result<Stmt, AstError> {
842        let pair = pair.into_inner().next().unwrap();
843
844        match pair.as_rule() {
845            Rule::assignment => {
846                let assign = pair.into_inner().next().unwrap();
847                let decl = assign.as_rule() == Rule::assignment_with_local;
848                Self::parse_assignment(assign, decl)
849            }
850            Rule::declaration => {
851                let decl = pair.into_inner().next().unwrap();
852                Self::parse_declaration(decl)
853            }
854            Rule::funcall => {
855                let mut pairs = pair.into_inner().next().unwrap().into_inner();
856                let name = pairs.next().unwrap();
857                let args = pairs
858                    .next()
859                    .unwrap()
860                    .into_inner()
861                    .next() // expr_list
862                    .map(|list| {
863                        list.into_inner()
864                            .map(Self::parse_expr)
865                            .collect::<Result<Vec<_>, _>>()
866                    })
867                    .transpose()?
868                    .unwrap_or_default();
869
870                Ok(Stmt::Intrinsic {
871                    name: name.as_str().into(),
872                    arguments: args,
873                })
874            }
875            Rule::goto_stmt => {
876                let target = pair.into_inner().next().unwrap();
877                Ok(Stmt::Branch {
878                    target: Self::parse_branch_target(target)?,
879                })
880            }
881            Rule::cond_stmt => {
882                let mut pairs = pair.into_inner();
883                let condition = Self::parse_expr(pairs.next().unwrap())?;
884                let target = Self::parse_branch_target(pairs.next().unwrap())?;
885                Ok(Stmt::CBranch { condition, target })
886            }
887            Rule::call_stmt => {
888                let target = pair.into_inner().next().unwrap();
889                Ok(Stmt::Call {
890                    target: Self::parse_branch_target(target)?,
891                })
892            }
893            Rule::return_stmt => {
894                let target = pair.into_inner().next().unwrap();
895                Ok(Stmt::Return {
896                    target: BranchTarget::Indirect(Self::parse_expr(target)?),
897                })
898            }
899            Rule::label => {
900                let label = pair.into_inner().next().unwrap();
901                Ok(Stmt::Label {
902                    label: label.as_str().into(),
903                })
904            }
905            rule => unreachable!("{rule:?}"),
906        }
907    }
908}
909
#[cfg(test)]
mod test {
    use super::*;

    /// Parses a block exercising every statement form and checks that one
    /// `Stmt` is produced per statement.
    #[test]
    fn test_parse() -> Result<(), Box<dyn std::error::Error>> {
        // Propagate with `?` so a failure reports the actual `AstError`
        // instead of an opaque `assert!(ast.is_ok())` failure.
        let ast = CodeBlock::parse(r#"
            memcpy(a || b || c, *[other] b, 10 + 20);

            local b:32 = 10;
            local a = *b;

            *[ram] (a + 10) = b;

            < label >

            a[0,1] = b[0,2] * c[0,10] + d:10 / b(10);

            goto dest1;
            goto [dest2];
            goto 0x0 [codespace];
            goto 0x0;
            goto [0x10 + 20];

            if (var > 10) goto <dest>;

            return [dest3];
            return [0x10];

            *:4 sp = inst_next;
            sp = sp-4;
            call dest;
"#,
        )?;

        // 16 `;`-terminated statements plus the `< label >` definition.
        assert_eq!(ast.statements().len(), 17);

        Ok(())
    }
}