1use std::sync::Arc;
8
9use crate::builtin::apply_builtin;
10use crate::env::Env;
11use crate::error::ExprError;
12use crate::expr::{BuiltinOp, Expr, Pattern};
13use crate::literal::Literal;
14
15#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
17pub struct EvalConfig {
18 pub max_steps: u64,
20 pub max_depth: u32,
22 pub max_list_len: usize,
24}
25
26impl Default for EvalConfig {
27 fn default() -> Self {
28 Self {
29 max_steps: 100_000,
30 max_depth: 256,
31 max_list_len: 10_000,
32 }
33 }
34}
35
/// Mutable bookkeeping shared by every recursive step of one [`eval`] call.
struct EvalState {
    /// Steps still available before the budget is exhausted.
    steps_remaining: u64,
    /// The original step budget, reported in `StepLimitExceeded`.
    max_steps: u64,
    /// Copied from [`EvalConfig::max_depth`].
    max_depth: u32,
    /// Copied from [`EvalConfig::max_list_len`].
    max_list_len: usize,
}
43
44impl EvalState {
45 const fn new(config: &EvalConfig) -> Self {
46 Self {
47 steps_remaining: config.max_steps,
48 max_steps: config.max_steps,
49 max_depth: config.max_depth,
50 max_list_len: config.max_list_len,
51 }
52 }
53
54 const fn tick(&mut self) -> Result<(), ExprError> {
55 if self.steps_remaining == 0 {
56 return Err(ExprError::StepLimitExceeded(self.max_steps));
57 }
58 self.steps_remaining -= 1;
59 Ok(())
60 }
61}
62
63pub fn eval(expr: &Expr, env: &Env, config: &EvalConfig) -> Result<Literal, ExprError> {
70 let mut state = EvalState::new(config);
71 eval_inner(expr, env, 0, &mut state)
72}
73
74fn eval_inner(
75 expr: &Expr,
76 env: &Env,
77 depth: u32,
78 state: &mut EvalState,
79) -> Result<Literal, ExprError> {
80 if depth > state.max_depth {
81 return Err(ExprError::DepthExceeded(state.max_depth));
82 }
83 state.tick()?;
84
85 match expr {
86 Expr::Var(name) => env
87 .get(name)
88 .cloned()
89 .ok_or_else(|| ExprError::UnboundVariable(name.to_string())),
90
91 Expr::Lit(lit) => Ok(lit.clone()),
92
93 Expr::Lam(param, body) => {
94 let captured: Vec<(Arc<str>, Literal)> = env
97 .iter()
98 .map(|(k, v)| (Arc::clone(k), v.clone()))
99 .collect();
100 Ok(Literal::Closure {
101 param: Arc::clone(param),
102 body: body.clone(),
103 env: captured,
104 })
105 }
106
107 Expr::App(func, arg) => eval_app(func, arg, env, depth, state),
108
109 Expr::Record(fields) => {
110 let mut result = Vec::with_capacity(fields.len());
111 for (name, expr) in fields {
112 let val = eval_inner(expr, env, depth + 1, state)?;
113 result.push((Arc::clone(name), val));
114 }
115 Ok(Literal::Record(result))
116 }
117
118 Expr::List(items) => {
119 let mut result = Vec::with_capacity(items.len());
120 for item in items {
121 let val = eval_inner(item, env, depth + 1, state)?;
122 result.push(val);
123 }
124 if result.len() > state.max_list_len {
125 return Err(ExprError::ListLengthExceeded(result.len()));
126 }
127 Ok(Literal::List(result))
128 }
129
130 Expr::Field(expr, field) => {
131 let val = eval_inner(expr, env, depth + 1, state)?;
132 match &val {
133 Literal::Record(fields) => fields
134 .iter()
135 .find(|(k, _)| k == field)
136 .map(|(_, v)| v.clone())
137 .ok_or_else(|| ExprError::FieldNotFound(field.to_string())),
138 _ => Err(ExprError::TypeError {
139 expected: "record".into(),
140 got: val.type_name().into(),
141 }),
142 }
143 }
144
145 Expr::Index(expr, idx_expr) => eval_index(expr, idx_expr, env, depth, state),
146
147 Expr::Match { scrutinee, arms } => eval_match(scrutinee, arms, env, depth, state),
148
149 Expr::Let { name, value, body } => {
150 let val = eval_inner(value, env, depth + 1, state)?;
151 let new_env = env.extend(Arc::clone(name), val);
152 eval_inner(body, &new_env, depth + 1, state)
153 }
154
155 Expr::Builtin(op, args) => {
156 match op {
158 BuiltinOp::Map => eval_map(args, env, depth, state),
159 BuiltinOp::Filter => eval_filter(args, env, depth, state),
160 BuiltinOp::Fold => eval_fold(args, env, depth, state),
161 BuiltinOp::FlatMap => eval_flat_map(args, env, depth, state),
162 _ => {
163 let evaluated: Result<Vec<_>, _> = args
164 .iter()
165 .map(|a| eval_inner(a, env, depth + 1, state))
166 .collect();
167 apply_builtin(*op, &evaluated?)
168 }
169 }
170 }
171 }
172}
173
174fn eval_app(
180 func: &Expr,
181 arg: &Expr,
182 env: &Env,
183 depth: u32,
184 state: &mut EvalState,
185) -> Result<Literal, ExprError> {
186 let func_val = eval_inner(func, env, depth + 1, state)?;
188 let arg_val = eval_inner(arg, env, depth + 1, state)?;
190 apply_closure(&func_val, &arg_val, depth, state)
192}
193
194fn apply_closure(
200 func: &Literal,
201 arg: &Literal,
202 depth: u32,
203 state: &mut EvalState,
204) -> Result<Literal, ExprError> {
205 match func {
206 Literal::Closure { param, body, env } => {
207 let mut closure_env: Env = env
209 .iter()
210 .map(|(k, v)| (Arc::clone(k), v.clone()))
211 .collect();
212 closure_env = closure_env.extend(Arc::clone(param), arg.clone());
214 eval_inner(body, &closure_env, depth + 1, state)
216 }
217 _ => Err(ExprError::NotAFunction),
218 }
219}
220
221#[allow(
223 clippy::cast_possible_wrap,
224 clippy::cast_possible_truncation,
225 clippy::cast_sign_loss
226)]
227fn eval_index(
228 expr: &Expr,
229 idx_expr: &Expr,
230 env: &Env,
231 depth: u32,
232 state: &mut EvalState,
233) -> Result<Literal, ExprError> {
234 let val = eval_inner(expr, env, depth + 1, state)?;
235 let idx = eval_inner(idx_expr, env, depth + 1, state)?;
236 match (&val, &idx) {
237 (Literal::List(items), Literal::Int(i)) => {
238 let index = if *i < 0 {
239 (items.len() as i64 + i) as usize
240 } else {
241 *i as usize
242 };
243 items
244 .get(index)
245 .cloned()
246 .ok_or(ExprError::IndexOutOfBounds {
247 index: *i,
248 len: items.len(),
249 })
250 }
251 _ => Err(ExprError::TypeError {
252 expected: "(list, int)".into(),
253 got: format!("({}, {})", val.type_name(), idx.type_name()),
254 }),
255 }
256}
257
258fn eval_match(
260 scrutinee: &Expr,
261 arms: &[(Pattern, Expr)],
262 env: &Env,
263 depth: u32,
264 state: &mut EvalState,
265) -> Result<Literal, ExprError> {
266 let val = eval_inner(scrutinee, env, depth + 1, state)?;
267 for (pattern, body) in arms {
268 if let Some(bindings) = match_pattern(pattern, &val) {
269 let mut new_env = env.clone();
270 for (name, bound_val) in bindings {
271 new_env = new_env.extend(name, bound_val);
272 }
273 return eval_inner(body, &new_env, depth + 1, state);
274 }
275 }
276 Err(ExprError::NonExhaustiveMatch)
277}
278
279fn eval_map(
281 args: &[Expr],
282 env: &Env,
283 depth: u32,
284 state: &mut EvalState,
285) -> Result<Literal, ExprError> {
286 if args.len() != 2 {
287 return Err(ExprError::ArityMismatch {
288 op: "Map".into(),
289 expected: 2,
290 got: args.len(),
291 });
292 }
293 let list_val = eval_inner(&args[0], env, depth + 1, state)?;
294 let items = match list_val {
295 Literal::List(items) => items,
296 other => {
297 return Err(ExprError::TypeError {
298 expected: "list".into(),
299 got: other.type_name().into(),
300 });
301 }
302 };
303
304 let func = &args[1];
305 let mut result = Vec::with_capacity(items.len());
306 for item in &items {
307 let val = apply_lambda(func, item, env, depth + 1, state)?;
308 result.push(val);
309 }
310 if result.len() > state.max_list_len {
311 return Err(ExprError::ListLengthExceeded(result.len()));
312 }
313 Ok(Literal::List(result))
314}
315
316fn eval_filter(
318 args: &[Expr],
319 env: &Env,
320 depth: u32,
321 state: &mut EvalState,
322) -> Result<Literal, ExprError> {
323 if args.len() != 2 {
324 return Err(ExprError::ArityMismatch {
325 op: "Filter".into(),
326 expected: 2,
327 got: args.len(),
328 });
329 }
330 let list_val = eval_inner(&args[0], env, depth + 1, state)?;
331 let items = match list_val {
332 Literal::List(items) => items,
333 other => {
334 return Err(ExprError::TypeError {
335 expected: "list".into(),
336 got: other.type_name().into(),
337 });
338 }
339 };
340
341 let pred = &args[1];
342 let mut result = Vec::new();
343 for item in &items {
344 let keep = apply_lambda(pred, item, env, depth + 1, state)?;
345 match keep {
346 Literal::Bool(true) => result.push(item.clone()),
347 Literal::Bool(false) => {}
348 other => {
349 return Err(ExprError::TypeError {
350 expected: "bool".into(),
351 got: other.type_name().into(),
352 });
353 }
354 }
355 }
356 Ok(Literal::List(result))
357}
358
359fn eval_fold(
361 args: &[Expr],
362 env: &Env,
363 depth: u32,
364 state: &mut EvalState,
365) -> Result<Literal, ExprError> {
366 if args.len() != 3 {
367 return Err(ExprError::ArityMismatch {
368 op: "Fold".into(),
369 expected: 3,
370 got: args.len(),
371 });
372 }
373 let list_val = eval_inner(&args[0], env, depth + 1, state)?;
374 let items = match list_val {
375 Literal::List(items) => items,
376 other => {
377 return Err(ExprError::TypeError {
378 expected: "list".into(),
379 got: other.type_name().into(),
380 });
381 }
382 };
383
384 let mut acc = eval_inner(&args[1], env, depth + 1, state)?;
385 let func = &args[2];
386
387 for item in &items {
388 acc = apply_lambda_2(func, &acc, item, env, depth + 1, state)?;
391 }
392 Ok(acc)
393}
394
395fn eval_flat_map(
397 args: &[Expr],
398 env: &Env,
399 depth: u32,
400 state: &mut EvalState,
401) -> Result<Literal, ExprError> {
402 if args.len() != 2 {
403 return Err(ExprError::ArityMismatch {
404 op: "FlatMap".into(),
405 expected: 2,
406 got: args.len(),
407 });
408 }
409 let list_val = eval_inner(&args[0], env, depth + 1, state)?;
410 let items = match list_val {
411 Literal::List(items) => items,
412 other => {
413 return Err(ExprError::TypeError {
414 expected: "list".into(),
415 got: other.type_name().into(),
416 });
417 }
418 };
419
420 let func = &args[1];
421 let mut result = Vec::new();
422 for item in &items {
423 let sub_list = apply_lambda(func, item, env, depth + 1, state)?;
424 match sub_list {
425 Literal::List(sub_items) => result.extend(sub_items),
426 other => {
427 return Err(ExprError::TypeError {
428 expected: "list".into(),
429 got: other.type_name().into(),
430 });
431 }
432 }
433 if result.len() > state.max_list_len {
434 return Err(ExprError::ListLengthExceeded(result.len()));
435 }
436 }
437 Ok(Literal::List(result))
438}
439
440fn apply_lambda(
445 func_expr: &Expr,
446 arg: &Literal,
447 env: &Env,
448 depth: u32,
449 state: &mut EvalState,
450) -> Result<Literal, ExprError> {
451 let func_val = eval_inner(func_expr, env, depth + 1, state)?;
452 apply_closure(&func_val, arg, depth, state)
453}
454
455fn apply_lambda_2(
460 func_expr: &Expr,
461 arg1: &Literal,
462 arg2: &Literal,
463 env: &Env,
464 depth: u32,
465 state: &mut EvalState,
466) -> Result<Literal, ExprError> {
467 let func_val = eval_inner(func_expr, env, depth + 1, state)?;
468 let partial = apply_closure(&func_val, arg1, depth, state)?;
469 apply_closure(&partial, arg2, depth, state)
470}
471
472fn match_pattern(pattern: &Pattern, value: &Literal) -> Option<Vec<(Arc<str>, Literal)>> {
474 let mut bindings = Vec::new();
475 if match_inner(pattern, value, &mut bindings) {
476 Some(bindings)
477 } else {
478 None
479 }
480}
481
482fn match_inner(
483 pattern: &Pattern,
484 value: &Literal,
485 bindings: &mut Vec<(Arc<str>, Literal)>,
486) -> bool {
487 match pattern {
488 Pattern::Wildcard => true,
489 Pattern::Var(name) => {
490 bindings.push((Arc::clone(name), value.clone()));
491 true
492 }
493 Pattern::Lit(lit) => lit == value,
494 Pattern::Record(field_pats) => {
495 if let Literal::Record(fields) = value {
496 for (pat_name, pat) in field_pats {
497 let field_val = fields.iter().find(|(k, _)| k == pat_name);
498 match field_val {
499 Some((_, v)) => {
500 if !match_inner(pat, v, bindings) {
501 return false;
502 }
503 }
504 None => return false,
505 }
506 }
507 true
508 } else {
509 false
510 }
511 }
512 Pattern::List(item_pats) => {
513 if let Literal::List(items) = value {
514 if items.len() != item_pats.len() {
515 return false;
516 }
517 for (pat, val) in item_pats.iter().zip(items.iter()) {
518 if !match_inner(pat, val, bindings) {
519 return false;
520 }
521 }
522 true
523 } else {
524 false
525 }
526 }
527 Pattern::Constructor(tag, arg_pats) => {
528 if let Literal::Record(fields) = value {
530 let tag_field = fields.iter().find(|(k, _)| &**k == "$tag");
531 if let Some((_, Literal::Str(t))) = tag_field {
532 if t.as_str() != &**tag {
533 return false;
534 }
535 for (i, pat) in arg_pats.iter().enumerate() {
537 let key = format!("${i}");
538 let field_val = fields.iter().find(|(k, _)| k.as_ref() == key.as_str());
539 match field_val {
540 Some((_, v)) => {
541 if !match_inner(pat, v, bindings) {
542 return false;
543 }
544 }
545 None => return false,
546 }
547 }
548 true
549 } else {
550 false
551 }
552 } else {
553 false
554 }
555 }
556 }
557}
558
#[cfg(test)]
#[allow(clippy::unwrap_used)] // unwrap on results is the assertion in these tests
mod tests {
    use super::*;

