Skip to main content

rlx_flow/blocks/
llama_decoder.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16use anyhow::Result;
17use rlx_ir::op::MaskKind;
18
19use super::BlockStage;
20use crate::context::FlowCtx;
21use crate::value::FlowValue;
22#[derive(Debug, Clone)]
23pub struct LlamaDecoderStage {
24    pub layer_prefix: String,
25    pub num_heads: usize,
26    pub head_dim: usize,
27    pub num_kv_heads: usize,
28    pub eps: f32,
29    pub mask: MaskKind,
30    pub hidden_shape: rlx_ir::Shape,
31}
32
33impl LlamaDecoderStage {
34    pub fn layer(layer_idx: usize, spec: LlamaDecoderSpec) -> Self {
35        Self {
36            layer_prefix: format!("model.layers.{layer_idx}"),
37            num_heads: spec.num_heads,
38            head_dim: spec.head_dim,
39            num_kv_heads: spec.num_kv_heads,
40            eps: spec.eps,
41            mask: spec.mask,
42            hidden_shape: spec.hidden_shape,
43        }
44    }
45}
46
47#[derive(Debug, Clone)]
48pub struct LlamaDecoderSpec {
49    pub num_heads: usize,
50    pub head_dim: usize,
51    pub num_kv_heads: usize,
52    pub eps: f32,
53    pub mask: MaskKind,
54    pub hidden_shape: rlx_ir::Shape,
55}
56
57impl BlockStage for LlamaDecoderStage {
58    fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
59        let lp = &self.layer_prefix;
60        let zero_beta = ctx
61            .state
62            .zero_beta
63            .ok_or_else(|| anyhow::anyhow!("LlamaDecoder requires ZeroBeta stage"))?;
64        let cos = ctx
65            .state
66            .rope_cos
67            .ok_or_else(|| anyhow::anyhow!("LlamaDecoder requires RopeTables stage"))?;
68        let sin = ctx
69            .state
70            .rope_sin
71            .ok_or_else(|| anyhow::anyhow!("LlamaDecoder requires RopeTables stage"))?;
72
73        let in_ln_g = ctx.load_param(&format!("{lp}.input_layernorm.weight"), false)?;
74        let q_w = ctx.load_param(&format!("{lp}.self_attn.q_proj.weight"), true)?;
75        let k_w = ctx.load_param(&format!("{lp}.self_attn.k_proj.weight"), true)?;
76        let v_w = ctx.load_param(&format!("{lp}.self_attn.v_proj.weight"), true)?;
77        let o_w = ctx.load_param(&format!("{lp}.self_attn.o_proj.weight"), true)?;
78        let post_ln_g = ctx.load_param(&format!("{lp}.post_attention_layernorm.weight"), false)?;
79        let gate_w = ctx.load_param(&format!("{lp}.mlp.gate_proj.weight"), true)?;
80        let up_w = ctx.load_param(&format!("{lp}.mlp.up_proj.weight"), true)?;
81        let down_w = ctx.load_param(&format!("{lp}.mlp.down_proj.weight"), true)?;
82
83        let id = ctx.hir().llama_decoder_block(
84            input.id,
85            in_ln_g,
86            zero_beta,
87            q_w,
88            k_w,
89            v_w,
90            o_w,
91            post_ln_g,
92            zero_beta,
93            gate_w,
94            up_w,
95            down_w,
96            cos,
97            sin,
98            None,
99            self.num_heads,
100            self.head_dim,
101            self.num_kv_heads,
102            self.eps,
103            self.mask,
104            self.hidden_shape.clone(),
105        );
106
107        Ok(Some(ctx.wrap(id, self.hidden_shape.clone())))
108    }
109}