//! Feed-forward (MLP) block supporting gated layouts (separate or fused gate/up
//! projections) as well as a plain two-layer variant.

use candle_core::{Module, Tensor};
use candle_nn::{Linear, VarBuilder};

use crate::config::{Activation, MlpLayout, TransformerConfig};
use crate::error::Result;
/// Position-wise feed-forward block. The concrete weight layout is selected by
/// `TransformerConfig::mlp_layout` at load time.
pub struct Mlp {
    variant: MlpVariant,
    activation: Activation,
}
enum MlpVariant {
    /// Gated MLP with separate `gate_proj`/`up_proj` weights (Llama-style naming).
    GatedSeparate {
        gate_proj: Linear,
        up_proj: Linear,
        down_proj: Linear,
    },
    /// Gated MLP whose gate and up projections are fused into a single
    /// `gate_up_proj` weight of width `2 * intermediate_size`, with the gate half
    /// stored first and the up half second.
    GatedFused {
        gate_up_proj: Linear,
        down_proj: Linear,
        intermediate_size: usize,
    },
    /// Plain two-layer MLP with GPT-2-style `c_fc`/`c_proj` naming.
    Plain {
        fc: Linear,
        proj: Linear,
    },
}
impl Mlp {
    /// Loads the MLP weights described by `config` from `vb`.
    #[allow(clippy::needless_pass_by_value)]
    pub fn load(config: &TransformerConfig, vb: VarBuilder<'_>) -> Result<Self> {
        let hidden = config.hidden_size;
        let inter = config.intermediate_size;
        let bias = config.mlp_bias;
        let variant = match config.mlp_layout {
            MlpLayout::GatedSeparate => {
                let gate_proj = load_linear(hidden, inter, bias, vb.pp("gate_proj"))?;
                let up_proj = load_linear(hidden, inter, bias, vb.pp("up_proj"))?;
                let down_proj = load_linear(inter, hidden, bias, vb.pp("down_proj"))?;
                MlpVariant::GatedSeparate {
                    gate_proj,
                    up_proj,
                    down_proj,
                }
            }
            MlpLayout::GatedFused => {
                // The fused weight concatenates the gate and up projections along
                // the output dimension, hence the `2 * inter` width.
                let gate_up_proj = load_linear(hidden, 2 * inter, bias, vb.pp("gate_up_proj"))?;
                let down_proj = load_linear(inter, hidden, bias, vb.pp("down_proj"))?;
                MlpVariant::GatedFused {
                    gate_up_proj,
                    down_proj,
                    intermediate_size: inter,
                }
            }
            MlpLayout::Plain => {
                let fc = load_linear(hidden, inter, bias, vb.pp("c_fc"))?;
                let proj = load_linear(inter, hidden, bias, vb.pp("c_proj"))?;
                MlpVariant::Plain { fc, proj }
            }
        };
        Ok(Self {
            variant,
            activation: config.activation,
        })
    }
    /// Applies the MLP to `x`, operating on the last dimension.
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        match &self.variant {
            MlpVariant::GatedSeparate {
                gate_proj,
                up_proj,
                down_proj,
            } => {
                // down_proj(act(gate_proj(x)) * up_proj(x))
                let gate = apply_activation(&gate_proj.forward(x)?, self.activation)?;
                let up = up_proj.forward(x)?;
                Ok(down_proj.forward(&(gate * up)?)?)
            }
            MlpVariant::GatedFused {
                gate_up_proj,
                down_proj,
                intermediate_size,
            } => {
                // Split the fused projection back into its gate and up halves
                // along the last dimension before applying the gating.
                let gate_up = gate_up_proj.forward(x)?;
                let gate = gate_up.narrow(candle_core::D::Minus1, 0, *intermediate_size)?;
                let up = gate_up.narrow(
                    candle_core::D::Minus1,
                    *intermediate_size,
                    *intermediate_size,
                )?;
                let gate = apply_activation(&gate, self.activation)?;
                Ok(down_proj.forward(&(gate * up)?)?)
            }
            MlpVariant::Plain { fc, proj } => {
                let hidden = apply_activation(&fc.forward(x)?, self.activation)?;
                Ok(proj.forward(&hidden)?)
            }
        }
    }
}
/// Applies the configured activation function elementwise.
fn apply_activation(x: &Tensor, activation: Activation) -> Result<Tensor> {
    match activation {
        Activation::Silu => Ok(candle_nn::ops::silu(x)?),
        // `gelu_erf` is the exact (erf-based) GELU; `gelu` is the tanh approximation.
        Activation::Gelu => Ok(x.gelu_erf()?),
        Activation::GeluApprox => Ok(x.gelu()?),
    }
}
/// Builds a `Linear` layer, with or without a bias term, from `vb`.
#[allow(clippy::needless_pass_by_value)]
fn load_linear(in_dim: usize, out_dim: usize, bias: bool, vb: VarBuilder<'_>) -> Result<Linear> {
    if bias {
        Ok(candle_nn::linear(in_dim, out_dim, vb)?)
    } else {
        Ok(candle_nn::linear_no_bias(in_dim, out_dim, vb)?)
    }
}
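
// Minimal smoke tests sketching how the pieces above fit together. They build
// fresh weights from a `VarMap` on the CPU rather than loading a checkpoint, and
// they bypass `TransformerConfig` by assembling the `Plain` variant by hand, so
// they only depend on what is defined in this module. They assume the crate's
// error type implements `Debug` so the tests can return `Result<()>`; adjust to
// the crate's actual test conventions as needed.
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device};
    use candle_nn::VarMap;

    #[test]
    fn plain_variant_preserves_batch_and_hidden_dims() -> Result<()> {
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);

        // Hand-assembled Plain MLP with hidden=4 and intermediate=8.
        let mlp = Mlp {
            variant: MlpVariant::Plain {
                fc: load_linear(4, 8, false, vb.pp("c_fc"))?,
                proj: load_linear(8, 4, false, vb.pp("c_proj"))?,
            },
            activation: Activation::Gelu,
        };

        // A (batch, seq, hidden) input should come back with the same shape.
        let x = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;
        let y = mlp.forward(&x)?;
        assert_eq!(y.dims(), &[2, 3, 4]);
        Ok(())
    }

    #[test]
    fn activation_is_elementwise() -> Result<()> {
        // The activation helpers only transform values, never shapes.
        let x = Tensor::zeros((2, 5), DType::F32, &Device::Cpu)?;
        let y = apply_activation(&x, Activation::Silu)?;
        assert_eq!(y.dims(), x.dims());
        Ok(())
    }
}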