1use anyhow::Result;
17use rlx_ir::HirGraphExt;
18use rlx_ir::hir::HirMut;
19use rlx_ir::op::MaskKind;
20use rlx_ir::shape;
21
22use super::BlockStage;
23use crate::context::FlowCtx;
24use crate::value::FlowValue;
/// Hyper-parameters for one Llama decoder layer emitted by [`LlamaDecodeLayerStage`].
#[derive(Debug, Clone)]
pub struct LlamaDecodeLayerSpec {
    /// Number of query attention heads; forwarded to the attention op.
    pub num_heads: usize,
    /// Per-head dimension; forwarded to `rope`, `repeat_kv` and the attention op.
    pub head_dim: usize,
    /// Number of key/value heads; forwarded to `repeat_kv`.
    pub num_kv_heads: usize,
    /// Repeat factor handed to `repeat_kv` for both K and V.
    /// NOTE(review): presumably `num_heads / num_kv_heads` (grouped-query attention) — confirm.
    pub kv_group_size: usize,
    /// Epsilon used by both RMS norms in the layer (input and post-attention).
    pub eps: f32,
    /// When true, attention uses the mask bound in the decode inputs;
    /// otherwise the built-in `MaskKind::Causal` mask is used.
    pub use_custom_mask: bool,
    /// Shape attached to the layer's output value.
    pub hidden_shape: rlx_ir::Shape,
}
35
/// Block stage that emits a single Llama decoder layer for one decode step.
#[derive(Debug, Clone)]
pub struct LlamaDecodeLayerStage {
    /// Prefix used to look up this layer's parameters (e.g. `model.layers.3`).
    pub layer_prefix: String,
    /// Layer hyper-parameters.
    pub spec: LlamaDecodeLayerSpec,
    /// Index used to pick this layer's entries out of the past K/V cache inputs.
    pub layer_idx: usize,
    /// Shared sink that receives the updated key-cache node id followed by the
    /// updated value-cache node id when this layer is emitted.
    pub kv_out: std::sync::Arc<std::sync::Mutex<Vec<rlx_ir::HirNodeId>>>,
    /// Optional shared sink that receives this layer's input node id when emitted.
    pub aux_in_out: Option<std::sync::Arc<std::sync::Mutex<Vec<rlx_ir::HirNodeId>>>>,
}
48
49impl LlamaDecodeLayerStage {
50 pub fn layer(
51 layer_idx: usize,
52 spec: LlamaDecodeLayerSpec,
53 kv_out: std::sync::Arc<std::sync::Mutex<Vec<rlx_ir::HirNodeId>>>,
54 ) -> Self {
55 Self {
56 layer_prefix: format!("model.layers.{layer_idx}"),
57 spec,
58 layer_idx,
59 kv_out,
60 aux_in_out: None,
61 }
62 }
63
64 pub fn with_aux_input_tap(
65 mut self,
66 sink: std::sync::Arc<std::sync::Mutex<Vec<rlx_ir::HirNodeId>>>,
67 ) -> Self {
68 self.aux_in_out = Some(sink);
69 self
70 }
71}
72
73impl BlockStage for LlamaDecodeLayerStage {
74 fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
75 if let Some(sink) = self.aux_in_out.as_ref() {
76 sink.lock().expect("aux in out").push(input.id);
77 }
78
79 let decode = ctx
80 .state
81 .decode
82 .clone()
83 .ok_or_else(|| anyhow::anyhow!("LlamaDecodeLayer requires BindDecodeInputs"))?;
84 let zero_beta = ctx
85 .state
86 .zero_beta
87 .ok_or_else(|| anyhow::anyhow!("LlamaDecodeLayer requires ZeroBeta"))?;
88
89 let lp = &self.layer_prefix;
90 let spec = &self.spec;
91 let in_ln_g = ctx.load_param(&format!("{lp}.input_layernorm.weight"), false)?;
92 let q_w = ctx.load_param(&format!("{lp}.self_attn.q_proj.weight"), true)?;
93 let k_w = ctx.load_param(&format!("{lp}.self_attn.k_proj.weight"), true)?;
94 let v_w = ctx.load_param(&format!("{lp}.self_attn.v_proj.weight"), true)?;
95 let o_w = ctx.load_param(&format!("{lp}.self_attn.o_proj.weight"), true)?;
96 let post_ln_g = ctx.load_param(&format!("{lp}.post_attention_layernorm.weight"), false)?;
97 let gate_w = ctx.load_param(&format!("{lp}.mlp.gate_proj.weight"), true)?;
98 let up_w = ctx.load_param(&format!("{lp}.mlp.up_proj.weight"), true)?;
99 let down_w = ctx.load_param(&format!("{lp}.mlp.down_proj.weight"), true)?;
100
101 let past_k = decode.past_k.get(self.layer_idx);
102 let past_v = decode.past_v.get(self.layer_idx);
103
104 let mut gb = HirMut::new(ctx.hir());
105 let normed_in = gb.rms_norm(input.id, in_ln_g, zero_beta, spec.eps);
106 let q = gb.mm(normed_in, q_w);
107 let k = gb.mm(normed_in, k_w);
108 let v = gb.mm(normed_in, v_w);
109
110 let q_rope = gb.rope(q, decode.cos, decode.sin, spec.head_dim);
111 let k_rope = gb.rope(k, decode.cos, decode.sin, spec.head_dim);
112
113 let (new_k, new_v) = match (past_k, past_v) {
114 (Some(past_k), Some(past_v)) => (
115 gb.concat_(vec![*past_k, k_rope], 1),
116 gb.concat_(vec![*past_v, v], 1),
117 ),
118 _ => (k_rope, v),
119 };
120 self.kv_out.lock().expect("kv out").push(new_k);
121 self.kv_out.lock().expect("kv out").push(new_v);
122
123 let k_rep = super::self_attn::repeat_kv(
124 &mut gb,
125 new_k,
126 spec.num_kv_heads,
127 spec.head_dim,
128 spec.kv_group_size,
129 );
130 let v_rep = super::self_attn::repeat_kv(
131 &mut gb,
132 new_v,
133 spec.num_kv_heads,
134 spec.head_dim,
135 spec.kv_group_size,
136 );
137
138 let attn_shape = shape::attention_shape(gb.shape(q_rope));
139 let attn = if spec.use_custom_mask {
140 let mask = decode
141 .mask
142 .ok_or_else(|| anyhow::anyhow!("custom mask requested but not bound"))?;
143 gb.attention(
144 q_rope,
145 k_rep,
146 v_rep,
147 mask,
148 spec.num_heads,
149 spec.head_dim,
150 attn_shape,
151 )
152 } else {
153 gb.attention_kind(
154 q_rope,
155 k_rep,
156 v_rep,
157 spec.num_heads,
158 spec.head_dim,
159 MaskKind::Causal,
160 attn_shape,
161 )
162 };
163
164 let attn_out = gb.mm(attn, o_w);
165 let post_attn = gb.add(input.id, attn_out);
166 let normed_post = gb.rms_norm(post_attn, post_ln_g, zero_beta, spec.eps);
167 let gate = gb.mm(normed_post, gate_w);
168 let up = gb.mm(normed_post, up_w);
169 let gate_act = gb.silu(gate);
170 let swiglu = gb.mul(gate_act, up);
171 let ffn_out = gb.mm(swiglu, down_w);
172 let h_id = gb.add(post_attn, ffn_out);
173
174 Ok(Some(ctx.wrap(h_id, spec.hidden_shape.clone())))
175 }
176}