Skip to main content

rlx_flow/blocks/
qwen3_decode_layer.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4use anyhow::Result;
5use rlx_ir::HirGraphExt;
6use rlx_ir::hir::HirMut;
7use rlx_ir::op::MaskKind;
8use rlx_ir::shape;
9
10use std::sync::{Arc, Mutex};
11
12use super::BlockStage;
13use super::qwen3_decoder::per_head_rms;
14use super::self_attn::repeat_kv;
15use crate::context::FlowCtx;
16use crate::value::FlowValue;
17
18#[derive(Debug, Clone)]
19pub struct Qwen3DecodeLayerSpec {
20    pub num_heads: usize,
21    pub num_kv_heads: usize,
22    pub head_dim: usize,
23    pub kv_group_size: usize,
24    pub eps: f32,
25    pub use_custom_mask: bool,
26    pub hidden_shape: rlx_ir::Shape,
27    pub batch: usize,
28    pub qk_norm: bool,
29    pub attention_bias: bool,
30}
31
32#[derive(Debug, Clone)]
33pub struct Qwen3DecodeLayerStage {
34    pub layer_prefix: String,
35    pub spec: Qwen3DecodeLayerSpec,
36    pub layer_idx: usize,
37    pub kv_out: Arc<Mutex<Vec<rlx_ir::HirNodeId>>>,
38}
39
40impl Qwen3DecodeLayerStage {
41    pub fn layer(
42        layer_idx: usize,
43        spec: Qwen3DecodeLayerSpec,
44        kv_out: Arc<Mutex<Vec<rlx_ir::HirNodeId>>>,
45    ) -> Self {
46        Self {
47            layer_prefix: format!("model.layers.{layer_idx}"),
48            spec,
49            layer_idx,
50            kv_out,
51        }
52    }
53}
54
55impl BlockStage for Qwen3DecodeLayerStage {
56    fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
57        let decode = ctx
58            .state
59            .decode
60            .clone()
61            .ok_or_else(|| anyhow::anyhow!("Qwen3DecodeLayer requires BindDecodeInputs"))?;
62        let zero_beta_h = ctx
63            .state
64            .zero_beta
65            .ok_or_else(|| anyhow::anyhow!("Qwen3DecodeLayer requires ZeroBeta"))?;
66        let zero_beta_dh = ctx
67            .state
68            .named
69            .get("zero_beta.head")
70            .copied()
71            .ok_or_else(|| anyhow::anyhow!("Qwen3DecodeLayer requires zero_beta.head"))?;
72
73        let lp = &self.layer_prefix;
74        let spec = &self.spec;
75        let nh = spec.num_heads;
76        let nkv = spec.num_kv_heads;
77        let dh = spec.head_dim;
78        let batch = spec.batch;
79
80        let in_ln_g = ctx.load_param(&format!("{lp}.input_layernorm.weight"), false)?;
81        let q_w = ctx.load_param(&format!("{lp}.self_attn.q_proj.weight"), true)?;
82        let k_w = ctx.load_param(&format!("{lp}.self_attn.k_proj.weight"), true)?;
83        let v_w = ctx.load_param(&format!("{lp}.self_attn.v_proj.weight"), true)?;
84        let o_w = ctx.load_param(&format!("{lp}.self_attn.o_proj.weight"), true)?;
85        let post_ln_g = ctx.load_param(&format!("{lp}.post_attention_layernorm.weight"), false)?;
86        let gate_w = ctx.load_param(&format!("{lp}.mlp.gate_proj.weight"), true)?;
87        let up_w = ctx.load_param(&format!("{lp}.mlp.up_proj.weight"), true)?;
88        let down_w = ctx.load_param(&format!("{lp}.mlp.down_proj.weight"), true)?;
89        let (q_bias, k_bias, v_bias) = if spec.attention_bias {
90            (
91                Some(ctx.load_param(&format!("{lp}.self_attn.q_proj.bias"), false)?),
92                Some(ctx.load_param(&format!("{lp}.self_attn.k_proj.bias"), false)?),
93                Some(ctx.load_param(&format!("{lp}.self_attn.v_proj.bias"), false)?),
94            )
95        } else {
96            (None, None, None)
97        };
98        let (q_norm_g, k_norm_g) = if spec.qk_norm {
99            (
100                Some(ctx.load_param(&format!("{lp}.self_attn.q_norm.weight"), false)?),
101                Some(ctx.load_param(&format!("{lp}.self_attn.k_norm.weight"), false)?),
102            )
103        } else {
104            (None, None)
105        };
106
107        let past_k = decode.past_k[self.layer_idx];
108        let past_v = decode.past_v[self.layer_idx];
109
110        let mut gb = HirMut::new(ctx.hir());
111        let skip = input.id;
112        let normed_in = gb.rms_norm(skip, in_ln_g, zero_beta_h, spec.eps);
113        let mut q = gb.mm(normed_in, q_w);
114        let mut k = gb.mm(normed_in, k_w);
115        let mut v = gb.mm(normed_in, v_w);
116
117        if let (Some(qb), Some(kb), Some(vb)) = (q_bias, k_bias, v_bias) {
118            q = gb.add(q, qb);
119            k = gb.add(k, kb);
120            v = gb.add(v, vb);
121        }
122
123        let (q_rope_in, k_rope_in) = if let (Some(qng), Some(kng)) = (q_norm_g, k_norm_g) {
124            let q_normed = per_head_rms(&mut gb, q, qng, zero_beta_dh, batch, 1, nh, dh, spec.eps);
125            let k_normed = per_head_rms(&mut gb, k, kng, zero_beta_dh, batch, 1, nkv, dh, spec.eps);
126            (q_normed, k_normed)
127        } else {
128            (q, k)
129        };
130
131        let q_rope = gb.rope(q_rope_in, decode.cos, decode.sin, dh);
132        let k_rope = gb.rope(k_rope_in, decode.cos, decode.sin, dh);
133
134        let new_k = gb.concat_(vec![past_k, k_rope], 1);
135        let new_v = gb.concat_(vec![past_v, v], 1);
136        self.kv_out.lock().expect("kv out").push(new_k);
137        self.kv_out.lock().expect("kv out").push(new_v);
138
139        let k_rep = repeat_kv(&mut gb, new_k, nkv, dh, spec.kv_group_size);
140        let v_rep = repeat_kv(&mut gb, new_v, nkv, dh, spec.kv_group_size);
141
142        let attn_shape = shape::attention_shape(gb.shape(q_rope));
143        let attn = if spec.use_custom_mask {
144            let mask = decode
145                .mask
146                .ok_or_else(|| anyhow::anyhow!("custom mask requested but not bound"))?;
147            gb.attention(q_rope, k_rep, v_rep, mask, nh, dh, attn_shape)
148        } else {
149            gb.attention_kind(q_rope, k_rep, v_rep, nh, dh, MaskKind::Causal, attn_shape)
150        };
151
152        let attn_out = gb.mm(attn, o_w);
153        let post_attn = gb.add(skip, attn_out);
154        let normed_post = gb.rms_norm(post_attn, post_ln_g, zero_beta_h, spec.eps);
155        let gate = gb.mm(normed_post, gate_w);
156        let up = gb.mm(normed_post, up_w);
157        let gate_act = gb.silu(gate);
158        let swiglu = gb.mul(gate_act, up);
159        let ffn_out = gb.mm(swiglu, down_w);
160        let h_id = gb.add(post_attn, ffn_out);
161
162        Ok(Some(ctx.wrap(h_id, spec.hidden_shape.clone())))
163    }
164}