use std::collections::HashMap;
use std::fmt;
use std::ops::{Add, Mul, Neg, Sub};
use crate::erf::erf;
use crate::gamma::{digamma, gamma, loggamma, polygamma};
use crate::hypergeometric::hyp1f1;
use crate::{bessel, erf as erf_mod};
/// A symbolic scalar expression tree over `f64`.
///
/// Nodes are evaluated with `Expr::eval` / `Expr::eval_ext`, differentiated with `Expr::diff`
/// and rewritten with `Expr::simplify`. There is no dedicated digamma variant: ψ(u) is encoded
/// as `Mul(Const(NaN), LogGamma(u))` (see `Expr::digamma_node`).
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a wildcard arm
/// when matching on it.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
/// Numeric literal.
Const(f64),
/// Named variable; `eval` looks it up in the `vars` map and yields NaN if it is unbound.
Var(String),
/// Sum `lhs + rhs`. Subtraction is represented as `Add(a, Neg(b))`.
Add(Box<Expr>, Box<Expr>),
/// Product `lhs * rhs`.
Mul(Box<Expr>, Box<Expr>),
/// Power `base^exp`, evaluated with `f64::powf`.
Pow(Box<Expr>, Box<Expr>),
/// Negation `-x`.
Neg(Box<Expr>),
/// Reciprocal `1/x`.
Recip(Box<Expr>),
/// Gamma function Γ(x), evaluated by `crate::gamma::gamma`.
Gamma(Box<Expr>),
/// ln Γ(x), evaluated by `crate::gamma::loggamma`; also used inside the digamma placeholder.
LogGamma(Box<Expr>),
/// Error function erf(x).
Erf(Box<Expr>),
/// Complementary error function erfc(x).
Erfc(Box<Expr>),
/// Bessel function J_n(x) with integer order `n`, evaluated by `bessel::jn`.
BesselJ(i32, Box<Expr>),
/// Bessel function Y_n(x) with integer order `n`, evaluated by `bessel::yn`.
BesselY(i32, Box<Expr>),
/// Modified Bessel function I_n(x) with integer order `n`, evaluated by `bessel::iv`
/// (the order is converted to `f64`).
BesselI(i32, Box<Expr>),
/// Confluent hypergeometric function ₁F₁(a; b; z); `eval` yields NaN if `hyp1f1` fails.
Hypergeometric1F1(Box<Expr>, Box<Expr>, Box<Expr>),
/// Exponential e^x.
Exp(Box<Expr>),
/// Natural logarithm ln(x).
Log(Box<Expr>),
/// Sine (argument in radians, as for `f64::sin`).
Sin(Box<Expr>),
/// Cosine (argument in radians, as for `f64::cos`).
Cos(Box<Expr>),
}
impl Expr {
#[inline]
pub fn var(name: &str) -> Self {
Expr::Var(name.to_string())
}
#[inline]
pub fn konst(value: f64) -> Self {
Expr::Const(value)
}
#[inline]
pub fn exp(self) -> Self {
Expr::Exp(Box::new(self))
}
#[inline]
pub fn ln(self) -> Self {
Expr::Log(Box::new(self))
}
#[inline]
pub fn sin(self) -> Self {
Expr::Sin(Box::new(self))
}
#[inline]
pub fn cos(self) -> Self {
Expr::Cos(Box::new(self))
}
#[inline]
pub fn recip(self) -> Self {
Expr::Recip(Box::new(self))
}
#[inline]
pub fn pow(self, exp: Expr) -> Self {
Expr::Pow(Box::new(self), Box::new(exp))
}
}
impl Expr {
    /// Numerically evaluates the expression, reading variable values from `vars`.
    ///
    /// An unbound variable evaluates to NaN, as does a `Hypergeometric1F1` node whose
    /// `hyp1f1` call does not produce a value. The digamma placeholder
    /// `Mul(Const(NaN), LogGamma(u))` produced by differentiating `Gamma`/`LogGamma` is not
    /// special-cased here and is simply multiplied out, giving NaN; use [`Expr::eval_ext`]
    /// for trees that contain it.
    pub fn eval(&self, vars: &HashMap<&str, f64>) -> f64 {
        // Evaluate a child node under the same variable bindings.
        let sub = |node: &Expr| node.eval(vars);

        match self {
            // Leaves.
            Expr::Const(value) => *value,
            Expr::Var(name) => vars.get(name.as_str()).copied().unwrap_or(f64::NAN),

            // Arithmetic.
            Expr::Add(lhs, rhs) => sub(lhs) + sub(rhs),
            Expr::Mul(lhs, rhs) => sub(lhs) * sub(rhs),
            Expr::Pow(base, power) => sub(base).powf(sub(power)),
            Expr::Neg(inner) => -sub(inner),
            Expr::Recip(inner) => sub(inner).recip(),

            // Elementary functions.
            Expr::Exp(arg) => sub(arg).exp(),
            Expr::Log(arg) => sub(arg).ln(),
            Expr::Sin(arg) => sub(arg).sin(),
            Expr::Cos(arg) => sub(arg).cos(),

            // Special functions delegate to the crate's numeric routines.
            Expr::Gamma(arg) => gamma(sub(arg)),
            Expr::LogGamma(arg) => loggamma(sub(arg)),
            Expr::Erf(arg) => erf(sub(arg)),
            Expr::Erfc(arg) => erf_mod::erfc(sub(arg)),
            Expr::BesselJ(order, arg) => bessel::jn(*order, sub(arg)),
            Expr::BesselY(order, arg) => bessel::yn(*order, sub(arg)),
            Expr::BesselI(order, arg) => bessel::iv(f64::from(*order), sub(arg)),
            Expr::Hypergeometric1F1(a, b, z) => {
                hyp1f1(sub(a), sub(b), sub(z)).unwrap_or(f64::NAN)
            }
        }
    }
}
impl Expr {
    /// Symbolically differentiates `self` with respect to the variable named `var`.
    ///
    /// The result is not simplified; pass it through [`Expr::simplify`] to fold the `0`/`1`
    /// factors introduced by the rules below.
    ///
    /// Notes:
    /// - `Gamma`/`LogGamma` derivatives contain the digamma placeholder built by
    ///   `Expr::digamma_node`. Plain [`Expr::eval`] evaluates that placeholder to NaN; use
    ///   [`Expr::eval_ext`] to evaluate such derivatives numerically.
    /// - `Hypergeometric1F1(a, b, z)` is differentiated through `z` only; any dependence of
    ///   `a` or `b` on `var` is not accounted for.
    /// - NOTE(review): second derivatives of `Gamma`/`LogGamma` are not supported. The
    ///   placeholder contains `Const(NaN)`, and differentiating it with the product rule
    ///   yields a tree that evaluates to NaN.
    pub fn diff(&self, var: &str) -> Expr {
        match self {
            Expr::Const(_) => Expr::Const(0.0),
            Expr::Var(name) => {
                if name == var {
                    Expr::Const(1.0)
                } else {
                    Expr::Const(0.0)
                }
            }
            // Sum rule.
            Expr::Add(u, v) => Expr::Add(Box::new(u.diff(var)), Box::new(v.diff(var))),
            // Product rule: (uv)' = u'v + uv'.
            Expr::Mul(u, v) => Expr::Add(
                Box::new(Expr::Mul(Box::new(u.diff(var)), v.clone())),
                Box::new(Expr::Mul(u.clone(), Box::new(v.diff(var)))),
            ),
            // Pick the narrowest rule that applies. The general formula needs ln(u), which is
            // NaN for u < 0, and u'/u, which is infinite at u = 0. Using it for a constant
            // exponent made e.g. d/dx x^2 evaluate to NaN for every x <= 0.
            Expr::Pow(u, v) => match (u.contains_var(var), v.contains_var(var)) {
                // Neither base nor exponent mentions `var`: the power is a constant.
                (false, false) => Expr::Const(0.0),
                // Power rule: (u^c)' = c * u^(c - 1) * u'.
                (true, false) => Expr::Mul(
                    Box::new(Expr::Mul(
                        v.clone(),
                        Box::new(Expr::Pow(
                            u.clone(),
                            Box::new(Expr::Add(v.clone(), Box::new(Expr::Const(-1.0)))),
                        )),
                    )),
                    Box::new(u.diff(var)),
                ),
                // Exponential rule: (c^v)' = c^v * ln(c) * v'.
                (false, true) => Expr::Mul(
                    Box::new(self.clone()),
                    Box::new(Expr::Mul(
                        Box::new(Expr::Log(u.clone())),
                        Box::new(v.diff(var)),
                    )),
                ),
                // General case: (u^v)' = u^v * (v * u'/u + v' * ln(u)).
                (true, true) => {
                    let u_diff = u.diff(var);
                    let v_diff = v.diff(var);
                    let term1 = Expr::Mul(
                        v.clone(),
                        Box::new(Expr::Mul(
                            Box::new(u_diff),
                            Box::new(Expr::Recip(u.clone())),
                        )),
                    );
                    let term2 = Expr::Mul(Box::new(v_diff), Box::new(Expr::Log(u.clone())));
                    Expr::Mul(
                        Box::new(self.clone()),
                        Box::new(Expr::Add(Box::new(term1), Box::new(term2))),
                    )
                }
            },
            Expr::Neg(u) => Expr::Neg(Box::new(u.diff(var))),
            // (1/u)' = -u' / u^2.
            Expr::Recip(u) => Expr::Neg(Box::new(Expr::Mul(
                Box::new(u.diff(var)),
                Box::new(Expr::Recip(Box::new(Expr::Pow(
                    u.clone(),
                    Box::new(Expr::Const(2.0)),
                )))),
            ))),
            Expr::Exp(u) => Expr::Mul(Box::new(self.clone()), Box::new(u.diff(var))),
            Expr::Log(u) => Expr::Mul(Box::new(u.diff(var)), Box::new(Expr::Recip(u.clone()))),
            Expr::Sin(u) => Expr::Mul(Box::new(Expr::Cos(u.clone())), Box::new(u.diff(var))),
            Expr::Cos(u) => Expr::Mul(
                Box::new(Expr::Neg(Box::new(Expr::Sin(u.clone())))),
                Box::new(u.diff(var)),
            ),
            // Γ(u)' = Γ(u) * ψ(u) * u'.
            Expr::Gamma(u) => {
                let digamma_u = Expr::digamma_node(*u.clone());
                Expr::Mul(
                    Box::new(self.clone()),
                    Box::new(Expr::Mul(Box::new(digamma_u), Box::new(u.diff(var)))),
                )
            }
            // lnΓ(u)' = ψ(u) * u'.
            Expr::LogGamma(u) => Expr::Mul(
                Box::new(Expr::digamma_node(u.as_ref().clone())),
                Box::new(u.diff(var)),
            ),
            // erf(u)' = 2/sqrt(pi) * exp(-u^2) * u'.
            Expr::Erf(u) => {
                let two_over_sqrt_pi = 2.0 / std::f64::consts::PI.sqrt();
                Expr::Mul(
                    Box::new(Expr::Mul(
                        Box::new(Expr::Const(two_over_sqrt_pi)),
                        Box::new(Expr::Exp(Box::new(Expr::Neg(Box::new(Expr::Pow(
                            u.clone(),
                            Box::new(Expr::Const(2.0)),
                        )))))),
                    )),
                    Box::new(u.diff(var)),
                )
            }
            // erfc = 1 - erf, so erfc' = -erf'.
            Expr::Erfc(u) => Expr::Neg(Box::new(Expr::Erf(u.clone()).diff(var))),
            // J_n' = (J_{n-1} - J_{n+1}) / 2.
            Expr::BesselJ(n, u) => {
                let n = *n;
                let jnm1 = Expr::BesselJ(n - 1, u.clone());
                let jnp1 = Expr::BesselJ(n + 1, u.clone());
                let half = Expr::Const(0.5);
                Expr::Mul(
                    Box::new(Expr::Mul(
                        Box::new(half),
                        Box::new(Expr::Add(
                            Box::new(jnm1),
                            Box::new(Expr::Neg(Box::new(jnp1))),
                        )),
                    )),
                    Box::new(u.diff(var)),
                )
            }
            // Y_n' = (Y_{n-1} - Y_{n+1}) / 2.
            Expr::BesselY(n, u) => {
                let n = *n;
                let ynm1 = Expr::BesselY(n - 1, u.clone());
                let ynp1 = Expr::BesselY(n + 1, u.clone());
                Expr::Mul(
                    Box::new(Expr::Mul(
                        Box::new(Expr::Const(0.5)),
                        Box::new(Expr::Add(
                            Box::new(ynm1),
                            Box::new(Expr::Neg(Box::new(ynp1))),
                        )),
                    )),
                    Box::new(u.diff(var)),
                )
            }
            // I_n' = (I_{n-1} + I_{n+1}) / 2.
            Expr::BesselI(n, u) => {
                let n = *n;
                let inm1 = Expr::BesselI(n - 1, u.clone());
                let inp1 = Expr::BesselI(n + 1, u.clone());
                Expr::Mul(
                    Box::new(Expr::Mul(
                        Box::new(Expr::Const(0.5)),
                        Box::new(Expr::Add(Box::new(inm1), Box::new(inp1))),
                    )),
                    Box::new(u.diff(var)),
                )
            }
            // d/dz 1F1(a; b; z) = (a/b) * 1F1(a+1; b+1; z); only the z-dependence is handled.
            Expr::Hypergeometric1F1(a, b, z) => {
                let ratio = Expr::Mul(a.clone(), Box::new(Expr::Recip(b.clone())));
                let shifted = Expr::Hypergeometric1F1(
                    Box::new(Expr::Add(a.clone(), Box::new(Expr::Const(1.0)))),
                    Box::new(Expr::Add(b.clone(), Box::new(Expr::Const(1.0)))),
                    z.clone(),
                );
                Expr::Mul(
                    Box::new(Expr::Mul(Box::new(ratio), Box::new(shifted))),
                    Box::new(z.diff(var)),
                )
            }
        }
    }

