/// Operators that take exactly two operand expressions.
///
/// Each variant is produced by `parse_op` from the numeric opcode noted beside
/// it. The declaration order is kept stable so discriminant values do not
/// change.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryOp {
    /// Addition (opcode 0).
    Add,
    /// Subtraction (opcode 1).
    Sub,
    /// Multiplication (opcode 2).
    Mul,
    /// Division (opcode 3).
    Div,
    /// Remainder (opcode 4).
    Mod,
    /// Exponentiation (opcode 5).
    Pow,
    /// Two-argument arctangent (opcode 48).
    Atan2,
    /// Less-than operator (opcode 22); see `eval_expr` for the value it yields.
    Less,
    /// Integer division (opcode 55).
    IntDiv,
}
/// Operators that take exactly one operand expression.
///
/// Each variant is produced by `parse_op` from the numeric opcode noted beside
/// it. The declaration order is kept stable so discriminant values do not
/// change.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOp {
    /// Absolute value (opcode 15).
    Abs,
    /// Arithmetic negation (opcode 16).
    Neg,
    /// Round towards negative infinity (opcode 13).
    Floor,
    /// Round towards positive infinity (opcode 14).
    Ceil,
    /// Hyperbolic tangent (opcode 37).
    Tanh,
    /// Tangent (opcode 38).
    Tan,
    /// Square root (opcode 39).
    Sqrt,
    /// Hyperbolic sine (opcode 40).
    Sinh,
    /// Sine (opcode 41).
    Sin,
    /// Base-10 logarithm (opcode 42).
    Log10,
    /// Natural logarithm (opcode 43).
    Log,
    /// Exponential function (opcode 44).
    Exp,
    /// Hyperbolic cosine (opcode 45).
    Cosh,
    /// Cosine (opcode 46).
    Cos,
    /// Inverse hyperbolic tangent (opcode 47).
    Atanh,
    /// Arctangent (opcode 49).
    Atan,
    /// Inverse hyperbolic sine (opcode 50).
    Asinh,
    /// Arcsine (opcode 51).
    Asin,
    /// Inverse hyperbolic cosine (opcode 52).
    Acosh,
    /// Arccosine (opcode 53).
    Acos,
}
/// Operators that take a variable-length list of operand expressions.
///
/// Each variant is produced by `parse_op` from the numeric opcode noted beside
/// it. The declaration order is kept stable so discriminant values do not
/// change.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NaryOp {
    /// Sum of all operands (opcode 54).
    Sum,
    /// Smallest operand (opcode 11).
    Min,
    /// Largest operand (opcode 12).
    Max,
}
/// A node of a parsed expression tree.
///
/// Trees are built by `parse_expr` and reduced to a single `f64` by
/// `eval_expr`.
#[derive(Debug, Clone)]
pub enum ExprNode {
    /// Numeric literal, read from an `n` token.
    Const(f64),
    /// Variable reference read from a `v` token; `eval_expr` uses the index to
    /// look the value up in the slice it is given.
    Var(usize),
    /// Two-operand operation: operator, left operand, right operand.
    Binary(BinaryOp, Box<ExprNode>, Box<ExprNode>),
    /// Single-operand operation: operator and its operand.
    Unary(UnaryOp, Box<ExprNode>),
    /// Operation over a list of operands.
    Nary(NaryOp, Vec<ExprNode>),
    /// Conditional: condition, then-branch, else-branch. `eval_expr` takes the
    /// then-branch when the condition evaluates to a value greater than 0.0.
    If(Box<ExprNode>, Box<ExprNode>, Box<ExprNode>),
    /// String literal read from an `h` token; `eval_expr` evaluates it to 0.0.
    StringLiteral(String),
}
/// Parses one expression from `lines`, consuming as many lines as the
/// expression needs (operators recursively consume their operands).
///
/// Each line is trimmed and everything from the first `#` onward is treated as
/// a comment. The first character of what remains selects the node kind:
///
/// * `n<number>` – numeric constant (`ExprNode::Const`)
/// * `v<index>` – variable reference (`ExprNode::Var`)
/// * `h<len>:<text>` – string literal (`ExprNode::StringLiteral`)
/// * `o<opcode>` – operator; its operands are read from the following lines
///
/// For string literals the declared byte length `<len>` is honoured when it is
/// a valid number and the line holds at least that many bytes after the colon,
/// so a `#` inside the text is kept instead of being mistaken for a comment.
/// Otherwise the text is whatever follows the first `:` of the
/// comment-stripped token (or the empty string when there is no `:`).
/// Literals spanning several lines are not reassembled.
///
/// # Errors
///
/// Returns a message when the input ends early, the token has an unknown
/// prefix, a numeric field fails to parse, or an operand fails to parse.
pub fn parse_expr<'a>(lines: &mut impl Iterator<Item = &'a str>) -> Result<ExprNode, String> {
    let line = lines
        .next()
        .ok_or("Unexpected end of expression")?
        .trim();
    let token = line.split('#').next().unwrap_or("").trim();

    if let Some(number) = token.strip_prefix('n') {
        let val: f64 = number
            .parse()
            .map_err(|e| format!("Bad number '{}': {}", number, e))?;
        Ok(ExprNode::Const(val))
    } else if let Some(index) = token.strip_prefix('v') {
        let idx: usize = index
            .parse()
            .map_err(|e| format!("Bad var index '{}': {}", index, e))?;
        Ok(ExprNode::Var(idx))
    } else if let Some(rest) = token.strip_prefix('h') {
        // Work on the raw line (comment not stripped): the literal is
        // length-delimited, so a `#` within the declared length is data.
        let declared = line.strip_prefix('h').and_then(|raw| {
            let colon = raw.find(':')?;
            let len: usize = raw[..colon].trim().parse().ok()?;
            // `get` yields `None` when the line is too short or `len` does not
            // fall on a character boundary; both cases use the fallback below.
            raw[colon + 1..].get(..len)
        });
        let text = match declared {
            Some(exact) => exact,
            None => match rest.find(':') {
                Some(colon) => &rest[colon + 1..],
                None => "",
            },
        };
        Ok(ExprNode::StringLiteral(text.to_string()))
    } else if let Some(code) = token.strip_prefix('o') {
        let opcode: usize = code
            .parse()
            .map_err(|e| format!("Bad opcode '{}': {}", code, e))?;
        parse_op(opcode, lines)
    } else {
        Err(format!("Unknown expression token: '{}'", token))
    }
}
fn parse_op<'a>(
opcode: usize,
lines: &mut impl Iterator<Item = &'a str>,
) -> Result<ExprNode, String> {
match opcode {
0 => binary(BinaryOp::Add, lines),
1 => binary(BinaryOp::Sub, lines),
2 => binary(BinaryOp::Mul, lines),
3 => binary(BinaryOp::Div, lines),
4 => binary(BinaryOp::Mod, lines),
5 => binary(BinaryOp::Pow, lines),
22 => binary(BinaryOp::Less, lines),
48 => binary(BinaryOp::Atan2, lines),
55 => binary(BinaryOp::IntDiv, lines),
13 => unary(UnaryOp::Floor, lines),
14 => unary(UnaryOp::Ceil, lines),
15 => unary(UnaryOp::Abs, lines),
16 => unary(UnaryOp::Neg, lines),
37 => unary(UnaryOp::Tanh, lines),
38 => unary(UnaryOp::Tan, lines),
39 => unary(UnaryOp::Sqrt, lines),
40 => unary(UnaryOp::Sinh, lines),
41 => unary(UnaryOp::Sin, lines),
42 => unary(UnaryOp::Log10, lines),
43 => unary(UnaryOp::Log, lines),
44 => unary(UnaryOp::Exp, lines),
45 => unary(UnaryOp::Cosh, lines),
46 => unary(UnaryOp::Cos, lines),
47 => unary(UnaryOp::Atanh, lines),
49 => unary(UnaryOp::Atan, lines),
50 => unary(UnaryOp::Asinh, lines),
51 => unary(UnaryOp::Asin, lines),
52 => unary(UnaryOp::Acosh, lines),
53 => unary(UnaryOp::Acos, lines),
11 => nary(NaryOp::Min, lines),
12 => nary(NaryOp::Max, lines),
54 => nary(NaryOp::Sum, lines),
35 => {
let cond = parse_expr(lines)?;
let then_expr = parse_expr(lines)?;
let else_expr = parse_expr(lines)?;
Ok(ExprNode::If(
Box::new(cond),
Box::new(then_expr),
Box::new(else_expr),
))
}
65 => {
let cond = parse_expr(lines)?;
let then_expr = parse_expr(lines)?;
Ok(ExprNode::If(
Box::new(cond),
Box::new(then_expr),
Box::new(ExprNode::Const(0.0)),
))
}
_ => Err(format!("Unsupported opcode: o{}", opcode)),
}
}
/// Reads two operand expressions from `lines` (left first, then right) and
/// wraps them in an `ExprNode::Binary` carrying `op`.
///
/// Any operand parse error is returned unchanged.
fn binary<'a>(
    op: BinaryOp,
    lines: &mut impl Iterator<Item = &'a str>,
) -> Result<ExprNode, String> {
    let lhs = Box::new(parse_expr(lines)?);
    let rhs = Box::new(parse_expr(lines)?);
    Ok(ExprNode::Binary(op, lhs, rhs))
}
/// Reads one operand expression from `lines` and wraps it in an
/// `ExprNode::Unary` carrying `op`.
///
/// An operand parse error is returned unchanged.
fn unary<'a>(op: UnaryOp, lines: &mut impl Iterator<Item = &'a str>) -> Result<ExprNode, String> {
    parse_expr(lines).map(|operand| ExprNode::Unary(op, Box::new(operand)))
}
/// Reads an operand count from the next line, then that many operand
/// expressions, and wraps them in an `ExprNode::Nary` carrying `op`.
///
/// The count line is trimmed and anything from the first `#` onward is ignored.
///
/// # Errors
///
/// Returns a message when the count line is missing or is not a valid
/// `usize`; the first operand parse error is returned unchanged.
fn nary<'a>(op: NaryOp, lines: &mut impl Iterator<Item = &'a str>) -> Result<ExprNode, String> {
    let raw = match lines.next() {
        Some(text) => text.trim(),
        None => return Err(String::from("Expected count for n-ary op")),
    };

    // Keep only the part in front of a trailing comment.
    let before_comment = match raw.find('#') {
        Some(hash) => &raw[..hash],
        None => raw,
    };
    let count_text = before_comment.trim();

    let operand_count = match count_text.parse::<usize>() {
        Ok(n) => n,
        Err(e) => return Err(format!("Bad n-ary count '{}': {}", count_text, e)),
    };

    let mut operands = Vec::new();
    for _ in 0..operand_count {
        operands.push(parse_expr(lines)?);
    }
    Ok(ExprNode::Nary(op, operands))
}
/// Recursively evaluates `node`, looking variables up in `vals`.
///
/// * `Const` yields its value; `Var(i)` yields `vals[i]`.
/// * `Binary` / `Unary` evaluate their operands first and then apply the
///   corresponding `f64` operation. `Mod` is the truncating remainder (`%`)
///   and `IntDiv` is the quotient truncated towards zero, so that
///   `(a IntDiv b) * b + (a Mod b) == a` holds for negative operands as well.
/// * `Less` is a comparison: it yields 1.0 when the left operand is strictly
///   smaller than the right one and 0.0 otherwise (also when either side is
///   NaN). Opcode 22, the only source of this variant, is `lt` in the AMPL
///   `.nl` opcode table, and a 1.0/0.0 result composes with `If`, which tests
///   its condition with `> 0.0`.
/// * `Nary` sums, or takes the minimum / maximum of, its operands. With no
///   operands `Min` yields `+inf` and `Max` yields `-inf`.
/// * `If` evaluates the condition and then only the selected branch; the
///   then-branch is selected when the condition is greater than 0.0.
/// * `StringLiteral` evaluates to 0.0.
///
/// # Panics
///
/// Panics when a `Var` index is out of bounds for `vals`.
pub fn eval_expr(node: &ExprNode, vals: &[f64]) -> f64 {
    match node {
        ExprNode::Const(c) => *c,
        ExprNode::Var(i) => vals[*i],
        ExprNode::Binary(op, l, r) => {
            let lv = eval_expr(l, vals);
            let rv = eval_expr(r, vals);
            match op {
                BinaryOp::Add => lv + rv,
                BinaryOp::Sub => lv - rv,
                BinaryOp::Mul => lv * rv,
                BinaryOp::Div => lv / rv,
                BinaryOp::Mod => lv % rv,
                BinaryOp::Pow => lv.powf(rv),
                BinaryOp::Atan2 => lv.atan2(rv),
                // Truth value, not the smaller operand: see the doc comment.
                BinaryOp::Less => {
                    if lv < rv {
                        1.0
                    } else {
                        0.0
                    }
                }
                // Truncate (not floor) to stay consistent with `%` above.
                BinaryOp::IntDiv => (lv / rv).trunc(),
            }
        }
        ExprNode::Unary(op, a) => {
            let av = eval_expr(a, vals);
            match op {
                UnaryOp::Abs => av.abs(),
                UnaryOp::Neg => -av,
                UnaryOp::Floor => av.floor(),
                UnaryOp::Ceil => av.ceil(),
                UnaryOp::Tanh => av.tanh(),
                UnaryOp::Tan => av.tan(),
                UnaryOp::Sqrt => av.sqrt(),
                UnaryOp::Sinh => av.sinh(),
                UnaryOp::Sin => av.sin(),
                UnaryOp::Log10 => av.log10(),
                UnaryOp::Log => av.ln(),
                UnaryOp::Exp => av.exp(),
                UnaryOp::Cosh => av.cosh(),
                UnaryOp::Cos => av.cos(),
                UnaryOp::Atanh => av.atanh(),
                UnaryOp::Atan => av.atan(),
                UnaryOp::Asinh => av.asinh(),
                UnaryOp::Asin => av.asin(),
                UnaryOp::Acosh => av.acosh(),
                UnaryOp::Acos => av.acos(),
            }
        }
        ExprNode::Nary(op, args) => {
            let operands = args.iter().map(|arg| eval_expr(arg, vals));
            match op {
                NaryOp::Sum => operands.sum(),
                NaryOp::Min => operands.fold(f64::INFINITY, f64::min),
                NaryOp::Max => operands.fold(f64::NEG_INFINITY, f64::max),
            }
        }
        ExprNode::If(cond, then_expr, else_expr) => {
            if eval_expr(cond, vals) > 0.0 {
                eval_expr(then_expr, vals)
            } else {
                eval_expr(else_expr, vals)
            }
        }
        ExprNode::StringLiteral(_) => 0.0,
    }
}