1use std::{collections::HashMap, sync::Arc};
2
3use miette::NamedSource;
4
5use crate::{
6 parse::{Expr, FuncDecl, Literal, ParsedProgram, Stmt},
7 spanned::Spanned,
8 Result,
9};
10
/// A runtime value produced while evaluating a script.
#[derive(Debug, Clone)]
enum Val {
    /// A struct value, mapping each field name to that field's value.
    Struct(HashMap<String, Self>),
    /// A signed 64-bit integer.
    Int(i64),
    /// A 64-bit floating-point number.
    Float(f64),
    /// A boolean.
    Bool(bool),
}
18
19pub fn eval_kdl_script(_src: &Arc<NamedSource>, program: &ParsedProgram) -> Result<i64> {
20 let main = lookup_func(program, "main");
21
22 let val = eval_call(program, main, HashMap::default());
23
24 match val {
25 Val::Int(val) => Ok(val),
26 Val::Float(val) => Ok(val as i64),
27 Val::Bool(val) => Ok(val as i64),
28 Val::Struct(_) => {
29 unreachable!("main returned struct!")
30 }
31 }
32}
33
34fn eval_call(program: &ParsedProgram, func: &FuncDecl, mut vars: HashMap<String, Val>) -> Val {
35 for stmt in &func.body {
36 match stmt {
37 Stmt::Let(stmt) => {
38 let temp = eval_expr(program, &stmt.expr, &vars);
39 if let Some(var) = &stmt.var {
40 vars.insert(var.to_string(), temp);
41 }
42 }
43 Stmt::Return(stmt) => {
44 let temp = eval_expr(program, &stmt.expr, &vars);
45 return temp;
46 }
47 Stmt::Print(stmt) => {
48 let temp = eval_expr(program, &stmt.expr, &vars);
49 print_val(&temp);
50 }
51 }
52 }
53 unreachable!("function didn't return!");
54}
55
56fn eval_expr(program: &ParsedProgram, expr: &Spanned<Expr>, vars: &HashMap<String, Val>) -> Val {
57 match &**expr {
58 Expr::Call(expr) => {
59 let func = lookup_func(program, &expr.func);
60 assert_eq!(
61 func.inputs.len(),
62 expr.args.len(),
63 "function {} had wrong number of args",
64 &**expr.func
65 );
66 let input = func
67 .inputs
68 .iter()
69 .zip(expr.args.iter())
70 .map(|(var, expr)| {
71 let val = eval_expr(program, expr, vars);
72 let var = var.name.as_ref().unwrap().to_string();
73 (var, val)
74 })
75 .collect();
76
77 match func.name.as_str() {
78 "+" => eval_add(input),
79 _ => eval_call(program, func, input),
80 }
81 }
82 Expr::Path(expr) => {
83 let mut sub_val = vars
84 .get(&**expr.var)
85 .unwrap_or_else(|| panic!("couldn't find var {}", &**expr.var));
86 for field in &expr.path {
87 if let Val::Struct(val) = sub_val {
88 sub_val = val
89 .get(field.as_str())
90 .unwrap_or_else(|| panic!("couldn't find field {}", &**field));
91 } else {
92 panic!("tried to get .{} on primitive", &**field);
93 }
94 }
95 sub_val.clone()
96 }
97 Expr::Ctor(expr) => {
98 let fields = expr
99 .vals
100 .iter()
101 .map(|stmt| {
102 let val = eval_expr(program, &stmt.expr, vars);
103 let var = stmt.var.as_ref().unwrap().to_string();
104 (var, val)
105 })
106 .collect();
107 Val::Struct(fields)
108 }
109 Expr::Literal(expr) => match expr.val {
110 Literal::Float(val) => Val::Float(val),
111 Literal::Int(val) => Val::Int(val),
112 Literal::Bool(val) => Val::Bool(val),
113 },
114 }
115}
116
117fn print_val(val: &Val) {
118 match val {
119 Val::Struct(vals) => {
120 println!("{{");
121 for (k, v) in vals {
122 print!(" {k}: ");
123 print_val(v);
124 }
125 println!("}}");
126 }
127 Val::Int(val) => println!("{val}"),
128 Val::Float(val) => println!("{val}"),
129 Val::Bool(val) => println!("{val}"),
130 }
131}
132
133fn eval_add(input: HashMap<String, Val>) -> Val {
134 let lhs = input.get("lhs").unwrap();
135 let rhs = input.get("rhs").unwrap();
136 match (lhs, rhs) {
137 (Val::Int(lhs), Val::Int(rhs)) => Val::Int(lhs + rhs),
138 (Val::Float(lhs), Val::Float(rhs)) => Val::Float(lhs + rhs),
139 _ => {
140 panic!("unsupported addition pair");
141 }
142 }
143}
144
145fn lookup_func<'a>(program: &'a ParsedProgram, func_name: &str) -> &'a FuncDecl {
146 let func = program
147 .funcs
148 .iter()
149 .find(|(name, _f)| name.as_str() == func_name);
150 if func.is_none() {
151 panic!("couldn't find {func_name} function");
152 }
153 func.unwrap().1
154}