rlx_flow/blocks/
gemma_layer.rs1use std::sync::{Arc, Mutex};
7
8use super::{
9 GeGluStage, GemmaDecodeLayerSpec, GemmaDecodeLayerStage, GemmaKvTapStage, GemmaRmsNormStage,
10 SelfAttnPrefillSpec,
11};
12use crate::layer::LayerStack;
13use crate::stage::FlowStage;
14use rlx_ir::op::MaskKind;
15
/// Selects which Gemma generation's decoder-layer layout the builders in this
/// module assemble.
///
/// `Gemma` uses the original layout (a single norm in front of the MLP), while
/// `Gemma2`, `Gemma3` and `Gemma4` are grouped together by
/// `gemma_prefill_layer_composed` and additionally get the
/// `pre_feedforward_layernorm` / `post_feedforward_layernorm` pair around the MLP.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GemmaLayerStyle {
    /// First-generation Gemma.
    Gemma,
    /// Gemma 2.
    Gemma2,
    /// Gemma 3.
    Gemma3,
    /// Gemma 4.
    Gemma4,
}
24
25pub fn gemma_attn_spec(
27 layer: usize,
28 num_heads: usize,
29 head_dim: usize,
30 num_kv_heads: usize,
31 mask: MaskKind,
32 score_scale: Option<f32>,
33 attn_logit_softcap: Option<f32>,
34) -> SelfAttnPrefillSpec {
35 let prefix = format!("model.layers.{layer}");
36 SelfAttnPrefillSpec {
37 q_key: format!("{prefix}.self_attn.q_proj.weight"),
38 k_key: format!("{prefix}.self_attn.k_proj.weight"),
39 v_key: format!("{prefix}.self_attn.v_proj.weight"),
40 num_heads,
41 head_dim,
42 num_kv_heads,
43 mask,
44 score_scale,
45 attn_logit_softcap,
46 }
47}
48
49pub fn gemma2_layer_mask(_layer: usize, window: usize) -> MaskKind {
51 MaskKind::SlidingWindow(window)
52}
53
54pub fn gemma_strided_layer_mask(layer: usize, window: usize, stride: usize) -> MaskKind {
56 if stride > 1 && (layer + 1).is_multiple_of(stride) {
57 MaskKind::Causal
58 } else {
59 MaskKind::SlidingWindow(window)
60 }
61}
62
63pub fn gemma_prefill_layer_composed(
65 layer_idx: usize,
66 style: GemmaLayerStyle,
67 attn: SelfAttnPrefillSpec,
68 eps: f32,
69 kv_sink: Option<Arc<Mutex<Vec<rlx_ir::HirNodeId>>>>,
70) -> FlowStage {
71 let prefix = format!("model.layers.{layer_idx}");
72 let mut stack = LayerStack::named(format!("layer{layer_idx}"))
73 .residual_save()
74 .stage(FlowStage::GemmaRmsNorm(GemmaRmsNormStage::hf_layer(
75 format!("{prefix}.input_layernorm"),
76 eps,
77 )));
78
79 if let Some(sink) = kv_sink {
80 stack = stack.stage(FlowStage::GemmaKvTap(GemmaKvTapStage::layer(
81 layer_idx,
82 attn.head_dim,
83 eps,
84 sink,
85 )));
86 }
87
88 stack = stack
89 .self_attn_prefill(attn)
90 .linear(format!("{prefix}.self_attn.o_proj.weight"), true)
91 .residual_add()
92 .residual_save();
93
94 stack = if matches!(
95 style,
96 GemmaLayerStyle::Gemma2 | GemmaLayerStyle::Gemma3 | GemmaLayerStyle::Gemma4
97 ) {
98 stack.stage(FlowStage::GemmaRmsNorm(GemmaRmsNormStage::hf_layer(
99 format!("{prefix}.pre_feedforward_layernorm"),
100 eps,
101 )))
102 } else {
103 stack.stage(FlowStage::GemmaRmsNorm(GemmaRmsNormStage::hf_layer(
104 format!("{prefix}.post_attention_layernorm"),
105 eps,
106 )))
107 };
108
109 stack = stack.stage(FlowStage::GeGlu(GeGluStage::hf_mlp(&prefix)));
110
111 if matches!(
112 style,
113 GemmaLayerStyle::Gemma2 | GemmaLayerStyle::Gemma3 | GemmaLayerStyle::Gemma4
114 ) {
115 stack = stack.stage(FlowStage::GemmaRmsNorm(GemmaRmsNormStage::hf_layer(
116 format!("{prefix}.post_feedforward_layernorm"),
117 eps,
118 )));
119 }
120
121 stack.residual_add().build()
122}
123
124pub fn gemma_moe_prefill_layer_composed(
126 layer_idx: usize,
127 style: GemmaLayerStyle,
128 attn: SelfAttnPrefillSpec,
129 eps: f32,
130 kv_sink: Option<Arc<Mutex<Vec<rlx_ir::HirNodeId>>>>,
131 _moe: super::MoeFfnStage,
132) -> FlowStage {
133 gemma_prefill_layer_composed(layer_idx, style, attn, eps, kv_sink)
134}
135
136pub fn gemma_moe_decode_layer_composed(
137 layer_idx: usize,
138 spec: GemmaDecodeLayerSpec,
139 kv_out: Arc<Mutex<Vec<rlx_ir::HirNodeId>>>,
140 _moe: super::MoeFfnStage,
141) -> FlowStage {
142 FlowStage::Named {
143 name: format!("layer{layer_idx}"),
144 inner: Arc::new(FlowStage::GemmaDecodeLayer(GemmaDecodeLayerStage::layer(
145 layer_idx, spec, kv_out,
146 ))),
147 }
148}