use anyhow::Result;
use rlx_ir::HirGraphExt;
use rlx_ir::hir::HirMut;
use rlx_ir::op::MaskKind;
use rlx_ir::shape;
use std::sync::{Arc, Mutex};
use super::BlockStage;
use super::self_attn::repeat_kv;
use crate::context::FlowCtx;
use crate::value::FlowValue;
/// Static configuration shared by the Qwen3 decoder layers emitted in this file.
///
/// The field notes below describe how `Qwen3DecoderStage::emit` and
/// `per_head_rms` consume each value.
#[derive(Debug, Clone)]
pub struct Qwen3DecoderSpec {
/// Number of query heads; passed to the attention op and used as the head
/// count when per-head normalising Q.
pub num_heads: usize,
/// Number of key/value heads; `num_heads / num_kv_heads` is the repeat
/// factor handed to `repeat_kv`.
pub num_kv_heads: usize,
/// Width of a single head; passed to RoPE, `repeat_kv`, the attention op and
/// used as the trailing dimension of the per-head norm reshape.
pub head_dim: usize,
/// Epsilon forwarded to every `rms_norm` call.
pub eps: f32,
/// Shape attached to the layer output when it is wrapped as a `FlowValue`.
pub hidden_shape: rlx_ir::Shape,
/// Batch size; only read when `qk_norm` is set, to build the per-head norm
/// reshape dimensions.
pub batch: usize,
/// Sequence length; only read when `qk_norm` is set, to build the per-head
/// norm reshape dimensions.
pub seq: usize,
/// When true, `self_attn.q_norm.weight` / `self_attn.k_norm.weight` are
/// loaded and Q/K are RMS-normalised per head before RoPE.
pub qk_norm: bool,
/// When true, `self_attn.{q,k,v}_proj.bias` are loaded and added to the
/// Q/K/V projections.
pub attention_bias: bool,
}
/// A single Qwen3 decoder layer, emitted into the graph through `BlockStage`.
#[derive(Debug, Clone)]
pub struct Qwen3DecoderStage {
/// Parameter-name prefix for this layer (the constructors use
/// `model.layers.{layer_idx}`); every weight is loaded as `{layer_prefix}.<name>`.
pub layer_prefix: String,
/// Layer hyper-parameters and shapes.
pub spec: Qwen3DecoderSpec,
/// Optional shared collector: when present, `emit` appends the post-RoPE K
/// node followed by the V node (both taken before `repeat_kv`).
pub kv_sink: Option<Arc<Mutex<Vec<rlx_ir::HirNodeId>>>>,
}
impl Qwen3DecoderStage {
    /// Creates the stage for decoder layer `layer_idx` with no KV sink attached.
    pub fn layer(layer_idx: usize, spec: Qwen3DecoderSpec) -> Self {
        Self::from_parts(layer_idx, spec, None)
    }

    /// Creates the stage for decoder layer `layer_idx` and attaches `kv_sink`,
    /// the shared vector that `emit` appends this layer's K and V nodes to.
    pub fn layer_with_kv(
        layer_idx: usize,
        spec: Qwen3DecoderSpec,
        kv_sink: Arc<Mutex<Vec<rlx_ir::HirNodeId>>>,
    ) -> Self {
        Self::from_parts(layer_idx, spec, Some(kv_sink))
    }

