Skip to main content

rlx_flow/blocks/
llama_decode_layer.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4use anyhow::Result;
5use rlx_ir::HirGraphExt;
6use rlx_ir::hir::HirMut;
7use rlx_ir::op::MaskKind;
8use rlx_ir::shape;
9
10use super::BlockStage;
11use crate::context::FlowCtx;
12use crate::value::FlowValue;
13#[derive(Debug, Clone)]
14pub struct LlamaDecodeLayerSpec {
15    pub num_heads: usize,
16    pub head_dim: usize,
17    pub num_kv_heads: usize,
18    pub kv_group_size: usize,
19    pub eps: f32,
20    pub use_custom_mask: bool,
21    pub hidden_shape: rlx_ir::Shape,
22}
23
24#[derive(Debug, Clone)]
25pub struct LlamaDecodeLayerStage {
26    pub layer_prefix: String,
27    pub spec: LlamaDecodeLayerSpec,
28    pub layer_idx: usize,
29    pub kv_out: std::sync::Arc<std::sync::Mutex<Vec<rlx_ir::HirNodeId>>>,
30}
31
32impl LlamaDecodeLayerStage {
33    pub fn layer(
34        layer_idx: usize,
35        spec: LlamaDecodeLayerSpec,
36        kv_out: std::sync::Arc<std::sync::Mutex<Vec<rlx_ir::HirNodeId>>>,
37    ) -> Self {
38        Self {
39            layer_prefix: format!("model.layers.{layer_idx}"),
40            spec,
41            layer_idx,
42            kv_out,
43        }
44    }
45}
46
47impl BlockStage for LlamaDecodeLayerStage {
48    fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
49        let decode = ctx
50            .state
51            .decode
52            .clone()
53            .ok_or_else(|| anyhow::anyhow!("LlamaDecodeLayer requires BindDecodeInputs"))?;
54        let zero_beta = ctx
55            .state
56            .zero_beta
57            .ok_or_else(|| anyhow::anyhow!("LlamaDecodeLayer requires ZeroBeta"))?;
58
59        let lp = &self.layer_prefix;
60        let spec = &self.spec;
61        let in_ln_g = ctx.load_param(&format!("{lp}.input_layernorm.weight"), false)?;
62        let q_w = ctx.load_param(&format!("{lp}.self_attn.q_proj.weight"), true)?;
63        let k_w = ctx.load_param(&format!("{lp}.self_attn.k_proj.weight"), true)?;
64        let v_w = ctx.load_param(&format!("{lp}.self_attn.v_proj.weight"), true)?;
65        let o_w = ctx.load_param(&format!("{lp}.self_attn.o_proj.weight"), true)?;
66        let post_ln_g = ctx.load_param(&format!("{lp}.post_attention_layernorm.weight"), false)?;
67        let gate_w = ctx.load_param(&format!("{lp}.mlp.gate_proj.weight"), true)?;
68        let up_w = ctx.load_param(&format!("{lp}.mlp.up_proj.weight"), true)?;
69        let down_w = ctx.load_param(&format!("{lp}.mlp.down_proj.weight"), true)?;
70
71        let past_k = decode.past_k[self.layer_idx];
72        let past_v = decode.past_v[self.layer_idx];
73
74        let mut gb = HirMut::new(ctx.hir());
75        let normed_in = gb.rms_norm(input.id, in_ln_g, zero_beta, spec.eps);
76        let q = gb.mm(normed_in, q_w);
77        let k = gb.mm(normed_in, k_w);
78        let v = gb.mm(normed_in, v_w);
79
80        let q_rope = gb.rope(q, decode.cos, decode.sin, spec.head_dim);
81        let k_rope = gb.rope(k, decode.cos, decode.sin, spec.head_dim);
82
83        let new_k = gb.concat_(vec![past_k, k_rope], 1);
84        let new_v = gb.concat_(vec![past_v, v], 1);
85        self.kv_out.lock().expect("kv out").push(new_k);
86        self.kv_out.lock().expect("kv out").push(new_v);
87
88        let k_rep = super::self_attn::repeat_kv(
89            &mut gb,
90            new_k,
91            spec.num_kv_heads,
92            spec.head_dim,
93            spec.kv_group_size,
94        );
95        let v_rep = super::self_attn::repeat_kv(
96            &mut gb,
97            new_v,
98            spec.num_kv_heads,
99            spec.head_dim,
100            spec.kv_group_size,
101        );
102
103        let attn_shape = shape::attention_shape(gb.shape(q_rope));
104        let attn = if spec.use_custom_mask {
105            let mask = decode
106                .mask
107                .ok_or_else(|| anyhow::anyhow!("custom mask requested but not bound"))?;
108            gb.attention(
109                q_rope,
110                k_rep,
111                v_rep,
112                mask,
113                spec.num_heads,
114                spec.head_dim,
115                attn_shape,
116            )
117        } else {
118            gb.attention_kind(
119                q_rope,
120                k_rep,
121                v_rep,
122                spec.num_heads,
123                spec.head_dim,
124                MaskKind::Causal,
125                attn_shape,
126            )
127        };
128
129        let attn_out = gb.mm(attn, o_w);
130        let post_attn = gb.add(input.id, attn_out);
131        let normed_post = gb.rms_norm(post_attn, post_ln_g, zero_beta, spec.eps);
132        let gate = gb.mm(normed_post, gate_w);
133        let up = gb.mm(normed_post, up_w);
134        let gate_act = gb.silu(gate);
135        let swiglu = gb.mul(gate_act, up);
136        let ffn_out = gb.mm(swiglu, down_w);
137        let h_id = gb.add(post_attn, ffn_out);
138
139        Ok(Some(ctx.wrap(h_id, spec.hidden_shape.clone())))
140    }
141}