1use anyhow::Result;
5use rlx_ir::HirGraphExt;
6use rlx_ir::hir::HirMut;
7use rlx_ir::op::MaskKind;
8use rlx_ir::shape;
9
10use std::sync::{Arc, Mutex};
11
12use super::BlockStage;
13use super::qwen3_decoder::per_head_rms;
14use super::self_attn::repeat_kv;
15use crate::context::FlowCtx;
16use crate::value::FlowValue;
17
18#[derive(Debug, Clone)]
19pub struct Qwen3DecodeLayerSpec {
20 pub num_heads: usize,
21 pub num_kv_heads: usize,
22 pub head_dim: usize,
23 pub kv_group_size: usize,
24 pub eps: f32,
25 pub use_custom_mask: bool,
26 pub hidden_shape: rlx_ir::Shape,
27 pub batch: usize,
28 pub qk_norm: bool,
29 pub attention_bias: bool,
30}
31
32#[derive(Debug, Clone)]
33pub struct Qwen3DecodeLayerStage {
34 pub layer_prefix: String,
35 pub spec: Qwen3DecodeLayerSpec,
36 pub layer_idx: usize,
37 pub kv_out: Arc<Mutex<Vec<rlx_ir::HirNodeId>>>,
38}
39
40impl Qwen3DecodeLayerStage {
41 pub fn layer(
42 layer_idx: usize,
43 spec: Qwen3DecodeLayerSpec,
44 kv_out: Arc<Mutex<Vec<rlx_ir::HirNodeId>>>,
45 ) -> Self {
46 Self {
47 layer_prefix: format!("model.layers.{layer_idx}"),
48 spec,
49 layer_idx,
50 kv_out,
51 }
52 }
53}
54
55impl BlockStage for Qwen3DecodeLayerStage {
56 fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
57 let decode = ctx
58 .state
59 .decode
60 .clone()
61 .ok_or_else(|| anyhow::anyhow!("Qwen3DecodeLayer requires BindDecodeInputs"))?;
62 let zero_beta_h = ctx
63 .state
64 .zero_beta
65 .ok_or_else(|| anyhow::anyhow!("Qwen3DecodeLayer requires ZeroBeta"))?;
66 let zero_beta_dh = ctx
67 .state
68 .named
69 .get("zero_beta.head")
70 .copied()
71 .ok_or_else(|| anyhow::anyhow!("Qwen3DecodeLayer requires zero_beta.head"))?;
72
73 let lp = &self.layer_prefix;
74 let spec = &self.spec;
75 let nh = spec.num_heads;
76 let nkv = spec.num_kv_heads;
77 let dh = spec.head_dim;
78 let batch = spec.batch;
79
80 let in_ln_g = ctx.load_param(&format!("{lp}.input_layernorm.weight"), false)?;
81 let q_w = ctx.load_param(&format!("{lp}.self_attn.q_proj.weight"), true)?;
82 let k_w = ctx.load_param(&format!("{lp}.self_attn.k_proj.weight"), true)?;
83 let v_w = ctx.load_param(&format!("{lp}.self_attn.v_proj.weight"), true)?;
84 let o_w = ctx.load_param(&format!("{lp}.self_attn.o_proj.weight"), true)?;
85 let post_ln_g = ctx.load_param(&format!("{lp}.post_attention_layernorm.weight"), false)?;
86 let gate_w = ctx.load_param(&format!("{lp}.mlp.gate_proj.weight"), true)?;
87 let up_w = ctx.load_param(&format!("{lp}.mlp.up_proj.weight"), true)?;
88 let down_w = ctx.load_param(&format!("{lp}.mlp.down_proj.weight"), true)?;
89 let (q_bias, k_bias, v_bias) = if spec.attention_bias {
90 (
91 Some(ctx.load_param(&format!("{lp}.self_attn.q_proj.bias"), false)?),
92 Some(ctx.load_param(&format!("{lp}.self_attn.k_proj.bias"), false)?),
93 Some(ctx.load_param(&format!("{lp}.self_attn.v_proj.bias"), false)?),
94 )
95 } else {
96 (None, None, None)
97 };
98 let (q_norm_g, k_norm_g) = if spec.qk_norm {
99 (
100 Some(ctx.load_param(&format!("{lp}.self_attn.q_norm.weight"), false)?),
101 Some(ctx.load_param(&format!("{lp}.self_attn.k_norm.weight"), false)?),
102 )
103 } else {
104 (None, None)
105 };
106
107 let past_k = decode.past_k[self.layer_idx];
108 let past_v = decode.past_v[self.layer_idx];
109
110 let mut gb = HirMut::new(ctx.hir());
111 let skip = input.id;
112 let normed_in = gb.rms_norm(skip, in_ln_g, zero_beta_h, spec.eps);
113 let mut q = gb.mm(normed_in, q_w);
114 let mut k = gb.mm(normed_in, k_w);
115 let mut v = gb.mm(normed_in, v_w);
116
117 if let (Some(qb), Some(kb), Some(vb)) = (q_bias, k_bias, v_bias) {
118 q = gb.add(q, qb);
119 k = gb.add(k, kb);
120 v = gb.add(v, vb);
121 }
122
123 let (q_rope_in, k_rope_in) = if let (Some(qng), Some(kng)) = (q_norm_g, k_norm_g) {
124 let q_normed = per_head_rms(&mut gb, q, qng, zero_beta_dh, batch, 1, nh, dh, spec.eps);
125 let k_normed = per_head_rms(&mut gb, k, kng, zero_beta_dh, batch, 1, nkv, dh, spec.eps);
126 (q_normed, k_normed)
127 } else {
128 (q, k)
129 };
130
131 let q_rope = gb.rope(q_rope_in, decode.cos, decode.sin, dh);
132 let k_rope = gb.rope(k_rope_in, decode.cos, decode.sin, dh);
133
134 let new_k = gb.concat_(vec![past_k, k_rope], 1);
135 let new_v = gb.concat_(vec![past_v, v], 1);
136 self.kv_out.lock().expect("kv out").push(new_k);
137 self.kv_out.lock().expect("kv out").push(new_v);
138
139 let k_rep = repeat_kv(&mut gb, new_k, nkv, dh, spec.kv_group_size);
140 let v_rep = repeat_kv(&mut gb, new_v, nkv, dh, spec.kv_group_size);
141
142 let attn_shape = shape::attention_shape(gb.shape(q_rope));
143 let attn = if spec.use_custom_mask {
144 let mask = decode
145 .mask
146 .ok_or_else(|| anyhow::anyhow!("custom mask requested but not bound"))?;
147 gb.attention(q_rope, k_rep, v_rep, mask, nh, dh, attn_shape)
148 } else {
149 gb.attention_kind(q_rope, k_rep, v_rep, nh, dh, MaskKind::Causal, attn_shape)
150 };
151
152 let attn_out = gb.mm(attn, o_w);
153 let post_attn = gb.add(skip, attn_out);
154 let normed_post = gb.rms_norm(post_attn, post_ln_g, zero_beta_h, spec.eps);
155 let gate = gb.mm(normed_post, gate_w);
156 let up = gb.mm(normed_post, up_w);
157 let gate_act = gb.silu(gate);
158 let swiglu = gb.mul(gate_act, up);
159 let ffn_out = gb.mm(swiglu, down_w);
160 let h_id = gb.add(post_attn, ffn_out);
161
162 Ok(Some(ctx.wrap(h_id, spec.hidden_shape.clone())))
163 }
164}