    /// Shared constructor body: derives the `model.layers.{layer_idx}` prefix
    /// and stores the remaining fields as given.
    fn from_parts(
        layer_idx: usize,
        spec: Qwen3DecoderSpec,
        kv_sink: Option<Arc<Mutex<Vec<rlx_ir::HirNodeId>>>>,
    ) -> Self {
        let layer_prefix = format!("model.layers.{layer_idx}");
        Self {
            layer_prefix,
            spec,
            kv_sink,
        }
    }
}
impl BlockStage for Qwen3DecoderStage {
    /// Emits one decoder layer into the graph and returns its output value.
    ///
    /// The layer is built as:
    /// 1. RMS-norm of `input`, then Q/K/V projections (plus biases when
    ///    `spec.attention_bias` is set).
    /// 2. Optional per-head RMS-norm of Q and K (`spec.qk_norm`), then RoPE on
    ///    Q and K.
    /// 3. K/V repeated via `repeat_kv`, causal attention, output projection and
    ///    a residual add with `input`.
    /// 4. RMS-norm, gated MLP (`silu(gate) * up`, then `down_proj`) and a
    ///    second residual add.
    ///
    /// When `kv_sink` is set, the post-RoPE K node and the V node (both before
    /// `repeat_kv`) are appended to it, in that order, under a single lock.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// - `spec.num_kv_heads` is zero or does not divide `spec.num_heads`;
    /// - the `ZeroBeta`, `zero_beta.head` or RoPE table state is missing;
    /// - any required parameter fails to load;
    /// - the KV sink mutex is poisoned.
    fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
        let lp = &self.layer_prefix;
        let spec = &self.spec;
        let nh = spec.num_heads;
        let nkv = spec.num_kv_heads;
        let dh = spec.head_dim;

        // Validate the head configuration up front: a zero KV-head count would
        // panic on the division below, and a non-multiple would silently
        // truncate the repeat factor.
        // NOTE(review): the divisibility check assumes `repeat_kv` expands each
        // KV head `group` times to reach `num_heads` — confirm against `repeat_kv`.
        anyhow::ensure!(nkv > 0, "Qwen3Decoder requires num_kv_heads > 0");
        anyhow::ensure!(
            nh % nkv == 0,
            "Qwen3Decoder requires num_heads ({}) to be a multiple of num_kv_heads ({})",
            nh,
            nkv
        );
        let group = nh / nkv;

        // Shared state that earlier stages must have registered.
        let zero_beta_h = ctx
            .state
            .zero_beta
            .ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires ZeroBeta"))?;
        let zero_beta_dh = ctx
            .state
            .named
            .get("zero_beta.head")
            .copied()
            .ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires zero_beta.head"))?;
        let cos = ctx
            .state
            .rope_cos
            .ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires RopeTables"))?;
        let sin = ctx
            .state
            .rope_sin
            .ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires RopeTables"))?;

        // Layer parameters. Projection matrices pass `true` and norm/bias
        // vectors pass `false` as the second `load_param` argument; its exact
        // meaning is defined by `FlowCtx::load_param`.
        let in_ln_g = ctx.load_param(&format!("{lp}.input_layernorm.weight"), false)?;
        let q_w = ctx.load_param(&format!("{lp}.self_attn.q_proj.weight"), true)?;
        let k_w = ctx.load_param(&format!("{lp}.self_attn.k_proj.weight"), true)?;
        let v_w = ctx.load_param(&format!("{lp}.self_attn.v_proj.weight"), true)?;
        let o_w = ctx.load_param(&format!("{lp}.self_attn.o_proj.weight"), true)?;
        let post_ln_g = ctx.load_param(&format!("{lp}.post_attention_layernorm.weight"), false)?;
        let gate_w = ctx.load_param(&format!("{lp}.mlp.gate_proj.weight"), true)?;
        let up_w = ctx.load_param(&format!("{lp}.mlp.up_proj.weight"), true)?;
        let down_w = ctx.load_param(&format!("{lp}.mlp.down_proj.weight"), true)?;
        let (q_bias, k_bias, v_bias) = if spec.attention_bias {
            (
                Some(ctx.load_param(&format!("{lp}.self_attn.q_proj.bias"), false)?),
                Some(ctx.load_param(&format!("{lp}.self_attn.k_proj.bias"), false)?),
                Some(ctx.load_param(&format!("{lp}.self_attn.v_proj.bias"), false)?),
            )
        } else {
            (None, None, None)
        };
        let (q_norm_g, k_norm_g) = if spec.qk_norm {
            (
                Some(ctx.load_param(&format!("{lp}.self_attn.q_norm.weight"), false)?),
                Some(ctx.load_param(&format!("{lp}.self_attn.k_norm.weight"), false)?),
            )
        } else {
            (None, None)
        };

