Skip to main content

rlx_flow/blocks/
llama_decode_layer.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16use anyhow::Result;
17use rlx_ir::HirGraphExt;
18use rlx_ir::hir::HirMut;
19use rlx_ir::op::MaskKind;
20use rlx_ir::shape;
21
22use super::BlockStage;
23use crate::context::FlowCtx;
24use crate::value::FlowValue;
25#[derive(Debug, Clone)]
26pub struct LlamaDecodeLayerSpec {
27    pub num_heads: usize,
28    pub head_dim: usize,
29    pub num_kv_heads: usize,
30    pub kv_group_size: usize,
31    pub eps: f32,
32    pub use_custom_mask: bool,
33    pub hidden_shape: rlx_ir::Shape,
34}
35
36#[derive(Debug, Clone)]
37pub struct LlamaDecodeLayerStage {
38    pub layer_prefix: String,
39    pub spec: LlamaDecodeLayerSpec,
40    pub layer_idx: usize,
41    pub kv_out: std::sync::Arc<std::sync::Mutex<Vec<rlx_ir::HirNodeId>>>,
42}
43
44impl LlamaDecodeLayerStage {
45    pub fn layer(
46        layer_idx: usize,
47        spec: LlamaDecodeLayerSpec,
48        kv_out: std::sync::Arc<std::sync::Mutex<Vec<rlx_ir::HirNodeId>>>,
49    ) -> Self {
50        Self {
51            layer_prefix: format!("model.layers.{layer_idx}"),
52            spec,
53            layer_idx,
54            kv_out,
55        }
56    }
57}
58
59impl BlockStage for LlamaDecodeLayerStage {
60    fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
61        let decode = ctx
62            .state
63            .decode
64            .clone()
65            .ok_or_else(|| anyhow::anyhow!("LlamaDecodeLayer requires BindDecodeInputs"))?;
66        let zero_beta = ctx
67            .state
68            .zero_beta
69            .ok_or_else(|| anyhow::anyhow!("LlamaDecodeLayer requires ZeroBeta"))?;
70
71        let lp = &self.layer_prefix;
72        let spec = &self.spec;
73        let in_ln_g = ctx.load_param(&format!("{lp}.input_layernorm.weight"), false)?;
74        let q_w = ctx.load_param(&format!("{lp}.self_attn.q_proj.weight"), true)?;
75        let k_w = ctx.load_param(&format!("{lp}.self_attn.k_proj.weight"), true)?;
76        let v_w = ctx.load_param(&format!("{lp}.self_attn.v_proj.weight"), true)?;
77        let o_w = ctx.load_param(&format!("{lp}.self_attn.o_proj.weight"), true)?;
78        let post_ln_g = ctx.load_param(&format!("{lp}.post_attention_layernorm.weight"), false)?;
79        let gate_w = ctx.load_param(&format!("{lp}.mlp.gate_proj.weight"), true)?;
80        let up_w = ctx.load_param(&format!("{lp}.mlp.up_proj.weight"), true)?;
81        let down_w = ctx.load_param(&format!("{lp}.mlp.down_proj.weight"), true)?;
82
83        let past_k = decode.past_k[self.layer_idx];
84        let past_v = decode.past_v[self.layer_idx];
85
86        let mut gb = HirMut::new(ctx.hir());
87        let normed_in = gb.rms_norm(input.id, in_ln_g, zero_beta, spec.eps);
88        let q = gb.mm(normed_in, q_w);
89        let k = gb.mm(normed_in, k_w);
90        let v = gb.mm(normed_in, v_w);
91
92        let q_rope = gb.rope(q, decode.cos, decode.sin, spec.head_dim);
93        let k_rope = gb.rope(k, decode.cos, decode.sin, spec.head_dim);
94
95        let new_k = gb.concat_(vec![past_k, k_rope], 1);
96        let new_v = gb.concat_(vec![past_v, v], 1);
97        self.kv_out.lock().expect("kv out").push(new_k);
98        self.kv_out.lock().expect("kv out").push(new_v);
99
100        let k_rep = super::self_attn::repeat_kv(
101            &mut gb,
102            new_k,
103            spec.num_kv_heads,
104            spec.head_dim,
105            spec.kv_group_size,
106        );
107        let v_rep = super::self_attn::repeat_kv(
108            &mut gb,
109            new_v,
110            spec.num_kv_heads,
111            spec.head_dim,
112            spec.kv_group_size,
113        );
114
115        let attn_shape = shape::attention_shape(gb.shape(q_rope));
116        let attn = if spec.use_custom_mask {
117            let mask = decode
118                .mask
119                .ok_or_else(|| anyhow::anyhow!("custom mask requested but not bound"))?;
120            gb.attention(
121                q_rope,
122                k_rep,
123                v_rep,
124                mask,
125                spec.num_heads,
126                spec.head_dim,
127                attn_shape,
128            )
129        } else {
130            gb.attention_kind(
131                q_rope,
132                k_rep,
133                v_rep,
134                spec.num_heads,
135                spec.head_dim,
136                MaskKind::Causal,
137                attn_shape,
138            )
139        };
140
141        let attn_out = gb.mm(attn, o_w);
142        let post_attn = gb.add(input.id, attn_out);
143        let normed_post = gb.rms_norm(post_attn, post_ln_g, zero_beta, spec.eps);
144        let gate = gb.mm(normed_post, gate_w);
145        let up = gb.mm(normed_post, up_w);
146        let gate_act = gb.silu(gate);
147        let swiglu = gb.mul(gate_act, up);
148        let ffn_out = gb.mm(swiglu, down_w);
149        let h_id = gb.add(post_attn, ffn_out);
150
151        Ok(Some(ctx.wrap(h_id, spec.hidden_shape.clone())))
152    }
153}