use serde::Deserialize;
use std::ops::{Add, Div, Mul, Rem, Sub};
use std::ops::{BitAnd, BitOr, BitXor};
use std::ops::{Neg, Not};
use anyhow::{anyhow, Result};
use crate::model::{Equation, Variable};
use crate::utils::bool_to_f64;
/// A node in an expression tree.
///
/// Deserialized as an internally tagged enum: the `"type"` field of the input
/// names the variant, and the remaining fields populate it.
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type")]
pub enum Expr {
    /// Operator `op` applied to the operands in `args`.
    ///
    /// The builder methods on `Expr` produce tags such as `"plus"`, `"sin"`,
    /// `"ifelse"`, `"Sum"`; the operand count depends on the operator.
    Tree {
        op: String,
        args: Vec<Expr>,
    },
    /// A numeric constant (integers and booleans are also stored as `f64`).
    Const {
        val: f64,
    },
    /// A reference to a variable by name.
    Var {
        name: String,
    },
    /// Marker used as the left-hand side of equations built by
    /// `Expr::special`. NOTE(review): its meaning is defined by the consumer of
    /// the equations — confirm there.
    Special,
    /// NOTE(review): presumably a jump target identified by `id` — confirm
    /// against the code that consumes these nodes.
    Label {
        id: usize,
    },
    /// NOTE(review): presumably an unconditional jump to the label `id` —
    /// confirm against the consumer.
    Branch {
        id: usize,
    },
    /// NOTE(review): presumably a jump to the label `id` guarded by `cond`,
    /// with `is_else` selecting the branch sense — confirm against the consumer.
    BranchIf {
        cond: Box<Expr>,
        id: usize,
        is_else: bool,
    },
}
impl From<f64> for Expr {
fn from(val: f64) -> Expr {
Expr::Const { val }
}
}
impl From<f32> for Expr {
fn from(val: f32) -> Expr {
Expr::Const { val: val.into() }
}
}
impl From<i32> for Expr {
fn from(val: i32) -> Expr {
Expr::Const { val: val.into() }
}
}
/// Builder helpers for assembling [`Expr`] trees.
///
/// Each operator method returns a new `Expr::Tree` tagged with the operator
/// string noted in its doc line. Operands are cloned into the new node, so the
/// receiver and arguments remain usable afterwards.
impl Expr {
    // ---- leaf and generic node constructors -------------------------------

    /// Creates a reference to the variable called `name`.
    pub fn var(name: &str) -> Expr {
        Self::Var {
            name: name.to_owned(),
        }
    }

    /// Builds a one-operand `Tree` node tagged `op`.
    pub fn unary(op: &str, arg: &Expr) -> Expr {
        let args = vec![arg.clone()];
        Self::Tree {
            op: op.to_owned(),
            args,
        }
    }

    /// Builds a two-operand `Tree` node tagged `op` with operands `[l, r]`.
    pub fn binary(op: &str, l: &Expr, r: &Expr) -> Expr {
        let args = vec![l.clone(), r.clone()];
        Self::Tree {
            op: op.to_owned(),
            args,
        }
    }

    /// Builds a three-operand `Tree` node tagged `op` with operands `[l, c, r]`.
    pub fn ternary(op: &str, l: &Expr, c: &Expr, r: &Expr) -> Expr {
        let args = vec![l.clone(), c.clone(), r.clone()];
        Self::Tree {
            op: op.to_owned(),
            args,
        }
    }

    /// Builds a `Tree` node tagged `op` over every element of `args`.
    ///
    /// A lone argument is returned directly (cloned) rather than wrapped. Any
    /// other count, including zero, yields a `Tree` node.
    pub fn nary(op: &str, args: &[&Expr]) -> Expr {
        match args {
            [only] => Expr::clone(only),
            _ => Self::Tree {
                op: op.to_owned(),
                args: args.iter().copied().cloned().collect(),
            },
        }
    }

    /// Pairs clones of `lhs` and `rhs` into an [`Equation`].
    pub fn equation(lhs: &Expr, rhs: &Expr) -> Equation {
        let (lhs, rhs) = (lhs.clone(), rhs.clone());
        Equation { lhs, rhs }
    }

    /// Builds an [`Equation`] whose left-hand side is [`Expr::Special`].
    pub fn special(rhs: &Expr) -> Equation {
        let rhs = rhs.clone();
        Equation {
            lhs: Self::Special,
            rhs,
        }
    }