    fn default_config() -> EvalConfig {
        EvalConfig::default()
    }

    /// Builds an `Expr::List` of integer literals.
    fn int_list(values: &[i64]) -> Expr {
        Expr::List(values.iter().map(|&v| Expr::Lit(Literal::Int(v))).collect())
    }

    /// Builds `list[index]` with a literal integer index.
    fn index_expr(list: Expr, index: i64) -> Expr {
        Expr::Index(Box::new(list), Box::new(Expr::Lit(Literal::Int(index))))
    }

    #[test]
    fn eval_literal() {
        let result = eval(&Expr::Lit(Literal::Int(42)), &Env::new(), &default_config());
        assert_eq!(result.unwrap(), Literal::Int(42));
    }

    #[test]
    fn eval_variable() {
        let env = Env::new().extend(Arc::from("x"), Literal::Int(10));
        let result = eval(&Expr::var("x"), &env, &default_config());
        assert_eq!(result.unwrap(), Literal::Int(10));
    }

    #[test]
    fn eval_unbound_variable() {
        let result = eval(&Expr::var("x"), &Env::new(), &default_config());
        assert!(matches!(result, Err(ExprError::UnboundVariable(_))));
    }

    #[test]
    fn eval_lambda_application() {
        let expr = Expr::App(
            Box::new(Expr::lam(
                "x",
                Expr::builtin(
                    BuiltinOp::Add,
                    vec![Expr::var("x"), Expr::Lit(Literal::Int(1))],
                ),
            )),
            Box::new(Expr::Lit(Literal::Int(41))),
        );
        let result = eval(&expr, &Env::new(), &default_config());
        assert_eq!(result.unwrap(), Literal::Int(42));
    }

