Skip to main content

rlx_flow/blocks/
qwen3_decoder.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4use anyhow::Result;
5use rlx_ir::HirGraphExt;
6use rlx_ir::hir::HirMut;
7use rlx_ir::op::MaskKind;
8use rlx_ir::shape;
9
10use std::sync::{Arc, Mutex};
11
12use super::BlockStage;
13use super::self_attn::repeat_kv;
14use crate::context::FlowCtx;
15use crate::value::FlowValue;
16
17#[derive(Debug, Clone)]
18pub struct Qwen3DecoderSpec {
19    pub num_heads: usize,
20    pub num_kv_heads: usize,
21    pub head_dim: usize,
22    pub eps: f32,
23    pub hidden_shape: rlx_ir::Shape,
24    pub batch: usize,
25    pub seq: usize,
26    /// Per-head Q/K RMSNorm before RoPE (Qwen3); Qwen2 skips.
27    pub qk_norm: bool,
28    /// Explicit Q/K/V bias vectors (Qwen2); Qwen3 typically false.
29    pub attention_bias: bool,
30}
31
32#[derive(Debug, Clone)]
33pub struct Qwen3DecoderStage {
34    pub layer_prefix: String,
35    pub spec: Qwen3DecoderSpec,
36    pub kv_sink: Option<Arc<Mutex<Vec<rlx_ir::HirNodeId>>>>,
37}
38
39impl Qwen3DecoderStage {
40    pub fn layer(layer_idx: usize, spec: Qwen3DecoderSpec) -> Self {
41        Self {
42            layer_prefix: format!("model.layers.{layer_idx}"),
43            spec,
44            kv_sink: None,
45        }
46    }
47
48    pub fn layer_with_kv(
49        layer_idx: usize,
50        spec: Qwen3DecoderSpec,
51        kv_sink: Arc<Mutex<Vec<rlx_ir::HirNodeId>>>,
52    ) -> Self {
53        Self {
54            layer_prefix: format!("model.layers.{layer_idx}"),
55            spec,
56            kv_sink: Some(kv_sink),
57        }
58    }
59}
60
61impl BlockStage for Qwen3DecoderStage {
62    fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
63        let lp = &self.layer_prefix;
64        let spec = &self.spec;
65        let nh = spec.num_heads;
66        let nkv = spec.num_kv_heads;
67        let dh = spec.head_dim;
68        let group = nh / nkv;
69
70        let zero_beta_h = ctx
71            .state
72            .zero_beta
73            .ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires ZeroBeta"))?;
74        let zero_beta_dh = ctx
75            .state
76            .named
77            .get("zero_beta.head")
78            .copied()
79            .ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires zero_beta.head"))?;
80        let cos = ctx
81            .state
82            .rope_cos
83            .ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires RopeTables"))?;
84        let sin = ctx
85            .state
86            .rope_sin
87            .ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires RopeTables"))?;
88
89        let in_ln_g = ctx.load_param(&format!("{lp}.input_layernorm.weight"), false)?;
90        let q_w = ctx.load_param(&format!("{lp}.self_attn.q_proj.weight"), true)?;
91        let k_w = ctx.load_param(&format!("{lp}.self_attn.k_proj.weight"), true)?;
92        let v_w = ctx.load_param(&format!("{lp}.self_attn.v_proj.weight"), true)?;
93        let o_w = ctx.load_param(&format!("{lp}.self_attn.o_proj.weight"), true)?;
94        let post_ln_g = ctx.load_param(&format!("{lp}.post_attention_layernorm.weight"), false)?;
95        let gate_w = ctx.load_param(&format!("{lp}.mlp.gate_proj.weight"), true)?;
96        let up_w = ctx.load_param(&format!("{lp}.mlp.up_proj.weight"), true)?;
97        let down_w = ctx.load_param(&format!("{lp}.mlp.down_proj.weight"), true)?;
98        let (q_bias, k_bias, v_bias) = if spec.attention_bias {
99            (
100                Some(ctx.load_param(&format!("{lp}.self_attn.q_proj.bias"), false)?),
101                Some(ctx.load_param(&format!("{lp}.self_attn.k_proj.bias"), false)?),
102                Some(ctx.load_param(&format!("{lp}.self_attn.v_proj.bias"), false)?),
103            )
104        } else {
105            (None, None, None)
106        };
107        let (q_norm_g, k_norm_g) = if spec.qk_norm {
108            (
109                Some(ctx.load_param(&format!("{lp}.self_attn.q_norm.weight"), false)?),
110                Some(ctx.load_param(&format!("{lp}.self_attn.k_norm.weight"), false)?),
111            )
112        } else {
113            (None, None)
114        };
115
116        let mut gb = HirMut::new(ctx.hir());
117        let skip = input.id;
118
119        let normed_in = gb.rms_norm(skip, in_ln_g, zero_beta_h, spec.eps);
120        let mut q = gb.mm(normed_in, q_w);
121        let mut k = gb.mm(normed_in, k_w);
122        let mut v = gb.mm(normed_in, v_w);
123
124        if let (Some(qb), Some(kb), Some(vb)) = (q_bias, k_bias, v_bias) {
125            q = gb.add(q, qb);
126            k = gb.add(k, kb);
127            v = gb.add(v, vb);
128        }
129
130        let (q_rope_in, k_rope_in) = if let (Some(qng), Some(kng)) = (q_norm_g, k_norm_g) {
131            let q_normed = per_head_rms(
132                &mut gb,
133                q,
134                qng,
135                zero_beta_dh,
136                spec.batch,
137                spec.seq,
138                nh,
139                dh,
140                spec.eps,
141            );
142            let k_normed = per_head_rms(
143                &mut gb,
144                k,
145                kng,
146                zero_beta_dh,
147                spec.batch,
148                spec.seq,
149                nkv,
150                dh,
151                spec.eps,
152            );
153            (q_normed, k_normed)
154        } else {
155            (q, k)
156        };
157
158        let q_rope = gb.rope(q_rope_in, cos, sin, dh);
159        let k_rope = gb.rope(k_rope_in, cos, sin, dh);
160        if let Some(ref sink) = self.kv_sink {
161            sink.lock().expect("qwen3 kv sink").push(k_rope);
162            sink.lock().expect("qwen3 kv sink").push(v);
163        }
164        let k_rep = repeat_kv(&mut gb, k_rope, nkv, dh, group);
165        let v_rep = repeat_kv(&mut gb, v, nkv, dh, group);
166
167        let attn_shape = shape::attention_shape(gb.shape(q_rope));
168        let attn = gb.attention_kind(q_rope, k_rep, v_rep, nh, dh, MaskKind::Causal, attn_shape);
169        let attn_out = gb.mm(attn, o_w);
170        let post_attn = gb.add(skip, attn_out);
171        let normed_post = gb.rms_norm(post_attn, post_ln_g, zero_beta_h, spec.eps);
172
173        let gate = gb.mm(normed_post, gate_w);
174        let up = gb.mm(normed_post, up_w);
175        let gate_act = gb.silu(gate);
176        let swiglu = gb.mul(gate_act, up);
177        let ffn_out = gb.mm(swiglu, down_w);
178        let out = gb.add(post_attn, ffn_out);
179
180        Ok(Some(ctx.wrap(out, spec.hidden_shape.clone())))
181    }
182}
183
184pub(crate) fn per_head_rms(
185    gb: &mut HirMut,
186    x: rlx_ir::HirNodeId,
187    gamma: rlx_ir::HirNodeId,
188    beta: rlx_ir::HirNodeId,
189    batch: usize,
190    seq: usize,
191    heads: usize,
192    head_dim: usize,
193    eps: f32,
194) -> rlx_ir::HirNodeId {
195    let flat = (batch * seq * heads) as i64;
196    let dh = head_dim as i64;
197    let r = gb.reshape_(x, vec![flat, dh]);
198    let n = gb.rms_norm(r, gamma, beta, eps);
199    gb.reshape_(n, vec![batch as i64, seq as i64, (heads * head_dim) as i64])
200}