// mist_parser/parser/common/expr.rs

1use crate::{
2    Rule,
3    ast::*,
4    ast_ensure, ast_expr,
5    error::{AstError, AstResult, GetLength, IntoErr, collect_recovered, collect_recovered_map},
6    parser::consume_rule,
7};
8use pest::pratt_parser::PrattParser;
9use std::sync::OnceLock;
10
11impl<'a> TryFrom<pest::iterators::Pair<'a, Rule>> for Expression {
12    type Error = AstError<'a, Self>;
13
14    fn try_from(pair: pest::iterators::Pair<'a, Rule>) -> Result<Self, Self::Error> {
15        let rule = pair.as_rule();
16        let mut inner = pair.clone().into_inner();
17
18        match rule {
19            Rule::expr => {
20                static PRATT_PARSER: OnceLock<PrattParser<Rule>> = OnceLock::new();
21                let pratt = PRATT_PARSER.get_or_init(|| {
22                    use pest::pratt_parser::{Assoc::*, Op};
23
24                    PrattParser::new().op(Op::infix(Rule::bin_op, Left))
25                });
26
27                pratt
28                    .map_primary(|primary_pair| Expression::try_from(primary_pair))
29                    .map_infix(|expr, op, rhs| {
30                        ast_expr!(Expression::Binary {
31                            lhs: expr.map(Box::new).get_map(Box::new),
32                            op: Ok(op.as_str().to_string()) as AstResult<'_, String>,
33                            rhs: rhs.map(Box::new).get_map(Box::new),
34                        })
35                    })
36                    .parse(inner)
37            }
38
39            Rule::term => {
40                let mut prefix_pairs = Vec::new();
41                let mut primary_pair = None;
42                let mut postfix_pairs = Vec::new();
43
44                for p in inner {
45                    match p.as_rule() {
46                        Rule::prefix => prefix_pairs.push(p),
47                        Rule::primary => primary_pair = Some(p),
48                        Rule::postfix => postfix_pairs.push(p),
49                        _ => {}
50                    }
51                }
52
53                let prefixes = collect_recovered::<Prefix, Prefix>(prefix_pairs.into_iter());
54                let exp = Expression::try_from(
55                    primary_pair.expect("Term must contain a primary expression"),
56                );
57                let postfixes = collect_recovered::<Postfix, Postfix>(postfix_pairs.into_iter());
58
59                if postfixes.len() > 0 || prefixes.len() > 0 {
60                    ast_expr!(Expression::Fix {
61                        initial: exp.map(Box::new),
62                        prefixes: prefixes,
63                        postfixes: postfixes,
64                    })
65                } else {
66                    ast_expr!(use exp?, prefixes, postfixes)
67                }
68            }
69
70            Rule::tuple => {
71                ast_expr!(Expression::Literal(
72                    collect_recovered(pair.into_inner())
73                        .map(Literal::Tuple)
74                        .get_map(Literal::Tuple)
75                ))
76            }
77
78            Rule::closure => {
79                ast_expr!(Expression::Closure {
80                    params: collect_recovered(inner.next().unwrap().into_inner()),
81                    return_type: consume_rule(&mut inner, Rule::type_expr)
82                        .map(TypeExpr::try_from)
83                        .transpose(),
84                    body: Expression::try_from(inner.next().unwrap()).map(Box::new),
85                })
86            }
87
88            Rule::array => {
89                ast_expr!(Expression::Array(collect_recovered(inner)))
90            }
91
92            Rule::array_repeat => {
93                ast_expr!(Expression::ArrayRepeat(
94                    Expression::try_from(inner.next().unwrap()).map(Box::new),
95                    Expression::try_from(inner.next().unwrap()).map(Box::new)
96                ))
97            }
98
99            Rule::primary => pair.into_inner().next().unwrap().try_into(),
100            Rule::static_path => ast_expr!(Expression::Path(pair.try_into())),
101            Rule::literal => ast_expr!(Expression::Literal(pair.try_into())),
102            Rule::expr_path => ast_expr!(Expression::Path(pair.try_into())),
103            Rule::statement_wrapper => {
104                let i = inner.next().unwrap();
105                match i.as_rule() {
106                    Rule::expr => i.try_into(),
107                    _ => ast_expr!(Expression::Statement(
108                        i.try_into().get_map(Box::new).map(Box::new)
109                    )),
110                }
111            }
112            Rule::statement | Rule::basic_stmt | Rule::control_flow | Rule::block => ast_expr!(
113                Expression::Statement(pair.try_into().get_map(Box::new).map(Box::new))
114            ),
115
116            _ => AstError::bug_unimplemented(pair),
117        }
118    }
119}
120
121impl<'a> TryFrom<pest::iterators::Pair<'a, Rule>> for Prefix {
122    type Error = AstError<'a, Self>;
123
124    fn try_from(pair: pest::iterators::Pair<'a, Rule>) -> Result<Self, Self::Error> {
125        Ok(match pair.as_rule() {
126            Rule::prefix => Self::try_from(pair.into_inner().next().unwrap())?,
127            Rule::deref_px => Self::Deref,
128            Rule::mut_ref_px => Self::RefMut,
129            Rule::ref_px => Self::Ref,
130            Rule::not_px => Self::Not,
131            Rule::neg_px => Self::Neg,
132
133            _ => return AstError::bug_unimplemented(pair),
134        })
135    }
136}
137
138impl<'a> TryFrom<pest::iterators::Pair<'a, Rule>> for Postfix {
139    type Error = AstError<'a, Self>;
140
141    fn try_from(pair: pest::iterators::Pair<'a, Rule>) -> Result<Self, Self::Error> {
142        let rule = pair.as_rule();
143        let mut inner = pair.clone().into_inner();
144
145        match rule {
146            Rule::postfix => Postfix::try_from(inner.next().unwrap()),
147
148            Rule::field_px => {
149                ast_expr!(Postfix::FieldAccess(
150                    inner.next().unwrap().try_into(),
151                    inner.next().map(Generics::try_from).transpose()
152                ))
153            }
154
155            Rule::tuple_field_px => {
156                ast_expr!(Postfix::TupleFieldAccess(
157                    Ok(inner.next().unwrap().as_str().parse().unwrap_or(255_u8))
158                        as AstResult<'_, u8>,
159                    inner.next().map(Generics::try_from).transpose(),
160                ))
161            }
162
163            Rule::call_px => ast_expr!(Postfix::Call(collect_recovered(inner))),
164
165            Rule::struct_px => ast_expr!(Postfix::StructCall(collect_recovered_map(inner, |p| {
166                let mut pi = p.into_inner();
167                Ok((
168                    Identifier::try_from(pi.next().unwrap())?,
169                    pi.next().map(Expression::try_from).transpose().get()?,
170                ))
171            }))),
172
173            Rule::index_px => {
174                ast_expr!(Postfix::Index(Expression::try_from(inner.next().unwrap())))
175            }
176
177            Rule::macro_call_paren => Ok(Postfix::MacroCall {
178                inner: inner.as_str().to_string(),
179                delimiter: MacroDelimiter::Paren,
180            }),
181            Rule::macro_call_bracket => Ok(Postfix::MacroCall {
182                inner: inner.as_str().to_string(),
183                delimiter: MacroDelimiter::Bracket,
184            }),
185            Rule::macro_call_brace => Ok(Postfix::MacroCall {
186                inner: inner.as_str().to_string(),
187                delimiter: MacroDelimiter::Brace,
188            }),
189
190            Rule::as_px => {
191                ast_expr!(Postfix::As(inner.next().unwrap().try_into()))
192            }
193
194            Rule::try_px => Ok(Postfix::Try),
195
196            Rule::increment => Ok(Postfix::Increment),
197            Rule::decrement => Ok(Postfix::Decrement),
198
199            _ => AstError::bug_unimplemented(pair),
200        }
201    }
202}
203
204impl<'a> TryFrom<pest::iterators::Pair<'a, Rule>> for ExprPath {
205    type Error = AstError<'a, Self>;
206
207    fn try_from(pair: pest::iterators::Pair<'a, Rule>) -> Result<Self, Self::Error> {
208        ast_ensure!(pair, Rule::expr_path => {
209            ast_expr!(ExprPath(collect_recovered(pair.into_inner())))
210        })
211    }
212}
213
214impl<'a> TryFrom<pest::iterators::Pair<'a, Rule>> for ExprPathSegment {
215    type Error = AstError<'a, Self>;
216
217    fn try_from(pair: pest::iterators::Pair<'a, Rule>) -> Result<Self, Self::Error> {
218        let mut inner = pair.clone().into_inner();
219
220        ast_ensure!(pair, Rule::expr_path_segment => {
221            ast_expr!(ExprPathSegment {
222                ident: Identifier::try_from(inner.next().unwrap()),
223                generics: inner.next().map(Generics::try_from).transpose(),
224            })
225        })
226    }
227}