    /// Converts a `Var` node into a [`Variable`] with the same name.
    ///
    /// # Errors
    /// Fails for every variant other than `Var`.
    pub fn to_variable(&self) -> Result<Variable> {
        match self {
            Self::Var { name } => Ok(Variable { name: name.clone() }),
            _ => Err(anyhow!("cannot convert {:?} to a Variable", self)),
        }
    }

    // ---- arithmetic / logical binary operators ----------------------------

    /// `self + other` — tag `"plus"`.
    pub fn add(&self, other: &Expr) -> Expr {
        Self::binary("plus", self, other)
    }

    /// `self - other` — tag `"minus"`.
    pub fn sub(&self, other: &Self) -> Expr {
        Self::binary("minus", self, other)
    }

    /// `self * other` — tag `"times"`.
    pub fn mul(&self, other: &Self) -> Expr {
        Self::binary("times", self, other)
    }

    /// `self / other` — tag `"divide"`.
    pub fn div(&self, other: &Self) -> Expr {
        Self::binary("divide", self, other)
    }

    /// `self % other` — tag `"rem"`.
    pub fn rem(&self, other: &Self) -> Expr {
        Self::binary("rem", self, other)
    }

    /// Logical conjunction — tag `"and"`.
    pub fn bitand(&self, other: &Self) -> Expr {
        Self::binary("and", self, other)
    }

    /// Logical disjunction — tag `"or"`.
    pub fn bitor(&self, other: &Self) -> Expr {
        Self::binary("or", self, other)
    }

    /// Exclusive or — tag `"xor"`.
    pub fn bitxor(&self, other: &Self) -> Expr {
        Self::binary("xor", self, other)
    }

    // ---- comparisons ------------------------------------------------------

    /// `self == other` as an expression — tag `"eq"`.
    pub fn eq(&self, other: &Expr) -> Expr {
        Self::binary("eq", self, other)
    }

    /// `self != other` as an expression — tag `"neq"`.
    pub fn ne(&self, other: &Expr) -> Expr {
        Self::binary("neq", self, other)
    }

    /// `self < other` — tag `"lt"`.
    pub fn lt(&self, other: &Expr) -> Expr {
        Self::binary("lt", self, other)
    }

    /// `self <= other` — tag `"leq"`.
    pub fn le(&self, other: &Expr) -> Expr {
        Self::binary("leq", self, other)
    }

    /// `self > other` — tag `"gt"`.
    pub fn gt(&self, other: &Expr) -> Expr {
        Self::binary("gt", self, other)
    }

    /// `self >= other` — tag `"geq"`.
    pub fn ge(&self, other: &Expr) -> Expr {
        Self::binary("geq", self, other)
    }

    // ---- other multi-operand forms ----------------------------------------

    /// Smaller of `self` and `other` — tag `"min"`.
    pub fn min(&self, other: &Expr) -> Expr {
        Self::binary("min", self, other)
    }

    /// Larger of `self` and `other` — tag `"max"`.
    pub fn max(&self, other: &Expr) -> Expr {
        Self::binary("max", self, other)
    }

    /// `self` raised to `other` — tag `"power"`.
    pub fn pow(&self, other: &Expr) -> Expr {
        Self::binary("power", self, other)
    }

    /// Conditional with `self` as the condition — tag `"ifelse"`, operands
    /// `[self, true_val, false_val]`.
    pub fn ifelse(&self, true_val: &Expr, false_val: &Expr) -> Expr {
        Self::ternary("ifelse", self, true_val, false_val)
    }

    /// Summation node — tag `"Sum"`, operands `[self, var, start, end]`.
    pub fn sum(&self, var: &Expr, start: &Expr, end: &Expr) -> Expr {
        Self::nary("Sum", &[self, var, start, end])
    }

    /// Product node — tag `"Product"`, operands `[self, var, start, end]`.
    pub fn prod(&self, var: &Expr, start: &Expr, end: &Expr) -> Expr {
        Self::nary("Product", &[self, var, start, end])
    }

    // ---- unary operators and elementary functions -------------------------

    /// Logical negation — tag `"not"`.
    pub fn not(&self) -> Expr {
        Self::unary("not", self)
    }

    /// Arithmetic negation — tag `"neg"`.
    pub fn neg(&self) -> Expr {
        Self::unary("neg", self)
    }

    /// Tag `"square"`.
    pub fn square(&self) -> Expr {
        Self::unary("square", self)
    }

    /// Tag `"cube"`.
    pub fn cube(&self) -> Expr {
        Self::unary("cube", self)
    }

