use crate::activations::Activation;
use crate::error::Result;
use crate::layers::Layer;
use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand, Zip};
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::Debug;
/// GELU activation configuration.
///
/// The type carries a single flag that selects which of the two formula
/// variants the `forward` / `backward` implementations evaluate.
#[derive(Debug, Clone, Copy)]
pub struct GELU {
    /// `true` selects the branch taken by the `fast` mode, `false` the other one.
    fast: bool,
}
impl GELU {
pub fn new() -> Self {
Self { fast: false }
}
pub fn fast() -> Self {
Self { fast: true }
}
}
impl Default for GELU {
fn default() -> Self {
Self::new()
}
}
impl<F: Float + Debug + NumAssign> Activation<F> for GELU {
fn forward(
&self,
input: &Array<F, scirs2_core::ndarray::IxDyn>,
) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
let mut output = input.clone();
if self.fast {
let sqrt_2_over_pi =
F::from(0.7978845608028654).expect("Failed to convert constant to float"); let coeff = F::from(0.044715).expect("Failed to convert constant to float");
let half = F::from(0.5).expect("Failed to convert constant to float");
let one = F::one();
Zip::from(&mut output).for_each(|x| {
let x3 = *x * *x * *x;
let inner = sqrt_2_over_pi * (*x + coeff * x3);
*x = half * *x * (one + inner.tanh());
});
} else {
let sqrt_pi_over_2 =
F::from(1.2533141373155).expect("Failed to convert constant to float"); let coeff = F::from(0.044715).expect("Failed to convert constant to float");
let half = F::from(0.5).expect("Failed to convert constant to float");
let one = F::one();
Zip::from(&mut output).for_each(|x| {
let x2 = *x * *x;
let inner = sqrt_pi_over_2 * *x * (one + coeff * x2);
*x = half * *x * (one + inner.tanh());
});
}
Ok(output)
}
fn backward(
&self,
grad_output: &Array<F, scirs2_core::ndarray::IxDyn>,
input: &Array<F, scirs2_core::ndarray::IxDyn>,
) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
let mut grad_input = Array::zeros(grad_output.raw_dim());
if self.fast {
let sqrt_2_over_pi =
F::from(0.7978845608028654).expect("Failed to convert constant to float"); let coeff = F::from(0.044715).expect("Failed to convert constant to float");
let half = F::from(0.5).expect("Failed to convert constant to float");
let one = F::one();
let three = F::from(3.0).expect("Failed to convert constant to float");
Zip::from(&mut grad_input)
.and(grad_output)
.and(input)
.for_each(|grad_in, &grad_out, &x| {
let x2 = x * x;
let x3 = x2 * x;
let inner = sqrt_2_over_pi * (x + coeff * x3);
let tanh_inner = inner.tanh();
let sech_sq = one - tanh_inner * tanh_inner;
let d_inner_dx = sqrt_2_over_pi * (one + three * coeff * x2);
let dgelu_dx = half * (one + tanh_inner) + half * x * sech_sq * d_inner_dx;
*grad_in = grad_out * dgelu_dx;
});
} else {
let sqrt_pi_over_2 =
F::from(1.2533141373155).expect("Failed to convert constant to float"); let coeff = F::from(0.044715).expect("Failed to convert constant to float");
let half = F::from(0.5).expect("Failed to convert constant to float");
let one = F::one();
let three = F::from(3.0).expect("Failed to convert constant to float");
Zip::from(&mut grad_input)
.and(grad_output)
.and(input)
.for_each(|grad_in, &grad_out, &x| {
let x2 = x * x;
let inner = sqrt_pi_over_2 * x * (one + coeff * x2);
let tanh_inner = inner.tanh();
let sech_sq = one - tanh_inner * tanh_inner;
let d_inner_dx = sqrt_pi_over_2 * (one + three * coeff * x2);
let dgelu_dx = half * (one + tanh_inner) + half * x * sech_sq * d_inner_dx;
*grad_in = grad_out * dgelu_dx;
});
}
Ok(grad_input)
}
}
impl<F: Float + Debug + ScalarOperand + NumAssign> Layer<F> for GELU {
    /// Exposes `self` as `&dyn Any` so callers can downcast to `GELU`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Exposes `self` as `&mut dyn Any` so callers can downcast to `GELU`.
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    /// Forward pass; delegates to the `Activation<F>` implementation.
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        <Self as Activation<F>>::forward(self, input)
    }

    /// Backward pass; delegates to the `Activation<F>` implementation.
    ///
    /// `Layer::backward` takes `(input, grad_output)` while
    /// `Activation::backward` takes `(grad_output, input)`, so the arguments
    /// are swapped when forwarding the call.
    fn backward(
        &self,
        input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        <Self as Activation<F>>::backward(self, grad_output, input)
    }

    /// No-op: returns `Ok(())` without reading the learning rate or touching
    /// any state.
    fn update(&mut self, _learningrate: F) -> Result<()> {
        Ok(())
    }
}