use ndarray::{Array, ArrayBase, Axis, Data, Dimension, RemoveAxis, ScalarOperand};
use num_traits::{Float, One, Zero};
/// Rectified linear unit: passes strictly positive inputs through
/// unchanged and clamps everything else (zero, negatives, and unordered
/// values such as NaN) to `T::zero()`.
pub fn relu<T>(args: T) -> T
where
    T: PartialOrd + Zero,
{
    let floor = T::zero();
    match args > floor {
        true => args,
        false => floor,
    }
}
/// Derivative of [`relu`]: slope one on the strictly positive side,
/// zero elsewhere (the kink at zero is assigned slope zero).
pub fn relu_derivative<T>(args: T) -> T
where
    T: PartialOrd + One + Zero,
{
    let is_positive = args > T::zero();
    if is_positive { T::one() } else { T::zero() }
}
/// Logistic sigmoid `1 / (1 + e^(-x))`, mapping the real line onto (0, 1).
pub fn sigmoid<T>(args: T) -> T
where
    T: Float,
{
    // `Float::recip` is defined as `one() / self`, so spelling the
    // division out is bit-for-bit equivalent.
    let denominator = T::one() + (-args).exp();
    T::one() / denominator
}
/// Derivative of [`sigmoid`], via the identity `σ'(x) = σ(x) · (1 − σ(x))`.
pub fn sigmoid_derivative<T>(args: T) -> T
where
    T: Float,
{
    let y = sigmoid(args);
    let complement = T::one() - y;
    y * complement
}
/// Softmax over every element of `args`, treating the whole array as a
/// single distribution: `softmax(x)_i = e^(x_i) / Σ_j e^(x_j)`.
///
/// Fixes vs. the previous revision:
/// - `ndarray::ArrayBase` takes two type parameters (`S`, `D`); the extra
///   `A` argument in `ArrayBase<S, D, A>` did not compile. The function's
///   own generic parameter list is unchanged, so call sites are unaffected.
/// - element-wise exponentiation goes through `mapv` (ndarray arrays have
///   no `exp` method).
/// - the global maximum is subtracted before exponentiating so large
///   inputs cannot overflow `exp` to infinity; softmax is invariant under
///   a constant shift, so the result is unchanged.
pub fn softmax<A, S, D>(args: &ArrayBase<S, D>) -> Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
    S: Data<Elem = A>,
{
    // Stabilizing shift; for an empty array this is -inf but `mapv`
    // then produces an empty result, so no element is ever evaluated.
    let max = args.fold(A::neg_infinity(), |acc, &x| acc.max(x));
    let exps = args.mapv(|x| (x - max).exp());
    let total = exps.sum();
    exps / total
}
/// Softmax along one `axis` of `args`: every 1-D lane running in that
/// direction is normalized into a probability distribution independently.
///
/// Fixes vs. the previous revision:
/// - `ndarray::ArrayBase` takes two type parameters; `ArrayBase<S, D, A>`
///   did not compile.
/// - `&e / &e.sum_axis(axis)` divides a rank-`D` array by a rank-`D-1`
///   array; ndarray's trailing-axis broadcasting only lines those shapes
///   up for particular axes and panics otherwise. Normalizing each lane
///   in place is correct for every axis.
/// - each lane's maximum is subtracted before `exp` so large inputs do
///   not overflow to infinity (softmax is shift-invariant per lane).
pub fn softmax_axis<A, S, D>(args: &ArrayBase<S, D>, axis: usize) -> Array<A, D>
where
    A: Float + ScalarOperand,
    D: RemoveAxis,
    S: Data<Elem = A>,
{
    let axis = Axis(axis);
    let mut out = args.to_owned();
    for mut lane in out.lanes_mut(axis) {
        // Stabilizing shift: maximum of this lane only.
        let max = lane.fold(A::neg_infinity(), |acc, &x| acc.max(x));
        lane.mapv_inplace(|x| (x - max).exp());
        let total = lane.sum();
        lane.mapv_inplace(|x| x / total);
    }
    out
}
pub fn tanh<T>(args: T) -> T
where
T: Float,
{
args.tanh()
}
pub fn tanh_derivative<T>(args: T) -> T
where
T: Float,
{
let t = tanh(args);
T::one() - t * t
}
/// Identity activation: returns its argument unchanged (`f(x) = x`).
/// `const` so it can be used in constant contexts.
pub const fn linear<T>(x: T) -> T {
x
}
/// Derivative of [`linear`]: constantly one, independent of the input
/// (hence no parameter).
pub fn linear_derivative<T>() -> T
where
    T: One,
{
    T::one()
}
/// Step function (conventionally spelled "Heaviside"; the name is kept
/// for API compatibility): one for strictly positive input, zero for
/// everything else — including `x == 0` and unordered comparisons such
/// as NaN, where `partial_cmp` returns `None`.
pub fn heavyside<T>(x: T) -> T
where
    T: One + PartialOrd + Zero,
{
    match x.partial_cmp(&T::zero()) {
        Some(core::cmp::Ordering::Greater) => T::one(),
        _ => T::zero(),
    }
}