use {
super::tape::Tape,
super::variable::Variable,
std::f64::consts::PI,
std::fmt::Display,
std::iter::{Product, Sum},
std::ops::{Add, Div, Mul, Neg, Sub},
};
impl<'v> Neg for Variable<'v> {
type Output = Self;
fn neg(self) -> Self::Output {
self * -1.0
}
}
/// `x + y` for two variables; both partial derivatives are `1`.
///
/// # Panics
/// Panics if the two operands were recorded on different tapes.
impl<'v> Add<Variable<'v>> for Variable<'v> {
    type Output = Variable<'v>;

    #[inline]
    fn add(self, other: Variable<'v>) -> Self::Output {
        assert_eq!(self.tape as *const Tape, other.tape as *const Tape);
        let total = self.value + other.value;
        let node = self.tape.push2(self.index, 1.0, other.index, 1.0);
        Variable {
            tape: self.tape,
            value: total,
            index: node,
        }
    }
}
/// `x + c` for a constant `c`; the partial with respect to `x` is `1`.
///
/// The second parent slot repeats `x` with weight `0.0`, so the constant
/// contributes nothing to the gradient.
impl<'v> Add<f64> for Variable<'v> {
    type Output = Variable<'v>;

    #[inline]
    fn add(self, other: f64) -> Self::Output {
        let shifted = self.value + other;
        let node = self.tape.push2(self.index, 1.0, self.index, 0.0);
        Variable {
            tape: self.tape,
            value: shifted,
            index: node,
        }
    }
}
/// `c + x` for a constant `c`: addition commutes, so this delegates to
/// `x + c`.
impl<'v> Add<Variable<'v>> for f64 {
    type Output = Variable<'v>;

    #[inline]
    fn add(self, other: Variable<'v>) -> Self::Output {
        Add::add(other, self)
    }
}
impl<'v> Sub<Variable<'v>> for Variable<'v> {
type Output = Variable<'v>;
#[inline]
fn sub(self, other: Variable<'v>) -> Self::Output {
assert_eq!(self.tape as *const Tape, other.tape as *const Tape);
self.add(other.neg())
}
}
impl<'v> Sub<f64> for Variable<'v> {
type Output = Variable<'v>;
#[inline]
fn sub(self, other: f64) -> Self::Output {
self.add(other.neg())
}
}
/// `c - x` for a constant `c`; the partial with respect to `x` is `-1`.
///
/// The first parent slot repeats `x` with weight `0.0` as a placeholder.
impl<'v> Sub<Variable<'v>> for f64 {
    type Output = Variable<'v>;

    #[inline]
    fn sub(self, other: Variable<'v>) -> Self::Output {
        let tape = other.tape;
        let difference = self - other.value;
        let node = tape.push2(other.index, 0.0, other.index, -1.0);
        Variable {
            tape,
            value: difference,
            index: node,
        }
    }
}
/// `x * y` for two variables (product rule): the partial with respect to
/// `x` is `y` and the partial with respect to `y` is `x`.
///
/// # Panics
/// Panics if the two operands were recorded on different tapes.
impl<'v> Mul<Variable<'v>> for Variable<'v> {
    type Output = Variable<'v>;

    #[inline]
    fn mul(self, other: Variable<'v>) -> Self::Output {
        assert_eq!(self.tape as *const Tape, other.tape as *const Tape);
        let (lhs, rhs) = (self.value, other.value);
        let node = self.tape.push2(self.index, rhs, other.index, lhs);
        Variable {
            tape: self.tape,
            value: lhs * rhs,
            index: node,
        }
    }
}
/// `x * c` for a constant `c`; the partial with respect to `x` is `c`.
///
/// The second parent slot repeats `x` with weight `0.0` as a placeholder.
impl<'v> Mul<f64> for Variable<'v> {
    type Output = Variable<'v>;

    #[inline]
    fn mul(self, other: f64) -> Self::Output {
        let scaled = self.value * other;
        let node = self.tape.push2(self.index, other, self.index, 0.0);
        Variable {
            tape: self.tape,
            value: scaled,
            index: node,
        }
    }
}
/// `c * x` for a constant `c`: multiplication commutes, so this delegates
/// to `x * c`.
impl<'v> Mul<Variable<'v>> for f64 {
    type Output = Variable<'v>;

