kdl-script 1.1.0

Execute KDL documents!
Documentation
use std::{collections::HashMap, sync::Arc};

use miette::NamedSource;

use crate::{
    parse::{Expr, FuncDecl, Literal, ParsedProgram, Stmt},
    spanned::Spanned,
    Result,
};

/// A runtime value produced by evaluating a KDL-script expression.
///
/// `PartialEq` is derived so values can be compared directly (e.g. in tests).
/// Floats compare with IEEE semantics, so `Float(NaN) != Float(NaN)`.
#[derive(Debug, Clone, PartialEq)]
enum Val {
    /// A struct value mapping field name to field value. Constructor
    /// expressions build it and path expressions read from it.
    Struct(HashMap<String, Val>),
    /// A 64-bit signed integer.
    Int(i64),
    /// A 64-bit float.
    Float(f64),
    /// A boolean.
    Bool(bool),
}

/// Runs `program` by calling its `main` function with no arguments and
/// converts the value `main` returns into an `i64`.
///
/// Conversion rules:
/// - ints are returned unchanged;
/// - floats are converted with an `as` cast, which truncates toward zero,
///   saturates at the `i64` bounds and maps NaN to `0`;
/// - bools become `0` / `1`.
///
/// `_src` is currently unused.
///
/// # Panics
///
/// Panics in these cases:
/// - the program has no `main` function;
/// - evaluation of the program panics (see `eval_call` / `eval_expr`);
/// - `main` returns a struct value.
pub fn eval_kdl_script(_src: &Arc<NamedSource>, program: &ParsedProgram) -> Result<i64> {
    let main = lookup_func(program, "main");

    let val = eval_call(program, main, HashMap::default());

    let code = match val {
        Val::Int(val) => val,
        Val::Float(val) => val as i64,
        Val::Bool(val) => val as i64,
        // A user program can legitimately return a struct from `main`, so this
        // is a user error, not an interpreter invariant: `panic!`, not
        // `unreachable!`.
        Val::Struct(_) => panic!("main returned struct!"),
    };
    Ok(code)
}

/// Executes the body of `func`, with `vars` as its initial variable bindings
/// (the already-evaluated arguments), and returns the value of the first
/// `return` statement reached.
///
/// Statements run in order:
/// - `let` evaluates its expression and, when a name is given, binds the
///   result. The expression is still evaluated for an unnamed `let`.
/// - `print` evaluates its expression and prints the result.
/// - `return` evaluates its expression and exits the function with it.
///
/// # Panics
///
/// Panics if the body finishes without executing a `return` statement, or if
/// evaluating any expression panics (see `eval_expr`).
fn eval_call(program: &ParsedProgram, func: &FuncDecl, mut vars: HashMap<String, Val>) -> Val {
    for stmt in &func.body {
        match stmt {
            Stmt::Let(stmt) => {
                let temp = eval_expr(program, &stmt.expr, &vars);
                if let Some(var) = &stmt.var {
                    vars.insert(var.to_string(), temp);
                }
            }
            Stmt::Return(stmt) => return eval_expr(program, &stmt.expr, &vars),
            Stmt::Print(stmt) => {
                let temp = eval_expr(program, &stmt.expr, &vars);
                print_val(&temp);
            }
        }
    }
    // Any user function lacking a `return` reaches this point, so it is a user
    // error, not an impossible state. Name the function to make it findable.
    panic!("function {} didn't return!", func.name.as_str());
}

