1use anyhow::Result;
5use rlx_ir::HirGraphExt;
6use rlx_ir::hir::HirMut;
7use rlx_ir::op::MaskKind;
8use rlx_ir::shape;
9
10use std::sync::{Arc, Mutex};
11
12use super::BlockStage;
13use super::self_attn::repeat_kv;
14use crate::context::FlowCtx;
15use crate::value::FlowValue;
16
17#[derive(Debug, Clone)]
18pub struct Qwen3DecoderSpec {
19 pub num_heads: usize,
20 pub num_kv_heads: usize,
21 pub head_dim: usize,
22 pub eps: f32,
23 pub hidden_shape: rlx_ir::Shape,
24 pub batch: usize,
25 pub seq: usize,
26 pub qk_norm: bool,
28 pub attention_bias: bool,
30}
31
32#[derive(Debug, Clone)]
33pub struct Qwen3DecoderStage {
34 pub layer_prefix: String,
35 pub spec: Qwen3DecoderSpec,
36 pub kv_sink: Option<Arc<Mutex<Vec<rlx_ir::HirNodeId>>>>,
37}
38
39impl Qwen3DecoderStage {
40 pub fn layer(layer_idx: usize, spec: Qwen3DecoderSpec) -> Self {
41 Self {
42 layer_prefix: format!("model.layers.{layer_idx}"),
43 spec,
44 kv_sink: None,
45 }
46 }
47
48 pub fn layer_with_kv(
49 layer_idx: usize,
50 spec: Qwen3DecoderSpec,
51 kv_sink: Arc<Mutex<Vec<rlx_ir::HirNodeId>>>,
52 ) -> Self {
53 Self {
54 layer_prefix: format!("model.layers.{layer_idx}"),
55 spec,
56 kv_sink: Some(kv_sink),
57 }
58 }
59}
60
61impl BlockStage for Qwen3DecoderStage {
62 fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
63 let lp = &self.layer_prefix;
64 let spec = &self.spec;
65 let nh = spec.num_heads;
66 let nkv = spec.num_kv_heads;
67 let dh = spec.head_dim;
68 let group = nh / nkv;
69
70 let zero_beta_h = ctx
71 .state
72 .zero_beta
73 .ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires ZeroBeta"))?;
74 let zero_beta_dh = ctx
75 .state
76 .named
77 .get("zero_beta.head")
78 .copied()
79 .ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires zero_beta.head"))?;
80 let cos = ctx
81 .state
82 .rope_cos
83 .ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires RopeTables"))?;
84 let sin = ctx
85 .state
86 .rope_sin
87 .ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires RopeTables"))?;
88
89 let in_ln_g = ctx.load_param(&format!("{lp}.input_layernorm.weight"), false)?;
90 let q_w = ctx.load_param(&format!("{lp}.self_attn.q_proj.weight"), true)?;
91 let k_w = ctx.load_param(&format!("{lp}.self_attn.k_proj.weight"), true)?;
92 let v_w = ctx.load_param(&format!("{lp}.self_attn.v_proj.weight"), true)?;
93 let o_w = ctx.load_param(&format!("{lp}.self_attn.o_proj.weight"), true)?;
94 let post_ln_g = ctx.load_param(&format!("{lp}.post_attention_layernorm.weight"), false)?;
95 let gate_w = ctx.load_param(&format!("{lp}.mlp.gate_proj.weight"), true)?;
96 let up_w = ctx.load_param(&format!("{lp}.mlp.up_proj.weight"), true)?;
97 let down_w = ctx.load_param(&format!("{lp}.mlp.down_proj.weight"), true)?;
98 let (q_bias, k_bias, v_bias) = if spec.attention_bias {
99 (
100 Some(ctx.load_param(&format!("{lp}.self_attn.q_proj.bias"), false)?),
101 Some(ctx.load_param(&format!("{lp}.self_attn.k_proj.bias"), false)?),
102 Some(ctx.load_param(&format!("{lp}.self_attn.v_proj.bias"), false)?),
103 )
104 } else {
105 (None, None, None)
106 };
107 let (q_norm_g, k_norm_g) = if spec.qk_norm {
108 (
109 Some(ctx.load_param(&format!("{lp}.self_attn.q_norm.weight"), false)?),
110 Some(ctx.load_param(&format!("{lp}.self_attn.k_norm.weight"), false)?),
111 )
112 } else {
113 (None, None)
114 };
115
116 let mut gb = HirMut::new(ctx.hir());
117 let skip = input.id;
118
119 let normed_in = gb.rms_norm(skip, in_ln_g, zero_beta_h, spec.eps);
120 let mut q = gb.mm(normed_in, q_w);
121 let mut k = gb.mm(normed_in, k_w);
122 let mut v = gb.mm(normed_in, v_w);
123
124 if let (Some(qb), Some(kb), Some(vb)) = (q_bias, k_bias, v_bias) {
125 q = gb.add(q, qb);
126 k = gb.add(k, kb);
127 v = gb.add(v, vb);
128 }
129
130 let (q_rope_in, k_rope_in) = if let (Some(qng), Some(kng)) = (q_norm_g, k_norm_g) {
131 let q_normed = per_head_rms(
132 &mut gb,
133 q,
134 qng,
135 zero_beta_dh,
136 spec.batch,
137 spec.seq,
138 nh,
139 dh,
140 spec.eps,
141 );
142 let k_normed = per_head_rms(
143 &mut gb,
144 k,
145 kng,
146 zero_beta_dh,
147 spec.batch,
148 spec.seq,
149 nkv,
150 dh,
151 spec.eps,
152 );
153 (q_normed, k_normed)
154 } else {
155 (q, k)
156 };
157
158 let q_rope = gb.rope(q_rope_in, cos, sin, dh);
159 let k_rope = gb.rope(k_rope_in, cos, sin, dh);
160 if let Some(ref sink) = self.kv_sink {
161 sink.lock().expect("qwen3 kv sink").push(k_rope);
162 sink.lock().expect("qwen3 kv sink").push(v);
163 }
164 let k_rep = repeat_kv(&mut gb, k_rope, nkv, dh, group);
165 let v_rep = repeat_kv(&mut gb, v, nkv, dh, group);
166
167 let attn_shape = shape::attention_shape(gb.shape(q_rope));
168 let attn = gb.attention_kind(q_rope, k_rep, v_rep, nh, dh, MaskKind::Causal, attn_shape);
169 let attn_out = gb.mm(attn, o_w);
170 let post_attn = gb.add(skip, attn_out);
171 let normed_post = gb.rms_norm(post_attn, post_ln_g, zero_beta_h, spec.eps);
172
173 let gate = gb.mm(normed_post, gate_w);
174 let up = gb.mm(normed_post, up_w);
175 let gate_act = gb.silu(gate);
176 let swiglu = gb.mul(gate_act, up);
177 let ffn_out = gb.mm(swiglu, down_w);
178 let out = gb.add(post_attn, ffn_out);
179
180 Ok(Some(ctx.wrap(out, spec.hidden_shape.clone())))
181 }
182}
183
184pub(crate) fn per_head_rms(
185 gb: &mut HirMut,
186 x: rlx_ir::HirNodeId,
187 gamma: rlx_ir::HirNodeId,
188 beta: rlx_ir::HirNodeId,
189 batch: usize,
190 seq: usize,
191 heads: usize,
192 head_dim: usize,
193 eps: f32,
194) -> rlx_ir::HirNodeId {
195 let flat = (batch * seq * heads) as i64;
196 let dh = head_dim as i64;
197 let r = gb.reshape_(x, vec![flat, dh]);
198 let n = gb.rms_norm(r, gamma, beta, eps);
199 gb.reshape_(n, vec![batch as i64, seq as i64, (heads * head_dim) as i64])
200}