    #[test]
    fn eval_let_binding() {
        let expr = Expr::let_in(
            "x",
            Expr::Lit(Literal::Int(10)),
            Expr::builtin(
                BuiltinOp::Add,
                vec![Expr::var("x"), Expr::Lit(Literal::Int(5))],
            ),
        );
        let result = eval(&expr, &Env::new(), &default_config());
        assert_eq!(result.unwrap(), Literal::Int(15));
    }

    #[test]
    fn eval_record_and_field() {
        let expr = Expr::field(
            Expr::Record(vec![
                (Arc::from("name"), Expr::Lit(Literal::Str("alice".into()))),
                (Arc::from("age"), Expr::Lit(Literal::Int(30))),
            ]),
            "age",
        );
        let result = eval(&expr, &Env::new(), &default_config());
        assert_eq!(result.unwrap(), Literal::Int(30));
    }

    #[test]
    fn eval_list_index() {
        let expr = Expr::Index(
            Box::new(Expr::List(vec![
                Expr::Lit(Literal::Int(10)),
                Expr::Lit(Literal::Int(20)),
                Expr::Lit(Literal::Int(30)),
            ])),
            Box::new(Expr::Lit(Literal::Int(1))),
        );
        let result = eval(&expr, &Env::new(), &default_config());
        assert_eq!(result.unwrap(), Literal::Int(20));
    }

