1use bumpalo::Bump;
2use chumsky::span::SimpleSpan;
3use lasso::{Rodeo, Spur};
4use std::{cell::RefCell, ops::Deref, rc::Rc};
5
6use crate::thunk::Thunk;
7
/// A constant written directly in the source program.
#[derive(Clone, Debug)]
pub enum Literal {
    /// A 64-bit signed integer constant.
    Int(i64),
    /// A 64-bit floating-point constant.
    Float(f64),
    /// `true` or `false`.
    Bool(bool),
}
14
15#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
16pub struct Ident(pub Spur);
17
/// Binary operators.
#[derive(Clone, Debug)]
pub enum BinOp {
    // Equality.
    Eq,
    NotEq,
    // Ordering comparisons.
    Lt,
    Gt,
    Le,
    Ge,
    // Boolean connectives.
    And,
    Or,
    // Arithmetic.
    Add,
    Sub,
    Mul,
    Div,
    Pow,
}
34
/// Prefix (unary) operators.
#[derive(Clone, Debug)]
pub enum UnaryOp {
    /// Arithmetic negation of an int or float.
    Neg,
    /// Logical negation of a bool.
    Not,
}
40
41#[derive(Debug, Clone)]
42pub enum Pat<'bump> {
43 Wildcard,
44 Var(Ident),
45 Lit(Literal),
46 Or(&'bump Pat<'bump>, &'bump Pat<'bump>),
47}
48
49#[derive(Debug, Clone, Copy)]
50pub struct Expr<'bump> {
51 pub kind: &'bump ExprKind<'bump>,
52 pub span: SimpleSpan,
53}
54
55impl<'bump> Deref for Expr<'bump> {
56 type Target = ExprKind<'bump>;
57
58 fn deref(&self) -> &Self::Target {
59 &self.kind
60 }
61}
62
63#[derive(Debug, Clone)]
64pub enum ExprKind<'bump> {
65 Literal(Literal),
66 Var(Ident),
67 If {
68 cond: Expr<'bump>,
69 then_expr: Expr<'bump>,
70 else_expr: Expr<'bump>,
71 },
72 BinOp {
73 op: BinOp,
74 lhs: Expr<'bump>,
75 rhs: Expr<'bump>,
76 },
77 UnaryOp {
78 op: UnaryOp,
79 rhs: Expr<'bump>,
80 },
81 Let {
82 name: Ident,
83 value: Expr<'bump>,
84 body: Expr<'bump>,
85 rec: bool,
86 },
87 Match {
88 scrutinee: Expr<'bump>,
89 arms: &'bump [(Pat<'bump>, Expr<'bump>)],
90 },
91 Abs {
92 param: Ident,
93 body: Expr<'bump>,
94 },
95 App {
96 func: Expr<'bump>,
97 arg: Expr<'bump>,
98 },
99}
100
/// Variable environment: maps each identifier to a thunk holding its
/// (possibly not yet evaluated) value.
///
/// The evaluator extends a scope by cloning an `Env` and inserting into the copy,
/// leaving the caller's environment untouched.
pub type Env<'bump> = im::HashMap<Ident, Thunk<'bump>>;
102
103#[derive(Debug, Clone)]
104pub enum Value<'bump> {
105 Int(i64),
106 Float(f64),
107 Bool(bool),
108 Closure {
109 param: Ident,
110 body: Expr<'bump>,
111 env: Env<'bump>,
112 },
113}
114
115impl std::fmt::Display for Value<'_> {
116 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117 match self {
118 Value::Int(n) => write!(f, "{n}"),
119 Value::Float(n) => write!(f, "{n}"),
120 Value::Bool(b) => write!(f, "{b}"),
121 Value::Closure { .. } => write!(f, "<closure>"),
122 }
123 }
124}
125
/// Runtime errors produced while evaluating an expression.
#[derive(Clone, Debug)]
pub enum EvalError {
    /// A variable was not found in the environment; holds its resolved name.
    UnboundVariable(String),
    /// An operation received a value of the wrong type. Both fields are short
    /// descriptions such as `"int"`, `"bool"` or `"number"`.
    TypeMismatch {
        /// What the operation required.
        expected: &'static str,
        /// What it actually received.
        got: &'static str,
    },
    /// Integer or float division with a zero divisor.
    DivisionByZero,
    /// A value that is not a closure was applied to an argument.
    NotAFunction,
    /// No arm of a `match` accepted the scrutinee.
    NonExhaustiveMatch,
}
137
138impl std::fmt::Display for EvalError {
139 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140 match self {
141 EvalError::NonExhaustiveMatch => write!(f, "non-exhaustive match"),
142 EvalError::UnboundVariable(name) => write!(f, "unbound variable `{name}`"),
143 EvalError::TypeMismatch { expected, got } => {
144 write!(f, "type mismatch: expected {expected}, got {got}")
145 }
146 EvalError::DivisionByZero => write!(f, "division by zero"),
147 EvalError::NotAFunction => write!(f, "applied a non-function value"),
148 }
149 }
150}
151
152fn type_name(v: &Value) -> &'static str {
153 match v {
154 Value::Int(_) => "int",
155 Value::Float(_) => "float",
156 Value::Bool(_) => "bool",
157 Value::Closure { .. } => "closure",
158 }
159}
160
161impl<'bump> ExprKind<'bump> {
162 fn node(expr: &'bump ExprKind<'bump>, span: SimpleSpan) -> Expr<'bump> {
163 Expr { kind: expr, span }
164 }
165
166 pub fn literal(bump: &'bump Bump, span: SimpleSpan, lit: Literal) -> Expr<'bump> {
167 Self::node(bump.alloc(ExprKind::Literal(lit)), span)
168 }
169
170 pub fn ident(bump: &'bump Bump, span: SimpleSpan, name: Ident) -> Expr<'bump> {
171 Self::node(bump.alloc(ExprKind::Var(name)), span)
172 }
173
174 pub fn if_expr(
175 bump: &'bump Bump,
176 span: SimpleSpan,
177 cond: Expr<'bump>,
178 then_expr: Expr<'bump>,
179 else_expr: Expr<'bump>,
180 ) -> Expr<'bump> {
181 Self::node(
182 bump.alloc(ExprKind::If {
183 cond,
184 then_expr,
185 else_expr,
186 }),
187 span,
188 )
189 }
190
191 pub fn binop(
192 bump: &'bump Bump,
193 span: SimpleSpan,
194 op: BinOp,
195 lhs: Expr<'bump>,
196 rhs: Expr<'bump>,
197 ) -> Expr<'bump> {
198 Self::node(bump.alloc(ExprKind::BinOp { op, lhs, rhs }), span)
199 }
200
201 pub fn unaryop(
202 bump: &'bump Bump,
203 span: SimpleSpan,
204 op: UnaryOp,
205 rhs: Expr<'bump>,
206 ) -> Expr<'bump> {
207 Self::node(bump.alloc(ExprKind::UnaryOp { op, rhs }), span)
208 }
209
210 pub fn let_expr(
211 bump: &'bump Bump,
212 span: SimpleSpan,
213 name: Ident,
214 value: Expr<'bump>,
215 body: Expr<'bump>,
216 rec: bool,
217 ) -> Expr<'bump> {
218 Self::node(
219 bump.alloc(ExprKind::Let {
220 name,
221 value,
222 body,
223 rec,
224 }),
225 span,
226 )
227 }
228
229 pub fn match_expr(
230 bump: &'bump Bump,
231 span: SimpleSpan,
232 scrutinee: Expr<'bump>,
233 arms: &'bump [(Pat<'bump>, Expr<'bump>)],
234 ) -> Expr<'bump> {
235 Self::node(bump.alloc(ExprKind::Match { scrutinee, arms }), span)
236 }
237
238 pub fn lambda(
239 bump: &'bump Bump,
240 span: SimpleSpan,
241 param: Ident,
242 body: Expr<'bump>,
243 ) -> Expr<'bump> {
244 Self::node(bump.alloc(ExprKind::Abs { param, body }), span)
245 }
246
247 pub fn app(
248 bump: &'bump Bump,
249 span: SimpleSpan,
250 func: Expr<'bump>,
251 arg: Expr<'bump>,
252 ) -> Expr<'bump> {
253 Self::node(bump.alloc(ExprKind::App { func, arg }), span)
254 }
255}
256
257impl<'bump> ExprKind<'bump> {
258 fn force(thunk: &Thunk<'bump>, rodeo: &Rodeo) -> Result<Value<'bump>, EvalError> {
259 thunk.force(rodeo)
260 }
261
262 pub fn thunk(expr: &'bump ExprKind<'bump>, env: &Env<'bump>) -> Thunk<'bump> {
263 Thunk::new(expr, env.clone())
264 }
265
266 pub fn eval_lazy(
267 &'bump self,
268 env: &Env<'bump>,
269 rodeo: &Rodeo,
270 ) -> Result<Value<'bump>, EvalError> {
271 match self {
272 ExprKind::Literal(Literal::Bool(b)) => Ok(Value::Bool(*b)),
273 ExprKind::Literal(Literal::Int(v)) => Ok(Value::Int(*v)),
274 ExprKind::Literal(Literal::Float(v)) => Ok(Value::Float(*v)),
275
276 ExprKind::Var(ident) => {
277 let thunk = env.get(ident).ok_or_else(|| {
278 EvalError::UnboundVariable(rodeo.resolve(&ident.0).to_owned())
279 })?;
280 Self::force(thunk, rodeo)
281 }
282
283 ExprKind::UnaryOp { op, rhs } => {
284 let rhs = rhs.kind.eval_lazy(env, rodeo)?;
285 match (op, rhs) {
286 (UnaryOp::Neg, Value::Int(n)) => Ok(Value::Int(-n)),
287 (UnaryOp::Neg, Value::Float(f)) => Ok(Value::Float(-f)),
288 (UnaryOp::Not, Value::Bool(b)) => Ok(Value::Bool(!b)),
289 (UnaryOp::Neg, v) => Err(EvalError::TypeMismatch {
290 expected: "number",
291 got: type_name(&v),
292 }),
293 (UnaryOp::Not, v) => Err(EvalError::TypeMismatch {
294 expected: "bool",
295 got: type_name(&v),
296 }),
297 }
298 }
299
300 ExprKind::BinOp { op, lhs, rhs } => {
301 let lhs = lhs.kind.eval_lazy(env, rodeo)?;
302 let rhs = rhs.kind.eval_lazy(env, rodeo)?;
303 Self::eval_binop(op, lhs, rhs)
304 }
305
306 ExprKind::Let {
307 name,
308 value,
309 body,
310 rec: true,
311 } => {
312 let rec_env = Rc::new(RefCell::new(env.clone()));
313 let thunk = Thunk::new_shared(value.kind, Rc::clone(&rec_env));
314 rec_env.borrow_mut().insert(*name, thunk.clone());
315 let mut body_env = env.clone();
316 body_env.insert(*name, thunk);
317 body.kind.eval_lazy(&body_env, rodeo)
318 }
319
320 ExprKind::Let {
321 name,
322 value,
323 body,
324 rec: false,
325 } => {
326 let mut env = env.clone();
327 env.insert(*name, Self::thunk(value.kind, &env));
328 body.kind.eval_lazy(&env, rodeo)
329 }
330
331 ExprKind::Match { scrutinee, arms } => {
332 let scrutinee_thunk = Thunk::new(scrutinee.kind, env.clone());
333 for (pat, body) in arms.iter() {
334 let mut arm_env = env.clone();
335 if Self::match_pat(pat, &scrutinee_thunk, &mut arm_env, rodeo)? {
336 return body.kind.eval_lazy(&arm_env, rodeo);
337 }
338 }
339 Err(EvalError::NonExhaustiveMatch)
340 }
341
342 ExprKind::If {
343 cond,
344 then_expr,
345 else_expr,
346 } => match cond.kind.eval_lazy(env, rodeo)? {
347 Value::Bool(true) => then_expr.kind.eval_lazy(env, rodeo),
348 Value::Bool(false) => else_expr.kind.eval_lazy(env, rodeo),
349 v => Err(EvalError::TypeMismatch {
350 expected: "bool",
351 got: type_name(&v),
352 }),
353 },
354
355 ExprKind::Abs { param, body } => Ok(Value::Closure {
356 param: *param,
357 body: *body,
358 env: env.clone(),
359 }),
360
361 ExprKind::App { func, arg } => {
362 let func = func.kind.eval_lazy(env, rodeo)?;
363 match func {
364 Value::Closure {
365 param,
366 body,
367 env: mut closure_env,
368 } => {
369 closure_env.insert(param, Self::thunk(arg.kind, env));
370 body.kind.eval_lazy(&closure_env, rodeo)
371 }
372 _ => Err(EvalError::NotAFunction),
373 }
374 }
375 }
376 }
377
378 fn match_pat(
379 pat: &Pat<'bump>,
380 thunk: &Thunk<'bump>,
381 env: &mut Env<'bump>,
382 rodeo: &Rodeo,
383 ) -> Result<bool, EvalError> {
384 match pat {
385 Pat::Wildcard => Ok(true),
386 Pat::Var(name) => {
387 env.insert(*name, thunk.clone());
388 Ok(true)
389 }
390 Pat::Lit(lit) => {
391 let val = thunk.force(rodeo)?;
392 Ok(match (lit, &val) {
393 (Literal::Int(a), Value::Int(b)) => a == b,
394 (Literal::Float(a), Value::Float(b)) => a == b,
395 (Literal::Bool(a), Value::Bool(b)) => a == b,
396 _ => false,
397 })
398 }
399 Pat::Or(a, b) => {
400 let mut env_a = env.clone();
401 if Self::match_pat(a, thunk, &mut env_a, rodeo)? {
402 *env = env_a;
403 Ok(true)
404 } else {
405 Self::match_pat(b, thunk, env, rodeo)
406 }
407 }
408 }
409 }
410
411 fn eval_binop(
412 op: &BinOp,
413 lhs: Value<'bump>,
414 rhs: Value<'bump>,
415 ) -> Result<Value<'bump>, EvalError> {
416 match (op, &lhs, &rhs) {
417 (BinOp::And, Value::Bool(l), Value::Bool(r)) => return Ok(Value::Bool(*l && *r)),
418 (BinOp::Or, Value::Bool(l), Value::Bool(r)) => return Ok(Value::Bool(*l || *r)),
419 (BinOp::And, _, _) | (BinOp::Or, _, _) => {
420 return Err(EvalError::TypeMismatch {
421 expected: "bool",
422 got: type_name(&lhs),
423 });
424 }
425 _ => {}
426 }
427
428 match op {
429 BinOp::Eq => return Ok(Value::Bool(Self::values_equal(&lhs, &rhs)?)),
430 BinOp::NotEq => return Ok(Value::Bool(!Self::values_equal(&lhs, &rhs)?)),
431 _ => {}
432 }
433
434 match (lhs, rhs) {
435 (Value::Int(l), Value::Int(r)) => Self::eval_int_binop(op, l, r),
436 (Value::Float(l), Value::Float(r)) => Self::eval_float_binop(op, l, r),
437 (Value::Int(l), Value::Float(r)) => Self::eval_float_binop(op, l as f64, r),
438 (Value::Float(l), Value::Int(r)) => Self::eval_float_binop(op, l, r as f64),
439 (lhs, _) => Err(EvalError::TypeMismatch {
440 expected: "number",
441 got: type_name(&lhs),
442 }),
443 }
444 }
445
446 fn values_equal(lhs: &Value, rhs: &Value) -> Result<bool, EvalError> {
447 match (lhs, rhs) {
448 (Value::Int(l), Value::Int(r)) => Ok(l == r),
449 (Value::Float(l), Value::Float(r)) => Ok(l == r),
450 (Value::Bool(l), Value::Bool(r)) => Ok(l == r),
451 (Value::Int(l), Value::Float(r)) => Ok((*l as f64) == *r),
452 (Value::Float(l), Value::Int(r)) => Ok(*l == (*r as f64)),
453 (l, r) => Err(EvalError::TypeMismatch {
454 expected: type_name(l),
455 got: type_name(r),
456 }),
457 }
458 }
459
460 fn eval_int_binop(op: &BinOp, lhs: i64, rhs: i64) -> Result<Value<'bump>, EvalError> {
461 match op {
462 BinOp::Lt => return Ok(Value::Bool(lhs < rhs)),
463 BinOp::Gt => return Ok(Value::Bool(lhs > rhs)),
464 BinOp::Le => return Ok(Value::Bool(lhs <= rhs)),
465 BinOp::Ge => return Ok(Value::Bool(lhs >= rhs)),
466 _ => {}
467 }
468
469 Ok(Value::Int(match op {
470 BinOp::Add => lhs + rhs,
471 BinOp::Sub => lhs - rhs,
472 BinOp::Mul => lhs * rhs,
473 BinOp::Div => {
474 if rhs == 0 {
475 return Err(EvalError::DivisionByZero);
476 }
477 lhs / rhs
478 }
479 BinOp::Pow => lhs.pow(rhs as u32),
480 _ => unreachable!(),
481 }))
482 }
483
484 fn eval_float_binop(op: &BinOp, lhs: f64, rhs: f64) -> Result<Value<'bump>, EvalError> {
485 match op {
486 BinOp::Lt => return Ok(Value::Bool(lhs < rhs)),
487 BinOp::Gt => return Ok(Value::Bool(lhs > rhs)),
488 BinOp::Le => return Ok(Value::Bool(lhs <= rhs)),
489 BinOp::Ge => return Ok(Value::Bool(lhs >= rhs)),
490 _ => {}
491 }
492
493 Ok(Value::Float(match op {
494 BinOp::Add => lhs + rhs,
495 BinOp::Sub => lhs - rhs,
496 BinOp::Mul => lhs * rhs,
497 BinOp::Div => {
498 if rhs == 0.0 {
499 return Err(EvalError::DivisionByZero);
500 }
501 lhs / rhs
502 }
503 BinOp::Pow => lhs.powf(rhs),
504 _ => unreachable!(),
505 }))
506 }
507}