Skip to main content

mest_core/
ast.rs

1use bumpalo::Bump;
2use chumsky::span::SimpleSpan;
3use lasso::{Rodeo, Spur};
4use std::{cell::RefCell, ops::Deref, rc::Rc};
5
6use crate::thunk::Thunk;
7
/// A literal constant as it appears in expressions and patterns.
///
/// Every payload is a plain scalar, so the type is `Copy` and comparable with
/// `==`. `Eq`/`Hash` are intentionally not derived because `f64` implements
/// neither.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Literal {
    /// A 64-bit signed integer literal.
    Int(i64),
    /// A 64-bit floating point literal.
    Float(f64),
    /// A boolean literal.
    Bool(bool),
}
14
/// An identifier, stored as an interned string key (`lasso::Spur`).
///
/// The evaluator turns it back into text with `Rodeo::resolve` when it needs
/// the name for an error message.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct Ident(pub Spur);
17
/// Binary operators of the expression language.
///
/// The enum carries no data, so it is `Copy` and can be compared and hashed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinOp {
    /// Equality test.
    Eq,
    /// Inequality test.
    NotEq,
    /// Less-than comparison.
    Lt,
    /// Greater-than comparison.
    Gt,
    /// Less-than-or-equal comparison.
    Le,
    /// Greater-than-or-equal comparison.
    Ge,
    /// Boolean conjunction.
    And,
    /// Boolean disjunction.
    Or,
    /// Addition.
    Add,
    /// Subtraction.
    Sub,
    /// Multiplication.
    Mul,
    /// Division.
    Div,
    /// Exponentiation.
    Pow,
}
34
/// Unary (prefix) operators of the expression language.
///
/// The enum carries no data, so it is `Copy` and can be compared and hashed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOp {
    /// Arithmetic negation of a number.
    Neg,
    /// Logical negation of a boolean.
    Not,
}
40
/// A pattern on the left-hand side of a `Match` arm.
///
/// Sub-patterns are held by `'bump` references rather than owned boxes.
#[derive(Debug, Clone)]
pub enum Pat<'bump> {
    /// Matches any scrutinee and binds nothing.
    Wildcard,
    /// Matches any scrutinee and binds it to the given identifier.
    Var(Ident),
    /// Matches when the scrutinee's value equals this literal.
    Lit(Literal),
    /// Alternation: the left pattern is tried first, then the right one.
    Or(&'bump Pat<'bump>, &'bump Pat<'bump>),
}
48
/// An expression node: a reference to its [`ExprKind`] plus its source span.
///
/// Both fields are cheap to copy (a reference and a span), so the handle
/// itself is `Copy`.
#[derive(Debug, Clone, Copy)]
pub struct Expr<'bump> {
    /// The node's payload, borrowed for the `'bump` lifetime.
    pub kind: &'bump ExprKind<'bump>,
    /// Source span associated with this node.
    pub span: SimpleSpan,
}
54
55impl<'bump> Deref for Expr<'bump> {
56    type Target = ExprKind<'bump>;
57
58    fn deref(&self) -> &Self::Target {
59        &self.kind
60    }
61}
62
/// The payload of an expression node.
///
/// Child expressions are stored as [`Expr`] handles, i.e. `'bump` references
/// to other `ExprKind` values plus their spans.
#[derive(Debug, Clone)]
pub enum ExprKind<'bump> {
    /// A literal constant.
    Literal(Literal),
    /// A reference to a bound identifier.
    Var(Ident),
    /// Conditional: `cond` selects between `then_expr` and `else_expr`.
    If {
        cond: Expr<'bump>,
        then_expr: Expr<'bump>,
        else_expr: Expr<'bump>,
    },
    /// Application of a binary operator to two operands.
    BinOp {
        op: BinOp,
        lhs: Expr<'bump>,
        rhs: Expr<'bump>,
    },
    /// Application of a unary operator to one operand.
    UnaryOp {
        op: UnaryOp,
        rhs: Expr<'bump>,
    },
    /// Binds `name` to `value` while evaluating `body`.
    ///
    /// When `rec` is true the binding is also visible inside `value` itself.
    Let {
        name: Ident,
        value: Expr<'bump>,
        body: Expr<'bump>,
        rec: bool,
    },
    /// Pattern match: arms are tried in order against `scrutinee`.
    Match {
        scrutinee: Expr<'bump>,
        arms: &'bump [(Pat<'bump>, Expr<'bump>)],
    },
    /// Single-parameter function abstraction (lambda).
    Abs {
        param: Ident,
        body: Expr<'bump>,
    },
    /// Application of `func` to a single argument `arg`.
    App {
        func: Expr<'bump>,
        arg: Expr<'bump>,
    },
}
100
/// Evaluation environment: maps each bound identifier to the [`Thunk`] bound to it.
pub type Env<'bump> = im::HashMap<Ident, Thunk<'bump>>;
102
/// A fully evaluated runtime value.
#[derive(Debug, Clone)]
pub enum Value<'bump> {
    /// 64-bit signed integer.
    Int(i64),
    /// 64-bit floating point number.
    Float(f64),
    /// Boolean.
    Bool(bool),
    /// A function value: its parameter, its body, and the environment that
    /// was in effect where the abstraction was evaluated.
    Closure {
        param: Ident,
        body: Expr<'bump>,
        env: Env<'bump>,
    },
}
114
115impl std::fmt::Display for Value<'_> {
116    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117        match self {
118            Value::Int(n) => write!(f, "{n}"),
119            Value::Float(n) => write!(f, "{n}"),
120            Value::Bool(b) => write!(f, "{b}"),
121            Value::Closure { .. } => write!(f, "<closure>"),
122        }
123    }
124}
125
/// Errors that evaluation can report.
///
/// The type implements [`std::error::Error`], so it works with `?` into
/// `Box<dyn Error>` and similar error plumbing, and `PartialEq`/`Eq`, so
/// callers and tests can compare errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EvalError {
    /// A variable was referenced that has no binding; carries the variable's name.
    UnboundVariable(String),
    /// An operand had the wrong runtime type.
    TypeMismatch {
        /// Name of the type that was required.
        expected: &'static str,
        /// Name of the type that was actually found.
        got: &'static str,
    },
    /// Division where the divisor was zero.
    DivisionByZero,
    /// A non-closure value was used in function position.
    NotAFunction,
    /// No arm of a `match` accepted the scrutinee.
    NonExhaustiveMatch,
}

