1use anyhow::Result;
5use rlx_ir::HirGraphExt;
6use rlx_ir::hir::HirMut;
7use rlx_ir::op::MaskKind;
8use rlx_ir::shape;
9use rlx_ir::{DType, Shape};
10
11use std::sync::{Arc, Mutex};
12
13use super::{BlockStage, GemmaLayerStyle};
14use crate::context::FlowCtx;
15use crate::value::FlowValue;
16
17#[derive(Debug, Clone)]
18pub struct GemmaDecodeLayerSpec {
19 pub style: GemmaLayerStyle,
20 pub num_heads: usize,
21 pub head_dim: usize,
22 pub num_kv_heads: usize,
23 pub kv_group_size: usize,
24 pub eps: f32,
25 pub use_custom_mask: bool,
26 pub hidden_shape: rlx_ir::Shape,
27 pub mask: MaskKind,
28 pub score_scale: Option<f32>,
29 pub attn_logit_softcap: Option<f32>,
30}
31
32#[derive(Debug, Clone)]
33pub struct GemmaDecodeLayerStage {
34 pub layer_prefix: String,
35 pub spec: GemmaDecodeLayerSpec,
36 pub layer_idx: usize,
37 pub kv_out: Arc<Mutex<Vec<rlx_ir::HirNodeId>>>,
38}
39
40impl GemmaDecodeLayerStage {
41 pub fn layer(
42 layer_idx: usize,
43 spec: GemmaDecodeLayerSpec,
44 kv_out: Arc<Mutex<Vec<rlx_ir::HirNodeId>>>,
45 ) -> Self {
46 Self {
47 layer_prefix: format!("model.layers.{layer_idx}"),
48 spec,
49 layer_idx,
50 kv_out,
51 }
52 }
53}
54
55impl BlockStage for GemmaDecodeLayerStage {
56 fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
57 let decode = ctx
58 .state
59 .decode
60 .clone()
61 .ok_or_else(|| anyhow::anyhow!("GemmaDecodeLayer requires BindDecodeInputs"))?;
62 let zero_beta = ctx
63 .state
64 .zero_beta
65 .ok_or_else(|| anyhow::anyhow!("GemmaDecodeLayer requires ZeroBeta"))?;
66
67 let lp = &self.layer_prefix;
68 let spec = &self.spec;
69 let style = spec.style;
70
71 let in_ln_w = ctx.load_param(&format!("{lp}.input_layernorm.weight"), false)?;
72 let in_ln_len = norm_len(ctx, in_ln_w)?;
73 let in_ln_ones = ctx.synth_param(
74 &format!("{lp}.input_layernorm.ones"),
75 vec![1.0f32; in_ln_len],
76 Shape::new(&[in_ln_len], DType::F32),
77 );
78
79 let pre_ffn_key = if matches!(
80 style,
81 GemmaLayerStyle::Gemma2 | GemmaLayerStyle::Gemma3 | GemmaLayerStyle::Gemma4
82 ) {
83 format!("{lp}.pre_feedforward_layernorm")
84 } else {
85 format!("{lp}.post_attention_layernorm")
86 };
87 let pre_ffn_w = ctx.load_param(&format!("{pre_ffn_key}.weight"), false)?;
88 let pre_ffn_len = norm_len(ctx, pre_ffn_w)?;
89 let pre_ffn_ones = ctx.synth_param(
90 &format!("{pre_ffn_key}.ones"),
91 vec![1.0f32; pre_ffn_len],
92 Shape::new(&[pre_ffn_len], DType::F32),
93 );
94
95 let post_ffn = if matches!(
96 style,
97 GemmaLayerStyle::Gemma2 | GemmaLayerStyle::Gemma3 | GemmaLayerStyle::Gemma4
98 ) {
99 let post_key = format!("{lp}.post_feedforward_layernorm");
100 let w = ctx.load_param(&format!("{post_key}.weight"), false)?;
101 let len = norm_len(ctx, w)?;
102 let ones = ctx.synth_param(
103 &format!("{post_key}.ones"),
104 vec![1.0f32; len],
105 Shape::new(&[len], DType::F32),
106 );
107 Some((w, ones))
108 } else {
109 None
110 };
111
112 let q_w = ctx.load_param(&format!("{lp}.self_attn.q_proj.weight"), true)?;
113 let k_w = ctx.load_param(&format!("{lp}.self_attn.k_proj.weight"), true)?;
114 let v_w = ctx.load_param(&format!("{lp}.self_attn.v_proj.weight"), true)?;
115 let o_w = ctx.load_param(&format!("{lp}.self_attn.o_proj.weight"), true)?;
116 let gate_w = ctx.load_param(&format!("{lp}.mlp.gate_proj.weight"), true)?;
117 let up_w = ctx.load_param(&format!("{lp}.mlp.up_proj.weight"), true)?;
118 let down_w = ctx.load_param(&format!("{lp}.mlp.down_proj.weight"), true)?;
119
120 let past_k = decode.past_k[self.layer_idx];
121 let past_v = decode.past_v[self.layer_idx];
122
123 let mut gb = HirMut::new(ctx.hir());
124 let in_gamma = gb.add(in_ln_ones, in_ln_w);
125 let normed_in = gb.rms_norm(input.id, in_gamma, zero_beta, spec.eps);
126 let q = gb.mm(normed_in, q_w);
127 let k = gb.mm(normed_in, k_w);
128 let v = gb.mm(normed_in, v_w);
129 let q_rope = gb.rope(q, decode.cos, decode.sin, spec.head_dim);
130 let k_rope = gb.rope(k, decode.cos, decode.sin, spec.head_dim);
131
132 let new_k = gb.concat_(vec![past_k, k_rope], 1);
133 let new_v = gb.concat_(vec![past_v, v], 1);
134 self.kv_out.lock().expect("kv out").push(new_k);
135 self.kv_out.lock().expect("kv out").push(new_v);
136
137 let k_rep = super::self_attn::repeat_kv(
138 &mut gb,
139 new_k,
140 spec.num_kv_heads,
141 spec.head_dim,
142 spec.kv_group_size,
143 );
144 let v_rep = super::self_attn::repeat_kv(
145 &mut gb,
146 new_v,
147 spec.num_kv_heads,
148 spec.head_dim,
149 spec.kv_group_size,
150 );
151
152 let attn_shape = shape::attention_shape(gb.shape(q_rope));
153 let attn = if spec.use_custom_mask {
154 let mask = decode
155 .mask
156 .ok_or_else(|| anyhow::anyhow!("custom mask requested but not bound"))?;
157 gb.attention_(q_rope, k_rep, v_rep, mask, spec.num_heads, spec.head_dim)
158 } else {
159 gb.attention_kind_opts(
160 q_rope,
161 k_rep,
162 v_rep,
163 spec.num_heads,
164 spec.head_dim,
165 spec.mask,
166 attn_shape,
167 spec.score_scale,
168 spec.attn_logit_softcap,
169 )
170 };
171
172 let attn_out = gb.mm(attn, o_w);
173 let post_attn = gb.add(input.id, attn_out);
174
175 let pre_gamma = gb.add(pre_ffn_ones, pre_ffn_w);
176 let mut h = gb.rms_norm(post_attn, pre_gamma, zero_beta, spec.eps);
177 let gate = gb.mm(h, gate_w);
178 let up = gb.mm(h, up_w);
179 let gate_act = gb.gelu_approx(gate);
180 h = gb.mul(gate_act, up);
181 h = gb.mm(h, down_w);
182
183 if let Some((post_w, post_ones)) = post_ffn {
184 let post_gamma = gb.add(post_ones, post_w);
185 h = gb.rms_norm(h, post_gamma, zero_beta, spec.eps);
186 }
187
188 let out_id = gb.add(post_attn, h);
189 Ok(Some(ctx.wrap(out_id, spec.hidden_shape.clone())))
190 }
191}
192
193fn norm_len(ctx: &FlowCtx<'_>, weight: rlx_ir::HirNodeId) -> Result<usize> {
194 match ctx.node_shape(weight)?.dims().last() {
195 Some(rlx_ir::shape::Dim::Static(n)) => Ok(*n),
196 _ => Ok(0),
197 }
198}