Skip to main content

rlx_flow/blocks/
qwen3_decoder.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4use anyhow::Result;
5use rlx_ir::HirGraphExt;
6use rlx_ir::hir::HirMut;
7use rlx_ir::op::MaskKind;
8use rlx_ir::shape;
9
10use std::sync::{Arc, Mutex};
11
12use super::BlockStage;
13use super::self_attn::repeat_kv;
14use crate::context::FlowCtx;
15use crate::value::FlowValue;
16
17#[derive(Debug, Clone)]
18pub struct Qwen3DecoderSpec {
19    pub num_heads: usize,
20    pub num_kv_heads: usize,
21    pub head_dim: usize,
22    pub eps: f32,
23    pub hidden_shape: rlx_ir::Shape,
24    pub batch: usize,
25    pub seq: usize,
26}
27
28#[derive(Debug, Clone)]
29pub struct Qwen3DecoderStage {
30    pub layer_prefix: String,
31    pub spec: Qwen3DecoderSpec,
32    pub kv_sink: Option<Arc<Mutex<Vec<rlx_ir::HirNodeId>>>>,
33}
34
35impl Qwen3DecoderStage {
36    pub fn layer(layer_idx: usize, spec: Qwen3DecoderSpec) -> Self {
37        Self {
38            layer_prefix: format!("model.layers.{layer_idx}"),
39            spec,
40            kv_sink: None,
41        }
42    }
43
44    pub fn layer_with_kv(
45        layer_idx: usize,
46        spec: Qwen3DecoderSpec,
47        kv_sink: Arc<Mutex<Vec<rlx_ir::HirNodeId>>>,
48    ) -> Self {
49        Self {
50            layer_prefix: format!("model.layers.{layer_idx}"),
51            spec,
52            kv_sink: Some(kv_sink),
53        }
54    }
55}
56
57impl BlockStage for Qwen3DecoderStage {
58    fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
59        let lp = &self.layer_prefix;
60        let spec = &self.spec;
61        let nh = spec.num_heads;
62        let nkv = spec.num_kv_heads;
63        let dh = spec.head_dim;
64        let group = nh / nkv;
65
66        let zero_beta_h = ctx
67            .state
68            .zero_beta
69            .ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires ZeroBeta"))?;
70        let zero_beta_dh = ctx
71            .state
72            .named
73            .get("zero_beta.head")
74            .copied()
75            .ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires zero_beta.head"))?;
76        let cos = ctx
77            .state
78            .rope_cos
79            .ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires RopeTables"))?;
80        let sin = ctx
81            .state
82            .rope_sin
83            .ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires RopeTables"))?;
84
85        let in_ln_g = ctx.load_param(&format!("{lp}.input_layernorm.weight"), false)?;
86        let q_w = ctx.load_param(&format!("{lp}.self_attn.q_proj.weight"), true)?;
87        let k_w = ctx.load_param(&format!("{lp}.self_attn.k_proj.weight"), true)?;
88        let v_w = ctx.load_param(&format!("{lp}.self_attn.v_proj.weight"), true)?;
89        let q_norm_g = ctx.load_param(&format!("{lp}.self_attn.q_norm.weight"), false)?;
90        let k_norm_g = ctx.load_param(&format!("{lp}.self_attn.k_norm.weight"), false)?;
91        let o_w = ctx.load_param(&format!("{lp}.self_attn.o_proj.weight"), true)?;
92        let post_ln_g = ctx.load_param(&format!("{lp}.post_attention_layernorm.weight"), false)?;
93        let gate_w = ctx.load_param(&format!("{lp}.mlp.gate_proj.weight"), true)?;
94        let up_w = ctx.load_param(&format!("{lp}.mlp.up_proj.weight"), true)?;
95        let down_w = ctx.load_param(&format!("{lp}.mlp.down_proj.weight"), true)?;
96
97        let mut gb = HirMut::new(ctx.hir());
98        let skip = input.id;
99
100        let normed_in = gb.rms_norm(skip, in_ln_g, zero_beta_h, spec.eps);
101        let q = gb.mm(normed_in, q_w);
102        let k = gb.mm(normed_in, k_w);
103        let v = gb.mm(normed_in, v_w);
104
105        let q_normed = per_head_rms(
106            &mut gb,
107            q,
108            q_norm_g,
109            zero_beta_dh,
110            spec.batch,
111            spec.seq,
112            nh,
113            dh,
114            spec.eps,
115        );
116        let k_normed = per_head_rms(
117            &mut gb,
118            k,
119            k_norm_g,
120            zero_beta_dh,
121            spec.batch,
122            spec.seq,
123            nkv,
124            dh,
125            spec.eps,
126        );
127
128        let q_rope = gb.rope(q_normed, cos, sin, dh);
129        let k_rope = gb.rope(k_normed, cos, sin, dh);
130        if let Some(ref sink) = self.kv_sink {
131            sink.lock().expect("qwen3 kv sink").push(k_rope);
132            sink.lock().expect("qwen3 kv sink").push(v);
133        }
134        let k_rep = repeat_kv(&mut gb, k_rope, nkv, dh, group);
135        let v_rep = repeat_kv(&mut gb, v, nkv, dh, group);
136
137        let attn_shape = shape::attention_shape(gb.shape(q_rope));
138        let attn = gb.attention_kind(q_rope, k_rep, v_rep, nh, dh, MaskKind::Causal, attn_shape);
139        let attn_out = gb.mm(attn, o_w);
140        let post_attn = gb.add(skip, attn_out);
141        let normed_post = gb.rms_norm(post_attn, post_ln_g, zero_beta_h, spec.eps);
142
143        let gate = gb.mm(normed_post, gate_w);
144        let up = gb.mm(normed_post, up_w);
145        let gate_act = gb.silu(gate);
146        let swiglu = gb.mul(gate_act, up);
147        let ffn_out = gb.mm(swiglu, down_w);
148        let out = gb.add(post_attn, ffn_out);
149
150        Ok(Some(ctx.wrap(out, spec.hidden_shape.clone())))
151    }
152}
153
154pub(crate) fn per_head_rms(
155    gb: &mut HirMut,
156    x: rlx_ir::HirNodeId,
157    gamma: rlx_ir::HirNodeId,
158    beta: rlx_ir::HirNodeId,
159    batch: usize,
160    seq: usize,
161    heads: usize,
162    head_dim: usize,
163    eps: f32,
164) -> rlx_ir::HirNodeId {
165    let flat = (batch * seq * heads) as i64;
166    let dh = head_dim as i64;
167    let r = gb.reshape_(x, vec![flat, dh]);
168    let n = gb.rms_norm(r, gamma, beta, eps);
169    gb.reshape_(n, vec![batch as i64, seq as i64, (heads * head_dim) as i64])
170}