//! gllm 0.10.6
//!
//! Pure Rust library for local embeddings, reranking, and text generation
//! with MoE-optimized inference and aggressive performance tuning.
use crate::causal_attention::CausalAttention;
use crate::rope::RotaryPositionEmbedding;
use crate::kv_cache::KVCache;
use crate::model_config::ModelConfig;
use crate::rms_norm::RmsNorm;
use crate::types::Result;
use burn::nn::{Linear, LinearConfig};
use burn::tensor::activation::silu;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
use std::sync::Arc;

/// A single pre-norm transformer decoder layer: causal self-attention
/// followed by a SwiGLU-style gated feed-forward network, each preceded
/// by an RMSNorm and wrapped in a residual connection.
#[derive(Clone)]
pub struct DecoderLayer<B: Backend> {
    // Normalizes the residual stream before self-attention (pre-norm).
    pub(crate) attention_norm: RmsNorm<B>,
    // Causal self-attention; may share a RoPE instance across layers (see `new`).
    pub(crate) attention: CausalAttention<B>,
    // Normalizes the residual stream before the feed-forward block (pre-norm).
    pub(crate) ffn_norm: RmsNorm<B>,
    // hidden -> intermediate; SiLU-activated gate branch of the SwiGLU FFN.
    pub(crate) gate_proj: Linear<B>,
    // hidden -> intermediate; linear branch multiplied element-wise by the gate.
    pub(crate) up_proj: Linear<B>,
    // intermediate -> hidden; projects back onto the residual stream.
    pub(crate) down_proj: Linear<B>,
}

impl<B: Backend> DecoderLayer<B> {
    /// Builds a decoder layer on `device` from `config`.
    ///
    /// When `config.intermediate_size` is `None`, the FFN width defaults to
    /// `4 * hidden_size` (saturating multiply guards against overflow on
    /// pathological configs). `rope` is an optional rotary-embedding instance
    /// shared with the attention module.
    ///
    /// # Errors
    /// Propagates any error from [`CausalAttention::new`].
    pub fn new(
        device: &B::Device,
        config: &ModelConfig,
        rope: Option<Arc<RotaryPositionEmbedding<B>>>,
    ) -> Result<Self> {
        let hidden_size = config.hidden_size;
        let intermediate = config
            .intermediate_size
            .unwrap_or(hidden_size.saturating_mul(4));

        Ok(Self {
            attention_norm: RmsNorm::new(device, config),
            attention: CausalAttention::new(device, config, rope)?,
            ffn_norm: RmsNorm::new(device, config),
            gate_proj: LinearConfig::new(hidden_size, intermediate).init(device),
            up_proj: LinearConfig::new(hidden_size, intermediate).init(device),
            down_proj: LinearConfig::new(intermediate, hidden_size).init(device),
        })
    }

    /// Forward pass without KV caching (e.g. a one-shot/prefill pass).
    ///
    /// `position_offset` is forwarded to attention for positional encoding.
    /// Returns a tensor of the same shape as `hidden_states`.
    pub fn forward(&self, hidden_states: Tensor<B, 3>, position_offset: usize) -> Tensor<B, 3> {
        // Pre-norm attention sub-block with residual connection.
        let attn_input = self.attention_norm.forward(hidden_states.clone());
        let attn_output = self.attention.forward(attn_input, position_offset);
        let hidden_states = hidden_states + attn_output;

        self.ffn_block(hidden_states)
    }

    /// Forward pass that reads and updates the KV cache entry for `layer`,
    /// for incremental (autoregressive) decoding.
    pub fn forward_with_cache(
        &self,
        hidden_states: Tensor<B, 3>,
        position_offset: usize,
        cache: &mut KVCache<B>,
        layer: usize,
    ) -> Tensor<B, 3> {
        // Pre-norm attention sub-block with residual connection, cached variant.
        let attn_input = self.attention_norm.forward(hidden_states.clone());
        let attn_output = self
            .attention
            .forward_with_cache(attn_input, position_offset, cache, layer);
        let hidden_states = hidden_states + attn_output;

        self.ffn_block(hidden_states)
    }

    /// Pre-norm SwiGLU feed-forward sub-block with residual connection:
    /// `x + down(silu(gate(norm(x))) * up(norm(x)))`.
    ///
    /// Shared by [`Self::forward`] and [`Self::forward_with_cache`] so the
    /// cached and uncached paths cannot drift apart.
    fn ffn_block(&self, hidden_states: Tensor<B, 3>) -> Tensor<B, 3> {
        let ffn_input = self.ffn_norm.forward(hidden_states.clone());
        let gate = silu(self.gate_proj.forward(ffn_input.clone()));
        let up = self.up_proj.forward(ffn_input);
        let ffn_output = self.down_proj.forward(gate * up);

        hidden_states + ffn_output
    }
}