    /// Returns `true` if the variable `var` occurs anywhere in `self`.
    ///
    /// Used by `diff` to choose the power, exponential or general rule for `Pow`.
    fn contains_var(&self, var: &str) -> bool {
        match self {
            Expr::Const(_) => false,
            Expr::Var(name) => name == var,
            Expr::Add(a, b) | Expr::Mul(a, b) | Expr::Pow(a, b) => {
                a.contains_var(var) || b.contains_var(var)
            }
            Expr::Hypergeometric1F1(a, b, z) => {
                a.contains_var(var) || b.contains_var(var) || z.contains_var(var)
            }
            Expr::Neg(x)
            | Expr::Recip(x)
            | Expr::Gamma(x)
            | Expr::LogGamma(x)
            | Expr::Erf(x)
            | Expr::Erfc(x)
            | Expr::BesselJ(_, x)
            | Expr::BesselY(_, x)
            | Expr::BesselI(_, x)
            | Expr::Exp(x)
            | Expr::Log(x)
            | Expr::Sin(x)
            | Expr::Cos(x) => x.contains_var(var),
        }
    }

    /// Returns the digamma placeholder for `self`; `_var` is ignored.
    ///
    /// NOTE(review): despite its name this does not differentiate anything, and no caller is
    /// visible in this file — confirm whether it is still needed.
    fn diff_no_chain(&self, _var: &str) -> Expr {
        Expr::digamma_node(self.clone())
    }

