Skip to main content

rlx_flow/blocks/
qwen3_decode_layer.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16use anyhow::Result;
17use rlx_ir::HirGraphExt;
18use rlx_ir::hir::HirMut;
19use rlx_ir::op::MaskKind;
20use rlx_ir::shape;
21
22use std::sync::{Arc, Mutex};
23
24use super::BlockStage;
25use super::qwen3_decoder::per_head_rms;
26use super::self_attn::repeat_kv;
27use crate::context::FlowCtx;
28use crate::value::FlowValue;
29
30#[derive(Debug, Clone)]
31pub struct Qwen3DecodeLayerSpec {
32    pub num_heads: usize,
33    pub num_kv_heads: usize,
34    pub head_dim: usize,
35    pub kv_group_size: usize,
36    pub eps: f32,
37    pub use_custom_mask: bool,
38    pub hidden_shape: rlx_ir::Shape,
39    pub batch: usize,
40    pub qk_norm: bool,
41    pub attention_bias: bool,
42}
43
44#[derive(Debug, Clone)]
45pub struct Qwen3DecodeLayerStage {
46    pub layer_prefix: String,
47    pub spec: Qwen3DecodeLayerSpec,
48    pub layer_idx: usize,
49    pub kv_out: Arc<Mutex<Vec<rlx_ir::HirNodeId>>>,
50}
51
52impl Qwen3DecodeLayerStage {
53    pub fn layer(
54        layer_idx: usize,
55        spec: Qwen3DecodeLayerSpec,
56        kv_out: Arc<Mutex<Vec<rlx_ir::HirNodeId>>>,
57    ) -> Self {
58        Self {
59            layer_prefix: format!("model.layers.{layer_idx}"),
60            spec,
61            layer_idx,
62            kv_out,
63        }
64    }
65}
66
67impl BlockStage for Qwen3DecodeLayerStage {
68    fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
69        let decode = ctx
70            .state
71            .decode
72            .clone()
73            .ok_or_else(|| anyhow::anyhow!("Qwen3DecodeLayer requires BindDecodeInputs"))?;
74        let zero_beta_h = ctx
75            .state
76            .zero_beta
77            .ok_or_else(|| anyhow::anyhow!("Qwen3DecodeLayer requires ZeroBeta"))?;
78        let zero_beta_dh = ctx
79            .state
80            .named
81            .get("zero_beta.head")
82            .copied()
83            .ok_or_else(|| anyhow::anyhow!("Qwen3DecodeLayer requires zero_beta.head"))?;
84
85        let lp = &self.layer_prefix;
86        let spec = &self.spec;
87        let nh = spec.num_heads;
88        let nkv = spec.num_kv_heads;
89        let dh = spec.head_dim;
90        let batch = spec.batch;
91
92        let in_ln_g = ctx.load_param(&format!("{lp}.input_layernorm.weight"), false)?;
93        let q_w = ctx.load_param(&format!("{lp}.self_attn.q_proj.weight"), true)?;
94        let k_w = ctx.load_param(&format!("{lp}.self_attn.k_proj.weight"), true)?;
95        let v_w = ctx.load_param(&format!("{lp}.self_attn.v_proj.weight"), true)?;
96        let o_w = ctx.load_param(&format!("{lp}.self_attn.o_proj.weight"), true)?;
97        let post_ln_g = ctx.load_param(&format!("{lp}.post_attention_layernorm.weight"), false)?;
98        let gate_w = ctx.load_param(&format!("{lp}.mlp.gate_proj.weight"), true)?;
99        let up_w = ctx.load_param(&format!("{lp}.mlp.up_proj.weight"), true)?;
100        let down_w = ctx.load_param(&format!("{lp}.mlp.down_proj.weight"), true)?;
101        let (q_bias, k_bias, v_bias) = if spec.attention_bias {
102            (
103                Some(ctx.load_param(&format!("{lp}.self_attn.q_proj.bias"), false)?),
104                Some(ctx.load_param(&format!("{lp}.self_attn.k_proj.bias"), false)?),
105                Some(ctx.load_param(&format!("{lp}.self_attn.v_proj.bias"), false)?),
106            )
107        } else {
108            (None, None, None)
109        };
110        let (q_norm_g, k_norm_g) = if spec.qk_norm {
111            (
112                Some(ctx.load_param(&format!("{lp}.self_attn.q_norm.weight"), false)?),
113                Some(ctx.load_param(&format!("{lp}.self_attn.k_norm.weight"), false)?),
114            )
115        } else {
116            (None, None)
117        };
118
119        let past_k = decode.past_k[self.layer_idx];
120        let past_v = decode.past_v[self.layer_idx];
121
122        let mut gb = HirMut::new(ctx.hir());
123        let skip = input.id;
124        let normed_in = gb.rms_norm(skip, in_ln_g, zero_beta_h, spec.eps);
125        let mut q = gb.mm(normed_in, q_w);
126        let mut k = gb.mm(normed_in, k_w);
127        let mut v = gb.mm(normed_in, v_w);
128
129        if let (Some(qb), Some(kb), Some(vb)) = (q_bias, k_bias, v_bias) {
130            q = gb.add(q, qb);
131            k = gb.add(k, kb);
132            v = gb.add(v, vb);
133        }
134
135        let (q_rope_in, k_rope_in) = if let (Some(qng), Some(kng)) = (q_norm_g, k_norm_g) {
136            let q_normed = per_head_rms(&mut gb, q, qng, zero_beta_dh, batch, 1, nh, dh, spec.eps);
137            let k_normed = per_head_rms(&mut gb, k, kng, zero_beta_dh, batch, 1, nkv, dh, spec.eps);
138            (q_normed, k_normed)
139        } else {
140            (q, k)
141        };
142
143        let q_rope = gb.rope(q_rope_in, decode.cos, decode.sin, dh);
144        let k_rope = gb.rope(k_rope_in, decode.cos, decode.sin, dh);
145
146        let new_k = gb.concat_(vec![past_k, k_rope], 1);
147        let new_v = gb.concat_(vec![past_v, v], 1);
148        self.kv_out.lock().expect("kv out").push(new_k);
149        self.kv_out.lock().expect("kv out").push(new_v);
150
151        let k_rep = repeat_kv(&mut gb, new_k, nkv, dh, spec.kv_group_size);
152        let v_rep = repeat_kv(&mut gb, new_v, nkv, dh, spec.kv_group_size);
153
154        let attn_shape = shape::attention_shape(gb.shape(q_rope));
155        let attn = if spec.use_custom_mask {
156            let mask = decode
157                .mask
158                .ok_or_else(|| anyhow::anyhow!("custom mask requested but not bound"))?;
159            gb.attention(q_rope, k_rep, v_rep, mask, nh, dh, attn_shape)
160        } else {
161            gb.attention_kind(q_rope, k_rep, v_rep, nh, dh, MaskKind::Causal, attn_shape)
162        };
163
164        let attn_out = gb.mm(attn, o_w);
165        let post_attn = gb.add(skip, attn_out);
166        let normed_post = gb.rms_norm(post_attn, post_ln_g, zero_beta_h, spec.eps);
167        let gate = gb.mm(normed_post, gate_w);
168        let up = gb.mm(normed_post, up_w);
169        let gate_act = gb.silu(gate);
170        let swiglu = gb.mul(gate_act, up);
171        let ffn_out = gb.mm(swiglu, down_w);
172        let h_id = gb.add(post_attn, ffn_out);
173
174        Ok(Some(ctx.wrap(h_id, spec.hidden_shape.clone())))
175    }
176}