use std::cmp::Ordering;
use num_traits::{Float, Zero};
use crate::{CmpExt, NumExt, core::Array};
impl<const D: usize, T: Zero + Clone + CmpExt> Array<D, T> {
    /// Rectified linear unit, applied element-wise via `map`.
    ///
    /// An element is kept only when `cmp_ext` reports it as strictly greater than
    /// `T::zero()`; every other element (equal, less, or any non-`Greater` ordering)
    /// becomes `T::zero()`.
    pub fn relu(&self) -> Self {
        self.map(|elem| {
            let zero = T::zero();
            if matches!(elem.cmp_ext(&zero), Ordering::Greater) {
                elem.clone()
            } else {
                zero
            }
        })
    }
}
impl<const D: usize> Array<D, f32> {
    /// GELU activation (tanh approximation), applied element-wise:
    ///
    /// `0.5 * x * (1 + tanh(0.7978845608 * (x + 0.044715 * x³)))`
    ///
    /// where `0.7978845608` approximates `sqrt(2/π)`.
    pub fn gelu(&self) -> Self {
        self.map(|v| {
            let inner = 0.7978845608 * (v + 0.044715 * v.powi(3));
            0.5 * v * (1.0 + inner.tanh())
        })
    }
}
impl<const D: usize> Array<D, f64> {
pub fn gelu(&self) -> Self {
self.map(|x| 0.5 * x * (1.0 + (0.7978845608 * (x + 0.044715 * x.powi(3))).tanh()))
}
}
impl<const D: usize, T: NumExt + Float> Array<D, T> {
    /// Logistic sigmoid, applied element-wise: `1 / (1 + exp(-x))`.
    pub fn sigmoid(&self) -> Self {
        (-self).exp().add_scalar(T::one()).recip()
    }

    /// Softmax over the last axis.
    ///
    /// The per-row maximum (from `max_axis`) is subtracted before exponentiating.
    /// This keeps `exp` from overflowing and does not change the result.
    ///
    /// NOTE(review): assumes `max_axis`/`sum_axis` keep the reduced axis (size 1),
    /// so the subtraction and division broadcast along it. TODO confirm.
    /// With `D == 0`, `D - 1` underflows (panics in debug builds).
    pub fn softmax(&self) -> Array<D, T> {
        let a = self.max_axis((D - 1) as isize);
        let a = (self - &a).exp();
        let a_t = a.sum_axis((D - 1) as isize);
        &a / &a_t
    }

    /// Root-mean-square normalization over the last axis:
    ///
    /// `y = x / sqrt(mean(x²) + eps)`
    ///
    /// `eps` is added inside the square root, so for `eps > 0` an all-zero row maps to
    /// zeros instead of dividing by zero. No learnable scale is applied.
    ///
    /// NOTE(review): like `softmax`, this assumes `sum_axis` keeps the reduced axis so
    /// the final division broadcasts. TODO confirm.
    pub fn rms_norm(&self, eps: T) -> Self {
        let axis = (D - 1) as isize;
        let sum_sq = self.map(|x| x.powi(2)).sum_axis(axis);
        // Element count along the reduced axis. Computing it with the same reduction
        // gives it exactly the shape of `sum_sq`, so the division is element-wise.
        let count = self.map(|_| T::one()).sum_axis(axis);
        let rms = (&sum_sq / &count).add_scalar(eps).sqrt();
        // Normalize the original values, not their squares.
        self / &rms
    }
}