oxidized_transformers/layers/activation.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44
use candle_core::ModuleT;
use candle_nn::{Activation as CandleActivation, VarBuilder};
use serde::{Deserialize, Serialize};
use crate::error::BoxedError;
use crate::layers::build_module::BuildModule;
/// Activation functions.
///
/// Because of `#[serde(rename_all = "snake_case")]`, the variants are
/// (de)serialized under their snake-case names: `gelu`, `gelu_new`, `relu`
/// and `silu`.
///
/// The enum is `#[non_exhaustive]`: code outside this crate that matches on
/// it needs a wildcard arm, since further variants may be added.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum Activation {
/// Gaussian Error Linear Unit.
///
/// See [Hendrycks and Gimpel, 2016](https://arxiv.org/abs/1606.08415).
Gelu,
/// Gaussian Error Linear Unit approximation.
///
/// Built as candle's `NewGelu` activation.
///
/// See [Hendrycks and Gimpel, 2016](https://arxiv.org/abs/1606.08415).
GeluNew,
/// Rectified Linear Unit.
///
/// See [Fukushima, 1969](https://ieeexplore.ieee.org/document/4082265).
Relu,
/// Sigmoid Linear Unit.
///
/// See [Hendrycks and Gimpel, 2016](https://arxiv.org/abs/1606.08415).
Silu,
}
impl BuildModule for Activation {
fn build(&self, _vb: VarBuilder) -> Result<Box<dyn ModuleT>, BoxedError> {
use Activation::*;
Ok(match self {
Gelu => Box::new(CandleActivation::Gelu),
GeluNew => Box::new(CandleActivation::NewGelu),
Relu => Box::new(CandleActivation::Relu),
Silu => Box::new(CandleActivation::Silu),
})
}
}