rlx_flow/blocks/
llama_decoder.rs1use anyhow::Result;
17use rlx_ir::op::MaskKind;
18
19use super::BlockStage;
20use crate::context::FlowCtx;
21use crate::value::FlowValue;
22#[derive(Debug, Clone)]
23pub struct LlamaDecoderStage {
24 pub layer_prefix: String,
25 pub num_heads: usize,
26 pub head_dim: usize,
27 pub num_kv_heads: usize,
28 pub eps: f32,
29 pub mask: MaskKind,
30 pub hidden_shape: rlx_ir::Shape,
31}
32
33impl LlamaDecoderStage {
34 pub fn layer(layer_idx: usize, spec: LlamaDecoderSpec) -> Self {
35 Self {
36 layer_prefix: format!("model.layers.{layer_idx}"),
37 num_heads: spec.num_heads,
38 head_dim: spec.head_dim,
39 num_kv_heads: spec.num_kv_heads,
40 eps: spec.eps,
41 mask: spec.mask,
42 hidden_shape: spec.hidden_shape,
43 }
44 }
45}
46
47#[derive(Debug, Clone)]
48pub struct LlamaDecoderSpec {
49 pub num_heads: usize,
50 pub head_dim: usize,
51 pub num_kv_heads: usize,
52 pub eps: f32,
53 pub mask: MaskKind,
54 pub hidden_shape: rlx_ir::Shape,
55}
56
57impl BlockStage for LlamaDecoderStage {
58 fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
59 let lp = &self.layer_prefix;
60 let zero_beta = ctx
61 .state
62 .zero_beta
63 .ok_or_else(|| anyhow::anyhow!("LlamaDecoder requires ZeroBeta stage"))?;
64 let cos = ctx
65 .state
66 .rope_cos
67 .ok_or_else(|| anyhow::anyhow!("LlamaDecoder requires RopeTables stage"))?;
68 let sin = ctx
69 .state
70 .rope_sin
71 .ok_or_else(|| anyhow::anyhow!("LlamaDecoder requires RopeTables stage"))?;
72
73 let in_ln_g = ctx.load_param(&format!("{lp}.input_layernorm.weight"), false)?;
74 let q_w = ctx.load_param(&format!("{lp}.self_attn.q_proj.weight"), true)?;
75 let k_w = ctx.load_param(&format!("{lp}.self_attn.k_proj.weight"), true)?;
76 let v_w = ctx.load_param(&format!("{lp}.self_attn.v_proj.weight"), true)?;
77 let o_w = ctx.load_param(&format!("{lp}.self_attn.o_proj.weight"), true)?;
78 let post_ln_g = ctx.load_param(&format!("{lp}.post_attention_layernorm.weight"), false)?;
79 let gate_w = ctx.load_param(&format!("{lp}.mlp.gate_proj.weight"), true)?;
80 let up_w = ctx.load_param(&format!("{lp}.mlp.up_proj.weight"), true)?;
81 let down_w = ctx.load_param(&format!("{lp}.mlp.down_proj.weight"), true)?;
82
83 let id = ctx.hir().llama_decoder_block(
84 input.id,
85 in_ln_g,
86 zero_beta,
87 q_w,
88 k_w,
89 v_w,
90 o_w,
91 post_ln_g,
92 zero_beta,
93 gate_w,
94 up_w,
95 down_w,
96 cos,
97 sin,
98 None,
99 self.num_heads,
100 self.head_dim,
101 self.num_kv_heads,
102 self.eps,
103 self.mask,
104 self.hidden_shape.clone(),
105 );
106
107 Ok(Some(ctx.wrap(id, self.hidden_shape.clone())))
108 }
109}