rlx_flow/blocks/
llama_decode_layer.rs1use anyhow::Result;
5use rlx_ir::HirGraphExt;
6use rlx_ir::hir::HirMut;
7use rlx_ir::op::MaskKind;
8use rlx_ir::shape;
9
10use super::BlockStage;
11use crate::context::FlowCtx;
12use crate::value::FlowValue;
13#[derive(Debug, Clone)]
14pub struct LlamaDecodeLayerSpec {
15 pub num_heads: usize,
16 pub head_dim: usize,
17 pub num_kv_heads: usize,
18 pub kv_group_size: usize,
19 pub eps: f32,
20 pub use_custom_mask: bool,
21 pub hidden_shape: rlx_ir::Shape,
22}
23
24#[derive(Debug, Clone)]
25pub struct LlamaDecodeLayerStage {
26 pub layer_prefix: String,
27 pub spec: LlamaDecodeLayerSpec,
28 pub layer_idx: usize,
29 pub kv_out: std::sync::Arc<std::sync::Mutex<Vec<rlx_ir::HirNodeId>>>,
30}
31
32impl LlamaDecodeLayerStage {
33 pub fn layer(
34 layer_idx: usize,
35 spec: LlamaDecodeLayerSpec,
36 kv_out: std::sync::Arc<std::sync::Mutex<Vec<rlx_ir::HirNodeId>>>,
37 ) -> Self {
38 Self {
39 layer_prefix: format!("model.layers.{layer_idx}"),
40 spec,
41 layer_idx,
42 kv_out,
43 }
44 }
45}
46
47impl BlockStage for LlamaDecodeLayerStage {
48 fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
49 let decode = ctx
50 .state
51 .decode
52 .clone()
53 .ok_or_else(|| anyhow::anyhow!("LlamaDecodeLayer requires BindDecodeInputs"))?;
54 let zero_beta = ctx
55 .state
56 .zero_beta
57 .ok_or_else(|| anyhow::anyhow!("LlamaDecodeLayer requires ZeroBeta"))?;
58
59 let lp = &self.layer_prefix;
60 let spec = &self.spec;
61 let in_ln_g = ctx.load_param(&format!("{lp}.input_layernorm.weight"), false)?;
62 let q_w = ctx.load_param(&format!("{lp}.self_attn.q_proj.weight"), true)?;
63 let k_w = ctx.load_param(&format!("{lp}.self_attn.k_proj.weight"), true)?;
64 let v_w = ctx.load_param(&format!("{lp}.self_attn.v_proj.weight"), true)?;
65 let o_w = ctx.load_param(&format!("{lp}.self_attn.o_proj.weight"), true)?;
66 let post_ln_g = ctx.load_param(&format!("{lp}.post_attention_layernorm.weight"), false)?;
67 let gate_w = ctx.load_param(&format!("{lp}.mlp.gate_proj.weight"), true)?;
68 let up_w = ctx.load_param(&format!("{lp}.mlp.up_proj.weight"), true)?;
69 let down_w = ctx.load_param(&format!("{lp}.mlp.down_proj.weight"), true)?;
70
71 let past_k = decode.past_k[self.layer_idx];
72 let past_v = decode.past_v[self.layer_idx];
73
74 let mut gb = HirMut::new(ctx.hir());
75 let normed_in = gb.rms_norm(input.id, in_ln_g, zero_beta, spec.eps);
76 let q = gb.mm(normed_in, q_w);
77 let k = gb.mm(normed_in, k_w);
78 let v = gb.mm(normed_in, v_w);
79
80 let q_rope = gb.rope(q, decode.cos, decode.sin, spec.head_dim);
81 let k_rope = gb.rope(k, decode.cos, decode.sin, spec.head_dim);
82
83 let new_k = gb.concat_(vec![past_k, k_rope], 1);
84 let new_v = gb.concat_(vec![past_v, v], 1);
85 self.kv_out.lock().expect("kv out").push(new_k);
86 self.kv_out.lock().expect("kv out").push(new_v);
87
88 let k_rep = super::self_attn::repeat_kv(
89 &mut gb,
90 new_k,
91 spec.num_kv_heads,
92 spec.head_dim,
93 spec.kv_group_size,
94 );
95 let v_rep = super::self_attn::repeat_kv(
96 &mut gb,
97 new_v,
98 spec.num_kv_heads,
99 spec.head_dim,
100 spec.kv_group_size,
101 );
102
103 let attn_shape = shape::attention_shape(gb.shape(q_rope));
104 let attn = if spec.use_custom_mask {
105 let mask = decode
106 .mask
107 .ok_or_else(|| anyhow::anyhow!("custom mask requested but not bound"))?;
108 gb.attention(
109 q_rope,
110 k_rep,
111 v_rep,
112 mask,
113 spec.num_heads,
114 spec.head_dim,
115 attn_shape,
116 )
117 } else {
118 gb.attention_kind(
119 q_rope,
120 k_rep,
121 v_rep,
122 spec.num_heads,
123 spec.head_dim,
124 MaskKind::Causal,
125 attn_shape,
126 )
127 };
128
129 let attn_out = gb.mm(attn, o_w);
130 let post_attn = gb.add(input.id, attn_out);
131 let normed_post = gb.rms_norm(post_attn, post_ln_g, zero_beta, spec.eps);
132 let gate = gb.mm(normed_post, gate_w);
133 let up = gb.mm(normed_post, up_w);
134 let gate_act = gb.silu(gate);
135 let swiglu = gb.mul(gate_act, up);
136 let ffn_out = gb.mm(swiglu, down_w);
137 let h_id = gb.add(post_attn, ffn_out);
138
139 Ok(Some(ctx.wrap(h_id, spec.hidden_shape.clone())))
140 }
141}