// moe-llm-core 1.3.6
//
// Part of the MoE-13 Ternary Intelligence Stack.
use candle_core::{Result, Tensor};
use candle_nn::VarBuilder;
use super::ternary_linear::TernaryLinear;

/// Two-layer MLP built from ternary projections: `c_fc`, then GELU, then `c_proj`.
///
/// The weights for each projection are read from the `c_fc` and `c_proj` sub-paths of the
/// [`VarBuilder`] handed to [`Mlp::new`].
pub struct Mlp {
    /// First projection, constructed with `(hidden_size, intermediate_size)`.
    c_fc: TernaryLinear,
    /// Second projection, constructed with `(intermediate_size, hidden_size)`.
    c_proj: TernaryLinear,
}

impl Mlp {
    /// Creates the block and loads both projections from `vb`.
    ///
    /// Both [`TernaryLinear`] layers receive the same `threshold` and the same boolean flag
    /// (`true`). NOTE(review): that flag is presumably the bias switch; confirm against
    /// `TernaryLinear::new`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `TernaryLinear::new`. `c_fc` is built before
    /// `c_proj`.
    pub fn new(
        hidden_size: usize,
        intermediate_size: usize,
        vb: VarBuilder,
        threshold: f32,
    ) -> Result<Self> {
        // Shared constructor. Only the dimensions and the weight sub-path differ per layer.
        let make_layer = |d_in: usize, d_out: usize, prefix: &str| {
            TernaryLinear::new(d_in, d_out, true, threshold, vb.pp(prefix))
        };

        let c_fc = make_layer(hidden_size, intermediate_size, "c_fc")?;
        let c_proj = make_layer(intermediate_size, hidden_size, "c_proj")?;

        Ok(Self { c_fc, c_proj })
    }

    /// Runs `x` through `c_fc`, applies `Tensor::gelu`, then runs the result through `c_proj`.
    ///
    /// # Errors
    ///
    /// Propagates any error from either projection or from the GELU activation.
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let activated = self.c_fc.forward(x)?.gelu()?;
        self.c_proj.forward(&activated)
    }
}