    #[test]
    fn eval_pattern_match() {
        let expr = Expr::Match {
            scrutinee: Box::new(Expr::Lit(Literal::Int(42))),
            arms: vec![
                (
                    Pattern::Lit(Literal::Int(0)),
                    Expr::Lit(Literal::Str("zero".into())),
                ),
                (
                    Pattern::Var(Arc::from("x")),
                    Expr::builtin(
                        BuiltinOp::Concat,
                        vec![
                            Expr::Lit(Literal::Str("num:".into())),
                            Expr::builtin(BuiltinOp::IntToStr, vec![Expr::var("x")]),
                        ],
                    ),
                ),
            ],
        };
        let result = eval(&expr, &Env::new(), &default_config());
        assert_eq!(result.unwrap(), Literal::Str("num:42".into()));
    }

    #[test]
    fn eval_map() {
        let expr = Expr::builtin(
            BuiltinOp::Map,
            vec![
                Expr::List(vec![
                    Expr::Lit(Literal::Int(1)),
                    Expr::Lit(Literal::Int(2)),
                    Expr::Lit(Literal::Int(3)),
                ]),
                Expr::lam(
                    "x",
                    Expr::builtin(
                        BuiltinOp::Mul,
                        vec![Expr::var("x"), Expr::Lit(Literal::Int(2))],
                    ),
                ),
            ],
        );
        let result = eval(&expr, &Env::new(), &default_config());
        assert_eq!(
            result.unwrap(),
            Literal::List(vec![Literal::Int(2), Literal::Int(4), Literal::Int(6)])
        );
    }