    #[inline]
    fn mul(self, other: Variable<'v>) -> Self::Output {
        Mul::mul(other, self)
    }
}
/// `x / y` for two variables, recorded as `x * (1 / y)`.
///
/// # Panics
/// Panics if the two operands were recorded on different tapes.
impl<'v> Div<Variable<'v>> for Variable<'v> {
    type Output = Variable<'v>;

    #[inline]
    fn div(self, other: Variable<'v>) -> Self::Output {
        assert_eq!(self.tape as *const Tape, other.tape as *const Tape);
        let inverse = other.recip();
        self * inverse
    }
}
impl<'v> Div<f64> for Variable<'v> {
type Output = Variable<'v>;
#[inline]
fn div(self, other: f64) -> Self::Output {
self * other.recip()
}
}
/// `c / x` for a constant `c`; the partial with respect to `x` is `-c / x²`.
///
/// The first parent slot repeats `x` with weight `0.0` as a placeholder.
impl<'v> Div<Variable<'v>> for f64 {
    type Output = Variable<'v>;

    #[inline]
    fn div(self, other: Variable<'v>) -> Self::Output {
        let denom = other.value;
        let quotient = self / denom;
        let partial = -self / (denom * denom);
        let node = other.tape.push2(other.index, 0.0, other.index, partial);
        Variable {
            tape: other.tape,
            value: quotient,
            index: node,
        }
    }
}
impl<'v> Sum<Variable<'v>> for Variable<'v> {
#[inline]
fn sum<I: Iterator<Item = Variable<'v>>>(iter: I) -> Self {
iter.reduce(|x, y| x + y)
.expect("Cannot call sum() since vector is empty. Exiting ...")
}
}
impl<'v> Product<Variable<'v>> for Variable<'v> {
#[inline]
fn product<I: Iterator<Item = Variable<'v>>>(iter: I) -> Self {
iter.reduce(|x, y| x * y)
.expect("Cannot call product() since vector is empty. Exiting ...")
}
}
/// Formats a variable as its current primal value.
///
/// Formatting is delegated to `f64`'s `Display`, so flags such as precision
/// (`{:.3}`), width and alignment are honoured. The previous
/// `write!(f, "{}", ..)` form silently ignored them.
impl<'v> Display for Variable<'v> {
    #[inline]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Display::fmt(&self.value, f)
    }
}
/// Exponentiation that can participate in automatic differentiation.
///
/// Implementing types raise `self` to a power of type `T`. When a tape
/// variable is involved, the result is recorded on the tape.
pub trait Powf<T> {
    /// Type produced by the exponentiation.
    type Output;

    /// Returns `self` raised to the power `exponent`.
    fn powf(&self, exponent: T) -> Self::Output;
}
impl<'v> Powf<Variable<'v>> for Variable<'v> {
    type Output = Variable<'v>;

    /// `x.powf(y)` with base `x = self` and exponent `y = other`.
    ///
    /// Partials:
    /// - w.r.t. the base: `y * x^(y - 1)`, taken as `0` when `y == 0`
    ///   (x^0 is constant). This avoids `0 * 0^-1 = NaN` at `x == 0`.
    /// - w.r.t. the exponent: `x^y * ln(x)`, taken as `0` when `x == 0` and
    ///   `y >= 0`. This avoids `0 * ln(0) = 0 * -inf = NaN`.
    ///
    /// For a negative base, `ln(x)` is NaN and so is the exponent partial.
    ///
    /// # Panics
    /// Panics if the two operands were recorded on different tapes.
    #[inline]
    fn powf(&self, other: Variable<'v>) -> Self::Output {
        assert_eq!(self.tape as *const Tape, other.tape as *const Tape);
        let (x, y) = (self.value, other.value);
        let value = f64::powf(x, y);
        let d_base = if y == 0.0 {
            0.0
        } else {
            y * f64::powf(x, y - 1.0)
        };
        let d_exponent = if x == 0.0 && y >= 0.0 {
            0.0
        } else {
            value * f64::ln(x)
        };
        Self::Output {
            tape: self.tape,
            value,
            index: self.tape.push2(self.index, d_base, other.index, d_exponent),
        }
    }
}
impl<'v> Powf<f64> for Variable<'v> {
    type Output = Variable<'v>;

    /// `x.powf(n)` for a constant exponent `n`.
    ///
    /// Partial w.r.t. `x`: `n * x^(n - 1)`. When `n == 0` the result is the
    /// constant `1`, so the partial is `0`. The guard avoids `0 * 0^-1 = NaN`
    /// at `x == 0`. The second parent slot repeats `x` with weight `0.0` as a
    /// placeholder.
    #[inline]
    fn powf(&self, n: f64) -> Self::Output {
        let x = self.value;
        let d_base = if n == 0.0 {
            0.0
        } else {
            n * f64::powf(x, n - 1.0)
        };
        Self::Output {
            tape: self.tape,
            value: f64::powf(x, n),
            index: self.tape.push2(self.index, d_base, self.index, 0.0),
        }
    }
}
impl<'v> Powf<Variable<'v>> for f64 {
    type Output = Variable<'v>;

