1use anyhow::Result;
17use rlx_ir::HirGraphExt;
18use rlx_ir::hir::HirMut;
19use rlx_ir::op::MaskKind;
20use rlx_ir::shape;
21
22use std::sync::{Arc, Mutex};
23
24use super::BlockStage;
25use super::self_attn::repeat_kv;
26use crate::context::FlowCtx;
27use crate::value::FlowValue;
28
29#[derive(Debug, Clone)]
30pub struct Qwen3DecoderSpec {
31 pub num_heads: usize,
32 pub num_kv_heads: usize,
33 pub head_dim: usize,
34 pub eps: f32,
35 pub hidden_shape: rlx_ir::Shape,
36 pub batch: usize,
37 pub seq: usize,
38 pub qk_norm: bool,
40 pub attention_bias: bool,
42}
43
44#[derive(Debug, Clone)]
45pub struct Qwen3DecoderStage {
46 pub layer_prefix: String,
47 pub spec: Qwen3DecoderSpec,
48 pub kv_sink: Option<Arc<Mutex<Vec<rlx_ir::HirNodeId>>>>,
49}
50
51impl Qwen3DecoderStage {
52 pub fn layer(layer_idx: usize, spec: Qwen3DecoderSpec) -> Self {
53 Self {
54 layer_prefix: format!("model.layers.{layer_idx}"),
55 spec,
56 kv_sink: None,
57 }
58 }
59
60 pub fn layer_with_kv(
61 layer_idx: usize,
62 spec: Qwen3DecoderSpec,
63 kv_sink: Arc<Mutex<Vec<rlx_ir::HirNodeId>>>,
64 ) -> Self {
65 Self {
66 layer_prefix: format!("model.layers.{layer_idx}"),
67 spec,
68 kv_sink: Some(kv_sink),
69 }
70 }
71}
72
73impl BlockStage for Qwen3DecoderStage {
74 fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
75 let lp = &self.layer_prefix;
76 let spec = &self.spec;
77 let nh = spec.num_heads;
78 let nkv = spec.num_kv_heads;
79 let dh = spec.head_dim;
80 let group = nh / nkv;
81
82 let zero_beta_h = ctx
83 .state
84 .zero_beta
85 .ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires ZeroBeta"))?;
86 let zero_beta_dh = ctx
87 .state
88 .named
89 .get("zero_beta.head")
90 .copied()
91 .ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires zero_beta.head"))?;
92 let cos = ctx
93 .state
94 .rope_cos
95 .ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires RopeTables"))?;
96 let sin = ctx
97 .state
98 .rope_sin
99 .ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires RopeTables"))?;
100
101 let in_ln_g = ctx.load_param(&format!("{lp}.input_layernorm.weight"), false)?;
102 let q_w = ctx.load_param(&format!("{lp}.self_attn.q_proj.weight"), true)?;
103 let k_w = ctx.load_param(&format!("{lp}.self_attn.k_proj.weight"), true)?;
104 let v_w = ctx.load_param(&format!("{lp}.self_attn.v_proj.weight"), true)?;
105 let o_w = ctx.load_param(&format!("{lp}.self_attn.o_proj.weight"), true)?;
106 let post_ln_g = ctx.load_param(&format!("{lp}.post_attention_layernorm.weight"), false)?;
107 let gate_w = ctx.load_param(&format!("{lp}.mlp.gate_proj.weight"), true)?;
108 let up_w = ctx.load_param(&format!("{lp}.mlp.up_proj.weight"), true)?;
109 let down_w = ctx.load_param(&format!("{lp}.mlp.down_proj.weight"), true)?;
110 let (q_bias, k_bias, v_bias) = if spec.attention_bias {
111 (
112 Some(ctx.load_param(&format!("{lp}.self_attn.q_proj.bias"), false)?),
113 Some(ctx.load_param(&format!("{lp}.self_attn.k_proj.bias"), false)?),
114 Some(ctx.load_param(&format!("{lp}.self_attn.v_proj.bias"), false)?),
115 )
116 } else {
117 (None, None, None)
118 };
119 let (q_norm_g, k_norm_g) = if spec.qk_norm {
120 (
121 Some(ctx.load_param(&format!("{lp}.self_attn.q_norm.weight"), false)?),
122 Some(ctx.load_param(&format!("{lp}.self_attn.k_norm.weight"), false)?),
123 )
124 } else {
125 (None, None)
126 };
127
128 let mut gb = HirMut::new(ctx.hir());
129 let skip = input.id;
130
131 let normed_in = gb.rms_norm(skip, in_ln_g, zero_beta_h, spec.eps);
132 let mut q = gb.mm(normed_in, q_w);
133 let mut k = gb.mm(normed_in, k_w);
134 let mut v = gb.mm(normed_in, v_w);
135
136 if let (Some(qb), Some(kb), Some(vb)) = (q_bias, k_bias, v_bias) {
137 q = gb.add(q, qb);
138 k = gb.add(k, kb);
139 v = gb.add(v, vb);
140 }
141
142 let (q_rope_in, k_rope_in) = if let (Some(qng), Some(kng)) = (q_norm_g, k_norm_g) {
143 let q_normed = per_head_rms(
144 &mut gb,
145 q,
146 qng,
147 zero_beta_dh,
148 spec.batch,
149 spec.seq,
150 nh,
151 dh,
152 spec.eps,
153 );
154 let k_normed = per_head_rms(
155 &mut gb,
156 k,
157 kng,
158 zero_beta_dh,
159 spec.batch,
160 spec.seq,
161 nkv,
162 dh,
163 spec.eps,
164 );
165 (q_normed, k_normed)
166 } else {
167 (q, k)
168 };
169
170 let q_rope = gb.rope(q_rope_in, cos, sin, dh);
171 let k_rope = gb.rope(k_rope_in, cos, sin, dh);
172 if let Some(ref sink) = self.kv_sink {
173 sink.lock().expect("qwen3 kv sink").push(k_rope);
174 sink.lock().expect("qwen3 kv sink").push(v);
175 }
176 let k_rep = repeat_kv(&mut gb, k_rope, nkv, dh, group);
177 let v_rep = repeat_kv(&mut gb, v, nkv, dh, group);
178
179 let attn_shape = shape::attention_shape(gb.shape(q_rope));
180 let attn = gb.attention_kind(q_rope, k_rep, v_rep, nh, dh, MaskKind::Causal, attn_shape);
181 let attn_out = gb.mm(attn, o_w);
182 let post_attn = gb.add(skip, attn_out);
183 let normed_post = gb.rms_norm(post_attn, post_ln_g, zero_beta_h, spec.eps);
184
185 let gate = gb.mm(normed_post, gate_w);
186 let up = gb.mm(normed_post, up_w);
187 let gate_act = gb.silu(gate);
188 let swiglu = gb.mul(gate_act, up);
189 let ffn_out = gb.mm(swiglu, down_w);
190 let out = gb.add(post_attn, ffn_out);
191
192 Ok(Some(ctx.wrap(out, spec.hidden_shape.clone())))
193 }
194}
195
196pub(crate) fn per_head_rms(
197 gb: &mut HirMut,
198 x: rlx_ir::HirNodeId,
199 gamma: rlx_ir::HirNodeId,
200 beta: rlx_ir::HirNodeId,
201 batch: usize,
202 seq: usize,
203 heads: usize,
204 head_dim: usize,
205 eps: f32,
206) -> rlx_ir::HirNodeId {
207 let flat = (batch * seq * heads) as i64;
208 let dh = head_dim as i64;
209 let r = gb.reshape_(x, vec![flat, dh]);
210 let n = gb.rms_norm(r, gamma, beta, eps);
211 gb.reshape_(n, vec![batch as i64, seq as i64, (heads * head_dim) as i64])
212}