use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use crate::Scalar;
use super::node::ExprId;
use super::with_graph;
/// `a + b` on expression handles: records an `add` node in the graph reached
/// through `with_graph` and yields the id of that node.
impl Add for ExprId {
    type Output = Self;

    #[inline]
    fn add(self, rhs: Self) -> Self {
        with_graph(move |graph| graph.add(self, rhs))
    }
}
/// `a - b` on expression handles, lowered to `a + (-b)` using the graph's
/// `neg` and `add` primitives.
impl Sub for ExprId {
    type Output = Self;

    #[inline]
    fn sub(self, rhs: Self) -> Self {
        with_graph(move |graph| {
            let negated_rhs = graph.neg(rhs);
            graph.add(self, negated_rhs)
        })
    }
}
/// `a * b` on expression handles: records a `mul` node and yields its id.
impl Mul for ExprId {
    type Output = Self;

    #[inline]
    fn mul(self, rhs: Self) -> Self {
        with_graph(move |graph| graph.mul(self, rhs))
    }
}
/// `a / b` on expression handles, lowered to `a * recip(b)` using the graph's
/// `recip` and `mul` primitives.
impl Div for ExprId {
    type Output = Self;

    #[inline]
    fn div(self, rhs: Self) -> Self {
        with_graph(move |graph| {
            let inverse = graph.recip(rhs);
            graph.mul(self, inverse)
        })
    }
}
/// Unary minus on an expression handle: records a `neg` node for the operand.
impl Neg for ExprId {
    type Output = Self;

    #[inline]
    fn neg(self) -> Self {
        with_graph(move |graph| graph.neg(self))
    }
}
impl AddAssign for ExprId {
#[inline]
fn add_assign(&mut self, rhs: Self) {
*self = *self + rhs;
}
}
impl SubAssign for ExprId {
#[inline]
fn sub_assign(&mut self, rhs: Self) {
*self = *self - rhs;
}
}
impl MulAssign for ExprId {
#[inline]
fn mul_assign(&mut self, rhs: Self) {
*self = *self * rhs;
}
}
impl DivAssign for ExprId {
#[inline]
fn div_assign(&mut self, rhs: Self) {
*self = *self / rhs;
}
}
/// Symbolic [`Scalar`] implementation: each method appends nodes to the graph
/// reached through `with_graph` and returns the id of the resulting node rather
/// than computing a number.
///
/// Only a handful of graph primitives are used (`add`, `neg`, `mul`, `recip`,
/// `sqrt`, `sin`, `atan2`, `exp2`, `log2`, `select`, `lit`); every other function
/// is expressed through identities over them. Comments noting "NaN" refer to what
/// those identities produce under IEEE `f64` evaluation of the graph.
impl Scalar for ExprId {
    const ZERO: Self = ExprId::ZERO;
    const ONE: Self = ExprId::ONE;
    const TWO: Self = ExprId::TWO;
    // NOTE(review): the constants below are sentinel ids counting down from
    // `u32::MAX`, not literal nodes created with `from_f64`; presumably the graph
    // resolves them specially — TODO confirm. The methods in this impl build the
    // values they need (0.5, pi/2) with `from_f64` instead of using these.
    const HALF: Self = ExprId(u32::MAX - 1);
    const PI: Self = ExprId(u32::MAX - 2);
    const TAU: Self = ExprId(u32::MAX - 3);
    const FRAC_PI_2: Self = ExprId(u32::MAX - 4);
    const EPSILON: Self = ExprId(u32::MAX - 5);
    const INFINITY: Self = ExprId(u32::MAX - 6);
    const NEG_INFINITY: Self = ExprId(u32::MAX - 7);

    /// Square root: a single `sqrt` node.
    #[inline]
    fn sqrt(self) -> Self {
        with_graph(|g| g.sqrt(self))
    }

    /// `|x| = sqrt(x * x)`.
    #[inline]
    fn abs(self) -> Self {
        let xx = with_graph(|g| g.mul(self, self));
        with_graph(|g| g.sqrt(xx))
    }

    /// Sine: a single `sin` node.
    #[inline]
    fn sin(self) -> Self {
        with_graph(|g| g.sin(self))
    }