/// Evaluates a single expression against the variables currently in scope.
///
/// - Literals map directly onto the matching `Val` variant.
/// - Constructors evaluate each field initializer, in order, into a new
///   struct value.
/// - Paths read a variable, then follow each `.field` through nested structs.
/// - Calls work in four steps:
///   1. look up the callee by name;
///   2. check the argument count;
///   3. evaluate the arguments left to right and bind each one to the
///      matching parameter name;
///   4. dispatch: a callee named `+` goes to `eval_add`, every other callee
///      goes to `eval_call`.
///
/// # Panics
///
/// Panics in these cases:
/// - arity mismatch;
/// - unknown function, variable or field;
/// - field access on a non-struct value;
/// - a parameter or constructor field without a name.
fn eval_expr(program: &ParsedProgram, expr: &Spanned<Expr>, vars: &HashMap<String, Val>) -> Val {
    match &**expr {
        Expr::Literal(lit) => match lit.val {
            Literal::Int(i) => Val::Int(i),
            Literal::Float(f) => Val::Float(f),
            Literal::Bool(b) => Val::Bool(b),
        },
        Expr::Ctor(ctor) => {
            let mut fields = HashMap::new();
            for field in ctor.vals.iter() {
                // Evaluate first, then resolve the name, matching the original order.
                let value = eval_expr(program, &field.expr, vars);
                fields.insert(field.var.as_ref().unwrap().to_string(), value);
            }
            Val::Struct(fields)
        }
        Expr::Path(path_expr) => {
            let root = vars
                .get(&**path_expr.var)
                .unwrap_or_else(|| panic!("couldn't find var {}", &**path_expr.var));
            let target = path_expr.path.iter().fold(root, |current, field| match current {
                Val::Struct(members) => members
                    .get(field.as_str())
                    .unwrap_or_else(|| panic!("couldn't find field {}", &**field)),
                Val::Int(_) | Val::Float(_) | Val::Bool(_) => {
                    panic!("tried to get .{} on primitive", &**field)
                }
            });
            target.clone()
        }
        Expr::Call(call) => {
            let callee = lookup_func(program, &call.func);
            assert_eq!(
                callee.inputs.len(),
                call.args.len(),
                "function {} had wrong number of args",
                &**call.func
            );
            let mut bound = HashMap::new();
            for (param, arg) in callee.inputs.iter().zip(call.args.iter()) {
                let value = eval_expr(program, arg, vars);
                bound.insert(param.name.as_ref().unwrap().to_string(), value);
            }
            // `+` is a builtin: its declaration exists for arity/param names,
            // but the body is implemented natively.
            if callee.name.as_str() == "+" {
                eval_add(bound)
            } else {
                eval_call(program, callee, bound)
            }
        }
    }
}

/// Prints `val` to stdout.
///
/// Scalars are printed on their own line. A struct is printed as a `{ ... }`
/// block with one `name: value` line per field. Fields are printed in sorted
/// name order so output is stable, because the backing `HashMap` iterates in
/// an unspecified order. Each level of nesting is indented by two more spaces.
fn print_val(val: &Val) {
    let mut out = String::new();
    render_val(val, 0, &mut out);
    // One write instead of one `println!` (and one stdout lock) per line.
    print!("{out}");
}

/// Appends the textual form of `val` to `out`. `depth` is the nesting level
/// of the struct containing `val`, and is `0` at top level. The appended text
/// always ends with a newline.
fn render_val(val: &Val, depth: usize, out: &mut String) {
    match val {
        Val::Struct(fields) => {
            let mut entries: Vec<(&String, &Val)> = fields.iter().collect();
            // Keys are unique, so an unstable sort is deterministic here.
            entries.sort_unstable_by_key(|&(name, _)| name);
            out.push_str("{\n");
            for (name, field) in entries {
                out.push_str(&"  ".repeat(depth + 1));
                out.push_str(name);
                out.push_str(": ");
                render_val(field, depth + 1, out);
            }
            out.push_str(&"  ".repeat(depth));
            out.push_str("}\n");
        }
        Val::Int(v) => {
            out.push_str(&v.to_string());
            out.push('\n');
        }
        Val::Float(v) => {
            out.push_str(&v.to_string());
            out.push('\n');
        }
        Val::Bool(v) => {
            out.push_str(&v.to_string());
            out.push('\n');
        }
    }
}

/// Implements the builtin `+` function.
///
/// `input` holds the already-evaluated operands under the parameter names
/// `lhs` and `rhs`. Two ints are added with overflow checking, so debug and
/// release builds behave the same. Two floats use ordinary IEEE addition.
///
/// # Panics
///
/// Panics in these cases:
/// - an operand is missing;
/// - the operands are not both ints or both floats;
/// - integer addition overflows.
fn eval_add(input: HashMap<String, Val>) -> Val {
    let lhs = input
        .get("lhs")
        .expect("builtin + called without an `lhs` argument");
    let rhs = input
        .get("rhs")
        .expect("builtin + called without an `rhs` argument");
    match (lhs, rhs) {
        (Val::Int(l), Val::Int(r)) => Val::Int(
            // Plain `+` panics in debug builds but silently wraps in release.
            l.checked_add(*r)
                .unwrap_or_else(|| panic!("integer overflow in addition: {l} + {r}")),
        ),
        (Val::Float(l), Val::Float(r)) => Val::Float(l + r),
        _ => panic!("unsupported addition pair: {lhs:?} + {rhs:?}"),
    }
}

/// Finds the function named `func_name` among the program's declared
/// functions, using a linear scan over `program.funcs`.
///
/// # Panics
///
/// Panics if no function with that name exists.
fn lookup_func<'a>(program: &'a ParsedProgram, func_name: &str) -> &'a FuncDecl {
    program
        .funcs
        .iter()
        .find(|(name, _f)| name.as_str() == func_name)
        .map(|(_name, func)| func)
        .unwrap_or_else(|| panic!("couldn't find {func_name} function"))
}