use num_traits::Float;
use crate::kernels;
/// Sentinel `u32` value, equal to `u32::MAX`.
///
/// NOTE(review): presumably marks an operand-index slot that holds no operand
/// (e.g. the second slot of a single-operand op) — confirm against the code that
/// consumes it; nothing in this module reads it.
pub const UNUSED: u32 = u32::MAX;
/// Operation identifiers understood by the scalar evaluators in this module.
///
/// The enum is `#[repr(u8)]` with implicit discriminants, so a variant's numeric
/// value is its position in this list; inserting or reordering variants changes
/// those values. In the per-variant notes, `a` and `b` are the operands as they
/// are passed to [`eval_forward`].
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum OpCode {
/// Input node. [`eval_forward`] panics on it; [`reverse_partials`] yields zeros.
Input,
/// Constant node. [`eval_forward`] panics on it; [`reverse_partials`] yields zeros.
Const,
/// `a + b`.
Add,
/// `a - b`.
Sub,
/// `a * b`.
Mul,
/// `a / b`.
Div,
/// `a % b`.
Rem,
/// `a.powf(b)`.
Powf,
/// `a.atan2(b)`.
Atan2,
/// `a.hypot(b)`.
Hypot,
/// The larger of `a` and `b`; `a` is kept on ties and when `b` is NaN.
Max,
/// The smaller of `a` and `b`; `a` is kept on ties and when `b` is NaN.
Min,
/// `-a`.
Neg,
/// `a.recip()`.
Recip,
/// `a.sqrt()`.
Sqrt,
/// `a.cbrt()`.
Cbrt,
/// `a.powi(n)`, where `n` is decoded from the `b` slot via [`powi_exp_decode_raw`].
Powi,
/// `a.exp()`.
Exp,
/// `a.exp2()`.
Exp2,
/// `a.exp_m1()`.
ExpM1,
/// `a.ln()`.
Ln,
/// `a.log2()`.
Log2,
/// `a.log10()`.
Log10,
/// `a.ln_1p()`.
Ln1p,
/// `a.sin()`.
Sin,
/// `a.cos()`.
Cos,
/// `a.tan()`.
Tan,
/// `a.asin()`.
Asin,
/// `a.acos()`.
Acos,
/// `a.atan()`.
Atan,
/// `a.sinh()`.
Sinh,
/// `a.cosh()`.
Cosh,
/// `a.tanh()`.
Tanh,
/// `a.asinh()`.
Asinh,
/// `a.acosh()`.
Acosh,
/// `a.atanh()`.
Atanh,
/// `a.abs()`.
Abs,
/// `a.signum()`.
Signum,
/// `a.floor()`.
Floor,
/// `a.ceil()`.
Ceil,
/// `a.round()`.
Round,
/// `a.trunc()`.
Trunc,
/// `a.fract()`.
Fract,
/// Custom operation; [`eval_forward`] and [`reverse_partials`] panic on it because
/// such ops are dispatched separately in the tape.
Custom,
}
/// Reports whether `op` is one of the non-smooth opcodes.
///
/// Two families qualify:
/// * `Abs`, `Min`, `Max` — continuous, but with a kink where the active branch changes;
/// * `Signum`, `Floor`, `Ceil`, `Round`, `Trunc`, `Fract` — functions with jump
///   discontinuities.
///
/// Every other opcode returns `false`.
#[inline]
#[must_use]
pub fn is_nonsmooth(op: OpCode) -> bool {
    let has_kink = matches!(op, OpCode::Abs | OpCode::Min | OpCode::Max);
    let has_jump = matches!(
        op,
        OpCode::Signum
            | OpCode::Floor
            | OpCode::Ceil
            | OpCode::Round
            | OpCode::Trunc
            | OpCode::Fract
    );
    has_kink || has_jump
}
/// Reports whether `op` is `Abs`, `Min` or `Max`.
///
/// These are the non-smooth opcodes for which [`forced_reverse_partials`] picks
/// between two different non-zero partials depending on the forced sign; the
/// remaining non-smooth opcodes get fixed partials there.
#[inline]
#[must_use]
pub fn has_nontrivial_subdifferential(op: OpCode) -> bool {
    [OpCode::Abs, OpCode::Min, OpCode::Max].contains(&op)
}
/// Reverse-mode partials `(d/da, d/db)` of `op` with the branch of a non-smooth op
/// chosen by `sign` instead of by the operand values.
///
/// * `Abs`: `sign >= 0` gives `(1, 0)`, otherwise `(-1, 0)`.
/// * `Max` / `Min`: `sign >= 0` routes the derivative to `a` (`(1, 0)`), otherwise
///   to `b` (`(0, 1)`).
/// * `Signum`, `Floor`, `Ceil`, `Round`, `Trunc`: always `(0, 0)`.
/// * `Fract`: always `(1, 0)`.
/// * Any other opcode: delegated unchanged to [`reverse_partials`], which is the only
///   place `a`, `b` and `r` are consumed.
#[inline]
pub fn forced_reverse_partials<T: Float>(op: OpCode, a: T, b: T, r: T, sign: i8) -> (T, T) {
    let (zero, one) = (T::zero(), T::one());
    // A non-negative sign selects the "first" branch of the op.
    let first_branch = sign >= 0;

    match op {
        OpCode::Abs => (if first_branch { one } else { -one }, zero),
        OpCode::Max | OpCode::Min => {
            if first_branch {
                (one, zero)
            } else {
                (zero, one)
            }
        }
        OpCode::Fract => (one, zero),
        OpCode::Signum | OpCode::Floor | OpCode::Ceil | OpCode::Round | OpCode::Trunc => {
            (zero, zero)
        }
        _ => reverse_partials(op, a, b, r),
    }
}
#[inline]
pub fn eval_forward<T: Float>(op: OpCode, a: T, b: T) -> T {
match op {
OpCode::Input | OpCode::Const => {
unreachable!("Input/Const should not be re-evaluated via eval_forward")
}
OpCode::Add => a + b,
OpCode::Sub => a - b,
OpCode::Mul => a * b,
OpCode::Div => a / b,
OpCode::Rem => a % b,
OpCode::Powf => a.powf(b),
OpCode::Atan2 => a.atan2(b),
OpCode::Hypot => a.hypot(b),
OpCode::Max => {
if a >= b || b.is_nan() {
a
} else {
b
}
}
OpCode::Min => {
if a <= b || b.is_nan() {
a
} else {
b
}
}
OpCode::Neg => -a,
OpCode::Recip => a.recip(),
OpCode::Sqrt => a.sqrt(),
OpCode::Cbrt => a.cbrt(),
OpCode::Powi => {
let exp = powi_exp_decode_raw(b.to_u32().unwrap_or(0));
a.powi(exp)
}
OpCode::Exp => a.exp(),
OpCode::Exp2 => a.exp2(),
OpCode::ExpM1 => a.exp_m1(),
OpCode::Ln => a.ln(),
OpCode::Log2 => a.log2(),
OpCode::Log10 => a.log10(),
OpCode::Ln1p => a.ln_1p(),
OpCode::Sin => a.sin(),
OpCode::Cos => a.cos(),
OpCode::Tan => a.tan(),
OpCode::Asin => a.asin(),
OpCode::Acos => a.acos(),
OpCode::Atan => a.atan(),
OpCode::Sinh => a.sinh(),
OpCode::Cosh => a.cosh(),
OpCode::Tanh => a.tanh(),
OpCode::Asinh => a.asinh(),
OpCode::Acosh => a.acosh(),
OpCode::Atanh => a.atanh(),
OpCode::Abs => a.abs(),
OpCode::Signum => a.signum(),
OpCode::Floor => a.floor(),
OpCode::Ceil => a.ceil(),
OpCode::Round => a.round(),
OpCode::Trunc => a.trunc(),
OpCode::Fract => a.fract(),
OpCode::Custom => unreachable!("Custom ops are dispatched separately in the tape"),
}
}
#[inline]
pub fn reverse_partials<T: Float>(op: OpCode, a: T, b: T, r: T) -> (T, T) {
let zero = T::zero();
let one = T::one();
match op {
OpCode::Input | OpCode::Const => (zero, zero),
OpCode::Add => (one, one),
OpCode::Sub => (one, -one),
OpCode::Mul => (b, a),
OpCode::Div => {
let inv = one / b;
(inv, -r * inv)
}
OpCode::Rem => (one, -(a / b).trunc()),
OpCode::Powf => {
if b == zero {
let db = if a > zero { a.ln() } else { zero };
(zero, db)
} else {
let da = if a == zero || r == zero || !a.is_finite() || !r.is_finite() {
b * a.powf(b - one)
} else {
b * r / a
};
let db = if r == zero || a <= zero {
zero
} else {
r * a.ln()
};
(da, db)
}
}
OpCode::Atan2 => kernels::atan2_partials(a, b),
OpCode::Hypot => kernels::hypot_partials(a, b, r),
OpCode::Max => {
if a >= b || b.is_nan() {
(one, zero)
} else {
(zero, one)
}
}
OpCode::Min => {
if a <= b || b.is_nan() {
(one, zero)
} else {
(zero, one)
}
}
OpCode::Neg => (-one, zero),
OpCode::Recip => {
let inv = one / a;
(-inv * inv, zero)
}
OpCode::Sqrt => {
let two = one + one;
(one / (two * r), zero)
}
OpCode::Cbrt => {
let three = T::from(3.0).unwrap();
(one / (three * r * r), zero)
}
OpCode::Powi => {
let exp = powi_exp_decode_raw(b.to_u32().unwrap_or(0));
if exp == 0 {
(zero, zero) } else if exp == i32::MIN {
let n = T::from(exp).unwrap();
(n * r / a, zero)
} else {
let n = T::from(exp).unwrap();
(n * a.powi(exp - 1), zero)
}
}
OpCode::Exp => (r, zero), OpCode::Exp2 => (r * T::ln(T::from(2.0).unwrap()), zero),
OpCode::ExpM1 => (r + one, zero), OpCode::Ln => {
if a >= zero {
(one / a, zero)
} else {
(T::nan(), zero)
}
}
OpCode::Log2 => {
if a >= zero {
(one / (a * T::ln(T::from(2.0).unwrap())), zero)
} else {
(T::nan(), zero)
}
}
OpCode::Log10 => {
if a >= zero {
(one / (a * T::ln(T::from(10.0).unwrap())), zero)
} else {
(T::nan(), zero)
}
}
OpCode::Ln1p => {
if a >= -one {
(one / (one + a), zero)
} else {
(T::nan(), zero)
}
}
OpCode::Sin => (a.cos(), zero),
OpCode::Cos => (-a.sin(), zero),
OpCode::Tan => {
let c = a.cos();
(one / (c * c), zero)
}
OpCode::Asin => (one / ((one - a) * (one + a)).sqrt(), zero),
OpCode::Acos => (-one / ((one - a) * (one + a)).sqrt(), zero),
OpCode::Atan => (kernels::atan_deriv(a), zero),
OpCode::Sinh => (a.cosh(), zero),
OpCode::Cosh => (a.sinh(), zero),
OpCode::Tanh => {
let c = a.cosh();
(one / (c * c), zero)
}
OpCode::Asinh => (kernels::asinh_deriv(a), zero),
OpCode::Acosh => (kernels::acosh_deriv(a), zero),
OpCode::Atanh => {
if a >= -one && a <= one {
(one / ((one - a) * (one + a)), zero)
} else {
(T::nan(), zero)
}
}
OpCode::Abs => {
if a == zero {
(zero, zero)
} else {
(a.signum(), zero)
}
}
OpCode::Signum | OpCode::Floor | OpCode::Ceil | OpCode::Round | OpCode::Trunc => {
(zero, zero)
}
OpCode::Fract => (one, zero),
OpCode::Custom => unreachable!("Custom ops are dispatched separately in the tape"),
}
}
/// Recovers a signed `powi` exponent from its raw `u32` encoding.
///
/// The 32 bits are reinterpreted unchanged as a two's-complement `i32`, so
/// `u32::MAX` decodes to `-1` and `0x8000_0000` to `i32::MIN`. This is the inverse
/// of [`powi_exp_encode`].
#[inline]
#[must_use]
pub fn powi_exp_decode_raw(b_idx: u32) -> i32 {
    i32::from_ne_bytes(b_idx.to_ne_bytes())
}
/// Packs a signed `powi` exponent into a raw `u32`.
///
/// The 32 bits of `exp` are kept unchanged (two's complement), so `-1` encodes to
/// `u32::MAX` and `i32::MIN` to `0x8000_0000`. [`powi_exp_decode_raw`] reverses this.
#[inline]
#[must_use]
pub fn powi_exp_encode(exp: i32) -> u32 {
    u32::from_ne_bytes(exp.to_ne_bytes())
}