1use anyhow::Result;
17use rlx_ir::HirGraphExt;
18use rlx_ir::hir::HirMut;
19use rlx_ir::op::MaskKind;
20use rlx_ir::shape;
21
22use std::sync::{Arc, Mutex};
23
24use super::BlockStage;
25use super::qwen3_decoder::per_head_rms;
26use super::self_attn::repeat_kv;
27use crate::context::FlowCtx;
28use crate::value::FlowValue;
29
/// Static configuration for one Qwen3-style decoder layer emitted for a decode step.
#[derive(Debug, Clone)]
pub struct Qwen3DecodeLayerSpec {
    /// Number of query heads; used as the head count for the q per-head norm and attention.
    pub num_heads: usize,
    /// Number of key/value heads; used for the k per-head norm and `repeat_kv`.
    pub num_kv_heads: usize,
    /// Size of each attention head (passed to RoPE, `repeat_kv`, and attention).
    pub head_dim: usize,
    /// Repeat factor handed to `repeat_kv`; presumably `num_heads / num_kv_heads` — TODO confirm.
    pub kv_group_size: usize,
    /// Epsilon for every RMS norm in the layer (input/post-attention and per-head q/k norms).
    pub eps: f32,
    /// When true, attention uses the mask bound in the decode inputs; otherwise a causal mask.
    pub use_custom_mask: bool,
    /// Shape attached to the layer's output value.
    pub hidden_shape: rlx_ir::Shape,
    /// Batch size, forwarded to the per-head q/k norm.
    pub batch: usize,
    /// Whether `q_norm` / `k_norm` weights are loaded and applied per head before RoPE.
    pub qk_norm: bool,
    /// Whether q/k/v projection biases are loaded and added after the projections.
    pub attention_bias: bool,
}
43
/// Graph-building stage that emits one decoder layer for a single decode step.
#[derive(Debug, Clone)]
pub struct Qwen3DecodeLayerStage {
    /// Parameter-name prefix, e.g. `model.layers.3` when built via `layer`.
    pub layer_prefix: String,
    /// Layer hyper-parameters.
    pub spec: Qwen3DecodeLayerSpec,
    /// Index used to select this layer's past K/V nodes from the bound decode inputs.
    pub layer_idx: usize,
    /// Shared sink: emitting the layer appends the updated K node, then the updated V node.
    pub kv_out: Arc<Mutex<Vec<rlx_ir::HirNodeId>>>,
}
51
52impl Qwen3DecodeLayerStage {
53 pub fn layer(
54 layer_idx: usize,
55 spec: Qwen3DecodeLayerSpec,
56 kv_out: Arc<Mutex<Vec<rlx_ir::HirNodeId>>>,
57 ) -> Self {
58 Self {
59 layer_prefix: format!("model.layers.{layer_idx}"),
60 spec,
61 layer_idx,
62 kv_out,
63 }
64 }
65}
66
67impl BlockStage for Qwen3DecodeLayerStage {
68 fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
69 let decode = ctx
70 .state
71 .decode
72 .clone()
73 .ok_or_else(|| anyhow::anyhow!("Qwen3DecodeLayer requires BindDecodeInputs"))?;
74 let zero_beta_h = ctx
75 .state
76 .zero_beta
77 .ok_or_else(|| anyhow::anyhow!("Qwen3DecodeLayer requires ZeroBeta"))?;
78 let zero_beta_dh = ctx
79 .state
80 .named
81 .get("zero_beta.head")
82 .copied()
83 .ok_or_else(|| anyhow::anyhow!("Qwen3DecodeLayer requires zero_beta.head"))?;
84
85 let lp = &self.layer_prefix;
86 let spec = &self.spec;
87 let nh = spec.num_heads;
88 let nkv = spec.num_kv_heads;
89 let dh = spec.head_dim;
90 let batch = spec.batch;
91
92 let in_ln_g = ctx.load_param(&format!("{lp}.input_layernorm.weight"), false)?;
93 let q_w = ctx.load_param(&format!("{lp}.self_attn.q_proj.weight"), true)?;
94 let k_w = ctx.load_param(&format!("{lp}.self_attn.k_proj.weight"), true)?;
95 let v_w = ctx.load_param(&format!("{lp}.self_attn.v_proj.weight"), true)?;
96 let o_w = ctx.load_param(&format!("{lp}.self_attn.o_proj.weight"), true)?;
97 let post_ln_g = ctx.load_param(&format!("{lp}.post_attention_layernorm.weight"), false)?;
98 let gate_w = ctx.load_param(&format!("{lp}.mlp.gate_proj.weight"), true)?;
99 let up_w = ctx.load_param(&format!("{lp}.mlp.up_proj.weight"), true)?;
100 let down_w = ctx.load_param(&format!("{lp}.mlp.down_proj.weight"), true)?;
101 let (q_bias, k_bias, v_bias) = if spec.attention_bias {
102 (
103 Some(ctx.load_param(&format!("{lp}.self_attn.q_proj.bias"), false)?),
104 Some(ctx.load_param(&format!("{lp}.self_attn.k_proj.bias"), false)?),
105 Some(ctx.load_param(&format!("{lp}.self_attn.v_proj.bias"), false)?),
106 )
107 } else {
108 (None, None, None)
109 };
110 let (q_norm_g, k_norm_g) = if spec.qk_norm {
111 (
112 Some(ctx.load_param(&format!("{lp}.self_attn.q_norm.weight"), false)?),
113 Some(ctx.load_param(&format!("{lp}.self_attn.k_norm.weight"), false)?),
114 )
115 } else {
116 (None, None)
117 };
118
119 let past_k = decode.past_k[self.layer_idx];
120 let past_v = decode.past_v[self.layer_idx];
121
122 let mut gb = HirMut::new(ctx.hir());
123 let skip = input.id;
124 let normed_in = gb.rms_norm(skip, in_ln_g, zero_beta_h, spec.eps);
125 let mut q = gb.mm(normed_in, q_w);
126 let mut k = gb.mm(normed_in, k_w);
127 let mut v = gb.mm(normed_in, v_w);
128
129 if let (Some(qb), Some(kb), Some(vb)) = (q_bias, k_bias, v_bias) {
130 q = gb.add(q, qb);
131 k = gb.add(k, kb);
132 v = gb.add(v, vb);
133 }
134
135 let (q_rope_in, k_rope_in) = if let (Some(qng), Some(kng)) = (q_norm_g, k_norm_g) {
136 let q_normed = per_head_rms(&mut gb, q, qng, zero_beta_dh, batch, 1, nh, dh, spec.eps);
137 let k_normed = per_head_rms(&mut gb, k, kng, zero_beta_dh, batch, 1, nkv, dh, spec.eps);
138 (q_normed, k_normed)
139 } else {
140 (q, k)
141 };
142
143 let q_rope = gb.rope(q_rope_in, decode.cos, decode.sin, dh);
144 let k_rope = gb.rope(k_rope_in, decode.cos, decode.sin, dh);
145
146 let new_k = gb.concat_(vec![past_k, k_rope], 1);
147 let new_v = gb.concat_(vec![past_v, v], 1);
148 self.kv_out.lock().expect("kv out").push(new_k);
149 self.kv_out.lock().expect("kv out").push(new_v);
150
151 let k_rep = repeat_kv(&mut gb, new_k, nkv, dh, spec.kv_group_size);
152 let v_rep = repeat_kv(&mut gb, new_v, nkv, dh, spec.kv_group_size);
153
154 let attn_shape = shape::attention_shape(gb.shape(q_rope));
155 let attn = if spec.use_custom_mask {
156 let mask = decode
157 .mask
158 .ok_or_else(|| anyhow::anyhow!("custom mask requested but not bound"))?;
159 gb.attention(q_rope, k_rep, v_rep, mask, nh, dh, attn_shape)
160 } else {
161 gb.attention_kind(q_rope, k_rep, v_rep, nh, dh, MaskKind::Causal, attn_shape)
162 };
163
164 let attn_out = gb.mm(attn, o_w);
165 let post_attn = gb.add(skip, attn_out);
166 let normed_post = gb.rms_norm(post_attn, post_ln_g, zero_beta_h, spec.eps);
167 let gate = gb.mm(normed_post, gate_w);
168 let up = gb.mm(normed_post, up_w);
169 let gate_act = gb.silu(gate);
170 let swiglu = gb.mul(gate_act, up);
171 let ffn_out = gb.mm(swiglu, down_w);
172 let h_id = gb.add(post_attn, ffn_out);
173
174 Ok(Some(ctx.wrap(h_id, spec.hidden_shape.clone())))
175 }
176}