    /// Reciprocal — tag `"recip"`.
    pub fn recip(&self) -> Expr {
        Self::unary("recip", self)
    }

    /// Tag `"round"`.
    pub fn round(&self) -> Expr {
        Self::unary("round", self)
    }

    /// Tag `"trunc"`.
    pub fn trunc(&self) -> Expr {
        Self::unary("trunc", self)
    }

    /// Tag `"floor"`.
    pub fn floor(&self) -> Expr {
        Self::unary("floor", self)
    }

    /// Tag `"ceiling"`.
    pub fn ceil(&self) -> Expr {
        Self::unary("ceiling", self)
    }

    /// Fractional part — tag `"frac"`.
    pub fn fract(&self) -> Expr {
        Self::unary("frac", self)
    }

    /// Tag `"heaviside"`.
    pub fn heaviside(&self) -> Expr {
        Self::unary("heaviside", self)
    }

    /// Square root — tag `"root"`.
    pub fn sqrt(&self) -> Expr {
        Self::unary("root", self)
    }

    /// Tag `"abs"`.
    pub fn abs(&self) -> Expr {
        Self::unary("abs", self)
    }

    /// Tag `"sin"`.
    pub fn sin(&self) -> Expr {
        Self::unary("sin", self)
    }

    /// Tag `"cos"`.
    pub fn cos(&self) -> Expr {
        Self::unary("cos", self)
    }

    /// Tag `"tan"`.
    pub fn tan(&self) -> Expr {
        Self::unary("tan", self)
    }

    /// Tag `"csc"`.
    pub fn csc(&self) -> Expr {
        Self::unary("csc", self)
    }

    /// Tag `"sec"`.
    pub fn sec(&self) -> Expr {
        Self::unary("sec", self)
    }

    /// Tag `"cot"`.
    pub fn cot(&self) -> Expr {
        Self::unary("cot", self)
    }

    /// Tag `"sinh"`.
    pub fn sinh(&self) -> Expr {
        Self::unary("sinh", self)
    }

    /// Tag `"cosh"`.
    pub fn cosh(&self) -> Expr {
        Self::unary("cosh", self)
    }

    /// Tag `"tanh"`.
    pub fn tanh(&self) -> Expr {
        Self::unary("tanh", self)
    }

    /// Tag `"csch"`.
    pub fn csch(&self) -> Expr {
        Self::unary("csch", self)
    }

    /// Tag `"sech"`.
    pub fn sech(&self) -> Expr {
        Self::unary("sech", self)
    }

    /// Tag `"coth"`.
    pub fn coth(&self) -> Expr {
        Self::unary("coth", self)
    }

    /// Tag `"arcsin"`.
    pub fn asin(&self) -> Expr {
        Self::unary("arcsin", self)
    }

    /// Tag `"arccos"`.
    pub fn acos(&self) -> Expr {
        Self::unary("arccos", self)
    }

    /// Tag `"arctan"`.
    pub fn atan(&self) -> Expr {
        Self::unary("arctan", self)
    }

    /// Tag `"arcsinh"`.
    pub fn asinh(&self) -> Expr {
        Self::unary("arcsinh", self)
    }

    /// Tag `"arccosh"`.
    pub fn acosh(&self) -> Expr {
        Self::unary("arccosh", self)
    }

    /// Tag `"arctanh"`.
    pub fn atanh(&self) -> Expr {
        Self::unary("arctanh", self)
    }

    /// Tag `"sinc"`.
    pub fn sinc(&self) -> Expr {
        Self::unary("sinc", self)
    }

    /// Cube root — tag `"cbrt"`.
    pub fn cbrt(&self) -> Expr {
        Self::unary("cbrt", self)
    }

    /// Tag `"exp"`.
    pub fn exp(&self) -> Expr {
        Self::unary("exp", self)
    }

    /// Natural logarithm — tag `"ln"`.
    pub fn ln(&self) -> Expr {
        Self::unary("ln", self)
    }

    /// Emitted with tag `"log"`.
    pub fn log10(&self) -> Expr {
        Self::unary("log", self)
    }

    /// Tag `"expm1"`.
    pub fn exp_m1(&self) -> Expr {
        Self::unary("expm1", self)
    }

    /// Tag `"log1p"`.
    pub fn ln_1p(&self) -> Expr {
        Self::unary("log1p", self)
    }

    /// Tag `"log2"`.
    pub fn log2(&self) -> Expr {
        Self::unary("log2", self)
    }

