Skip to main content

rlx_flow/blocks/
gemma_decode_layer.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4use anyhow::Result;
5use rlx_ir::HirGraphExt;
6use rlx_ir::hir::HirMut;
7use rlx_ir::op::MaskKind;
8use rlx_ir::shape;
9use rlx_ir::{DType, Shape};
10
11use std::sync::{Arc, Mutex};
12
13use super::{BlockStage, GemmaLayerStyle};
14use crate::context::FlowCtx;
15use crate::value::FlowValue;
16
/// Static configuration for one Gemma decoder layer emitted in decode mode.
///
/// Field notes describe how `GemmaDecodeLayerStage::emit` consumes each value.
#[derive(Debug, Clone)]
pub struct GemmaDecodeLayerSpec {
    /// Selects which layernorm parameter keys are loaded: `Gemma2`/`Gemma3`/`Gemma4`
    /// use `pre_feedforward_layernorm` + `post_feedforward_layernorm`; other styles
    /// use `post_attention_layernorm` as the pre-FFN norm and have no post-FFN norm.
    pub style: GemmaLayerStyle,
    /// Head count passed to the attention op.
    pub num_heads: usize,
    /// Per-head dimension passed to RoPE, `repeat_kv` and the attention op.
    pub head_dim: usize,
    /// KV head count passed to `repeat_kv`.
    pub num_kv_heads: usize,
    /// Group size passed to `repeat_kv`
    /// (presumably `num_heads / num_kv_heads` — TODO confirm against callers).
    pub kv_group_size: usize,
    /// Epsilon passed to every `rms_norm` in the layer.
    pub eps: f32,
    /// When `true`, attention uses the mask node bound in the decode inputs;
    /// when `false`, it uses `mask` via `attention_kind_opts`.
    pub use_custom_mask: bool,
    /// Shape attached to the `FlowValue` returned by the layer.
    pub hidden_shape: rlx_ir::Shape,
    /// Mask kind forwarded to `attention_kind_opts` (only when `use_custom_mask` is `false`).
    pub mask: MaskKind,
    /// Optional score scale forwarded to `attention_kind_opts` (only when `use_custom_mask` is `false`).
    pub score_scale: Option<f32>,
    /// Optional logit soft-cap forwarded to `attention_kind_opts` (only when `use_custom_mask` is `false`).
    pub attn_logit_softcap: Option<f32>,
}
31
/// Flow stage that emits a single Gemma decoder layer for decode.
#[derive(Debug, Clone)]
pub struct GemmaDecodeLayerStage {
    /// Prefix prepended to every parameter key loaded by this layer
    /// (e.g. `{layer_prefix}.self_attn.q_proj.weight`).
    pub layer_prefix: String,
    /// Layer configuration.
    pub spec: GemmaDecodeLayerSpec,
    /// Index used to pick this layer's entries from the decode inputs' `past_k` / `past_v`.
    pub layer_idx: usize,
    /// Shared sink that `emit` appends to: the updated K node, then the updated V node.
    pub kv_out: Arc<Mutex<Vec<rlx_ir::HirNodeId>>>,
}
39
40impl GemmaDecodeLayerStage {
41    pub fn layer(
42        layer_idx: usize,
43        spec: GemmaDecodeLayerSpec,
44        kv_out: Arc<Mutex<Vec<rlx_ir::HirNodeId>>>,
45    ) -> Self {
46        Self {
47            layer_prefix: format!("model.layers.{layer_idx}"),
48            spec,
49            layer_idx,
50            kv_out,
51        }
52    }
53}
54
55impl BlockStage for GemmaDecodeLayerStage {
56    fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
57        let decode = ctx
58            .state
59            .decode
60            .clone()
61            .ok_or_else(|| anyhow::anyhow!("GemmaDecodeLayer requires BindDecodeInputs"))?;
62        let zero_beta = ctx
63            .state
64            .zero_beta
65            .ok_or_else(|| anyhow::anyhow!("GemmaDecodeLayer requires ZeroBeta"))?;
66
67        let lp = &self.layer_prefix;
68        let spec = &self.spec;
69        let style = spec.style;
70
71        let in_ln_w = ctx.load_param(&format!("{lp}.input_layernorm.weight"), false)?;
72        let in_ln_len = norm_len(ctx, in_ln_w)?;
73        let in_ln_ones = ctx.synth_param(
74            &format!("{lp}.input_layernorm.ones"),
75            vec![1.0f32; in_ln_len],
76            Shape::new(&[in_ln_len], DType::F32),
77        );
78
79        let pre_ffn_key = if matches!(
80            style,
81            GemmaLayerStyle::Gemma2 | GemmaLayerStyle::Gemma3 | GemmaLayerStyle::Gemma4
82        ) {
83            format!("{lp}.pre_feedforward_layernorm")
84        } else {
85            format!("{lp}.post_attention_layernorm")
86        };
87        let pre_ffn_w = ctx.load_param(&format!("{pre_ffn_key}.weight"), false)?;
88        let pre_ffn_len = norm_len(ctx, pre_ffn_w)?;
89        let pre_ffn_ones = ctx.synth_param(
90            &format!("{pre_ffn_key}.ones"),
91            vec![1.0f32; pre_ffn_len],
92            Shape::new(&[pre_ffn_len], DType::F32),
93        );
94
95        let post_ffn = if matches!(
96            style,
97            GemmaLayerStyle::Gemma2 | GemmaLayerStyle::Gemma3 | GemmaLayerStyle::Gemma4
98        ) {
99            let post_key = format!("{lp}.post_feedforward_layernorm");
100            let w = ctx.load_param(&format!("{post_key}.weight"), false)?;
101            let len = norm_len(ctx, w)?;
102            let ones = ctx.synth_param(
103                &format!("{post_key}.ones"),
104                vec![1.0f32; len],
105                Shape::new(&[len], DType::F32),
106            );
107            Some((w, ones))
108        } else {
109            None
110        };
111
112        let q_w = ctx.load_param(&format!("{lp}.self_attn.q_proj.weight"), true)?;
113        let k_w = ctx.load_param(&format!("{lp}.self_attn.k_proj.weight"), true)?;
114        let v_w = ctx.load_param(&format!("{lp}.self_attn.v_proj.weight"), true)?;
115        let o_w = ctx.load_param(&format!("{lp}.self_attn.o_proj.weight"), true)?;
116        let gate_w = ctx.load_param(&format!("{lp}.mlp.gate_proj.weight"), true)?;
117        let up_w = ctx.load_param(&format!("{lp}.mlp.up_proj.weight"), true)?;
118        let down_w = ctx.load_param(&format!("{lp}.mlp.down_proj.weight"), true)?;
119
120        let past_k = decode.past_k[self.layer_idx];
121        let past_v = decode.past_v[self.layer_idx];
122
123        let mut gb = HirMut::new(ctx.hir());
124        let in_gamma = gb.add(in_ln_ones, in_ln_w);
125        let normed_in = gb.rms_norm(input.id, in_gamma, zero_beta, spec.eps);
126        let q = gb.mm(normed_in, q_w);
127        let k = gb.mm(normed_in, k_w);
128        let v = gb.mm(normed_in, v_w);
129        let q_rope = gb.rope(q, decode.cos, decode.sin, spec.head_dim);
130        let k_rope = gb.rope(k, decode.cos, decode.sin, spec.head_dim);
131
132        let new_k = gb.concat_(vec![past_k, k_rope], 1);
133        let new_v = gb.concat_(vec![past_v, v], 1);
134        self.kv_out.lock().expect("kv out").push(new_k);
135        self.kv_out.lock().expect("kv out").push(new_v);
136
137        let k_rep = super::self_attn::repeat_kv(
138            &mut gb,
139            new_k,
140            spec.num_kv_heads,
141            spec.head_dim,
142            spec.kv_group_size,
143        );
144        let v_rep = super::self_attn::repeat_kv(
145            &mut gb,
146            new_v,
147            spec.num_kv_heads,
148            spec.head_dim,
149            spec.kv_group_size,
150        );
151
152        let attn_shape = shape::attention_shape(gb.shape(q_rope));
153        let attn = if spec.use_custom_mask {
154            let mask = decode
155                .mask
156                .ok_or_else(|| anyhow::anyhow!("custom mask requested but not bound"))?;
157            gb.attention_(q_rope, k_rep, v_rep, mask, spec.num_heads, spec.head_dim)
158        } else {
159            gb.attention_kind_opts(
160                q_rope,
161                k_rep,
162                v_rep,
163                spec.num_heads,
164                spec.head_dim,
165                spec.mask,
166                attn_shape,
167                spec.score_scale,
168                spec.attn_logit_softcap,
169            )
170        };
171
172        let attn_out = gb.mm(attn, o_w);
173        let post_attn = gb.add(input.id, attn_out);
174
175        let pre_gamma = gb.add(pre_ffn_ones, pre_ffn_w);
176        let mut h = gb.rms_norm(post_attn, pre_gamma, zero_beta, spec.eps);
177        let gate = gb.mm(h, gate_w);
178        let up = gb.mm(h, up_w);
179        let gate_act = gb.gelu_approx(gate);
180        h = gb.mul(gate_act, up);
181        h = gb.mm(h, down_w);
182
183        if let Some((post_w, post_ones)) = post_ffn {
184            let post_gamma = gb.add(post_ones, post_w);
185            h = gb.rms_norm(h, post_gamma, zero_beta, spec.eps);
186        }
187
188        let out_id = gb.add(post_attn, h);
189        Ok(Some(ctx.wrap(out_id, spec.hidden_shape.clone())))
190    }
191}
192
193fn norm_len(ctx: &FlowCtx<'_>, weight: rlx_ir::HirNodeId) -> Result<usize> {
194    match ctx.node_shape(weight)?.dims().last() {
195        Some(rlx_ir::shape::Dim::Static(n)) => Ok(*n),
196        _ => Ok(0),
197    }
198}