rlx_flow/blocks/
llama_decode_layer.rs1use anyhow::Result;
17use rlx_ir::HirGraphExt;
18use rlx_ir::hir::HirMut;
19use rlx_ir::op::MaskKind;
20use rlx_ir::shape;
21
22use super::BlockStage;
23use crate::context::FlowCtx;
24use crate::value::FlowValue;
25#[derive(Debug, Clone)]
26pub struct LlamaDecodeLayerSpec {
27 pub num_heads: usize,
28 pub head_dim: usize,
29 pub num_kv_heads: usize,
30 pub kv_group_size: usize,
31 pub eps: f32,
32 pub use_custom_mask: bool,
33 pub hidden_shape: rlx_ir::Shape,
34}
35
36#[derive(Debug, Clone)]
37pub struct LlamaDecodeLayerStage {
38 pub layer_prefix: String,
39 pub spec: LlamaDecodeLayerSpec,
40 pub layer_idx: usize,
41 pub kv_out: std::sync::Arc<std::sync::Mutex<Vec<rlx_ir::HirNodeId>>>,
42}
43
44impl LlamaDecodeLayerStage {
45 pub fn layer(
46 layer_idx: usize,
47 spec: LlamaDecodeLayerSpec,
48 kv_out: std::sync::Arc<std::sync::Mutex<Vec<rlx_ir::HirNodeId>>>,
49 ) -> Self {
50 Self {
51 layer_prefix: format!("model.layers.{layer_idx}"),
52 spec,
53 layer_idx,
54 kv_out,
55 }
56 }
57}
58
59impl BlockStage for LlamaDecodeLayerStage {
60 fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
61 let decode = ctx
62 .state
63 .decode
64 .clone()
65 .ok_or_else(|| anyhow::anyhow!("LlamaDecodeLayer requires BindDecodeInputs"))?;
66 let zero_beta = ctx
67 .state
68 .zero_beta
69 .ok_or_else(|| anyhow::anyhow!("LlamaDecodeLayer requires ZeroBeta"))?;
70
71 let lp = &self.layer_prefix;
72 let spec = &self.spec;
73 let in_ln_g = ctx.load_param(&format!("{lp}.input_layernorm.weight"), false)?;
74 let q_w = ctx.load_param(&format!("{lp}.self_attn.q_proj.weight"), true)?;
75 let k_w = ctx.load_param(&format!("{lp}.self_attn.k_proj.weight"), true)?;
76 let v_w = ctx.load_param(&format!("{lp}.self_attn.v_proj.weight"), true)?;
77 let o_w = ctx.load_param(&format!("{lp}.self_attn.o_proj.weight"), true)?;
78 let post_ln_g = ctx.load_param(&format!("{lp}.post_attention_layernorm.weight"), false)?;
79 let gate_w = ctx.load_param(&format!("{lp}.mlp.gate_proj.weight"), true)?;
80 let up_w = ctx.load_param(&format!("{lp}.mlp.up_proj.weight"), true)?;
81 let down_w = ctx.load_param(&format!("{lp}.mlp.down_proj.weight"), true)?;
82
83 let past_k = decode.past_k[self.layer_idx];
84 let past_v = decode.past_v[self.layer_idx];
85
86 let mut gb = HirMut::new(ctx.hir());
87 let normed_in = gb.rms_norm(input.id, in_ln_g, zero_beta, spec.eps);
88 let q = gb.mm(normed_in, q_w);
89 let k = gb.mm(normed_in, k_w);
90 let v = gb.mm(normed_in, v_w);
91
92 let q_rope = gb.rope(q, decode.cos, decode.sin, spec.head_dim);
93 let k_rope = gb.rope(k, decode.cos, decode.sin, spec.head_dim);
94
95 let new_k = gb.concat_(vec![past_k, k_rope], 1);
96 let new_v = gb.concat_(vec![past_v, v], 1);
97 self.kv_out.lock().expect("kv out").push(new_k);
98 self.kv_out.lock().expect("kv out").push(new_v);
99
100 let k_rep = super::self_attn::repeat_kv(
101 &mut gb,
102 new_k,
103 spec.num_kv_heads,
104 spec.head_dim,
105 spec.kv_group_size,
106 );
107 let v_rep = super::self_attn::repeat_kv(
108 &mut gb,
109 new_v,
110 spec.num_kv_heads,
111 spec.head_dim,
112 spec.kv_group_size,
113 );
114
115 let attn_shape = shape::attention_shape(gb.shape(q_rope));
116 let attn = if spec.use_custom_mask {
117 let mask = decode
118 .mask
119 .ok_or_else(|| anyhow::anyhow!("custom mask requested but not bound"))?;
120 gb.attention(
121 q_rope,
122 k_rep,
123 v_rep,
124 mask,
125 spec.num_heads,
126 spec.head_dim,
127 attn_shape,
128 )
129 } else {
130 gb.attention_kind(
131 q_rope,
132 k_rep,
133 v_rep,
134 spec.num_heads,
135 spec.head_dim,
136 MaskKind::Causal,
137 attn_shape,
138 )
139 };
140
141 let attn_out = gb.mm(attn, o_w);
142 let post_attn = gb.add(input.id, attn_out);
143 let normed_post = gb.rms_norm(post_attn, post_ln_g, zero_beta, spec.eps);
144 let gate = gb.mm(normed_post, gate_w);
145 let up = gb.mm(normed_post, up_w);
146 let gate_act = gb.silu(gate);
147 let swiglu = gb.mul(gate_act, up);
148 let ffn_out = gb.mm(swiglu, down_w);
149 let h_id = gb.add(post_attn, ffn_out);
150
151 Ok(Some(ctx.wrap(h_id, spec.hidden_shape.clone())))
152 }
153}