1use anyhow::Result;
5use rlx_ir::HirGraphExt;
6use rlx_ir::hir::HirMut;
7use rlx_ir::op::MaskKind;
8use rlx_ir::shape;
9
10use std::sync::{Arc, Mutex};
11
12use super::BlockStage;
13use super::self_attn::repeat_kv;
14use crate::context::FlowCtx;
15use crate::value::FlowValue;
16
17#[derive(Debug, Clone)]
18pub struct Qwen3DecoderSpec {
19 pub num_heads: usize,
20 pub num_kv_heads: usize,
21 pub head_dim: usize,
22 pub eps: f32,
23 pub hidden_shape: rlx_ir::Shape,
24 pub batch: usize,
25 pub seq: usize,
26}
27
28#[derive(Debug, Clone)]
29pub struct Qwen3DecoderStage {
30 pub layer_prefix: String,
31 pub spec: Qwen3DecoderSpec,
32 pub kv_sink: Option<Arc<Mutex<Vec<rlx_ir::HirNodeId>>>>,
33}
34
35impl Qwen3DecoderStage {
36 pub fn layer(layer_idx: usize, spec: Qwen3DecoderSpec) -> Self {
37 Self {
38 layer_prefix: format!("model.layers.{layer_idx}"),
39 spec,
40 kv_sink: None,
41 }
42 }
43
44 pub fn layer_with_kv(
45 layer_idx: usize,
46 spec: Qwen3DecoderSpec,
47 kv_sink: Arc<Mutex<Vec<rlx_ir::HirNodeId>>>,
48 ) -> Self {
49 Self {
50 layer_prefix: format!("model.layers.{layer_idx}"),
51 spec,
52 kv_sink: Some(kv_sink),
53 }
54 }
55}
56
57impl BlockStage for Qwen3DecoderStage {
58 fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
59 let lp = &self.layer_prefix;
60 let spec = &self.spec;
61 let nh = spec.num_heads;
62 let nkv = spec.num_kv_heads;
63 let dh = spec.head_dim;
64 let group = nh / nkv;
65
66 let zero_beta_h = ctx
67 .state
68 .zero_beta
69 .ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires ZeroBeta"))?;
70 let zero_beta_dh = ctx
71 .state
72 .named
73 .get("zero_beta.head")
74 .copied()
75 .ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires zero_beta.head"))?;
76 let cos = ctx
77 .state
78 .rope_cos
79 .ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires RopeTables"))?;
80 let sin = ctx
81 .state
82 .rope_sin
83 .ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires RopeTables"))?;
84
85 let in_ln_g = ctx.load_param(&format!("{lp}.input_layernorm.weight"), false)?;
86 let q_w = ctx.load_param(&format!("{lp}.self_attn.q_proj.weight"), true)?;
87 let k_w = ctx.load_param(&format!("{lp}.self_attn.k_proj.weight"), true)?;
88 let v_w = ctx.load_param(&format!("{lp}.self_attn.v_proj.weight"), true)?;
89 let q_norm_g = ctx.load_param(&format!("{lp}.self_attn.q_norm.weight"), false)?;
90 let k_norm_g = ctx.load_param(&format!("{lp}.self_attn.k_norm.weight"), false)?;
91 let o_w = ctx.load_param(&format!("{lp}.self_attn.o_proj.weight"), true)?;
92 let post_ln_g = ctx.load_param(&format!("{lp}.post_attention_layernorm.weight"), false)?;
93 let gate_w = ctx.load_param(&format!("{lp}.mlp.gate_proj.weight"), true)?;
94 let up_w = ctx.load_param(&format!("{lp}.mlp.up_proj.weight"), true)?;
95 let down_w = ctx.load_param(&format!("{lp}.mlp.down_proj.weight"), true)?;
96
97 let mut gb = HirMut::new(ctx.hir());
98 let skip = input.id;
99
100 let normed_in = gb.rms_norm(skip, in_ln_g, zero_beta_h, spec.eps);
101 let q = gb.mm(normed_in, q_w);
102 let k = gb.mm(normed_in, k_w);
103 let v = gb.mm(normed_in, v_w);
104
105 let q_normed = per_head_rms(
106 &mut gb,
107 q,
108 q_norm_g,
109 zero_beta_dh,
110 spec.batch,
111 spec.seq,
112 nh,
113 dh,
114 spec.eps,
115 );
116 let k_normed = per_head_rms(
117 &mut gb,
118 k,
119 k_norm_g,
120 zero_beta_dh,
121 spec.batch,
122 spec.seq,
123 nkv,
124 dh,
125 spec.eps,
126 );
127
128 let q_rope = gb.rope(q_normed, cos, sin, dh);
129 let k_rope = gb.rope(k_normed, cos, sin, dh);
130 if let Some(ref sink) = self.kv_sink {
131 sink.lock().expect("qwen3 kv sink").push(k_rope);
132 sink.lock().expect("qwen3 kv sink").push(v);
133 }
134 let k_rep = repeat_kv(&mut gb, k_rope, nkv, dh, group);
135 let v_rep = repeat_kv(&mut gb, v, nkv, dh, group);
136
137 let attn_shape = shape::attention_shape(gb.shape(q_rope));
138 let attn = gb.attention_kind(q_rope, k_rep, v_rep, nh, dh, MaskKind::Causal, attn_shape);
139 let attn_out = gb.mm(attn, o_w);
140 let post_attn = gb.add(skip, attn_out);
141 let normed_post = gb.rms_norm(post_attn, post_ln_g, zero_beta_h, spec.eps);
142
143 let gate = gb.mm(normed_post, gate_w);
144 let up = gb.mm(normed_post, up_w);
145 let gate_act = gb.silu(gate);
146 let swiglu = gb.mul(gate_act, up);
147 let ffn_out = gb.mm(swiglu, down_w);
148 let out = gb.add(post_attn, ffn_out);
149
150 Ok(Some(ctx.wrap(out, spec.hidden_shape.clone())))
151 }
152}
153
154pub(crate) fn per_head_rms(
155 gb: &mut HirMut,
156 x: rlx_ir::HirNodeId,
157 gamma: rlx_ir::HirNodeId,
158 beta: rlx_ir::HirNodeId,
159 batch: usize,
160 seq: usize,
161 heads: usize,
162 head_dim: usize,
163 eps: f32,
164) -> rlx_ir::HirNodeId {
165 let flat = (batch * seq * heads) as i64;
166 let dh = head_dim as i64;
167 let r = gb.reshape_(x, vec![flat, dh]);
168 let n = gb.rms_norm(r, gamma, beta, eps);
169 gb.reshape_(n, vec![batch as i64, seq as i64, (heads * head_dim) as i64])
170}