use candle_core::ModuleT;
use candle_nn::{Activation as CandleActivation, VarBuilder};
use serde::{Deserialize, Serialize};
use crate::error::BoxedError;
use crate::layers::build_module::BuildModule;
/// Selector for the activation function a layer should use.
///
/// Values are (de)serialized by variant name in `snake_case`, so the
/// serialized forms are `gelu`, `gelu_new`, `relu` and `silu`.
///
/// The enum is `#[non_exhaustive]`: code outside this crate that matches on
/// it must include a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum Activation {
    /// Built as `candle_nn::Activation::Gelu`.
    Gelu,
    /// Built as `candle_nn::Activation::NewGelu`.
    GeluNew,
    /// Built as `candle_nn::Activation::Relu`.
    Relu,
    /// Built as `candle_nn::Activation::Silu`.
    Silu,
}
impl BuildModule for Activation {
fn build(&self, _vb: VarBuilder) -> Result<Box<dyn ModuleT>, BoxedError> {
use Activation::*;
Ok(match self {
Gelu => Box::new(CandleActivation::Gelu),
GeluNew => Box::new(CandleActivation::NewGelu),
Relu => Box::new(CandleActivation::Relu),
Silu => Box::new(CandleActivation::Silu),
})
}
}