Skip to main content

rlx_flow/blocks/
llama_decode_layer.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16use anyhow::Result;
17use rlx_ir::HirGraphExt;
18use rlx_ir::hir::HirMut;
19use rlx_ir::op::MaskKind;
20use rlx_ir::shape;
21
22use super::BlockStage;
23use crate::context::FlowCtx;
24use crate::value::FlowValue;
25#[derive(Debug, Clone)]
26pub struct LlamaDecodeLayerSpec {
27    pub num_heads: usize,
28    pub head_dim: usize,
29    pub num_kv_heads: usize,
30    pub kv_group_size: usize,
31    pub eps: f32,
32    pub use_custom_mask: bool,
33    pub hidden_shape: rlx_ir::Shape,
34}
35
36#[derive(Debug, Clone)]
37pub struct LlamaDecodeLayerStage {
38    pub layer_prefix: String,
39    pub spec: LlamaDecodeLayerSpec,
40    pub layer_idx: usize,
41    pub kv_out: std::sync::Arc<std::sync::Mutex<Vec<rlx_ir::HirNodeId>>>,
42    /// Optional EAGLE3-style tap for the pre-attention-norm layer
43    /// input. Mirrors the field on
44    /// [`crate::blocks::GemmaDecodeLayerStage`]; see that doc for
45    /// semantics and push-order guarantees.
46    pub aux_in_out: Option<std::sync::Arc<std::sync::Mutex<Vec<rlx_ir::HirNodeId>>>>,
47}
48
49impl LlamaDecodeLayerStage {
50    pub fn layer(
51        layer_idx: usize,
52        spec: LlamaDecodeLayerSpec,
53        kv_out: std::sync::Arc<std::sync::Mutex<Vec<rlx_ir::HirNodeId>>>,
54    ) -> Self {
55        Self {
56            layer_prefix: format!("model.layers.{layer_idx}"),
57            spec,
58            layer_idx,
59            kv_out,
60            aux_in_out: None,
61        }
62    }
63
64    pub fn with_aux_input_tap(
65        mut self,
66        sink: std::sync::Arc<std::sync::Mutex<Vec<rlx_ir::HirNodeId>>>,
67    ) -> Self {
68        self.aux_in_out = Some(sink);
69        self
70    }
71}
72
73impl BlockStage for LlamaDecodeLayerStage {
74    fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
75        if let Some(sink) = self.aux_in_out.as_ref() {
76            sink.lock().expect("aux in out").push(input.id);
77        }
78
79        let decode = ctx
80            .state
81            .decode
82            .clone()
83            .ok_or_else(|| anyhow::anyhow!("LlamaDecodeLayer requires BindDecodeInputs"))?;
84        let zero_beta = ctx
85            .state
86            .zero_beta
87            .ok_or_else(|| anyhow::anyhow!("LlamaDecodeLayer requires ZeroBeta"))?;
88
89        let lp = &self.layer_prefix;
90        let spec = &self.spec;
91        let in_ln_g = ctx.load_param(&format!("{lp}.input_layernorm.weight"), false)?;
92        let q_w = ctx.load_param(&format!("{lp}.self_attn.q_proj.weight"), true)?;
93        let k_w = ctx.load_param(&format!("{lp}.self_attn.k_proj.weight"), true)?;
94        let v_w = ctx.load_param(&format!("{lp}.self_attn.v_proj.weight"), true)?;
95        let o_w = ctx.load_param(&format!("{lp}.self_attn.o_proj.weight"), true)?;
96        let post_ln_g = ctx.load_param(&format!("{lp}.post_attention_layernorm.weight"), false)?;
97        let gate_w = ctx.load_param(&format!("{lp}.mlp.gate_proj.weight"), true)?;
98        let up_w = ctx.load_param(&format!("{lp}.mlp.up_proj.weight"), true)?;
99        let down_w = ctx.load_param(&format!("{lp}.mlp.down_proj.weight"), true)?;
100
101        let past_k = decode.past_k.get(self.layer_idx);
102        let past_v = decode.past_v.get(self.layer_idx);
103
104        let mut gb = HirMut::new(ctx.hir());
105        let normed_in = gb.rms_norm(input.id, in_ln_g, zero_beta, spec.eps);
106        let q = gb.mm(normed_in, q_w);
107        let k = gb.mm(normed_in, k_w);
108        let v = gb.mm(normed_in, v_w);
109
110        let q_rope = gb.rope(q, decode.cos, decode.sin, spec.head_dim);
111        let k_rope = gb.rope(k, decode.cos, decode.sin, spec.head_dim);
112
113        let (new_k, new_v) = match (past_k, past_v) {
114            (Some(past_k), Some(past_v)) => (
115                gb.concat_(vec![*past_k, k_rope], 1),
116                gb.concat_(vec![*past_v, v], 1),
117            ),
118            _ => (k_rope, v),
119        };
120        self.kv_out.lock().expect("kv out").push(new_k);
121        self.kv_out.lock().expect("kv out").push(new_v);
122
123        let k_rep = super::self_attn::repeat_kv(
124            &mut gb,
125            new_k,
126            spec.num_kv_heads,
127            spec.head_dim,
128            spec.kv_group_size,
129        );
130        let v_rep = super::self_attn::repeat_kv(
131            &mut gb,
132            new_v,
133            spec.num_kv_heads,
134            spec.head_dim,
135            spec.kv_group_size,
136        );
137
138        let attn_shape = shape::attention_shape(gb.shape(q_rope));
139        let attn = if spec.use_custom_mask {
140            let mask = decode
141                .mask
142                .ok_or_else(|| anyhow::anyhow!("custom mask requested but not bound"))?;
143            gb.attention(
144                q_rope,
145                k_rep,
146                v_rep,
147                mask,
148                spec.num_heads,
149                spec.head_dim,
150                attn_shape,
151            )
152        } else {
153            gb.attention_kind(
154                q_rope,
155                k_rep,
156                v_rep,
157                spec.num_heads,
158                spec.head_dim,
159                MaskKind::Causal,
160                attn_shape,
161            )
162        };
163
164        let attn_out = gb.mm(attn, o_w);
165        let post_attn = gb.add(input.id, attn_out);
166        let normed_post = gb.rms_norm(post_attn, post_ln_g, zero_beta, spec.eps);
167        let gate = gb.mm(normed_post, gate_w);
168        let up = gb.mm(normed_post, up_w);
169        let gate_act = gb.silu(gate);
170        let swiglu = gb.mul(gate_act, up);
171        let ffn_out = gb.mm(swiglu, down_w);
172        let h_id = gb.add(post_attn, ffn_out);
173
174        Ok(Some(ctx.wrap(h_id, spec.hidden_shape.clone())))
175    }
176}