// rlx_flow/blocks/qwen3_decoder.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

16use anyhow::Result;
17use rlx_ir::HirGraphExt;
18use rlx_ir::hir::HirMut;
19use rlx_ir::op::MaskKind;
20use rlx_ir::shape;
21
22use std::sync::{Arc, Mutex};
23
24use super::BlockStage;
25use super::self_attn::repeat_kv;
26use crate::context::FlowCtx;
27use crate::value::FlowValue;
28
29#[derive(Debug, Clone)]
30pub struct Qwen3DecoderSpec {
31    pub num_heads: usize,
32    pub num_kv_heads: usize,
33    pub head_dim: usize,
34    pub eps: f32,
35    pub hidden_shape: rlx_ir::Shape,
36    pub batch: usize,
37    pub seq: usize,
38    /// Per-head Q/K RMSNorm before RoPE (Qwen3); Qwen2 skips.
39    pub qk_norm: bool,
40    /// Explicit Q/K/V bias vectors (Qwen2); Qwen3 typically false.
41    pub attention_bias: bool,
42}
43
44#[derive(Debug, Clone)]
45pub struct Qwen3DecoderStage {
46    pub layer_prefix: String,
47    pub spec: Qwen3DecoderSpec,
48    pub kv_sink: Option<Arc<Mutex<Vec<rlx_ir::HirNodeId>>>>,
49}
50
51impl Qwen3DecoderStage {
52    pub fn layer(layer_idx: usize, spec: Qwen3DecoderSpec) -> Self {
53        Self {
54            layer_prefix: format!("model.layers.{layer_idx}"),
55            spec,
56            kv_sink: None,
57        }
58    }
59
60    pub fn layer_with_kv(
61        layer_idx: usize,
62        spec: Qwen3DecoderSpec,
63        kv_sink: Arc<Mutex<Vec<rlx_ir::HirNodeId>>>,
64    ) -> Self {
65        Self {
66            layer_prefix: format!("model.layers.{layer_idx}"),
67            spec,
68            kv_sink: Some(kv_sink),
69        }
70    }
71}
72
73impl BlockStage for Qwen3DecoderStage {
74    fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
75        let lp = &self.layer_prefix;
76        let spec = &self.spec;
77        let nh = spec.num_heads;
78        let nkv = spec.num_kv_heads;
79        let dh = spec.head_dim;
80        let group = nh / nkv;
81
82        let zero_beta_h = ctx
83            .state
84            .zero_beta
85            .ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires ZeroBeta"))?;
86        let zero_beta_dh = ctx
87            .state
88            .named
89            .get("zero_beta.head")
90            .copied()
91            .ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires zero_beta.head"))?;
92        let cos = ctx
93            .state
94            .rope_cos
95            .ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires RopeTables"))?;
96        let sin = ctx
97            .state
98            .rope_sin
99            .ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires RopeTables"))?;
100
101        let in_ln_g = ctx.load_param(&format!("{lp}.input_layernorm.weight"), false)?;
102        let q_w = ctx.load_param(&format!("{lp}.self_attn.q_proj.weight"), true)?;
103        let k_w = ctx.load_param(&format!("{lp}.self_attn.k_proj.weight"), true)?;
104        let v_w = ctx.load_param(&format!("{lp}.self_attn.v_proj.weight"), true)?;
105        let o_w = ctx.load_param(&format!("{lp}.self_attn.o_proj.weight"), true)?;
106        let post_ln_g = ctx.load_param(&format!("{lp}.post_attention_layernorm.weight"), false)?;
107        let gate_w = ctx.load_param(&format!("{lp}.mlp.gate_proj.weight"), true)?;
108        let up_w = ctx.load_param(&format!("{lp}.mlp.up_proj.weight"), true)?;
109        let down_w = ctx.load_param(&format!("{lp}.mlp.down_proj.weight"), true)?;
110        let (q_bias, k_bias, v_bias) = if spec.attention_bias {
111            (
112                Some(ctx.load_param(&format!("{lp}.self_attn.q_proj.bias"), false)?),
113                Some(ctx.load_param(&format!("{lp}.self_attn.k_proj.bias"), false)?),
114                Some(ctx.load_param(&format!("{lp}.self_attn.v_proj.bias"), false)?),
115            )
116        } else {
117            (None, None, None)
118        };
119        let (q_norm_g, k_norm_g) = if spec.qk_norm {
120            (
121                Some(ctx.load_param(&format!("{lp}.self_attn.q_norm.weight"), false)?),
122                Some(ctx.load_param(&format!("{lp}.self_attn.k_norm.weight"), false)?),
123            )
124        } else {
125            (None, None)
126        };
127
128        let mut gb = HirMut::new(ctx.hir());
129        let skip = input.id;
130
131        let normed_in = gb.rms_norm(skip, in_ln_g, zero_beta_h, spec.eps);
132        let mut q = gb.mm(normed_in, q_w);
133        let mut k = gb.mm(normed_in, k_w);
134        let mut v = gb.mm(normed_in, v_w);
135
136        if let (Some(qb), Some(kb), Some(vb)) = (q_bias, k_bias, v_bias) {
137            q = gb.add(q, qb);
138            k = gb.add(k, kb);
139            v = gb.add(v, vb);
140        }
141
142        let (q_rope_in, k_rope_in) = if let (Some(qng), Some(kng)) = (q_norm_g, k_norm_g) {
143            let q_normed = per_head_rms(
144                &mut gb,
145                q,
146                qng,
147                zero_beta_dh,
148                spec.batch,
149                spec.seq,
150                nh,
151                dh,
152                spec.eps,
153            );
154            let k_normed = per_head_rms(
155                &mut gb,
156                k,
157                kng,
158                zero_beta_dh,
159                spec.batch,
160                spec.seq,
161                nkv,
162                dh,
163                spec.eps,
164            );
165            (q_normed, k_normed)
166        } else {
167            (q, k)
168        };
169
170        let q_rope = gb.rope(q_rope_in, cos, sin, dh);
171        let k_rope = gb.rope(k_rope_in, cos, sin, dh);
172        if let Some(ref sink) = self.kv_sink {
173            sink.lock().expect("qwen3 kv sink").push(k_rope);
174            sink.lock().expect("qwen3 kv sink").push(v);
175        }
176        let k_rep = repeat_kv(&mut gb, k_rope, nkv, dh, group);
177        let v_rep = repeat_kv(&mut gb, v, nkv, dh, group);
178
179        let attn_shape = shape::attention_shape(gb.shape(q_rope));
180        let attn = gb.attention_kind(q_rope, k_rep, v_rep, nh, dh, MaskKind::Causal, attn_shape);
181        let attn_out = gb.mm(attn, o_w);
182        let post_attn = gb.add(skip, attn_out);
183        let normed_post = gb.rms_norm(post_attn, post_ln_g, zero_beta_h, spec.eps);
184
185        let gate = gb.mm(normed_post, gate_w);
186        let up = gb.mm(normed_post, up_w);
187        let gate_act = gb.silu(gate);
188        let swiglu = gb.mul(gate_act, up);
189        let ffn_out = gb.mm(swiglu, down_w);
190        let out = gb.add(post_attn, ffn_out);
191
192        Ok(Some(ctx.wrap(out, spec.hidden_shape.clone())))
193    }
194}
195
196pub(crate) fn per_head_rms(
197    gb: &mut HirMut,
198    x: rlx_ir::HirNodeId,
199    gamma: rlx_ir::HirNodeId,
200    beta: rlx_ir::HirNodeId,
201    batch: usize,
202    seq: usize,
203    heads: usize,
204    head_dim: usize,
205    eps: f32,
206) -> rlx_ir::HirNodeId {
207    let flat = (batch * seq * heads) as i64;
208    let dh = head_dim as i64;
209    let r = gb.reshape_(x, vec![flat, dh]);
210    let n = gb.rms_norm(r, gamma, beta, eps);
211    gb.reshape_(n, vec![batch as i64, seq as i64, (heads * head_dim) as i64])
212}