#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, string::String, vec::Vec};
use crate::error::{CompileError, EvalError};
/// A node of the expression tree.
///
/// Sub-expressions are held in `Box`/`Vec`, so a single `Expr` value owns its
/// entire subtree. Only `PartialEq` is derived (not `Eq`) because the `f64`
/// payload of [`Expr::Number`] does not implement `Eq`.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
/// A floating-point literal.
Number(f64),
/// A reference to a variable, identified by the contained name.
Variable(String),
/// NOTE(review): presumably a placeholder for the value currently being
/// processed by the evaluator; its exact meaning is defined outside this
/// chunk, so confirm against the evaluator.
CurrentValue,
/// A binary operation `left <op> right`.
BinaryOp {
/// The operator to apply.
op: BinOp,
/// Left-hand operand.
left: Box<Expr>,
/// Right-hand operand.
right: Box<Expr>,
},
/// Arithmetic negation of the boxed sub-expression.
UnaryMinus(Box<Expr>),
/// A call to a function by name with positional arguments.
FunctionCall {
/// Function name as written in the source expression.
name: String,
/// Argument expressions, in call order.
args: Vec<Expr>,
},
}
/// Binary arithmetic operators usable in [`Expr::BinaryOp`].
///
/// The numeric behaviour of each operator is defined by `BinOp::eval`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinOp {
/// Addition: `left + right`.
Add,
/// Subtraction: `left - right`.
Sub,
/// Multiplication: `left * right`.
Mul,
/// Division: `left / right`; evaluation fails when `right` equals zero.
Div,
/// Remainder: `left % right`; evaluation fails when `right` equals zero.
Mod,
/// Exponentiation: `left` raised to the power `right`.
Pow,
}
impl BinOp {
    /// Applies this operator to two already-evaluated operands.
    ///
    /// # Errors
    ///
    /// Returns [`EvalError::DivisionByZero`] when the operator is `Div` or
    /// `Mod` and `right` compares equal to zero (this includes `-0.0`).
    #[inline]
    pub fn eval(self, left: f64, right: f64) -> Result<f64, EvalError> {
        // Division and remainder are the only fallible operators; reject a
        // zero divisor up front so the match below is infallible.
        if matches!(self, BinOp::Div | BinOp::Mod) && right == 0.0 {
            return Err(EvalError::DivisionByZero);
        }

        let value = match self {
            BinOp::Add => left + right,
            BinOp::Sub => left - right,
            BinOp::Mul => left * right,
            BinOp::Div => left / right,
            BinOp::Mod => left % right,
            BinOp::Pow => libm::pow(left, right),
        };
        Ok(value)
    }
}
/// Built-in functions and constants that an expression can call by name.
///
/// Name lookup is done by `BuiltinFn::from_name`, numeric behaviour is
/// defined by `BuiltinFn::eval`, and the argument count by `BuiltinFn::arity`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BuiltinFn {
/// Absolute value (1 argument).
Abs,
/// Square root (1 argument); NaN for negative input.
Sqrt,
/// Natural logarithm (1 argument), named `log` or `ln`; NaN for input <= 0.
Log,
/// Base-10 logarithm (1 argument); NaN for input <= 0.
Log10,
/// Exponential function `e^x` (1 argument).
Exp,
/// Minimum of two values (2 arguments).
Min,
/// Maximum of two values (2 arguments).
Max,
/// Power `x^y` (2 arguments).
Pow,
/// Floating-point remainder (2 arguments); fails when the divisor is zero.
Mod,
/// Sine (1 argument).
Sin,
/// Cosine (1 argument).
Cos,
/// Tangent (1 argument).
Tan,
/// Arcsine (1 argument).
Asin,
/// Arccosine (1 argument).
Acos,
/// Arctangent (1 argument).
Atan,
/// Hyperbolic sine (1 argument).
Sinh,
/// Hyperbolic cosine (1 argument).
Cosh,
/// Hyperbolic tangent (1 argument).
Tanh,
/// Round toward negative infinity (1 argument).
Floor,
/// Round toward positive infinity (1 argument).
Ceil,
/// Round to the nearest integer via `libm::round` (1 argument).
Round,
/// Drop the fractional part (1 argument).
Trunc,
/// Sign of the input: `1.0`, `-1.0`, the zero itself, or NaN (1 argument).
Signum,
/// Cube root (1 argument).
Cbrt,
/// Base-2 logarithm (1 argument); NaN for input <= 0.
Log2,
/// Restrict a value to a `[min, max]` range (3 arguments: x, min, max).
Clamp,
/// The constant pi (0 arguments).
Pi,
/// Euler's number e (0 arguments).
E,
}
impl BuiltinFn {
pub fn from_name(name: &str) -> Option<(Self, usize)> {
match name {
"abs" => Some((BuiltinFn::Abs, 1)),
"sqrt" => Some((BuiltinFn::Sqrt, 1)),
"log" | "ln" => Some((BuiltinFn::Log, 1)),
"log10" => Some((BuiltinFn::Log10, 1)),
"exp" => Some((BuiltinFn::Exp, 1)),
"min" => Some((BuiltinFn::Min, 2)),
"max" => Some((BuiltinFn::Max, 2)),
"pow" => Some((BuiltinFn::Pow, 2)),
"mod" => Some((BuiltinFn::Mod, 2)),
"sin" => Some((BuiltinFn::Sin, 1)),
"cos" => Some((BuiltinFn::Cos, 1)),
"tan" => Some((BuiltinFn::Tan, 1)),
"asin" => Some((BuiltinFn::Asin, 1)),
"acos" => Some((BuiltinFn::Acos, 1)),
"atan" => Some((BuiltinFn::Atan, 1)),
"sinh" => Some((BuiltinFn::Sinh, 1)),
"cosh" => Some((BuiltinFn::Cosh, 1)),
"tanh" => Some((BuiltinFn::Tanh, 1)),
"floor" => Some((BuiltinFn::Floor, 1)),
"ceil" => Some((BuiltinFn::Ceil, 1)),
"round" => Some((BuiltinFn::Round, 1)),
"trunc" => Some((BuiltinFn::Trunc, 1)),
"signum" => Some((BuiltinFn::Signum, 1)),
"cbrt" => Some((BuiltinFn::Cbrt, 1)),
"log2" => Some((BuiltinFn::Log2, 1)),
"clamp" => Some((BuiltinFn::Clamp, 3)),
"pi" => Some((BuiltinFn::Pi, 0)),
"e" => Some((BuiltinFn::E, 0)),
_ => None,
}
}
#[inline]
pub fn eval(self, args: &[f64]) -> Result<f64, EvalError> {
match self {
BuiltinFn::Abs => Ok(libm::fabs(args[0])),
BuiltinFn::Sqrt => {
if args[0] < 0.0 {
Ok(f64::NAN)
} else {
Ok(libm::sqrt(args[0]))
}
}
BuiltinFn::Log => {
if args[0] <= 0.0 {
Ok(f64::NAN)
} else {
Ok(libm::log(args[0]))
}
}
BuiltinFn::Log10 => {
if args[0] <= 0.0 {
Ok(f64::NAN)
} else {
Ok(libm::log10(args[0]))
}
}
BuiltinFn::Exp => Ok(libm::exp(args[0])),
BuiltinFn::Min => Ok(libm::fmin(args[0], args[1])),
BuiltinFn::Max => Ok(libm::fmax(args[0], args[1])),
BuiltinFn::Pow => Ok(libm::pow(args[0], args[1])),
BuiltinFn::Mod => {
if args[1] == 0.0 {
Err(EvalError::DivisionByZero)
} else {
Ok(libm::fmod(args[0], args[1]))
}
}
BuiltinFn::Sin => Ok(libm::sin(args[0])),
BuiltinFn::Cos => Ok(libm::cos(args[0])),
BuiltinFn::Tan => Ok(libm::tan(args[0])),
BuiltinFn::Asin => Ok(libm::asin(args[0])), BuiltinFn::Acos => Ok(libm::acos(args[0])), BuiltinFn::Atan => Ok(libm::atan(args[0])),
BuiltinFn::Sinh => Ok(libm::sinh(args[0])),
BuiltinFn::Cosh => Ok(libm::cosh(args[0])),
BuiltinFn::Tanh => Ok(libm::tanh(args[0])),
BuiltinFn::Floor => Ok(libm::floor(args[0])),
BuiltinFn::Ceil => Ok(libm::ceil(args[0])),
BuiltinFn::Round => Ok(libm::round(args[0])),
BuiltinFn::Trunc => Ok(libm::trunc(args[0])),
BuiltinFn::Signum => {
if args[0].is_nan() {
Ok(f64::NAN)
} else if args[0] == 0.0 {
Ok(args[0]) } else if args[0] > 0.0 {
Ok(1.0)
} else {
Ok(-1.0)
}
}
BuiltinFn::Cbrt => Ok(libm::cbrt(args[0])),
BuiltinFn::Log2 => {
if args[0] <= 0.0 {
Ok(f64::NAN)
} else {
Ok(libm::log2(args[0]))
}
}
BuiltinFn::Clamp => {
let (x, min, max) = (args[0], args[1], args[2]);
Ok(libm::fmin(libm::fmax(x, min), max))
}
BuiltinFn::Pi => Ok(core::f64::consts::PI),
BuiltinFn::E => Ok(core::f64::consts::E),
}
}
pub fn arity(self) -> usize {
match self {
BuiltinFn::Pi | BuiltinFn::E => 0,
BuiltinFn::Abs
| BuiltinFn::Sqrt
| BuiltinFn::Log
| BuiltinFn::Log10
| BuiltinFn::Log2
| BuiltinFn::Exp
| BuiltinFn::Sin
| BuiltinFn::Cos
| BuiltinFn::Tan
| BuiltinFn::Asin
| BuiltinFn::Acos
| BuiltinFn::Atan
| BuiltinFn::Sinh
| BuiltinFn::Cosh
| BuiltinFn::Tanh
| BuiltinFn::Floor
| BuiltinFn::Ceil
| BuiltinFn::Round
| BuiltinFn::Trunc
| BuiltinFn::Signum
| BuiltinFn::Cbrt => 1,
BuiltinFn::Min | BuiltinFn::Max | BuiltinFn::Pow | BuiltinFn::Mod => 2,
BuiltinFn::Clamp => 3,
}
}
pub fn check_arity(self, got: usize, name: &str) -> Result<(), CompileError> {
let expected = self.arity();
if got != expected {
Err(CompileError::WrongArity {
name: String::from(name),
expected,
got,
})
} else {
Ok(())
}
}
}