    /// `cos(x) = sin(x + pi/2)`, so only the `sin` primitive is needed.
    #[inline]
    fn cos(self) -> Self {
        let half_pi = Self::from_f64(std::f64::consts::FRAC_PI_2);
        let shifted = with_graph(|g| g.add(self, half_pi));
        with_graph(|g| g.sin(shifted))
    }

    /// `tan(x) = sin(x) * recip(cos(x))`.
    #[inline]
    fn tan(self) -> Self {
        let s = self.sin();
        let c = self.cos();
        let rc = with_graph(|g| g.recip(c));
        with_graph(|g| g.mul(s, rc))
    }

    /// `asin(x) = atan2(x, sqrt(1 - x^2))`; NaN for `|x| > 1`.
    #[inline]
    fn asin(self) -> Self {
        let one = ExprId::ONE;
        let xx = with_graph(|g| g.mul(self, self));
        let diff = with_graph(|g| {
            let neg_xx = g.neg(xx);
            g.add(one, neg_xx)
        });
        let sq = with_graph(|g| g.sqrt(diff));
        with_graph(|g| g.atan2(self, sq))
    }

    /// `acos(x) = atan2(sqrt(1 - x^2), x)`; NaN for `|x| > 1`.
    #[inline]
    fn acos(self) -> Self {
        let one = ExprId::ONE;
        let xx = with_graph(|g| g.mul(self, self));
        let diff = with_graph(|g| {
            let neg_xx = g.neg(xx);
            g.add(one, neg_xx)
        });
        let sq = with_graph(|g| g.sqrt(diff));
        with_graph(|g| g.atan2(sq, self))
    }

    /// Two-argument arctangent: a single `atan2(self, other)` node.
    #[inline]
    fn atan2(self, other: Self) -> Self {
        with_graph(|g| g.atan2(self, other))
    }

    /// Returns `(sin(x), cos(x))`, built independently.
    #[inline]
    fn sin_cos(self) -> (Self, Self) {
        (self.sin(), self.cos())
    }

    /// Branch-free minimum: `min(a, b) = (a + b - |a - b|) / 2`.
    #[inline]
    fn min(self, other: Self) -> Self {
        let half = Self::from_f64(0.5);
        let sum = self + other;
        let diff = self - other;
        let diff_sq = with_graph(|g| g.mul(diff, diff));
        let abs_diff = with_graph(|g| g.sqrt(diff_sq));
        let neg_abs = with_graph(|g| g.neg(abs_diff));
        let inner = with_graph(|g| g.add(sum, neg_abs));
        with_graph(|g| g.mul(half, inner))
    }

    /// Branch-free maximum: `max(a, b) = (a + b + |a - b|) / 2`.
    #[inline]
    fn max(self, other: Self) -> Self {
        let half = Self::from_f64(0.5);
        let sum = self + other;
        let diff = self - other;
        let diff_sq = with_graph(|g| g.mul(diff, diff));
        let abs_diff = with_graph(|g| g.sqrt(diff_sq));
        let inner = with_graph(|g| g.add(sum, abs_diff));
        with_graph(|g| g.mul(half, inner))
    }

    /// `clamp(x, lo, hi) = min(max(x, lo), hi)`.
    #[inline]
    fn clamp(self, lo: Self, hi: Self) -> Self {
        self.max(lo).min(hi)
    }

    /// Reciprocal: a single `recip` node.
    #[inline]
    fn recip(self) -> Self {
        with_graph(|g| g.recip(self))
    }

    /// Integer power `x^n` by exponentiation by squaring.
    ///
    /// Only `mul` nodes are emitted (plus one `recip` when `n < 0`). That keeps
    /// the result well-defined for negative bases (e.g. `(-2)^5 = -32`), where the
    /// `log2`-based `powf` would yield NaN. It also emits O(log |n|) nodes. For
    /// n in {0, ±1, ±2, 3, 4} the emitted nodes are `x*x`, `(x*x)*x`, and
    /// `(x*x)*(x*x)`, as before. `n == 0` yields `ONE`, matching `f64::powi`.
    #[inline]
    fn powi(self, n: i32) -> Self {
        // `unsigned_abs` avoids the overflow `n.abs()` would hit at `i32::MIN`.
        let mut remaining = n.unsigned_abs();
        if remaining == 0 {
            return ExprId::ONE;
        }
        // Skip the low zero bits first so the accumulator can start from a real
        // power instead of a `ONE` literal (no extra `1 * ...` node).
        let mut square = self;
        while remaining & 1 == 0 {
            square = with_graph(|g| g.mul(square, square));
            remaining >>= 1;
        }
        let mut acc = square;
        remaining >>= 1;
        while remaining != 0 {
            square = with_graph(|g| g.mul(square, square));
            if remaining & 1 == 1 {
                acc = with_graph(|g| g.mul(square, acc));
            }
            remaining >>= 1;
        }
        if n < 0 {
            with_graph(|g| g.recip(acc))
        } else {
            acc
        }
    }

