Skip to main content

rlx_flow/blocks/
llama_decoder.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4use anyhow::Result;
5use rlx_ir::op::MaskKind;
6
7use super::BlockStage;
8use crate::context::FlowCtx;
9use crate::value::FlowValue;
/// Flow stage that emits one Llama decoder layer into the HIR.
///
/// The `BlockStage` implementation loads this layer's weights by name under
/// `layer_prefix` and forwards every other field unchanged to the HIR
/// `llama_decoder_block` builder.
#[derive(Debug, Clone)]
pub struct LlamaDecoderStage {
    /// Prefix prepended (with a `.`) to each parameter name loaded by `emit`;
    /// `LlamaDecoderStage::layer` sets it to `model.layers.{layer_idx}`.
    pub layer_prefix: String,
    /// Forwarded to `llama_decoder_block`.
    /// NOTE(review): presumably the number of query attention heads — confirm.
    pub num_heads: usize,
    /// Forwarded to `llama_decoder_block`.
    /// NOTE(review): presumably the per-head feature dimension — confirm.
    pub head_dim: usize,
    /// Forwarded to `llama_decoder_block`.
    /// NOTE(review): presumably the number of key/value heads — confirm.
    pub num_kv_heads: usize,
    /// Forwarded to `llama_decoder_block`.
    /// NOTE(review): presumably the normalization epsilon — confirm.
    pub eps: f32,
    /// Mask kind forwarded to `llama_decoder_block`.
    pub mask: MaskKind,
    /// Shape passed to `llama_decoder_block` and also used to wrap the
    /// resulting node id into the stage's output `FlowValue`.
    pub hidden_shape: rlx_ir::Shape,
}
20
21impl LlamaDecoderStage {
22    pub fn layer(layer_idx: usize, spec: LlamaDecoderSpec) -> Self {
23        Self {
24            layer_prefix: format!("model.layers.{layer_idx}"),
25            num_heads: spec.num_heads,
26            head_dim: spec.head_dim,
27            num_kv_heads: spec.num_kv_heads,
28            eps: spec.eps,
29            mask: spec.mask,
30            hidden_shape: spec.hidden_shape,
31        }
32    }
33}
34
/// Layer-independent settings for a Llama decoder stage.
///
/// `LlamaDecoderStage::layer` copies every field of this struct into the
/// stage it builds, adding only the per-layer parameter prefix.
#[derive(Debug, Clone)]
pub struct LlamaDecoderSpec {
    /// Copied into `LlamaDecoderStage::num_heads`.
    /// NOTE(review): presumably the number of query attention heads — confirm.
    pub num_heads: usize,
    /// Copied into `LlamaDecoderStage::head_dim`.
    /// NOTE(review): presumably the per-head feature dimension — confirm.
    pub head_dim: usize,
    /// Copied into `LlamaDecoderStage::num_kv_heads`.
    /// NOTE(review): presumably the number of key/value heads — confirm.
    pub num_kv_heads: usize,
    /// Copied into `LlamaDecoderStage::eps`.
    /// NOTE(review): presumably the normalization epsilon — confirm.
    pub eps: f32,
    /// Copied into `LlamaDecoderStage::mask`.
    pub mask: MaskKind,
    /// Copied into `LlamaDecoderStage::hidden_shape`.
    pub hidden_shape: rlx_ir::Shape,
}
44
45impl BlockStage for LlamaDecoderStage {
46    fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
47        let lp = &self.layer_prefix;
48        let zero_beta = ctx
49            .state
50            .zero_beta
51            .ok_or_else(|| anyhow::anyhow!("LlamaDecoder requires ZeroBeta stage"))?;
52        let cos = ctx
53            .state
54            .rope_cos
55            .ok_or_else(|| anyhow::anyhow!("LlamaDecoder requires RopeTables stage"))?;
56        let sin = ctx
57            .state
58            .rope_sin
59            .ok_or_else(|| anyhow::anyhow!("LlamaDecoder requires RopeTables stage"))?;
60
61        let in_ln_g = ctx.load_param(&format!("{lp}.input_layernorm.weight"), false)?;
62        let q_w = ctx.load_param(&format!("{lp}.self_attn.q_proj.weight"), true)?;
63        let k_w = ctx.load_param(&format!("{lp}.self_attn.k_proj.weight"), true)?;
64        let v_w = ctx.load_param(&format!("{lp}.self_attn.v_proj.weight"), true)?;
65        let o_w = ctx.load_param(&format!("{lp}.self_attn.o_proj.weight"), true)?;
66        let post_ln_g = ctx.load_param(&format!("{lp}.post_attention_layernorm.weight"), false)?;
67        let gate_w = ctx.load_param(&format!("{lp}.mlp.gate_proj.weight"), true)?;
68        let up_w = ctx.load_param(&format!("{lp}.mlp.up_proj.weight"), true)?;
69        let down_w = ctx.load_param(&format!("{lp}.mlp.down_proj.weight"), true)?;
70
71        let id = ctx.hir().llama_decoder_block(
72            input.id,
73            in_ln_g,
74            zero_beta,
75            q_w,
76            k_w,
77            v_w,
78            o_w,
79            post_ln_g,
80            zero_beta,
81            gate_w,
82            up_w,
83            down_w,
84            cos,
85            sin,
86            None,
87            self.num_heads,
88            self.head_dim,
89            self.num_kv_heads,
90            self.eps,
91            self.mask,
92            self.hidden_shape.clone(),
93        );
94
95        Ok(Some(ctx.wrap(id, self.hidden_shape.clone())))
96    }
97}