    /// `a.powf(x)` for a constant base `a = self` and variable exponent `x`.
    ///
    /// Partial w.r.t. the exponent: `d/dx a^x = a^x * ln(a)`. When `a == 0`
    /// and `x >= 0` the partial is taken as `0`, which avoids
    /// `0 * ln(0) = 0 * -inf = NaN`. For a negative base, `ln(a)` is NaN and
    /// so is the partial. The first parent slot repeats `x` with weight `0.0`
    /// as a placeholder.
    #[inline]
    fn powf(&self, other: Variable<'v>) -> Self::Output {
        let base = *self;
        let exponent = other.value;
        let value = f64::powf(base, exponent);
        // Note: `exponent * base^(exponent - 1)` would be the derivative with
        // respect to the *base*, which is a constant here.
        let d_exponent = if base == 0.0 && exponent >= 0.0 {
            0.0
        } else {
            value * f64::ln(base)
        };
        Self::Output {
            tape: other.tape,
            value,
            index: other.tape.push2(other.index, 0.0, other.index, d_exponent),
        }
    }
}
impl<'v> Variable<'v> {
    /// Absolute value `|x|`.
    ///
    /// The recorded partial is `x.signum()`: `1.0` for positive inputs and
    /// `+0.0`, `-1.0` for negative inputs and `-0.0`, and NaN for NaN. At the
    /// kink this picks a one-sided subgradient rather than `0`.
    pub fn abs(self) -> Self {
        Variable {
            tape: self.tape,
            value: self.value.abs(),
            index: self.tape.push1(self.index, self.value.signum()),
        }
    }

    /// Arc cosine; partial `-1 / sqrt(1 - x²)`.
    ///
    /// The partial is NaN for `|x| > 1` and `-inf` at `x = ±1`.
    pub fn acos(self) -> Self {
        let x = self.value;
        Variable {
            tape: self.tape,
            value: x.acos(),
            index: self
                .tape
                .push1(self.index, -(1.0 - x.powi(2)).sqrt().recip()),
        }
    }

    /// Inverse hyperbolic cosine; partial `1 / (sqrt(x - 1) * sqrt(x + 1))`.
    pub fn acosh(self) -> Self {
        let x = self.value;
        Variable {
            tape: self.tape,
            value: x.acosh(),
            index: self
                .tape
                .push1(self.index, ((x - 1.0).sqrt() * (x + 1.0).sqrt()).recip()),
        }
    }

    /// Arc sine; partial `1 / sqrt(1 - x²)` on the open interval `(-1, 1)`.
    ///
    /// Outside that interval (endpoints included) and for NaN input, the
    /// recorded partial is NaN.
    pub fn asin(self) -> Self {
        let x = self.value;
        let partial = if (x > -1.0) && (x < 1.0) {
            (1.0 - x.powi(2)).sqrt().recip()
        } else {
            f64::NAN
        };
        Variable {
            tape: self.tape,
            value: x.asin(),
            index: self.tape.push1(self.index, partial),
        }
    }

    /// Inverse hyperbolic sine; partial `1 / sqrt(1 + x²)`.
    pub fn asinh(self) -> Self {
        let x = self.value;
        Variable {
            tape: self.tape,
            value: x.asinh(),
            index: self.tape.push1(self.index, (1.0 + x.powi(2)).sqrt().recip()),
        }
    }

    /// Arc tangent; partial `1 / (1 + x²)`.
    pub fn atan(self) -> Self {
        let x = self.value;
        Variable {
            tape: self.tape,
            value: x.atan(),
            index: self.tape.push1(self.index, (1.0 + x.powi(2)).recip()),
        }
    }

    /// Inverse hyperbolic tangent; partial `1 / (1 - x²)`.
    pub fn atanh(self) -> Self {
        let x = self.value;
        Variable {
            tape: self.tape,
            value: x.atanh(),
            index: self.tape.push1(self.index, (1.0 - x.powi(2)).recip()),
        }
    }

    /// Cube root; partial `1 / (3 * cbrt(x)²)`.
    ///
    /// The partial is built from `cbrt(x)` rather than `x.powf(2.0 / 3.0)`.
    /// `powf` returns NaN for a negative base with a fractional exponent,
    /// while the cube root and its derivative are defined for every real
    /// `x != 0`. At `x == 0` the partial is `+inf`.
    pub fn cbrt(self) -> Self {
        let root = self.value.cbrt();
        Variable {
            tape: self.tape,
            value: root,
            index: self.tape.push1(self.index, (3.0 * root.powi(2)).recip()),
        }
    }