    /// `copysign(x, s) = |x| * signum(s)`.
    ///
    /// Because `signum` is `s / |s|`, this is NaN when `sign` is zero (`f64::copysign`
    /// would return `|x|` there).
    #[inline]
    fn copysign(self, sign: Self) -> Self {
        let ax = self.abs();
        let ss = sign.signum();
        with_graph(|g| g.mul(ax, ss))
    }

    /// `signum(x) = x * recip(sqrt(x * x))`.
    ///
    /// At `x == 0` this is `0 * recip(0)`, i.e. NaN rather than `f64::signum`'s ±1.
    #[inline]
    fn signum(self) -> Self {
        let xx = with_graph(|g| g.mul(self, self));
        let abs_x = with_graph(|g| g.sqrt(xx));
        let r = with_graph(|g| g.recip(abs_x));
        with_graph(|g| g.mul(self, r))
    }

    /// Constant-folds `floor` on a literal node into a new literal.
    ///
    /// # Panics
    /// If `self` is not a literal (`as_f64()` returns `None`); there is no
    /// symbolic floor node.
    #[inline]
    fn floor(self) -> Self {
        with_graph(|g| {
            if let Some(v) = g.node(self).as_f64() {
                g.lit(v.floor())
            } else {
                panic!("floor() requires a literal expression")
            }
        })
    }

    /// Constant-folds `ceil` on a literal node into a new literal.
    ///
    /// # Panics
    /// If `self` is not a literal (`as_f64()` returns `None`).
    #[inline]
    fn ceil(self) -> Self {
        with_graph(|g| {
            if let Some(v) = g.node(self).as_f64() {
                g.lit(v.ceil())
            } else {
                panic!("ceil() requires a literal expression")
            }
        })
    }

    /// Constant-folds `round` on a literal node into a new literal.
    ///
    /// # Panics
    /// If `self` is not a literal (`as_f64()` returns `None`).
    #[inline]
    fn round(self) -> Self {
        with_graph(|g| {
            if let Some(v) = g.node(self).as_f64() {
                g.lit(v.round())
            } else {
                panic!("round() requires a literal expression")
            }
        })
    }

    /// `e^x = 2^(x * log2(e))`.
    #[inline]
    fn exp(self) -> Self {
        let log2_e = Self::from_f64(std::f64::consts::LOG2_E);
        let scaled = with_graph(|g| g.mul(self, log2_e));
        with_graph(|g| g.exp2(scaled))
    }

    /// `ln(x) = log2(x) * ln(2)`; NaN for negative `x`.
    #[inline]
    fn ln(self) -> Self {
        let ln_2 = Self::from_f64(std::f64::consts::LN_2);
        let l = with_graph(|g| g.log2(self));
        with_graph(|g| g.mul(l, ln_2))
    }

    /// `x^p = 2^(p * log2(x))`.
    ///
    /// Only meaningful for `x >= 0`; negative bases give NaN through `log2`. Use
    /// `powi` for integer exponents.
    #[inline]
    fn powf(self, p: Self) -> Self {
        let l = with_graph(|g| g.log2(self));
        let pl = with_graph(|g| g.mul(p, l));
        with_graph(|g| g.exp2(pl))
    }

    /// `sinh(x) = (e^x - e^-x) / 2`.
    #[inline]
    fn sinh(self) -> Self {
        let half = Self::from_f64(0.5);
        let ex = self.exp();
        let neg_x = with_graph(|g| g.neg(self));
        let enx = Scalar::exp(neg_x);
        let diff = ex - enx;
        with_graph(|g| g.mul(half, diff))
    }

    /// `cosh(x) = (e^x + e^-x) / 2`.
    #[inline]
    fn cosh(self) -> Self {
        let half = Self::from_f64(0.5);
        let ex = self.exp();
        let neg_x = with_graph(|g| g.neg(self));
        let enx = Scalar::exp(neg_x);
        let sum = ex + enx;
        with_graph(|g| g.mul(half, sum))
    }