        let mut gb = HirMut::new(ctx.hir());

        // --- Attention sub-block -------------------------------------------
        let skip = input.id;
        let normed_in = gb.rms_norm(skip, in_ln_g, zero_beta_h, spec.eps);
        let mut q = gb.mm(normed_in, q_w);
        let mut k = gb.mm(normed_in, k_w);
        let mut v = gb.mm(normed_in, v_w);
        if let (Some(qb), Some(kb), Some(vb)) = (q_bias, k_bias, v_bias) {
            q = gb.add(q, qb);
            k = gb.add(k, kb);
            v = gb.add(v, vb);
        }

        // Optional per-head normalisation of Q and K, applied before RoPE.
        let (q_rope_in, k_rope_in) = if let (Some(qng), Some(kng)) = (q_norm_g, k_norm_g) {
            let q_normed = per_head_rms(
                &mut gb,
                q,
                qng,
                zero_beta_dh,
                spec.batch,
                spec.seq,
                nh,
                dh,
                spec.eps,
            );
            let k_normed = per_head_rms(
                &mut gb,
                k,
                kng,
                zero_beta_dh,
                spec.batch,
                spec.seq,
                nkv,
                dh,
                spec.eps,
            );
            (q_normed, k_normed)
        } else {
            (q, k)
        };
        let q_rope = gb.rope(q_rope_in, cos, sin, dh);
        let k_rope = gb.rope(k_rope_in, cos, sin, dh);

        if let Some(sink) = &self.kv_sink {
            // Take the lock once so the (K, V) pair is appended atomically and
            // cannot be interleaved with pushes from another holder of the sink.
            let mut collected = sink
                .lock()
                .map_err(|_| anyhow::anyhow!("Qwen3Decoder kv sink mutex poisoned"))?;
            collected.push(k_rope);
            collected.push(v);
        }

        let k_rep = repeat_kv(&mut gb, k_rope, nkv, dh, group);
        let v_rep = repeat_kv(&mut gb, v, nkv, dh, group);
        let attn_shape = shape::attention_shape(gb.shape(q_rope));
        let attn = gb.attention_kind(q_rope, k_rep, v_rep, nh, dh, MaskKind::Causal, attn_shape);
        let attn_out = gb.mm(attn, o_w);
        let post_attn = gb.add(skip, attn_out);

        // --- Gated MLP sub-block -------------------------------------------
        let normed_post = gb.rms_norm(post_attn, post_ln_g, zero_beta_h, spec.eps);
        let gate = gb.mm(normed_post, gate_w);
        let up = gb.mm(normed_post, up_w);
        let gate_act = gb.silu(gate);
        let swiglu = gb.mul(gate_act, up);
        let ffn_out = gb.mm(swiglu, down_w);
        let out = gb.add(post_attn, ffn_out);

        Ok(Some(ctx.wrap(out, spec.hidden_shape.clone())))
    }
}
/// Applies RMS-norm independently to every head of `x`.
///
/// `x` is reshaped to `[batch * seq * heads, head_dim]`, normalised with
/// `gamma`, `beta` and `eps` via `rms_norm`, and then reshaped to
/// `[batch, seq, heads * head_dim]`, which is the node returned.
pub(crate) fn per_head_rms(
    gb: &mut HirMut,
    x: rlx_ir::HirNodeId,
    gamma: rlx_ir::HirNodeId,
    beta: rlx_ir::HirNodeId,
    batch: usize,
    seq: usize,
    heads: usize,
    head_dim: usize,
    eps: f32,
) -> rlx_ir::HirNodeId {
    // One row per (batch, position, head) so the norm sees a single head per row.
    let rows = (batch * seq * heads) as i64;
    let per_head_dims = vec![rows, head_dim as i64];
    let split = gb.reshape_(x, per_head_dims);

    let normed = gb.rms_norm(split, gamma, beta, eps);

    // Fold the heads back into the trailing dimension.
    let merged_dims = vec![batch as i64, seq as i64, (heads * head_dim) as i64];
    gb.reshape_(normed, merged_dims)
}