    #[test]
    fn eval_filter() {
        let expr = Expr::builtin(
            BuiltinOp::Filter,
            vec![
                Expr::List(vec![
                    Expr::Lit(Literal::Int(1)),
                    Expr::Lit(Literal::Int(2)),
                    Expr::Lit(Literal::Int(3)),
                    Expr::Lit(Literal::Int(4)),
                ]),
                Expr::lam(
                    "x",
                    Expr::builtin(
                        BuiltinOp::Gt,
                        vec![Expr::var("x"), Expr::Lit(Literal::Int(2))],
                    ),
                ),
            ],
        );
        let result = eval(&expr, &Env::new(), &default_config());
        assert_eq!(
            result.unwrap(),
            Literal::List(vec![Literal::Int(3), Literal::Int(4)])
        );
    }

    #[test]
    fn eval_fold() {
        let expr = Expr::builtin(
            BuiltinOp::Fold,
            vec![
                Expr::List(vec![
                    Expr::Lit(Literal::Int(1)),
                    Expr::Lit(Literal::Int(2)),
                    Expr::Lit(Literal::Int(3)),
                ]),
                Expr::Lit(Literal::Int(0)),
                Expr::lam(
                    "acc",
                    Expr::lam(
                        "x",
                        Expr::builtin(BuiltinOp::Add, vec![Expr::var("acc"), Expr::var("x")]),
                    ),
                ),
            ],
        );
        let result = eval(&expr, &Env::new(), &default_config());
        assert_eq!(result.unwrap(), Literal::Int(6));
    }

