use candle_core::{Result, Tensor};
use candle_nn::{Linear, Module, VarBuilder};
use crate::tensor_utils::silu;
/// Feed-forward block made of three linear projections.
///
/// [`SiluMlp::forward`] combines them as
/// `down_proj(silu(gate_proj(x)) * up_proj(x))`, where `silu` is the helper
/// imported from `crate::tensor_utils`.
pub struct SiluMlp {
    /// Projection whose output is passed through `silu` before gating.
    gate_proj: Linear,
    /// Projection whose output is multiplied by the activated gate.
    up_proj: Linear,
    /// Projection applied to the gated product to produce the final output.
    down_proj: Linear,
}
impl SiluMlp {
    /// Builds the three bias-free linear layers from `vb`.
    ///
    /// `gate_proj` and `up_proj` are created with
    /// `linear_no_bias(hidden_size, intermediate_size, ..)`, and `down_proj`
    /// with `linear_no_bias(intermediate_size, hidden_size, ..)`. Each layer
    /// is requested through `vb.pp(..)` using its field name
    /// (`"gate_proj"`, `"up_proj"`, `"down_proj"`).
    ///
    /// # Errors
    ///
    /// Returns the first error reported by `candle_nn::linear_no_bias`.
    pub fn load(hidden_size: usize, intermediate_size: usize, vb: VarBuilder) -> Result<Self> {
        // Field initialisers run top to bottom, so the layers are still
        // created in gate -> up -> down order.
        Ok(Self {
            gate_proj: candle_nn::linear_no_bias(
                hidden_size,
                intermediate_size,
                vb.pp("gate_proj"),
            )?,
            up_proj: candle_nn::linear_no_bias(
                hidden_size,
                intermediate_size,
                vb.pp("up_proj"),
            )?,
            down_proj: candle_nn::linear_no_bias(
                intermediate_size,
                hidden_size,
                vb.pp("down_proj"),
            )?,
        })
    }

    /// Runs the block on `x`, computing
    /// `down_proj(silu(gate_proj(x)) * up_proj(x))`.
    ///
    /// # Errors
    ///
    /// Propagates any error raised by the projections, by `silu`, or by the
    /// tensor multiplication.
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Gate branch: project, then activate.
        let gate_out = self.gate_proj.forward(x)?;
        let activated = silu(&gate_out)?;
        // Up branch: plain projection of the same input.
        let up_out = self.up_proj.forward(x)?;
        // Multiply the two branches and project the result back down.
        let gated = (activated * up_out)?;
        self.down_proj.forward(&gated)
    }
}
/// Debug formatting that prints only the type name; no field is shown.
impl std::fmt::Debug for SiluMlp {
    /// Writes the literal text `SiluMlp` to the formatter.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A field-less `debug_struct` emits just the name, so write it directly.
        f.write_str("SiluMlp")
    }
}