Skip to main content

rlx_flow/blocks/
gemma_layer.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
//! Gemma-family (Gemma, Gemma 2, Gemma 3, Gemma 4) decoder blocks for tier-0 [`ModelFlow`] recipes.
5
6use std::sync::{Arc, Mutex};
7
8use super::{
9    GeGluStage, GemmaDecodeLayerSpec, GemmaDecodeLayerStage, GemmaKvTapStage, GemmaRmsNormStage,
10    SelfAttnPrefillSpec,
11};
12use crate::layer::LayerStack;
13use crate::stage::FlowStage;
14use rlx_ir::op::MaskKind;
15
/// Selects which decoder-layer recipe a Gemma-family model uses.
///
/// The variant decides which RMS-norm stages [`gemma_prefill_layer_composed`]
/// places around the feed-forward block. That function treats `Gemma2`,
/// `Gemma3` and `Gemma4` identically; only `Gemma` takes a different path.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GemmaLayerStyle {
    /// First-generation layout: a single `post_attention_layernorm` ahead of the FFN.
    Gemma,
    /// FFN wrapped by `pre_feedforward_layernorm` and `post_feedforward_layernorm`.
    Gemma2,
    /// Same norm layout as [`GemmaLayerStyle::Gemma2`] in the prefill recipe.
    Gemma3,
    /// Same norm layout as [`GemmaLayerStyle::Gemma2`] in the prefill recipe.
    Gemma4,
}
24
25/// Build prefill self-attention spec for one layer.
26pub fn gemma_attn_spec(
27    layer: usize,
28    num_heads: usize,
29    head_dim: usize,
30    num_kv_heads: usize,
31    mask: MaskKind,
32    score_scale: Option<f32>,
33    attn_logit_softcap: Option<f32>,
34) -> SelfAttnPrefillSpec {
35    let prefix = format!("model.layers.{layer}");
36    SelfAttnPrefillSpec {
37        q_key: format!("{prefix}.self_attn.q_proj.weight"),
38        k_key: format!("{prefix}.self_attn.k_proj.weight"),
39        v_key: format!("{prefix}.self_attn.v_proj.weight"),
40        num_heads,
41        head_dim,
42        num_kv_heads,
43        mask,
44        score_scale,
45        attn_logit_softcap,
46    }
47}
48
49/// Sliding-window mask for Gemma 2 local-attention layers.
50pub fn gemma2_layer_mask(_layer: usize, window: usize) -> MaskKind {
51    MaskKind::SlidingWindow(window)
52}
53
54/// Gemma 3 / 4 strided pattern: `stride-1` layers use full causal, others sliding.
55pub fn gemma_strided_layer_mask(layer: usize, window: usize, stride: usize) -> MaskKind {
56    if stride > 1 && (layer + 1).is_multiple_of(stride) {
57        MaskKind::Causal
58    } else {
59        MaskKind::SlidingWindow(window)
60    }
61}
62
63/// Composed Gemma prefill decoder block.
64pub fn gemma_prefill_layer_composed(
65    layer_idx: usize,
66    style: GemmaLayerStyle,
67    attn: SelfAttnPrefillSpec,
68    eps: f32,
69    kv_sink: Option<Arc<Mutex<Vec<rlx_ir::HirNodeId>>>>,
70) -> FlowStage {
71    let prefix = format!("model.layers.{layer_idx}");
72    let mut stack = LayerStack::named(format!("layer{layer_idx}"))
73        .residual_save()
74        .stage(FlowStage::GemmaRmsNorm(GemmaRmsNormStage::hf_layer(
75            format!("{prefix}.input_layernorm"),
76            eps,
77        )));
78
79    if let Some(sink) = kv_sink {
80        stack = stack.stage(FlowStage::GemmaKvTap(GemmaKvTapStage::layer(
81            layer_idx,
82            attn.head_dim,
83            eps,
84            sink,
85        )));
86    }
87
88    stack = stack
89        .self_attn_prefill(attn)
90        .linear(format!("{prefix}.self_attn.o_proj.weight"), true)
91        .residual_add()
92        .residual_save();
93
94    stack = if matches!(
95        style,
96        GemmaLayerStyle::Gemma2 | GemmaLayerStyle::Gemma3 | GemmaLayerStyle::Gemma4
97    ) {
98        stack.stage(FlowStage::GemmaRmsNorm(GemmaRmsNormStage::hf_layer(
99            format!("{prefix}.pre_feedforward_layernorm"),
100            eps,
101        )))
102    } else {
103        stack.stage(FlowStage::GemmaRmsNorm(GemmaRmsNormStage::hf_layer(
104            format!("{prefix}.post_attention_layernorm"),
105            eps,
106        )))
107    };
108
109    stack = stack.stage(FlowStage::GeGlu(GeGluStage::hf_mlp(&prefix)));
110
111    if matches!(
112        style,
113        GemmaLayerStyle::Gemma2 | GemmaLayerStyle::Gemma3 | GemmaLayerStyle::Gemma4
114    ) {
115        stack = stack.stage(FlowStage::GemmaRmsNorm(GemmaRmsNormStage::hf_layer(
116            format!("{prefix}.post_feedforward_layernorm"),
117            eps,
118        )));
119    }
120
121    stack.residual_add().build()
122}
123
124/// MoE placeholder — dense Gemma paths use [`gemma_prefill_layer_composed`].
125pub fn gemma_moe_prefill_layer_composed(
126    layer_idx: usize,
127    style: GemmaLayerStyle,
128    attn: SelfAttnPrefillSpec,
129    eps: f32,
130    kv_sink: Option<Arc<Mutex<Vec<rlx_ir::HirNodeId>>>>,
131    _moe: super::MoeFfnStage,
132) -> FlowStage {
133    gemma_prefill_layer_composed(layer_idx, style, attn, eps, kv_sink)
134}
135
136pub fn gemma_moe_decode_layer_composed(
137    layer_idx: usize,
138    spec: GemmaDecodeLayerSpec,
139    kv_out: Arc<Mutex<Vec<rlx_ir::HirNodeId>>>,
140    _moe: super::MoeFfnStage,
141) -> FlowStage {
142    FlowStage::Named {
143        name: format!("layer{layer_idx}"),
144        inner: Arc::new(FlowStage::GemmaDecodeLayer(GemmaDecodeLayerStage::layer(
145            layer_idx, spec, kv_out,
146        ))),
147    }
148}