    /// Builds the placeholder that stands for the digamma function ψ(`u`).
    ///
    /// `Expr` has no digamma variant, so ψ(u) is encoded as `Mul(Const(NaN), LogGamma(u))`.
    /// `Display` prints this shape as `ψ(u)`, and `eval_ext` evaluates it numerically. Plain
    /// `eval` just multiplies it out and yields NaN.
    ///
    /// NOTE(review): a user-built `Const(NaN) * LogGamma(u)` is indistinguishable from this
    /// placeholder.
    pub(crate) fn digamma_node(u: Expr) -> Expr {
        Expr::Mul(
            Box::new(Expr::Const(f64::NAN)),
            Box::new(Expr::LogGamma(Box::new(u))),
        )
    }
}
impl Expr {
pub fn simplify(&self) -> Expr {
match self {
Expr::Add(a, b) => {
let a = a.simplify();
let b = b.simplify();
match (&a, &b) {
(Expr::Const(0.0), _) => b,
(_, Expr::Const(0.0)) => a,
(Expr::Const(c1), Expr::Const(c2)) => Expr::Const(c1 + c2),
_ => Expr::Add(Box::new(a), Box::new(b)),
}
}
Expr::Mul(a, b) => {
let a = a.simplify();
let b = b.simplify();
match (&a, &b) {
(Expr::Const(c), _) if *c == 0.0 => Expr::Const(0.0),
(_, Expr::Const(c)) if *c == 0.0 => Expr::Const(0.0),
(Expr::Const(c), _) if *c == 1.0 => b,
(_, Expr::Const(c)) if *c == 1.0 => a,
(Expr::Const(c1), Expr::Const(c2)) => Expr::Const(c1 * c2),
_ => Expr::Mul(Box::new(a), Box::new(b)),
}
}
Expr::Pow(base, exp) => {
let base = base.simplify();
let exp = exp.simplify();
match (&base, &exp) {
(Expr::Const(c), _) if *c == 1.0 => Expr::Const(1.0),
(_, Expr::Const(c)) if *c == 1.0 => base,
(Expr::Const(c), _) if *c == 0.0 => Expr::Const(0.0),
(_, Expr::Const(c)) if *c == 0.0 => Expr::Const(1.0),
(Expr::Const(c1), Expr::Const(c2)) => Expr::Const(c1.powf(*c2)),
_ => Expr::Pow(Box::new(base), Box::new(exp)),
}
}
Expr::Neg(inner) => {
let inner = inner.simplify();
match inner {
Expr::Neg(x) => *x,
Expr::Const(c) => Expr::Const(-c),
other => Expr::Neg(Box::new(other)),
}
}
Expr::Recip(inner) => {
let inner = inner.simplify();
match inner {
Expr::Const(c) if c != 0.0 => Expr::Const(1.0 / c),
other => Expr::Recip(Box::new(other)),
}
}
Expr::Const(_) | Expr::Var(_) => self.clone(),
Expr::Gamma(u) => Expr::Gamma(Box::new(u.simplify())),
Expr::LogGamma(u) => Expr::LogGamma(Box::new(u.simplify())),
Expr::Erf(u) => Expr::Erf(Box::new(u.simplify())),
Expr::Erfc(u) => Expr::Erfc(Box::new(u.simplify())),
Expr::BesselJ(n, u) => Expr::BesselJ(*n, Box::new(u.simplify())),
Expr::BesselY(n, u) => Expr::BesselY(*n, Box::new(u.simplify())),
Expr::BesselI(n, u) => Expr::BesselI(*n, Box::new(u.simplify())),
Expr::Hypergeometric1F1(a, b, z) => Expr::Hypergeometric1F1(
Box::new(a.simplify()),
Box::new(b.simplify()),
Box::new(z.simplify()),
),
Expr::Exp(u) => Expr::Exp(Box::new(u.simplify())),
Expr::Log(u) => Expr::Log(Box::new(u.simplify())),
Expr::Sin(u) => Expr::Sin(Box::new(u.simplify())),
Expr::Cos(u) => Expr::Cos(Box::new(u.simplify())),
}
}
}
impl Add for Expr {
    type Output = Expr;