    /// Tag `"exp2"`.
    pub fn exp2(&self) -> Expr {
        Self::unary("exp2", self)
    }

    /// Tag `"erf"`.
    pub fn erf(&self) -> Expr {
        Self::unary("erf", self)
    }

    /// Tag `"erfc"`.
    pub fn erfc(&self) -> Expr {
        Self::unary("erfc", self)
    }

    /// Tag `"gamma"`.
    pub fn gamma(&self) -> Expr {
        Self::unary("gamma", self)
    }

    /// Tag `"loggamma"`.
    pub fn lgam(&self) -> Expr {
        Self::unary("loggamma", self)
    }

    /// Tag `"si"`.
    pub fn si(&self) -> Expr {
        Self::unary("si", self)
    }

    /// Tag `"ci"`.
    pub fn ci(&self) -> Expr {
        Self::unary("ci", self)
    }

    /// Tag `"shi"`.
    pub fn shi(&self) -> Expr {
        Self::unary("shi", self)
    }

    /// Tag `"chi"`.
    pub fn chi(&self) -> Expr {
        Self::unary("chi", self)
    }
}
/// `&a + &b` builds a `"plus"` node by delegating to the inherent [`Expr::add`].
impl Add for &Expr {
    type Output = Expr;

    fn add(self, rhs: Self) -> Self::Output {
        Expr::add(self, rhs)
    }
}
/// `&a - &b` builds a `"minus"` node by delegating to the inherent [`Expr::sub`].
impl Sub for &Expr {
    type Output = Expr;

    fn sub(self, rhs: Self) -> Self::Output {
        Expr::sub(self, rhs)
    }
}
/// `&a * &b` builds a `"times"` node by delegating to the inherent [`Expr::mul`].
impl Mul for &Expr {
    type Output = Expr;

    fn mul(self, rhs: Self) -> Self::Output {
        Expr::mul(self, rhs)
    }
}
/// `&a / &b` builds a `"divide"` node by delegating to the inherent [`Expr::div`].
impl Div for &Expr {
    type Output = Expr;

    fn div(self, rhs: Self) -> Self::Output {
        Expr::div(self, rhs)
    }
}
/// `&a % &b` builds a `"rem"` node by delegating to the inherent [`Expr::rem`].
impl Rem for &Expr {
    type Output = Expr;

    fn rem(self, rhs: Self) -> Self::Output {
        Expr::rem(self, rhs)
    }
}
/// `&a & &b` builds an `"and"` node by delegating to the inherent
/// [`Expr::bitand`].
impl BitAnd for &Expr {
    type Output = Expr;

    fn bitand(self, rhs: Self) -> Self::Output {
        Expr::bitand(self, rhs)
    }
}
/// `&a | &b` builds an `"or"` node by delegating to the inherent
/// [`Expr::bitor`].
impl BitOr for &Expr {
    type Output = Expr;

    fn bitor(self, rhs: Self) -> Self::Output {
        Expr::bitor(self, rhs)
    }
}
/// `&a ^ &b` builds an `"xor"` node by delegating to the inherent
/// [`Expr::bitxor`].
impl BitXor for &Expr {
    type Output = Expr;

    fn bitxor(self, rhs: Self) -> Self::Output {
        Expr::bitxor(self, rhs)
    }
}
/// `-&a` builds an arithmetic `"neg"` node by delegating to the inherent
/// [`Expr::neg`].
impl Neg for &Expr {
    type Output = Expr;

    fn neg(self) -> Self::Output {
        Expr::neg(self)
    }
}
/// `!&a` builds a logical `"not"` node by delegating to the inherent
/// [`Expr::not`].
impl Not for &Expr {
    type Output = Expr;

    fn not(self) -> Self::Output {
        // `!` is logical negation, paired with `&`/`|`/`^` -> "and"/"or"/"xor".
        // It must delegate to `Expr::not`; arithmetic negation ("neg") belongs
        // to unary `-` via the `Neg` impl.
        Expr::not(self)
    }
}
/// Free-function shorthand: a reference to the variable called `name`.
pub fn var(name: &str) -> Expr {
    Expr::Var {
        name: String::from(name),
    }
}
pub fn double(val: f64) -> Expr {
Expr::from(val)
}
pub fn int(val: i32) -> Expr {
Expr::from(val)
}
/// Free-function shorthand: a constant node holding `val` converted to a
/// number via `bool_to_f64`.
pub fn boolean(val: bool) -> Expr {
    let numeric = bool_to_f64(val);
    numeric.into()
}