syntaxdot-transformers 0.1.0

Transformer architectures, such as BERT
use std::f64;

use tch::nn::Module;
use tch::Tensor;

/// Marker trait for activation functions.
///
/// It declares no methods of its own: an implementor only has to be
/// `Clone` and a `tch` `Module`, so it can be copied and applied to a
/// tensor through `Module::forward`.
pub trait Activation
where
    Self: Clone + Module,
{
}

/// GELU activation computed with the tanh-based formula.
///
/// For an input `x`, the forward pass evaluates
/// `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`
/// element-wise.
#[derive(Clone, Copy, Debug)]
pub struct GELUNew;

impl Module for GELUNew {
    /// Applies the tanh-based GELU formula to `input` and returns the
    /// result as a new tensor.
    fn forward(&self, input: &Tensor) -> Tensor {
        // sqrt(2 / pi), the constant factor inside the tanh.
        let scale = (2. / f64::consts::PI).sqrt();

        // x + 0.044715 * x^3
        let cubic = input + 0.044715 * input.pow(3.0);

        // 1 + tanh(scale * (x + 0.044715 * x^3))
        let gate = 1.0 + (scale * cubic).tanh();

        0.5 * input * gate
    }
}

impl Activation for GELUNew {}

#[derive(Clone, Copy, Debug)]
pub struct GELU;

impl Module for GELU {
    fn forward(&self, input: &Tensor) -> Tensor {
        input.gelu()
    }
}

impl Activation for GELU {}

#[cfg(test)]
mod tests {
    use std::convert::TryInto;

    use approx::assert_abs_diff_eq;
    use ndarray::{array, ArrayD};
    use tch::nn::Module;
    use tch::Tensor;

    use super::GELUNew;

    /// Compares `GELUNew` against reference values at five sample points,
    /// allowing an absolute difference of `1e-4`.
    #[test]
    fn gelu_new_returns_correct_values() {
        let input = Tensor::of_slice(&[-1., -0.5, 0., 0.5, 1.]);
        let output = GELUNew.forward(&input);

        let activations: ArrayD<f32> = (&output).try_into().unwrap();
        let expected: ArrayD<f32> = array![-0.1588, -0.1543, 0.0000, 0.3457, 0.8412].into_dyn();

        assert_abs_diff_eq!(activations, expected, epsilon = 1e-4);
    }
}