    /// Builds the unsimplified node `self + rhs`.
    fn add(self, rhs: Expr) -> Expr {
        Expr::Add(self.into(), rhs.into())
    }
}
impl Mul for Expr {
    type Output = Expr;

    /// Builds the unsimplified node `self * rhs`.
    fn mul(self, rhs: Expr) -> Expr {
        Expr::Mul(self.into(), rhs.into())
    }
}
impl Neg for Expr {
    type Output = Expr;

    /// Builds the unsimplified node `-self`.
    fn neg(self) -> Expr {
        Expr::Neg(self.into())
    }
}
impl Sub for Expr {
type Output = Expr;
fn sub(self, rhs: Expr) -> Expr {
Expr::Add(Box::new(self), Box::new(Expr::Neg(Box::new(rhs))))
}
}
impl fmt::Display for Expr {
    /// Writes the expression in fully parenthesised infix form.
    ///
    /// Binary operators, negation and reciprocals are wrapped in parentheses. Functions are
    /// written as `name(arg)` and Bessel functions as `J_n(x)` / `Y_n(x)` / `I_n(x)`. The
    /// digamma placeholder `Mul(Const(NaN), LogGamma(u))` is printed as `ψ(u)` rather than as
    /// a product.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Leaves.
            Expr::Const(c) => write!(f, "{c}"),
            Expr::Var(name) => write!(f, "{name}"),

