cake_core/models/common/mlp.rs

use candle_core::{Result, Tensor, D};
use candle_nn::{Linear, Module, VarBuilder};

/// SwiGLU feed-forward block with the gate and up projections fused into a
/// single linear layer, so both are computed with one matmul.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone)]
pub struct MLP {
    gate_up_proj: Linear,
    down_proj: Linear,
    intermediate_size: usize,
}

impl MLP {
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // One projection yields [gate | up] along the last dimension;
        // split it, apply SiLU to the gate, and scale by the up half.
        let fused = self.gate_up_proj.forward(x)?;
        let gate = fused.narrow(D::Minus1, 0, self.intermediate_size)?;
        let up = fused.narrow(D::Minus1, self.intermediate_size, self.intermediate_size)?;
        let x = (candle_nn::ops::silu(&gate)? * up)?;
        self.down_proj.forward(&x)
    }

    pub fn load(vb: VarBuilder, cfg: &super::Config) -> Result<Self> {
        let h_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;

        // Concatenate the gate_proj and up_proj weights along dim 0 so a
        // single Linear computes both projections at once.
        let gate_w = vb.pp("gate_proj").get((i_size, h_size), "weight")?;
        let up_w = vb.pp("up_proj").get((i_size, h_size), "weight")?;
        let fused_w = Tensor::cat(&[&gate_w, &up_w], 0)?;
        let gate_up_proj = Linear::new(fused_w, None);

        let down_w = vb.pp("down_proj").get((h_size, i_size), "weight")?;
        let down_proj = Linear::new(down_w, None);

        Ok(Self {
            gate_up_proj,
            down_proj,
            intermediate_size: i_size,
        })
    }
}
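
A minimal sanity-check sketch (not part of the original file) that the fused gate_up_proj layout used above agrees with two separate gate/up projections; the test name, shapes, and tolerance are illustrative assumptions.

// Hypothetical test module, not in the repo: concatenating the gate and up
// weights on dim 0 and splitting the output with narrow should reproduce
// the two individual projections.
#[cfg(test)]
mod tests {
    use candle_core::{Device, Result, Tensor, D};
    use candle_nn::{Linear, Module};

    #[test]
    fn fused_gate_up_matches_separate_projections() -> Result<()> {
        let dev = Device::Cpu;
        let (h_size, i_size) = (8usize, 16usize);

        let gate_w = Tensor::randn(0f32, 1.0, (i_size, h_size), &dev)?;
        let up_w = Tensor::randn(0f32, 1.0, (i_size, h_size), &dev)?;
        let x = Tensor::randn(0f32, 1.0, (2, h_size), &dev)?;

        // Two separate projections, as in the unfused layout.
        let gate = Linear::new(gate_w.clone(), None).forward(&x)?;
        let up = Linear::new(up_w.clone(), None).forward(&x)?;

        // Fused projection: concatenate weights on dim 0, then split the
        // output along the last dimension, mirroring MLP::forward above.
        let fused = Linear::new(Tensor::cat(&[&gate_w, &up_w], 0)?, None).forward(&x)?;
        let gate_f = fused.narrow(D::Minus1, 0, i_size)?;
        let up_f = fused.narrow(D::Minus1, i_size, i_size)?;

        let diff = ((gate - gate_f)?.abs()?.sum_all()?
            + (up - up_f)?.abs()?.sum_all()?)?
            .to_scalar::<f32>()?;
        assert!(diff < 1e-4);
        Ok(())
    }
}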