    /// Cosine; partial `-sin(x)`.
    pub fn cos(self) -> Self {
        Variable {
            tape: self.tape,
            value: self.value.cos(),
            index: self.tape.push1(self.index, -self.value.sin()),
        }
    }

    /// Hyperbolic cosine; partial `sinh(x)`.
    pub fn cosh(self) -> Self {
        Variable {
            tape: self.tape,
            value: self.value.cosh(),
            index: self.tape.push1(self.index, self.value.sinh()),
        }
    }

    /// Complementary error function (value from `statrs`); partial
    /// `-2 / sqrt(pi) * exp(-x²)`.
    pub fn erfc(self) -> Self {
        use statrs::function::erf::erfc;
        let x = self.value;
        Variable {
            tape: self.tape,
            value: erfc(x),
            index: self
                .tape
                .push1(self.index, -(2.0 * (-x.powi(2)).exp()) / PI.sqrt()),
        }
    }

    /// Natural exponential; the partial equals the value `exp(x)`.
    pub fn exp(self) -> Self {
        let e = self.value.exp();
        Variable {
            tape: self.tape,
            value: e,
            index: self.tape.push1(self.index, e),
        }
    }

    /// Base-2 exponential; partial `2^x * ln(2)`.
    pub fn exp2(self) -> Self {
        let p = self.value.exp2();
        Variable {
            tape: self.tape,
            value: p,
            index: self.tape.push1(self.index, p * std::f64::consts::LN_2),
        }
    }

    /// `exp(x) - 1`, computed accurately near zero; partial `exp(x)`.
    pub fn exp_m1(self) -> Self {
        Variable {
            tape: self.tape,
            value: self.value.exp_m1(),
            index: self.tape.push1(self.index, self.value.exp()),
        }
    }

    /// Natural logarithm; partial `1 / x`.
    pub fn ln(self) -> Self {
        Variable {
            tape: self.tape,
            value: self.value.ln(),
            index: self.tape.push1(self.index, self.value.recip()),
        }
    }

    /// `ln(1 + x)`, computed accurately near zero; partial `1 / (1 + x)`.
    pub fn ln_1p(self) -> Self {
        Variable {
            tape: self.tape,
            value: self.value.ln_1p(),
            index: self.tape.push1(self.index, (1.0 + self.value).recip()),
        }
    }

    /// Base-10 logarithm; partial `1 / (x * ln(10))`.
    pub fn log10(self) -> Self {
        Variable {
            tape: self.tape,
            value: self.value.log10(),
            index: self
                .tape
                .push1(self.index, (self.value * std::f64::consts::LN_10).recip()),
        }
    }

    /// Base-2 logarithm; partial `1 / (x * ln(2))`.
    pub fn log2(self) -> Self {
        Variable {
            tape: self.tape,
            value: self.value.log2(),
            index: self
                .tape
                .push1(self.index, (self.value * std::f64::consts::LN_2).recip()),
        }
    }

    /// Reciprocal `1 / x`; partial `-1 / x²`.
    pub fn recip(self) -> Self {
        Variable {
            tape: self.tape,
            value: self.value.recip(),
            index: self.tape.push1(self.index, -self.value.powi(2).recip()),
        }
    }

    /// Sine; partial `cos(x)`.
    pub fn sin(self) -> Self {
        Variable {
            tape: self.tape,
            value: self.value.sin(),
            index: self.tape.push1(self.index, self.value.cos()),
        }
    }

    /// Hyperbolic sine; partial `cosh(x)`.
    pub fn sinh(self) -> Self {
        Variable {
            tape: self.tape,
            value: self.value.sinh(),
            index: self.tape.push1(self.index, self.value.cosh()),
        }
    }

    /// Square root; partial `1 / (2 * sqrt(x))`.
    pub fn sqrt(self) -> Self {
        let root = self.value.sqrt();
        Variable {
            tape: self.tape,
            value: root,
            index: self.tape.push1(self.index, (2.0 * root).recip()),
        }
    }

    /// Tangent; partial `1 / cos²(x)`.
    pub fn tan(self) -> Self {
        Variable {
            tape: self.tape,
            value: self.value.tan(),
            index: self.tape.push1(self.index, self.value.cos().powi(2).recip()),
        }
    }

    /// Hyperbolic tangent; partial `1 / cosh²(x)`.
    pub fn tanh(self) -> Self {
        Variable {
            tape: self.tape,
            value: self.value.tanh(),
            index: self.tape.push1(self.index, self.value.cosh().powi(2).recip()),
        }
    }
}