#![warn(missing_docs)]
use std::error::Error;
use std::str::FromStr;
use crate::constants::{E, PI};
use crate::{Tensor, TensorElement};
use rayon::prelude::*;
/// Rectified linear unit, applied element-wise.
///
/// Every element `v` of `x` is kept when `v >= 0` and replaced by zero
/// otherwise. The test is `v >= 0`, which is false for NaN, so NaN elements
/// also become zero. The returned tensor reuses `x`'s shape.
///
/// # Panics
///
/// Panics if `Tensor::new` rejects the produced data/shape pair. The data has
/// exactly one entry per input element and the shape is cloned from `x`, so
/// this is presumably unreachable. NOTE(review): confirm against `Tensor::new`.
// UpperCamelCase name is kept so existing callers keep compiling.
#[allow(non_snake_case)]
pub fn ReLU<'a, 'b, T>(x: &Tensor<'a, T>) -> Tensor<'b, T>
where
    T: TensorElement,
    <T as FromStr>::Err: Error,
    Vec<T>: FromParallelIterator<T>,
{
    let zero = T::zero();
    let rectified = x
        .data
        .par_iter()
        .map(|&v| if v >= zero { v } else { zero })
        .collect::<Vec<T>>();
    let shape = x.shape.clone();
    Tensor::new(rectified, shape).unwrap()
}
/// Parametric ReLU, applied element-wise.
///
/// Every element `v` of `x` is kept when `v >= 0`. Otherwise it is scaled by
/// `alpha`, giving `v * alpha`. The returned tensor reuses `x`'s shape.
///
/// # Arguments
///
/// * `x` - input tensor.
/// * `alpha` - multiplier applied to elements that fail the `v >= 0` test.
///
/// # Panics
///
/// Panics if `Tensor::new` rejects the produced data/shape pair. The data has
/// exactly one entry per input element and the shape is cloned from `x`, so
/// this is presumably unreachable. NOTE(review): confirm against `Tensor::new`.
// UpperCamelCase name is kept so existing callers keep compiling.
#[allow(non_snake_case)]
pub fn PReLU<'a, 'b, T>(x: &Tensor<'a, T>, alpha: T) -> Tensor<'b, T>
where
    T: TensorElement,
    <T as FromStr>::Err: Error,
    Vec<T>: FromParallelIterator<T>,
{
    let zero = T::zero();
    let leaky = |&v: &T| if v >= zero { v } else { v * alpha };
    let activated: Vec<T> = x.data.par_iter().map(leaky).collect();
    Tensor::new(activated, x.shape.clone()).unwrap()
}
/// Logistic sigmoid, `σ(v) = 1 / (1 + e^(-v))`, applied element-wise.
///
/// The function is numerically stable. Large positive inputs, including
/// `+inf`, map to `1` instead of NaN, and large negative inputs map to `0`.
/// NaN inputs propagate as NaN. The returned tensor reuses `x`'s shape.
///
/// # Panics
///
/// Panics if `Tensor::new` rejects the produced data/shape pair. The data has
/// exactly one entry per input element and the shape is cloned from `x`, so
/// this is presumably unreachable. NOTE(review): confirm against `Tensor::new`.
// UpperCamelCase name is kept so existing callers keep compiling.
#[allow(non_snake_case)]
pub fn Sigmoid<'a, T>(x: &Tensor<'a, T>) -> Tensor<'a, T>
where
    T: TensorElement,
    <T as FromStr>::Err: Error,
    Vec<T>: FromParallelIterator<T>,
{
    // Convert Euler's number once rather than twice per element.
    let base = T::from(E).unwrap();
    let data: Vec<T> = x
        .data
        .par_iter()
        .map(|&v| {
            let z = base.powf(v);
            if v >= T::zero() {
                // For large v, e^v can overflow to +inf. Then 1/inf == 0 and the
                // result is exactly 1. The naive z / (z + 1) would give inf/inf = NaN.
                T::one() / (T::one() + T::one() / z)
            } else {
                // Here e^v lies in [0, 1), so this form cannot overflow.
                // NaN inputs also land here and propagate through z.
                z / (z + T::one())
            }
        })
        .collect();
    Tensor::new(data, x.shape.clone())
        .expect("sigmoid output has one element per input element and the input's shape")
}
/// Gaussian error linear unit, using the tanh approximation.
///
/// Computes `0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))`, built
/// from the tensor's own scalar (`mul_val`, `add_val`) and tensor-tensor
/// (`add`, `mul`) operations.
///
/// # Panics
///
/// Panics if the tensor-tensor `add` or `mul` returns an error, since both
/// results are unwrapped. Every operand is derived from `x`, so they are
/// presumably shape-compatible. NOTE(review): confirm.
// UpperCamelCase name is kept so existing callers keep compiling.
#[allow(non_snake_case)]
pub fn GeLU<'a, T>(x: &Tensor<'a, T>) -> Tensor<'a, T>
where
    T: TensorElement,
    <T as FromStr>::Err: Error,
{
    let half = T::from(0.5).unwrap();
    let cubic_coeff = T::from(0.044715).unwrap();
    let scale = (T::from(2).unwrap() / T::from(PI).unwrap()).sqrt();
    let one = T::from(1).unwrap();

    let half_x = x.mul_val(half);
    let cubic_term = x.pow(3).mul_val(cubic_coeff);
    let shifted = cubic_term.add(x).unwrap();
    let inner = shifted.mul_val(scale);
    let gate = inner.tanh().add_val(one);
    half_x.mul(&gate).unwrap()
}