Skip to main content

mist_parser/parser/common/
expr.rs

1use crate::{
2    Rule,
3    ast::*,
4    ast_ensure, ast_expr,
5    error::{AstError, AstResult, GetLength, IntoErr, collect_recovered, collect_recovered_map},
6};
7use pest::pratt_parser::PrattParser;
8use std::sync::OnceLock;
9
10impl<'a> TryFrom<pest::iterators::Pair<'a, Rule>> for Expression {
11    type Error = AstError<'a, Self>;
12
13    fn try_from(pair: pest::iterators::Pair<'a, Rule>) -> Result<Self, Self::Error> {
14        let rule = pair.as_rule();
15        let inner = pair.clone().into_inner();
16
17        match rule {
18            Rule::expr => {
19                static PRATT_PARSER: OnceLock<PrattParser<Rule>> = OnceLock::new();
20                let pratt = PRATT_PARSER.get_or_init(|| {
21                    use Rule::*;
22                    use pest::pratt_parser::{Assoc::*, Op};
23
24                    PrattParser::new()
25                        .op(Op::infix(range_inc, Left) | Op::infix(range_exc, Left))
26                        .op(Op::infix(or, Left))
27                        .op(Op::infix(and, Left))
28                        .op(Op::infix(bitor, Left))
29                        .op(Op::infix(bitxor, Left))
30                        .op(Op::infix(bitand, Left))
31                        .op(Op::infix(eq, Left) | Op::infix(neq, Left))
32                        .op(Op::infix(lt, Left)
33                            | Op::infix(lte, Left)
34                            | Op::infix(gt, Left)
35                            | Op::infix(gte, Left))
36                        .op(Op::infix(shl, Left) | Op::infix(shr, Left))
37                        .op(Op::infix(add, Left) | Op::infix(sub, Left))
38                        .op(Op::infix(mul, Left) | Op::infix(div, Left) | Op::infix(rem, Left))
39                });
40
41                pratt
42                    .map_primary(|primary_pair| Expression::try_from(primary_pair))
43                    .map_infix(|lhs, op, rhs| {
44                        let bin_op = match op.as_rule() {
45                            Rule::shl => BinaryOp::ShiftLeft,
46                            Rule::shr => BinaryOp::ShiftRight,
47                            Rule::range_inc => BinaryOp::RangeInclusive,
48                            Rule::range_exc => BinaryOp::RangeExclusive,
49                            Rule::lte => BinaryOp::LessThanOrEqual,
50                            Rule::gte => BinaryOp::GreaterThanOrEqual,
51                            Rule::eq => BinaryOp::Equal,
52                            Rule::neq => BinaryOp::NotEqual,
53                            Rule::and => BinaryOp::And,
54                            Rule::or => BinaryOp::Or,
55                            Rule::add => BinaryOp::Plus,
56                            Rule::sub => BinaryOp::Minus,
57                            Rule::mul => BinaryOp::Multiply,
58                            Rule::div => BinaryOp::Divide,
59                            Rule::rem => BinaryOp::Modulo,
60                            Rule::lt => BinaryOp::LessThan,
61                            Rule::gt => BinaryOp::GreaterThan,
62                            Rule::bitand => BinaryOp::BitAnd,
63                            Rule::bitor => BinaryOp::BitOr,
64                            Rule::bitxor => BinaryOp::BitXor,
65                            _ => return AstError::bug_unimplemented(op),
66                        };
67
68                        ast_expr!(Expression::Binary {
69                            lhs: lhs.map(Box::new),
70                            op: Ok(bin_op) as AstResult<'_, BinaryOp>,
71                            rhs: rhs.map(Box::new),
72                        })
73                    })
74                    .parse(inner)
75            }
76
77            Rule::term => {
78                let mut prefix_pairs = Vec::new();
79                let mut primary_pair = None;
80                let mut postfix_pairs = Vec::new();
81
82                for p in inner {
83                    match p.as_rule() {
84                        Rule::prefix => prefix_pairs.push(p),
85                        Rule::primary => primary_pair = Some(p),
86                        Rule::postfix => postfix_pairs.push(p),
87                        _ => {}
88                    }
89                }
90
91                let prefixes = collect_recovered::<Prefix, Prefix>(prefix_pairs.into_iter());
92                let exp = Expression::try_from(
93                    primary_pair.expect("Term must contain a primary expression"),
94                );
95                let postfixes = collect_recovered::<Postfix, Postfix>(postfix_pairs.into_iter());
96
97                if postfixes.len() > 0 || prefixes.len() > 0 {
98                    ast_expr!(Expression::Fix {
99                        initial: exp.map(Box::new),
100                        prefixes: prefixes,
101                        postfixes: postfixes,
102                    })
103                } else {
104                    ast_expr!(use exp?, prefixes, postfixes)
105                }
106            }
107
108            Rule::tuple => {
109                ast_expr!(Expression::Literal(
110                    collect_recovered(pair.into_inner())
111                        .map(Literal::Tuple)
112                        .get_map(Literal::Tuple)
113                ))
114            }
115
116            Rule::primary => pair.into_inner().next().unwrap().try_into(),
117            Rule::static_path => ast_expr!(Expression::Path(pair.try_into())),
118            Rule::literal => ast_expr!(Expression::Literal(pair.try_into())),
119            Rule::expr_path => ast_expr!(Expression::Path(pair.try_into())),
120
121            _ => AstError::bug_unimplemented(pair),
122        }
123    }
124}
125
126impl<'a> TryFrom<pest::iterators::Pair<'a, Rule>> for Prefix {
127    type Error = AstError<'a, Self>;
128
129    fn try_from(pair: pest::iterators::Pair<'a, Rule>) -> Result<Self, Self::Error> {
130        Ok(match pair.as_rule() {
131            Rule::prefix => Self::try_from(pair.into_inner().next().unwrap())?,
132            Rule::deref_px => Self::Deref,
133            Rule::mut_ref_px => Self::RefMut,
134            Rule::ref_px => Self::Ref,
135            Rule::new_px => Self::New(
136                pair.into_inner()
137                    .next()
138                    .map(|v| v.try_into().get())
139                    .transpose()?,
140            ),
141            Rule::not_px => Self::Not,
142            Rule::neg_px => Self::Neg,
143
144            _ => return AstError::bug_unimplemented(pair),
145        })
146    }
147}
148
149impl<'a> TryFrom<pest::iterators::Pair<'a, Rule>> for Postfix {
150    type Error = AstError<'a, Self>;
151
152    fn try_from(pair: pest::iterators::Pair<'a, Rule>) -> Result<Self, Self::Error> {
153        let rule = pair.as_rule();
154        let mut inner = pair.clone().into_inner();
155
156        match rule {
157            Rule::postfix => Postfix::try_from(inner.next().unwrap()),
158
159            Rule::field_px => {
160                ast_expr!(Postfix::FieldAccess(
161                    inner.next().unwrap().try_into(),
162                    inner.next().map(Generics::try_from).transpose()
163                ))
164            }
165
166            Rule::call_px => ast_expr!(Postfix::Call(collect_recovered(inner))),
167
168            Rule::struct_px => ast_expr!(Postfix::StructCall(collect_recovered_map(inner, |p| {
169                let mut pi = p.into_inner();
170                Ok((
171                    Identifier::try_from(pi.next().unwrap())?,
172                    Expression::try_from(pi.next().unwrap()).get()?,
173                ))
174            }))),
175
176            Rule::index_px => {
177                ast_expr!(Postfix::Index(Expression::try_from(inner.next().unwrap())))
178            }
179
180            Rule::macro_call_px => Ok(Postfix::MacroCall(inner.as_str().to_string())),
181
182            Rule::as_px => {
183                ast_expr!(Postfix::As(inner.next().unwrap().try_into()))
184            }
185
186            Rule::try_px => Ok(Postfix::Try),
187
188            _ => AstError::bug_unimplemented(pair),
189        }
190    }
191}
192
193impl<'a> TryFrom<pest::iterators::Pair<'a, Rule>> for ExprPath {
194    type Error = AstError<'a, Self>;
195
196    fn try_from(pair: pest::iterators::Pair<'a, Rule>) -> Result<Self, Self::Error> {
197        ast_ensure!(pair, Rule::expr_path => {
198            ast_expr!(ExprPath(collect_recovered(pair.into_inner())))
199        })
200    }
201}
202
203impl<'a> TryFrom<pest::iterators::Pair<'a, Rule>> for ExprPathSegment {
204    type Error = AstError<'a, Self>;
205
206    fn try_from(pair: pest::iterators::Pair<'a, Rule>) -> Result<Self, Self::Error> {
207        let mut inner = pair.clone().into_inner();
208
209        ast_ensure!(pair, Rule::expr_path_segment => {
210            ast_expr!(ExprPathSegment {
211                ident: Identifier::try_from(inner.next().unwrap()),
212                generics: inner.next().map(Generics::try_from).transpose(),
213            })
214        })
215    }
216}