impl std::fmt::Display for EvalError {
    /// Writes a short, lowercase, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            EvalError::NonExhaustiveMatch => write!(f, "non-exhaustive match"),
            EvalError::UnboundVariable(name) => write!(f, "unbound variable `{name}`"),
            EvalError::TypeMismatch { expected, got } => {
                write!(f, "type mismatch: expected {expected}, got {got}")
            }
            EvalError::DivisionByZero => write!(f, "division by zero"),
            EvalError::NotAFunction => write!(f, "applied a non-function value"),
        }
    }
}

// No variant wraps another error, so the default `source()` (None) is correct.
impl std::error::Error for EvalError {}
151
152fn type_name(v: &Value) -> &'static str {
153    match v {
154        Value::Int(_) => "int",
155        Value::Float(_) => "float",
156        Value::Bool(_) => "bool",
157        Value::Closure { .. } => "closure",
158    }
159}
160
161impl<'bump> ExprKind<'bump> {
162    fn node(expr: &'bump ExprKind<'bump>, span: SimpleSpan) -> Expr<'bump> {
163        Expr { kind: expr, span }
164    }
165
166    pub fn literal(bump: &'bump Bump, span: SimpleSpan, lit: Literal) -> Expr<'bump> {
167        Self::node(bump.alloc(ExprKind::Literal(lit)), span)
168    }
169
170    pub fn ident(bump: &'bump Bump, span: SimpleSpan, name: Ident) -> Expr<'bump> {
171        Self::node(bump.alloc(ExprKind::Var(name)), span)
172    }
173
174    pub fn if_expr(
175        bump: &'bump Bump,
176        span: SimpleSpan,
177        cond: Expr<'bump>,
178        then_expr: Expr<'bump>,
179        else_expr: Expr<'bump>,
180    ) -> Expr<'bump> {
181        Self::node(
182            bump.alloc(ExprKind::If {
183                cond,
184                then_expr,
185                else_expr,
186            }),
187            span,
188        )
189    }
190
191    pub fn binop(
192        bump: &'bump Bump,
193        span: SimpleSpan,
194        op: BinOp,
195        lhs: Expr<'bump>,
196        rhs: Expr<'bump>,
197    ) -> Expr<'bump> {
198        Self::node(bump.alloc(ExprKind::BinOp { op, lhs, rhs }), span)
199    }
200
201    pub fn unaryop(
202        bump: &'bump Bump,
203        span: SimpleSpan,
204        op: UnaryOp,
205        rhs: Expr<'bump>,
206    ) -> Expr<'bump> {
207        Self::node(bump.alloc(ExprKind::UnaryOp { op, rhs }), span)
208    }
209
210    pub fn let_expr(
211        bump: &'bump Bump,
212        span: SimpleSpan,
213        name: Ident,
214        value: Expr<'bump>,
215        body: Expr<'bump>,
216        rec: bool,
217    ) -> Expr<'bump> {
218        Self::node(
219            bump.alloc(ExprKind::Let {
220                name,
221                value,
222                body,
223                rec,
224            }),
225            span,
226        )
227    }
228
229    pub fn match_expr(
230        bump: &'bump Bump,
231        span: SimpleSpan,
232        scrutinee: Expr<'bump>,
233        arms: &'bump [(Pat<'bump>, Expr<'bump>)],
234    ) -> Expr<'bump> {
235        Self::node(bump.alloc(ExprKind::Match { scrutinee, arms }), span)
236    }
237
238    pub fn lambda(
239        bump: &'bump Bump,
240        span: SimpleSpan,
241        param: Ident,
242        body: Expr<'bump>,
243    ) -> Expr<'bump> {
244        Self::node(bump.alloc(ExprKind::Abs { param, body }), span)
245    }
246
247    pub fn app(
248        bump: &'bump Bump,
249        span: SimpleSpan,
250        func: Expr<'bump>,
251        arg: Expr<'bump>,
252    ) -> Expr<'bump> {
253        Self::node(bump.alloc(ExprKind::App { func, arg }), span)
254    }
255}
256
257impl<'bump> ExprKind<'bump> {
258    fn force(thunk: &Thunk<'bump>, rodeo: &Rodeo) -> Result<Value<'bump>, EvalError> {
259        thunk.force(rodeo)
260    }
261
262    pub fn thunk(expr: &'bump ExprKind<'bump>, env: &Env<'bump>) -> Thunk<'bump> {
263        Thunk::new(expr, env.clone())
264    }
265
266    pub fn eval_lazy(
267        &'bump self,
268        env: &Env<'bump>,
269        rodeo: &Rodeo,
270    ) -> Result<Value<'bump>, EvalError> {
271        match self {
272            ExprKind::Literal(Literal::Bool(b)) => Ok(Value::Bool(*b)),
273            ExprKind::Literal(Literal::Int(v)) => Ok(Value::Int(*v)),
274            ExprKind::Literal(Literal::Float(v)) => Ok(Value::Float(*v)),
275
276            ExprKind::Var(ident) => {
277                let thunk = env.get(ident).ok_or_else(|| {
278                    EvalError::UnboundVariable(rodeo.resolve(&ident.0).to_owned())
279                })?;
280                Self::force(thunk, rodeo)
281            }
282
283            ExprKind::UnaryOp { op, rhs } => {
284                let rhs = rhs.kind.eval_lazy(env, rodeo)?;
285                match (op, rhs) {
286                    (UnaryOp::Neg, Value::Int(n)) => Ok(Value::Int(-n)),
287                    (UnaryOp::Neg, Value::Float(f)) => Ok(Value::Float(-f)),
288                    (UnaryOp::Not, Value::Bool(b)) => Ok(Value::Bool(!b)),
289                    (UnaryOp::Neg, v) => Err(EvalError::TypeMismatch {
290                        expected: "number",
291                        got: type_name(&v),
292                    }),
293                    (UnaryOp::Not, v) => Err(EvalError::TypeMismatch {
294                        expected: "bool",
295                        got: type_name(&v),
296                    }),
297                }
298            }
299
300            ExprKind::BinOp { op, lhs, rhs } => {
301                let lhs = lhs.kind.eval_lazy(env, rodeo)?;
302                let rhs = rhs.kind.eval_lazy(env, rodeo)?;
303                Self::eval_binop(op, lhs, rhs)
304            }
305
306            ExprKind::Let {
307                name,
308                value,
309                body,
310                rec: true,
311            } => {
312                let rec_env = Rc::new(RefCell::new(env.clone()));
313                let thunk = Thunk::new_shared(value.kind, Rc::clone(&rec_env));
314                rec_env.borrow_mut().insert(*name, thunk.clone());
315                let mut body_env = env.clone();
316                body_env.insert(*name, thunk);
317                body.kind.eval_lazy(&body_env, rodeo)
318            }
319
320            ExprKind::Let {
321                name,
322                value,
323                body,
324                rec: false,
325            } => {
326                let mut env = env.clone();
327                env.insert(*name, Self::thunk(value.kind, &env));
328                body.kind.eval_lazy(&env, rodeo)
329            }
330
331            ExprKind::Match { scrutinee, arms } => {
332                let scrutinee_thunk = Thunk::new(scrutinee.kind, env.clone());
333                for (pat, body) in arms.iter() {
334                    let mut arm_env = env.clone();
335                    if Self::match_pat(pat, &scrutinee_thunk, &mut arm_env, rodeo)? {
336                        return body.kind.eval_lazy(&arm_env, rodeo);
337                    }
338                }
339                Err(EvalError::NonExhaustiveMatch)
340            }
341
342            ExprKind::If {
343                cond,
344                then_expr,
345                else_expr,
346            } => match cond.kind.eval_lazy(env, rodeo)? {
347                Value::Bool(true) => then_expr.kind.eval_lazy(env, rodeo),
348                Value::Bool(false) => else_expr.kind.eval_lazy(env, rodeo),
349                v => Err(EvalError::TypeMismatch {
350                    expected: "bool",
351                    got: type_name(&v),
352                }),
353            },
354
355            ExprKind::Abs { param, body } => Ok(Value::Closure {
356                param: *param,
357                body: *body,
358                env: env.clone(),
359            }),
360
361            ExprKind::App { func, arg } => {
362                let func = func.kind.eval_lazy(env, rodeo)?;
363                match func {
364                    Value::Closure {
365                        param,
366                        body,
367                        env: mut closure_env,
368                    } => {
369                        closure_env.insert(param, Self::thunk(arg.kind, env));
370                        body.kind.eval_lazy(&closure_env, rodeo)
371                    }
372                    _ => Err(EvalError::NotAFunction),
373                }
374            }
375        }
376    }
377
378    fn match_pat(
379        pat: &Pat<'bump>,
380        thunk: &Thunk<'bump>,
381        env: &mut Env<'bump>,
382        rodeo: &Rodeo,
383    ) -> Result<bool, EvalError> {
384        match pat {
385            Pat::Wildcard => Ok(true),
386            Pat::Var(name) => {
387                env.insert(*name, thunk.clone());
388                Ok(true)
389            }
390            Pat::Lit(lit) => {
391                let val = thunk.force(rodeo)?;
392                Ok(match (lit, &val) {
393                    (Literal::Int(a), Value::Int(b)) => a == b,
394                    (Literal::Float(a), Value::Float(b)) => a == b,
395                    (Literal::Bool(a), Value::Bool(b)) => a == b,
396                    _ => false,
397                })
398            }
399            Pat::Or(a, b) => {
400                let mut env_a = env.clone();
401                if Self::match_pat(a, thunk, &mut env_a, rodeo)? {
402                    *env = env_a;
403                    Ok(true)
404                } else {
405                    Self::match_pat(b, thunk, env, rodeo)
406                }
407            }
408        }
409    }
410
411    fn eval_binop(
412        op: &BinOp,
413        lhs: Value<'bump>,
414        rhs: Value<'bump>,
415    ) -> Result<Value<'bump>, EvalError> {
416        match (op, &lhs, &rhs) {
417            (BinOp::And, Value::Bool(l), Value::Bool(r)) => return Ok(Value::Bool(*l && *r)),
418            (BinOp::Or, Value::Bool(l), Value::Bool(r)) => return Ok(Value::Bool(*l || *r)),
419            (BinOp::And, _, _) | (BinOp::Or, _, _) => {
420                return Err(EvalError::TypeMismatch {
421                    expected: "bool",
422                    got: type_name(&lhs),
423                });
424            }
425            _ => {}
426        }
427
428        match op {
429            BinOp::Eq => return Ok(Value::Bool(Self::values_equal(&lhs, &rhs)?)),
430            BinOp::NotEq => return Ok(Value::Bool(!Self::values_equal(&lhs, &rhs)?)),
431            _ => {}
432        }
433
434        match (lhs, rhs) {
435            (Value::Int(l), Value::Int(r)) => Self::eval_int_binop(op, l, r),
436            (Value::Float(l), Value::Float(r)) => Self::eval_float_binop(op, l, r),
437            (Value::Int(l), Value::Float(r)) => Self::eval_float_binop(op, l as f64, r),
438            (Value::Float(l), Value::Int(r)) => Self::eval_float_binop(op, l, r as f64),
439            (lhs, _) => Err(EvalError::TypeMismatch {
440                expected: "number",
441                got: type_name(&lhs),
442            }),
443        }
444    }
445
446    fn values_equal(lhs: &Value, rhs: &Value) -> Result<bool, EvalError> {
447        match (lhs, rhs) {
448            (Value::Int(l), Value::Int(r)) => Ok(l == r),
449            (Value::Float(l), Value::Float(r)) => Ok(l == r),
450            (Value::Bool(l), Value::Bool(r)) => Ok(l == r),
451            (Value::Int(l), Value::Float(r)) => Ok((*l as f64) == *r),
452            (Value::Float(l), Value::Int(r)) => Ok(*l == (*r as f64)),
453            (l, r) => Err(EvalError::TypeMismatch {
454                expected: type_name(l),
455                got: type_name(r),
456            }),
457        }
458    }
459
460    fn eval_int_binop(op: &BinOp, lhs: i64, rhs: i64) -> Result<Value<'bump>, EvalError> {
461        match op {
462            BinOp::Lt => return Ok(Value::Bool(lhs < rhs)),
463            BinOp::Gt => return Ok(Value::Bool(lhs > rhs)),
464            BinOp::Le => return Ok(Value::Bool(lhs <= rhs)),
465            BinOp::Ge => return Ok(Value::Bool(lhs >= rhs)),
466            _ => {}
467        }
468
469        Ok(Value::Int(match op {
470            BinOp::Add => lhs + rhs,
471            BinOp::Sub => lhs - rhs,
472            BinOp::Mul => lhs * rhs,
473            BinOp::Div => {
474                if rhs == 0 {
475                    return Err(EvalError::DivisionByZero);
476                }
477                lhs / rhs
478            }
479            BinOp::Pow => lhs.pow(rhs as u32),
480            _ => unreachable!(),
481        }))
482    }
483
484    fn eval_float_binop(op: &BinOp, lhs: f64, rhs: f64) -> Result<Value<'bump>, EvalError> {
485        match op {
486            BinOp::Lt => return Ok(Value::Bool(lhs < rhs)),
487            BinOp::Gt => return Ok(Value::Bool(lhs > rhs)),
488            BinOp::Le => return Ok(Value::Bool(lhs <= rhs)),
489            BinOp::Ge => return Ok(Value::Bool(lhs >= rhs)),
490            _ => {}
491        }
492
493        Ok(Value::Float(match op {
494            BinOp::Add => lhs + rhs,
495            BinOp::Sub => lhs - rhs,
496            BinOp::Mul => lhs * rhs,
497            BinOp::Div => {
498                if rhs == 0.0 {
499                    return Err(EvalError::DivisionByZero);
500                }
501                lhs / rhs
502            }
503            BinOp::Pow => lhs.powf(rhs),
504            _ => unreachable!(),
505        }))
506    }
507}