use crate::common::matrix::DMat;
use crate::{activation::ActivationFunction, error::NetworkError};
use serde::{Deserialize, Serialize};
use typetag;
use super::{he_initialization, ActivationFunctionClone};
/// Private unit type that carries the GELU implementation of `ActivationFunction`.
///
/// It holds no state. Instances are only created inside `GELU::build`, which hands
/// them out as `Box<dyn ActivationFunction>`.
#[derive(Serialize, Deserialize, Clone)]
struct GELUActivation;
/// Public handle for the GELU activation; call `GELU::build` to obtain the boxed
/// `ActivationFunction` implementation.
pub struct GELU;
impl GELU {
fn new() -> Self {
Self {}
}
pub fn build() -> Result<Box<dyn ActivationFunction>, NetworkError> {
Ok(Box::new(GELUActivation {}))
}
}
impl Default for GELU {
fn default() -> Self {
Self::new()
}
}
#[typetag::serde]
impl ActivationFunction for GELUActivation {
    /// Applies GELU in place: each element `x` of `input` is replaced by
    /// `0.5 * x * (1 + erf(x / sqrt(2)))`.
    fn forward(&self, input: &mut DMat) {
        input.apply(|x| {
            // `SQRT_2` is the f32 nearest to sqrt(2), i.e. the same value as `2.0_f32.sqrt()`.
            let erf_term = special::Primitive::erf(x / std::f32::consts::SQRT_2);
            0.5 * x * (1.0 + erf_term)
        });
    }

    /// Computes the gradient with respect to the pre-activation values, in place.
    ///
    /// Each element `x` of `input` is first replaced by the GELU derivative
    /// `Phi(x) + x * phi(x)` (standard normal CDF and PDF), and the result is then
    /// multiplied by `d_output` via `mul_elem`. The `_output` argument is not read.
    fn backward(&self, d_output: &DMat, input: &mut DMat, _output: &DMat) {
        input.apply(|x| {
            // phi(x) = exp(-x^2 / 2) / sqrt(2 * pi)
            let density = (-(x * x) / 2.0).exp() / (2.0 * std::f32::consts::PI).sqrt();
            // Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))
            let erf_term = special::Primitive::erf(x / std::f32::consts::SQRT_2);
            let cumulative = 0.5 * (1.0 + erf_term);
            cumulative + x * density
        });
        input.mul_elem(d_output);
    }

    /// Returns `he_initialization` as the weight-initialization function.
    fn weight_initialization_factor(&self) -> fn(usize, usize) -> f32 {
        he_initialization
    }
}
impl ActivationFunctionClone for GELUActivation {
    /// Clones this activation and returns the copy as a boxed trait object.
    fn clone_box(&self) -> Box<dyn ActivationFunction> {
        let duplicate: GELUActivation = self.clone();
        Box::new(duplicate)
    }
}
#[cfg(test)]
mod gelu_tests {
    use super::*;
    use crate::{common::matrix::DMat, util::equal_approx};

    /// Forward pass checked against reference values of the exact (erf-based) GELU.
    #[test]
    fn test_gelu_forward() {
        let mut input = DMat::new(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let gelu = GELU::build().unwrap();
        gelu.forward(&mut input);
        // Reference values are x * Phi(x) for x = 1..=6 (Phi = standard normal CDF),
        // matching the erf formulation under test rather than the tanh approximation.
        let expected = DMat::new(
            2,
            3,
            &[0.841_345, 1.954_500, 2.995_950, 3.999_873, 4.999_999, 6.0],
        );
        assert!(equal_approx(&input, &expected, 1e-3), "GELU forward pass failed");
    }

    /// Backward pass: expected values are `d_output * (Phi(x) + x * phi(x))`.
    #[test]
    fn test_gelu_backward() {
        let mut input = DMat::new(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let d_output = DMat::new(2, 3, &[0.5, 1.0, 0.7, 0.2, 0.3, 0.1]);
        // The activation output is not read by GELU's backward pass, so zeros suffice.
        let output: DMat = DMat::new(2, 3, &[0.0; 6]);
        let gelu = GELU::build().unwrap();
        gelu.backward(&d_output, &mut input, &output);
        let expected = DMat::new(
            2,
            3,
            &[0.541658, 1.085232, 0.708362, 0.200101, 0.300002, 0.100000],
        );
        assert!(equal_approx(&input, &expected, 1e-3), "GELU backward pass failed");
    }

    /// The initialization factor is compared with a tolerance instead of exact
    /// float equality, so the check does not hinge on the last bit of the result.
    #[test]
    fn test_gelu_weight_initialization() {
        let gelu = GELU::build().unwrap();
        let factor = gelu.weight_initialization_factor()(2, 3);
        assert!(
            (factor - 0.816_496_6).abs() < 1e-6,
            "GELU weight initialization factor should be 0.8164966"
        );
    }

    /// A cloned activation must produce the same forward result as the original.
    #[test]
    fn test_gelu_clone() {
        let gelu = GELU::build().unwrap();
        let cloned_gelu = gelu.clone();
        let mut from_original = DMat::new(1, 3, &[-1.0, 0.5, 2.0]);
        let mut from_clone = DMat::new(1, 3, &[-1.0, 0.5, 2.0]);
        gelu.forward(&mut from_original);
        cloned_gelu.forward(&mut from_clone);
        assert!(
            equal_approx(&from_original, &from_clone, 1e-6),
            "cloned GELU should match the original"
        );
    }
}