use crate::{Node, Differentiable, Identifier, ops::{OneSub, Double, Mul}};
use num_traits::real::Real;
// `Rabbit(u)` is the unary node `u * (1 - u)`.
//
// This is also the shape of the sigmoid's derivative, `σ(u)(1 - σ(u))`; the
// `Sigmoid` adjoint in this file builds on it.
impl_unary!(
    Rabbit<F: crate::buffers::Scalar>, |x| { x * (F::one() - x) }, |self| {
        use crate::fmt::{PreWrap, Expr::*};

        match self.0.to_expr() {
            // Both 0·(1 - 0) and 1·(1 - 1) are exactly zero.
            Zero | One => Zero,
            Text(inner) => {
                // The operand appears twice, so render it once; presumably
                // `to_safe_string` parenthesises it only when it needs wrapping.
                let operand = inner.to_safe_string('(', ')');
                Text(PreWrap {
                    text: format!("{0} \u{2218} (1 - {0})", operand),
                    // The result is a binary expression, so flag it for wrapping
                    // when it is embedded in a larger expression.
                    needs_wrap: true,
                })
            }
        }
    }
);
/// Adjoint of `Rabbit(u) = u(1 - u)`.
///
/// By the product rule, `d/dt [u(1 - u)] = (1 - 2u) · du/dt`. The local slope
/// is `OneSub(Double(u))`, i.e. `1 - 2u`, multiplied by the adjoint of the
/// inner node `u`.
impl<N, I> Differentiable<I> for Rabbit<N>
where
    N: Clone + Differentiable<I>,
    I: Identifier,
{
    type Adjoint = Mul<OneSub<Double<N>>, N::Adjoint>;

    fn adjoint(&self, target: I) -> Self::Adjoint {
        // Build the local slope first, then the inner adjoint, preserving the
        // original evaluation order (receiver before argument).
        let slope = OneSub(Double(self.0.clone()));
        let inner = self.0.adjoint(target);
        slope.mul(inner)
    }
}
pub fn sigmoid<F: Real>(x: F) -> F {
if x >= num_traits::zero() {
let l: F = num_traits::one();
l / (l + (-x).exp())
} else {
let l: F = num_traits::one();
let z = x.exp();
return z / (l + z);
}
}
// `Sigmoid(u)` is the logistic function `σ(u)`, evaluated with the
// branch-split `sigmoid` helper defined in this file.
impl_unary!(
    Sigmoid<F: Real>, |x| { sigmoid(x) }, |self| {
        use crate::fmt::{PreWrap, Expr::*};

        // Function-application notation `σ(...)` is never wrapped, so the arms
        // differ only in the rendered text.
        let text = match self.0.to_expr() {
            // σ(0) is exactly one half.
            Zero => "\u{00BD}".to_string(),
            One => "\u{03C3}(1)".to_string(),
            Text(pw) => format!("\u{03C3}({})", pw),
        };
        Text(PreWrap { text, needs_wrap: false })
    }
);
impl<T, N> Differentiable<T> for Sigmoid<N>
where
T: Identifier,
N: Differentiable<T> + Clone,
{
type Adjoint = crate::ops::Mul<N::Adjoint, Rabbit<Self>>;
fn adjoint(&self, target: T) -> Self::Adjoint {
crate::ops::Mul(self.0.adjoint(target), Rabbit(self.clone()))
}
}