use chumsky::prelude::*;
/// Abstract syntax tree of a `foo` program.
///
/// Identifiers (in `Var`, `Call`, `Let` and `Fn`) are slices borrowed from the
/// source text, which is what the `'src` lifetime tracks.
#[derive(Debug)]
enum Expr<'src> {
    /// Numeric literal.
    Num(f64),
    /// Reference to a variable by name.
    Var(&'src str),
    /// Unary negation: `-a`.
    Neg(Box<Expr<'src>>),
    /// Addition: `a + b`.
    Add(Box<Expr<'src>>, Box<Expr<'src>>),
    /// Subtraction: `a - b`.
    Sub(Box<Expr<'src>>, Box<Expr<'src>>),
    /// Multiplication: `a * b`.
    Mul(Box<Expr<'src>>, Box<Expr<'src>>),
    /// Division: `a / b`.
    Div(Box<Expr<'src>>, Box<Expr<'src>>),
    /// Function call: callee name followed by the argument expressions, in order.
    Call(&'src str, Vec<Expr<'src>>),
    /// `let name = rhs; then` — `name` is bound to the value of `rhs` while
    /// `then` is evaluated.
    Let {
        name: &'src str,
        rhs: Box<Expr<'src>>,
        then: Box<Expr<'src>>,
    },
    /// `fn name args... = body; then` — defines a function with the given
    /// parameter names that can be called while `then` is evaluated.
    Fn {
        name: &'src str,
        args: Vec<&'src str>,
        body: Box<Expr<'src>>,
        then: Box<Expr<'src>>,
    },
}
/// Builds the parser for a complete `foo` program.
///
/// Shape of the accepted language:
///
/// ```text
/// program := "let" IDENT "=" expr ";" program
///          | "fn" IDENT IDENT* "=" expr ";" program
///          | expr
/// expr    := term (("+" | "-") term)*      (left-associative)
/// term    := unary (("*" | "/") unary)*    (left-associative)
/// unary   := "-"* atom
/// atom    := INT | "(" expr ")" | IDENT "(" args ")" | IDENT
/// ```
///
/// `INT` is a decimal integer literal (stored as `f64`), and `args` is a
/// comma-separated list of expressions that may end with a trailing comma.
fn parser<'src>() -> impl Parser<'src, &'src str, Expr<'src>> {
    // ASCII identifier with surrounding whitespace skipped.
    // NOTE(review): keywords are not excluded here, so `let`/`fn` can appear
    // as variable or function names inside expressions — confirm intended.
    let identifier = text::ascii::ident().padded();

    let expression = recursive(|expression| {
        // `text::int(10)` only yields runs of decimal digits, which `f64`
        // parsing accepts, so the `unwrap` is not reachable with an error.
        let number = text::int(10).map(|digits: &str| Expr::Num(digits.parse().unwrap()));

        let arguments = expression
            .clone()
            .separated_by(just(','))
            .allow_trailing()
            .collect::<Vec<_>>()
            .delimited_by(just('('), just(')'));
        let call = identifier
            .then(arguments)
            .map(|(callee, args)| Expr::Call(callee, args));

        let grouped = expression.delimited_by(just('('), just(')'));

        // `call` must be tried before the bare-variable case; otherwise `f(x)`
        // would stop after `f` and leave `(x)` unconsumed.
        let atom = number
            .or(grouped)
            .or(call)
            .or(identifier.map(Expr::Var))
            .padded();

        let symbol = |ch| just(ch).padded();

        // Any number of prefix minus signs, applied innermost-first.
        let negation = symbol('-')
            .repeated()
            .foldr(atom, |_minus, operand| Expr::Neg(Box::new(operand)));

        // `*` and `/` bind tighter than `+` and `-`; both levels fold to the
        // left, so `a - b - c` parses as `(a - b) - c`.
        let term = negation.clone().foldl(
            choice((
                symbol('*').to(Expr::Mul as fn(_, _) -> _),
                symbol('/').to(Expr::Div as fn(_, _) -> _),
            ))
            .then(negation)
            .repeated(),
            |left, (combine, right)| combine(Box::new(left), Box::new(right)),
        );

        term.clone().foldl(
            choice((
                symbol('+').to(Expr::Add as fn(_, _) -> _),
                symbol('-').to(Expr::Sub as fn(_, _) -> _),
            ))
            .then(term)
            .repeated(),
            |left, (combine, right)| combine(Box::new(left), Box::new(right)),
        )
    });

    // A program is a chain of `let` / `fn` declarations ending in an expression.
    recursive(|program| {
        let binding = text::ascii::keyword("let")
            .ignore_then(identifier)
            .then_ignore(just('='))
            .then(expression.clone())
            .then_ignore(just(';'))
            .then(program.clone())
            .map(|((name, value), rest)| Expr::Let {
                name,
                rhs: Box::new(value),
                then: Box::new(rest),
            });

        let function = text::ascii::keyword("fn")
            .ignore_then(identifier)
            .then(identifier.repeated().collect::<Vec<_>>())
            .then_ignore(just('='))
            .then(expression.clone())
            .then_ignore(just(';'))
            .then(program)
            .map(|(((name, params), body), rest)| Expr::Fn {
                name,
                args: params,
                body: Box::new(body),
                then: Box::new(rest),
            });

        binding.or(function).or(expression).padded()
    })
}
/// Evaluates `expr` to a number.
///
/// `vars` and `funcs` are scope stacks: lookups search them from the most
/// recently pushed entry backwards, so inner bindings shadow outer ones. Every
/// entry pushed while evaluating a node is removed again before that node
/// returns, whether evaluation succeeded or failed.
///
/// A function body is evaluated against the caller's whole variable stack with
/// the parameters pushed on top. The body can therefore read variables that are
/// in scope at the call site. All arguments are evaluated, in the caller's
/// scope, before any parameter is bound.
///
/// Division uses plain `f64` arithmetic, so dividing by zero yields an
/// infinity or NaN rather than an error.
///
/// # Errors
///
/// Returns a message when a variable or function name is not in scope, or
/// when a call passes a different number of arguments than the function
/// declares.
fn eval<'src>(
    expr: &'src Expr<'src>,
    vars: &mut Vec<(&'src str, f64)>,
    funcs: &mut Vec<(&'src str, &'src [&'src str], &'src Expr<'src>)>,
) -> Result<f64, String> {
    let value = match expr {
        Expr::Num(literal) => *literal,
        Expr::Var(name) => vars
            .iter()
            .rev()
            .find(|(var, _)| var == name)
            .map(|&(_, bound)| bound)
            .ok_or_else(|| format!("Cannot find variable `{name}` in scope"))?,
        Expr::Neg(operand) => -eval(operand, vars, funcs)?,
        Expr::Add(lhs, rhs) | Expr::Sub(lhs, rhs) | Expr::Mul(lhs, rhs) | Expr::Div(lhs, rhs) => {
            // Left operand first, then right — errors surface in that order.
            let left = eval(lhs, vars, funcs)?;
            let right = eval(rhs, vars, funcs)?;
            match expr {
                Expr::Add(..) => left + right,
                Expr::Sub(..) => left - right,
                Expr::Mul(..) => left * right,
                Expr::Div(..) => left / right,
                _ => unreachable!("only reached for binary operators"),
            }
        }
        Expr::Let { name, rhs, then } => {
            let bound = eval(rhs, vars, funcs)?;
            vars.push((*name, bound));
            let scoped = eval(then, vars, funcs);
            // Pop before propagating so the stack is restored even on error.
            vars.pop();
            scoped?
        }
        Expr::Fn {
            name,
            args,
            body,
            then,
        } => {
            funcs.push((*name, args.as_slice(), &**body));
            let scoped = eval(then, vars, funcs);
            funcs.pop();
            scoped?
        }
        Expr::Call(name, args) => {
            let (params, body) = match funcs.iter().rev().find(|(func, _, _)| func == name) {
                Some(&(_, params, body)) => (params, body),
                None => return Err(format!("Cannot find function `{name}` in scope")),
            };
            if params.len() != args.len() {
                return Err(format!(
                    "Wrong number of arguments for function `{name}`: expected {}, found {}",
                    params.len(),
                    args.len(),
                ));
            }
            // Evaluate every argument in the caller's scope before binding any
            // parameter, so an argument never sees a sibling parameter.
            let mut bindings = Vec::with_capacity(args.len());
            for (param, arg) in params.iter().zip(args) {
                bindings.push((*param, eval(arg, vars, funcs)?));
            }
            let depth = vars.len();
            vars.extend(bindings);
            let scoped = eval(body, vars, funcs);
            vars.truncate(depth);
            scoped?
        }
    };
    Ok(value)
}
/// Reads the `foo` program named by the first command-line argument, parses
/// and evaluates it, and prints the numeric result to stdout.
///
/// The following are reported on stderr and make the process exit with a
/// non-zero status, instead of panicking or exiting successfully:
/// - a missing argument (exit code 2);
/// - an unreadable file;
/// - parse errors;
/// - evaluation errors.
fn main() -> std::process::ExitCode {
    use std::process::ExitCode;

    let usage = "Run `cargo run --example foo -- examples/sample.foo`";

    let path = match std::env::args().nth(1) {
        Some(path) => path,
        None => {
            eprintln!("{usage}");
            return ExitCode::from(2);
        }
    };
    let src = match std::fs::read_to_string(&path) {
        Ok(src) => src,
        Err(err) => {
            eprintln!("Cannot read `{path}`: {err}");
            eprintln!("{usage}");
            return ExitCode::FAILURE;
        }
    };

    // Bind the result first: the parser's type carries `src`'s lifetime, and
    // finishing the statement here drops that temporary before `src` is dropped.
    let parsed = parser().parse(&src).into_result();
    match parsed {
        Ok(ast) => match eval(&ast, &mut Vec::new(), &mut Vec::new()) {
            Ok(output) => {
                println!("{output}");
                ExitCode::SUCCESS
            }
            Err(eval_err) => {
                eprintln!("Evaluation error: {eval_err}");
                ExitCode::FAILURE
            }
        },
        Err(parse_errs) => {
            for err in parse_errs {
                eprintln!("Parse error: {err}");
            }
            ExitCode::FAILURE
        }
    }
}