    /// `tanh(x) = sinh(x) * recip(cosh(x))`.
    ///
    /// For large `|x|` both factors overflow to infinity under f64 evaluation,
    /// giving `inf * 0 = NaN` rather than ±1.
    #[inline]
    fn tanh(self) -> Self {
        let s = self.sinh();
        let c = self.cosh();
        let rc = with_graph(|g| g.recip(c));
        with_graph(|g| g.mul(s, rc))
    }

    /// `acosh(x) = ln(x + sqrt(x^2 - 1))`; NaN for `x < 1`.
    #[inline]
    fn acosh(self) -> Self {
        let one = ExprId::ONE;
        let xx = with_graph(|g| g.mul(self, self));
        let diff = with_graph(|g| {
            let neg_one = g.neg(one);
            g.add(xx, neg_one)
        });
        let sq = with_graph(|g| g.sqrt(diff));
        let sum = with_graph(|g| g.add(self, sq));
        Scalar::ln(sum)
    }

    /// `asinh(x) = ln(x + sqrt(x^2 + 1))`.
    ///
    /// For large negative `x` the sum cancels, so precision degrades there.
    #[inline]
    fn asinh(self) -> Self {
        let one = ExprId::ONE;
        let xx = with_graph(|g| g.mul(self, self));
        let sum_inner = with_graph(|g| g.add(xx, one));
        let sq = with_graph(|g| g.sqrt(sum_inner));
        let sum = with_graph(|g| g.add(self, sq));
        Scalar::ln(sum)
    }

    /// `atanh(x) = ln((1 + x) / (1 - x)) / 2`.
    #[inline]
    fn atanh(self) -> Self {
        let half = Self::from_f64(0.5);
        let one = ExprId::ONE;
        let one_plus = with_graph(|g| g.add(one, self));
        let neg_x = with_graph(|g| g.neg(self));
        let one_minus = with_graph(|g| g.add(one, neg_x));
        let ratio = one_plus / one_minus;
        let l = Scalar::ln(ratio);
        with_graph(|g| g.mul(half, l))
    }

    /// Creates (or reuses, if the graph deduplicates) a literal node for `v`.
    #[inline]
    fn from_f64(v: f64) -> Self {
        with_graph(|g| g.lit(v))
    }

    /// Not supported for symbolic values.
    ///
    /// # Panics
    /// Always; evaluate the graph with `ExprGraph::eval()` instead.
    #[inline]
    fn to_f64(self) -> f64 {
        panic!("cannot evaluate symbolic ExprId to f64 — use ExprGraph::eval() instead")
    }

    /// Literal node for `v` (converted losslessly to `f64`).
    #[inline]
    fn from_i32(v: i32) -> Self {
        with_graph(|g| g.lit(v as f64))
    }

    /// Records a `select(cond, a, b)` node; the condition semantics are defined by
    /// the graph's `select` primitive.
    #[inline]
    fn select(cond: Self, a: Self, b: Self) -> Self {
        with_graph(|g| g.select(cond, a, b))
    }
}
#[cfg(test)]
mod tests {
    // NOTE(review): `trace` is neither defined nor imported in the visible part of
    // the parent module; confirm it is reachable as `super::trace`.
    use super::{trace, ExprId};
    use crate::Scalar;

    /// Adding two literal handles and evaluating the graph gives their sum.
    #[test]
    fn basic_arithmetic() {
        let (graph, sum) = trace(|| ExprId::from_f64(3.0) + ExprId::from_f64(4.0));
        let value = graph.eval::<f64>(sum, &[]);
        assert!((value - 7.0).abs() < 1e-10);
    }

    /// Tracing the literal 0.0 hands back `ExprId::ZERO` and a three-node graph.
    #[test]
    fn var_trace() {
        let (graph, result) = trace(|| <ExprId as Scalar>::from_f64(0.0));
        assert_eq!(result, ExprId::ZERO);
        assert_eq!(graph.len(), 3);
    }

    /// Distinct literal values are given distinct ids.
    #[test]
    fn constants_are_lits() {
        let (_graph, (half, pi)) = trace(|| {
            (
                ExprId::from_f64(0.5),
                ExprId::from_f64(std::f64::consts::PI),
            )
        });
        assert_ne!(half, pi);
    }
}