rlx_flow/blocks/
llama_decoder.rs1use anyhow::Result;
5use rlx_ir::op::MaskKind;
6
7use super::BlockStage;
8use crate::context::FlowCtx;
9use crate::value::FlowValue;
10#[derive(Debug, Clone)]
11pub struct LlamaDecoderStage {
12 pub layer_prefix: String,
13 pub num_heads: usize,
14 pub head_dim: usize,
15 pub num_kv_heads: usize,
16 pub eps: f32,
17 pub mask: MaskKind,
18 pub hidden_shape: rlx_ir::Shape,
19}
20
21impl LlamaDecoderStage {
22 pub fn layer(layer_idx: usize, spec: LlamaDecoderSpec) -> Self {
23 Self {
24 layer_prefix: format!("model.layers.{layer_idx}"),
25 num_heads: spec.num_heads,
26 head_dim: spec.head_dim,
27 num_kv_heads: spec.num_kv_heads,
28 eps: spec.eps,
29 mask: spec.mask,
30 hidden_shape: spec.hidden_shape,
31 }
32 }
33}
34
35#[derive(Debug, Clone)]
36pub struct LlamaDecoderSpec {
37 pub num_heads: usize,
38 pub head_dim: usize,
39 pub num_kv_heads: usize,
40 pub eps: f32,
41 pub mask: MaskKind,
42 pub hidden_shape: rlx_ir::Shape,
43}
44
45impl BlockStage for LlamaDecoderStage {
46 fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
47 let lp = &self.layer_prefix;
48 let zero_beta = ctx
49 .state
50 .zero_beta
51 .ok_or_else(|| anyhow::anyhow!("LlamaDecoder requires ZeroBeta stage"))?;
52 let cos = ctx
53 .state
54 .rope_cos
55 .ok_or_else(|| anyhow::anyhow!("LlamaDecoder requires RopeTables stage"))?;
56 let sin = ctx
57 .state
58 .rope_sin
59 .ok_or_else(|| anyhow::anyhow!("LlamaDecoder requires RopeTables stage"))?;
60
61 let in_ln_g = ctx.load_param(&format!("{lp}.input_layernorm.weight"), false)?;
62 let q_w = ctx.load_param(&format!("{lp}.self_attn.q_proj.weight"), true)?;
63 let k_w = ctx.load_param(&format!("{lp}.self_attn.k_proj.weight"), true)?;
64 let v_w = ctx.load_param(&format!("{lp}.self_attn.v_proj.weight"), true)?;
65 let o_w = ctx.load_param(&format!("{lp}.self_attn.o_proj.weight"), true)?;
66 let post_ln_g = ctx.load_param(&format!("{lp}.post_attention_layernorm.weight"), false)?;
67 let gate_w = ctx.load_param(&format!("{lp}.mlp.gate_proj.weight"), true)?;
68 let up_w = ctx.load_param(&format!("{lp}.mlp.up_proj.weight"), true)?;
69 let down_w = ctx.load_param(&format!("{lp}.mlp.down_proj.weight"), true)?;
70
71 let id = ctx.hir().llama_decoder_block(
72 input.id,
73 in_ln_g,
74 zero_beta,
75 q_w,
76 k_w,
77 v_w,
78 o_w,
79 post_ln_g,
80 zero_beta,
81 gate_w,
82 up_w,
83 down_w,
84 cos,
85 sin,
86 None,
87 self.num_heads,
88 self.head_dim,
89 self.num_kv_heads,
90 self.eps,
91 self.mask,
92 self.hidden_shape.clone(),
93 );
94
95 Ok(Some(ctx.wrap(id, self.hidden_shape.clone())))
96 }
97}