kdl_script/
eval.rs

1use std::{collections::HashMap, sync::Arc};
2
3use miette::NamedSource;
4
5use crate::{
6    parse::{Expr, FuncDecl, Literal, ParsedProgram, Stmt},
7    spanned::Spanned,
8    Result,
9};
10
/// A runtime value produced while evaluating a program.
#[derive(Clone, Debug)]
enum Val {
    /// A struct value: field name mapped to that field's value.
    Struct(HashMap<String, Val>),
    /// A signed 64-bit integer.
    Int(i64),
    /// A 64-bit floating point number.
    Float(f64),
    /// A boolean.
    Bool(bool),
}
18
19pub fn eval_kdl_script(_src: &Arc<NamedSource>, program: &ParsedProgram) -> Result<i64> {
20    let main = lookup_func(program, "main");
21
22    let val = eval_call(program, main, HashMap::default());
23
24    match val {
25        Val::Int(val) => Ok(val),
26        Val::Float(val) => Ok(val as i64),
27        Val::Bool(val) => Ok(val as i64),
28        Val::Struct(_) => {
29            unreachable!("main returned struct!")
30        }
31    }
32}
33
34fn eval_call(program: &ParsedProgram, func: &FuncDecl, mut vars: HashMap<String, Val>) -> Val {
35    for stmt in &func.body {
36        match stmt {
37            Stmt::Let(stmt) => {
38                let temp = eval_expr(program, &stmt.expr, &vars);
39                if let Some(var) = &stmt.var {
40                    vars.insert(var.to_string(), temp);
41                }
42            }
43            Stmt::Return(stmt) => {
44                let temp = eval_expr(program, &stmt.expr, &vars);
45                return temp;
46            }
47            Stmt::Print(stmt) => {
48                let temp = eval_expr(program, &stmt.expr, &vars);
49                print_val(&temp);
50            }
51        }
52    }
53    unreachable!("function didn't return!");
54}
55
56fn eval_expr(program: &ParsedProgram, expr: &Spanned<Expr>, vars: &HashMap<String, Val>) -> Val {
57    match &**expr {
58        Expr::Call(expr) => {
59            let func = lookup_func(program, &expr.func);
60            assert_eq!(
61                func.inputs.len(),
62                expr.args.len(),
63                "function {} had wrong number of args",
64                &**expr.func
65            );
66            let input = func
67                .inputs
68                .iter()
69                .zip(expr.args.iter())
70                .map(|(var, expr)| {
71                    let val = eval_expr(program, expr, vars);
72                    let var = var.name.as_ref().unwrap().to_string();
73                    (var, val)
74                })
75                .collect();
76
77            match func.name.as_str() {
78                "+" => eval_add(input),
79                _ => eval_call(program, func, input),
80            }
81        }
82        Expr::Path(expr) => {
83            let mut sub_val = vars
84                .get(&**expr.var)
85                .unwrap_or_else(|| panic!("couldn't find var {}", &**expr.var));
86            for field in &expr.path {
87                if let Val::Struct(val) = sub_val {
88                    sub_val = val
89                        .get(field.as_str())
90                        .unwrap_or_else(|| panic!("couldn't find field {}", &**field));
91                } else {
92                    panic!("tried to get .{} on primitive", &**field);
93                }
94            }
95            sub_val.clone()
96        }
97        Expr::Ctor(expr) => {
98            let fields = expr
99                .vals
100                .iter()
101                .map(|stmt| {
102                    let val = eval_expr(program, &stmt.expr, vars);
103                    let var = stmt.var.as_ref().unwrap().to_string();
104                    (var, val)
105                })
106                .collect();
107            Val::Struct(fields)
108        }
109        Expr::Literal(expr) => match expr.val {
110            Literal::Float(val) => Val::Float(val),
111            Literal::Int(val) => Val::Int(val),
112            Literal::Bool(val) => Val::Bool(val),
113        },
114    }
115}
116
117fn print_val(val: &Val) {
118    match val {
119        Val::Struct(vals) => {
120            println!("{{");
121            for (k, v) in vals {
122                print!("  {k}: ");
123                print_val(v);
124            }
125            println!("}}");
126        }
127        Val::Int(val) => println!("{val}"),
128        Val::Float(val) => println!("{val}"),
129        Val::Bool(val) => println!("{val}"),
130    }
131}
132
133fn eval_add(input: HashMap<String, Val>) -> Val {
134    let lhs = input.get("lhs").unwrap();
135    let rhs = input.get("rhs").unwrap();
136    match (lhs, rhs) {
137        (Val::Int(lhs), Val::Int(rhs)) => Val::Int(lhs + rhs),
138        (Val::Float(lhs), Val::Float(rhs)) => Val::Float(lhs + rhs),
139        _ => {
140            panic!("unsupported addition pair");
141        }
142    }
143}
144
145fn lookup_func<'a>(program: &'a ParsedProgram, func_name: &str) -> &'a FuncDecl {
146    let func = program
147        .funcs
148        .iter()
149        .find(|(name, _f)| name.as_str() == func_name);
150    if func.is_none() {
151        panic!("couldn't find {func_name} function");
152    }
153    func.unwrap().1
154}