            // Multiplication doubles as the carrier of the digamma placeholder.
            Expr::Mul(a, b) => match (a.as_ref(), b.as_ref()) {
                (Expr::Const(c), Expr::LogGamma(u)) if c.is_nan() => write!(f, "ψ({u})"),
                _ => write!(f, "({a} * {b})"),
            },
            Expr::Add(a, b) => write!(f, "({a} + {b})"),
            Expr::Pow(base, exp) => write!(f, "({base}^{exp})"),
            Expr::Neg(x) => write!(f, "(-{x})"),
            Expr::Recip(x) => write!(f, "(1/{x})"),

            // Elementary functions.
            Expr::Exp(x) => write!(f, "exp({x})"),
            Expr::Log(x) => write!(f, "ln({x})"),
            Expr::Sin(x) => write!(f, "sin({x})"),
            Expr::Cos(x) => write!(f, "cos({x})"),

            // Special functions.
            Expr::Gamma(x) => write!(f, "Γ({x})"),
            Expr::LogGamma(x) => write!(f, "lnΓ({x})"),
            Expr::Erf(x) => write!(f, "erf({x})"),
            Expr::Erfc(x) => write!(f, "erfc({x})"),
            Expr::BesselJ(n, x) => write!(f, "J_{n}({x})"),
            Expr::BesselY(n, x) => write!(f, "Y_{n}({x})"),
            Expr::BesselI(n, x) => write!(f, "I_{n}({x})"),
            Expr::Hypergeometric1F1(a, b, z) => write!(f, "₁F₁({a};{b};{z})"),
        }
    }
}
/// Crate-internal extension of `Expr` evaluation.
///
/// The implementation for `Expr` recognises the digamma placeholder
/// `Mul(Const(NaN), LogGamma(u))` and evaluates it as ψ(u). Plain `Expr::eval` does not; it
/// multiplies the placeholder out and gets NaN.
pub(crate) trait EvalExt {
/// Evaluates `self` with variable values taken from `vars`.
fn eval_full(&self, vars: &HashMap<&str, f64>) -> f64;
}
impl EvalExt for Expr {
fn eval_full(&self, vars: &HashMap<&str, f64>) -> f64 {
match self {
Expr::Mul(a, b) => {
if let Expr::Const(c) = a.as_ref() {
if c.is_nan() {
if let Expr::LogGamma(u) = b.as_ref() {
let uv = u.eval_full(vars);
return fd_digamma(uv);
}
}
}
a.eval_full(vars) * b.eval_full(vars)
}
Expr::Add(a, b) => a.eval_full(vars) + b.eval_full(vars),
Expr::Neg(x) => -x.eval_full(vars),
Expr::Recip(x) => x.eval_full(vars).recip(),
Expr::Pow(base, exp) => base.eval_full(vars).powf(exp.eval_full(vars)),
Expr::Exp(x) => x.eval_full(vars).exp(),
Expr::Log(x) => x.eval_full(vars).ln(),
Expr::Sin(x) => x.eval_full(vars).sin(),
Expr::Cos(x) => x.eval_full(vars).cos(),
other => other.eval(vars),
}
}
}
/// Digamma function ψ(x), used by `eval_full` for the digamma placeholder.
///
/// For `x > 0` it applies the recurrence ψ(z) = ψ(z + 1) − 1/z until the argument is at
/// least 20. It then evaluates the asymptotic expansion
/// ψ(z) ≈ ln z − 1/(2z) − 1/(12z²) + 1/(120z⁴) − 1/(252z⁶).
/// Arguments `<= 0` are delegated to the crate's `digamma`. A NaN argument fails both
/// comparisons and propagates to a NaN result.
fn fd_digamma(x: f64) -> f64 {
    if x <= 0.0 {
        return digamma(x);
    }

    // Shift the argument up into the asymptotic regime, accumulating -1/z per step.
    let mut z = x;
    let mut shift = 0.0_f64;
    while z < 20.0 {
        shift -= 1.0 / z;
        z += 1.0;
    }

    // Individual series terms; combined below in the same left-to-right order.
    let z2 = z * z;
    let log_term = z.ln();
    let first = 0.5 / z;
    let second = 1.0 / (12.0 * z2);
    let fourth = 1.0 / (120.0 * z2 * z2);
    let sixth = 1.0 / (252.0 * z2 * z2 * z2);

    log_term - first - second + fourth - sixth + shift
}
impl Expr {
pub fn eval_ext(&self, vars: &HashMap<&str, f64>) -> f64 {
EvalExt::eval_full(self, vars)
}
}