    #[test]
    fn eval_step_limit() {
        let config = EvalConfig {
            max_steps: 5,
            ..EvalConfig::default()
        };
        let items: Vec<_> = (1..=10).map(|i| Expr::Lit(Literal::Int(i))).collect();
        let expr = Expr::builtin(
            BuiltinOp::Map,
            vec![
                Expr::List(items),
                Expr::lam(
                    "x",
                    Expr::builtin(
                        BuiltinOp::Add,
                        vec![Expr::var("x"), Expr::Lit(Literal::Int(1))],
                    ),
                ),
            ],
        );
        let result = eval(&expr, &Env::new(), &config);
        assert!(matches!(result, Err(ExprError::StepLimitExceeded(_))));
    }

    #[test]
    fn eval_merge_example() {
        let merge_fn = Expr::lam(
            "first",
            Expr::lam(
                "last",
                Expr::builtin(
                    BuiltinOp::Concat,
                    vec![
                        Expr::var("first"),
                        Expr::builtin(
                            BuiltinOp::Concat,
                            vec![Expr::Lit(Literal::Str(" ".into())), Expr::var("last")],
                        ),
                    ],
                ),
            ),
        );
        let expr = Expr::App(
            Box::new(Expr::App(
                Box::new(merge_fn),
                Box::new(Expr::Lit(Literal::Str("Alice".into()))),
            )),
            Box::new(Expr::Lit(Literal::Str("Smith".into()))),
        );
        let result = eval(&expr, &Env::new(), &default_config());
        assert_eq!(result.unwrap(), Literal::Str("Alice Smith".into()));
    }

    #[test]
    fn eval_split_example() {
        let split_fn = Expr::lam(
            "full",
            Expr::let_in(
                "parts",
                Expr::builtin(
                    BuiltinOp::Split,
                    vec![Expr::var("full"), Expr::Lit(Literal::Str(" ".into()))],
                ),
                Expr::Record(vec![
                    (
                        Arc::from("firstName"),
                        Expr::builtin(BuiltinOp::Head, vec![Expr::var("parts")]),
                    ),
                    (
                        Arc::from("lastName"),
                        Expr::builtin(
                            BuiltinOp::Join,
                            vec![
                                Expr::builtin(BuiltinOp::Tail, vec![Expr::var("parts")]),
                                Expr::Lit(Literal::Str(" ".into())),
                            ],
                        ),
                    ),
                ]),
            ),
        );
        let expr = Expr::App(
            Box::new(split_fn),
            Box::new(Expr::Lit(Literal::Str("Alice B Smith".into()))),
        );
        let result = eval(&expr, &Env::new(), &default_config());
        let expected = Literal::Record(vec![
            (Arc::from("firstName"), Literal::Str("Alice".into())),
            (Arc::from("lastName"), Literal::Str("B Smith".into())),
        ]);
        assert_eq!(result.unwrap(), expected);
    }

    #[test]
    fn eval_coercion_example() {
        let coerce = Expr::lam(
            "v",
            Expr::builtin(BuiltinOp::StrToInt, vec![Expr::var("v")]),
        );
        let expr = Expr::App(
            Box::new(coerce),
            Box::new(Expr::Lit(Literal::Str("42".into()))),
        );
        let result = eval(&expr, &Env::new(), &default_config());
        assert_eq!(result.unwrap(), Literal::Int(42));
    }

    #[test]
    fn eval_negative_list_index() {
        let expr = index_expr(int_list(&[10, 20, 30]), -1);
        let result = eval(&expr, &Env::new(), &default_config());
        assert_eq!(result.unwrap(), Literal::Int(30));
    }

    #[test]
    fn eval_list_index_out_of_bounds() {
        for index in [3, -4, i64::MAX, i64::MIN] {
            let expr = index_expr(int_list(&[10, 20, 30]), index);
            let result = eval(&expr, &Env::new(), &default_config());
            assert!(matches!(
                result,
                Err(ExprError::IndexOutOfBounds { index: got, len: 3 }) if got == index
            ));
        }
    }

    #[test]
    fn eval_index_type_error() {
        let expr = index_expr(Expr::Lit(Literal::Int(1)), 0);
        let result = eval(&expr, &Env::new(), &default_config());
        assert!(matches!(result, Err(ExprError::TypeError { .. })));
    }

