//! `cake_core/models/common/mlp.rs`
//!
//! Shared MLP (feed-forward) block with a fused gate+up projection.
use candle_core::{Result, Tensor, D};
use candle_nn::{Linear, Module, VarBuilder};

/// Multi-layer perceptron (MLP / feed-forward) block whose gate and up
/// projections are fused into a single linear layer, so one matmul produces
/// both halves instead of two separate matmuls.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone)]
pub struct MLP {
    // Fused [gate_proj; up_proj] weights: maps hidden_size -> 2 * intermediate_size,
    // with the gate half occupying the first `intermediate_size` output features.
    gate_up_proj: Linear,
    // Projects intermediate_size back down to hidden_size.
    down_proj: Linear,
    // Width of each half of the fused output; used in forward() to split gate from up.
    intermediate_size: usize,
}
12
13impl MLP {
14    /// Execute MLP(x).
15    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
16        let fused = self.gate_up_proj.forward(x)?;
17        let gate = fused.narrow(D::Minus1, 0, self.intermediate_size)?;
18        let up = fused.narrow(D::Minus1, self.intermediate_size, self.intermediate_size)?;
19        let x = (candle_nn::ops::silu(&gate)? * up)?;
20        self.down_proj.forward(&x)
21    }
22
23    /// Load this block from the VarBuilder given the specific configuration.
24    pub fn load(vb: VarBuilder, cfg: &super::Config) -> Result<Self> {
25        let h_size = cfg.hidden_size;
26        let i_size = cfg.intermediate_size;
27
28        // Fuse gate_proj and up_proj into a single matmul
29        let gate_w = vb.pp("gate_proj").get((i_size, h_size), "weight")?;
30        let up_w = vb.pp("up_proj").get((i_size, h_size), "weight")?;
31        let fused_w = Tensor::cat(&[&gate_w, &up_w], 0)?;
32        let gate_up_proj = Linear::new(fused_w, None);
33
34        let down_w = vb.pp("down_proj").get((h_size, i_size), "weight")?;
35        let down_proj = Linear::new(down_w, None);
36
37        Ok(Self {
38            gate_up_proj,
39            down_proj,
40            intermediate_size: i_size,
41        })
42    }
43}