    #[test]
    fn eval_wildcard_arm() {
        let expr = Expr::Match {
            scrutinee: Box::new(Expr::Lit(Literal::Int(5))),
            arms: vec![
                (
                    Pattern::Lit(Literal::Int(0)),
                    Expr::Lit(Literal::Str("zero".into())),
                ),
                (Pattern::Wildcard, Expr::Lit(Literal::Str("other".into()))),
            ],
        };
        let result = eval(&expr, &Env::new(), &default_config());
        assert_eq!(result.unwrap(), Literal::Str("other".into()));
    }

    #[test]
    fn eval_non_exhaustive_match() {
        let expr = Expr::Match {
            scrutinee: Box::new(Expr::Lit(Literal::Int(1))),
            arms: vec![(Pattern::Lit(Literal::Int(0)), Expr::Lit(Literal::Int(0)))],
        };
        let result = eval(&expr, &Env::new(), &default_config());
        assert!(matches!(result, Err(ExprError::NonExhaustiveMatch)));
    }

    #[test]
    fn eval_depth_limit() {
        let config = EvalConfig {
            max_depth: 2,
            ..EvalConfig::default()
        };
        let expr = Expr::let_in(
            "a",
            Expr::Lit(Literal::Int(1)),
            Expr::let_in(
                "b",
                Expr::Lit(Literal::Int(2)),
                Expr::let_in("c", Expr::Lit(Literal::Int(3)), Expr::var("c")),
            ),
        );
        let result = eval(&expr, &Env::new(), &config);
        assert!(matches!(result, Err(ExprError::DepthExceeded(2))));
    }

    #[test]
    fn eval_list_length_limit() {
        let config = EvalConfig {
            max_list_len: 2,
            ..EvalConfig::default()
        };
        let result = eval(&int_list(&[1, 2, 3]), &Env::new(), &config);
        assert!(matches!(result, Err(ExprError::ListLengthExceeded(3))));
    }

    #[test]
    fn eval_flat_map() {
        let expr = Expr::builtin(
            BuiltinOp::FlatMap,
            vec![
                int_list(&[1, 2]),
                Expr::lam("x", Expr::List(vec![Expr::var("x"), Expr::var("x")])),
            ],
        );
        let result = eval(&expr, &Env::new(), &default_config());
        assert_eq!(
            result.unwrap(),
            Literal::List(vec![
                Literal::Int(1),
                Literal::Int(1),
                Literal::Int(2),
                Literal::Int(2),
            ])
        );
    }

    #[test]
    fn eval_filter_requires_bool_predicate() {
        let expr = Expr::builtin(
            BuiltinOp::Filter,
            vec![int_list(&[1, 2]), Expr::lam("x", Expr::var("x"))],
        );
        let result = eval(&expr, &Env::new(), &default_config());
        assert!(matches!(result, Err(ExprError::TypeError { .. })));
    }

    #[test]
    fn eval_builtin_arity_mismatch() {
        let expr = Expr::builtin(BuiltinOp::Map, vec![int_list(&[1])]);
        let result = eval(&expr, &Env::new(), &default_config());
        assert!(matches!(result, Err(ExprError::ArityMismatch { .. })));
    }

    #[test]
    fn eval_apply_non_function() {
        let expr = Expr::App(
            Box::new(Expr::Lit(Literal::Int(1))),
            Box::new(Expr::Lit(Literal::Int(2))),
        );
        let result = eval(&expr, &Env::new(), &default_config());
        assert!(matches!(result, Err(ExprError::NotAFunction)));
    }

    #[test]
    fn eval_missing_field() {
        let expr = Expr::field(
            Expr::Record(vec![(Arc::from("a"), Expr::Lit(Literal::Int(1)))]),
            "b",
        );
        let result = eval(&expr, &Env::new(), &default_config());
        assert!(matches!(result, Err(ExprError::FieldNotFound(_))));
    }

    #[test]
    fn eval_closure_captures_environment() {
        let expr = Expr::let_in(
            "y",
            Expr::Lit(Literal::Int(10)),
            Expr::App(
                Box::new(Expr::lam(
                    "x",
                    Expr::builtin(BuiltinOp::Add, vec![Expr::var("x"), Expr::var("y")]),
                )),
                Box::new(Expr::Lit(Literal::Int(5))),
            ),
        );
        let result = eval(&expr, &Env::new(), &default_config());
        assert_eq!(result.unwrap(), Literal::Int(15));
    }
}