Skip to main content

rlx_gemma/
config.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Gemma family configuration — HF `config.json` and GGUF metadata.
17
18use rlx_flow::blocks::{GemmaLayerStyle, gemma_strided_layer_mask, gemma2_layer_mask};
19use rlx_gguf::{GgufFile, MetaValue};
20use rlx_ir::op::MaskKind;
21use serde::Deserialize;
22use std::path::Path;
23
/// Gemma model generation described by a [`GemmaConfig`].
///
/// Deserializes from the lowercase variant name (`"gemma"`, `"gemma2"`,
/// `"gemma3"`, `"gemma4"`) and defaults to [`GemmaArch::Gemma`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum GemmaArch {
    /// Default variant; also the result for GGUF architecture tags that are
    /// not recognised by `from_gguf_tag`.
    #[default]
    Gemma,
    /// Selected by the GGUF tag `gemma2`.
    Gemma2,
    /// Selected by the GGUF tags `gemma3` and `gemma3n`.
    Gemma3,
    /// Selected by the GGUF tags `gemma4` and `gemma4moe`.
    Gemma4,
}
33
34impl GemmaArch {
35    pub fn sliding_window_stride(self) -> usize {
36        match self {
37            GemmaArch::Gemma3 | GemmaArch::Gemma4 => 6,
38            _ => 0,
39        }
40    }
41
42    fn from_gguf_tag(tag: &str) -> Self {
43        match tag {
44            "gemma2" => GemmaArch::Gemma2,
45            "gemma3" | "gemma3n" => GemmaArch::Gemma3,
46            "gemma4" | "gemma4moe" => GemmaArch::Gemma4,
47            _ => GemmaArch::Gemma,
48        }
49    }
50}
51
/// Gemma model hyper-parameters.
///
/// Populated either by deserializing a JSON config (see
/// [`GemmaConfig::from_file`]) or from GGUF metadata (see
/// [`GemmaConfig::from_gguf`]).
#[derive(Debug, Clone, Deserialize)]
pub struct GemmaConfig {
    /// Model generation; defaults to [`GemmaArch::Gemma`] when the key is absent.
    #[serde(default)]
    pub arch: GemmaArch,
    /// Vocabulary size (the GGUF loader falls back to 256 000 when unspecified).
    pub vocab_size: usize,
    /// Hidden size; the GGUF loader reads it from `embedding_length`.
    pub hidden_size: usize,
    /// Feed-forward width; the GGUF loader reads it from `feed_forward_length`.
    pub intermediate_size: usize,
    /// Layer count; the GGUF loader reads it from `block_count`.
    pub num_hidden_layers: usize,
    /// Number of attention (query) heads.
    pub num_attention_heads: usize,
    /// Number of key/value heads; divisor in [`GemmaConfig::kv_group_size`].
    pub num_key_value_heads: usize,
    /// Maximum position count (the GGUF loader falls back to 8192).
    pub max_position_embeddings: usize,
    /// RMS-norm epsilon; defaults to `1e-6` when absent from JSON.
    #[serde(default = "default_rms_norm_eps")]
    pub rms_norm_eps: f64,
    /// RoPE base; defaults to `10_000.0` when absent from JSON.
    #[serde(default = "default_rope_theta")]
    pub rope_theta: f64,
    /// Defaults to `false` when absent from JSON, whereas the GGUF loader
    /// defaults it to `true`.
    // NOTE(review): the differing defaults look deliberate but are not
    // explained anywhere in this file; confirm.
    #[serde(default)]
    pub tie_word_embeddings: bool,
    /// Defaults to `false` in both the JSON and GGUF paths.
    #[serde(default)]
    pub attention_bias: bool,
    /// Explicit per-head dimension. When `None`, [`GemmaConfig::head_dim`]
    /// derives it as `hidden_size / num_attention_heads`.
    #[serde(default)]
    pub head_dim: Option<usize>,
    /// Returned unchanged as the soft-cap element of
    /// [`GemmaConfig::layer_attn_options`].
    #[serde(default)]
    pub attn_logit_softcapping: Option<f32>,
    /// Not consumed in this file.
    // NOTE(review): presumably a soft-cap for the output logits; confirm at the use site.
    #[serde(default)]
    pub final_logit_softcapping: Option<f32>,
    /// Sliding-window width. When `None`, every layer gets a causal mask.
    #[serde(default)]
    pub sliding_window: Option<usize>,
    /// Optional override consumed by [`GemmaConfig::attn_score_scale`] for
    /// Gemma2 and later; ignored for [`GemmaArch::Gemma`].
    #[serde(default)]
    pub query_pre_attn_scalar: Option<f32>,
    /// When set, replaces `num_hidden_layers` in
    /// [`GemmaConfig::active_num_layers`].
    #[serde(default)]
    pub effective_num_layers: Option<usize>,
    /// Expert count; a non-zero value on a Gemma4 config makes
    /// [`GemmaConfig::is_moe`] return `true`.
    #[serde(default)]
    pub num_experts: usize,
    /// Read from GGUF `expert_used_count`; not consumed in this file.
    // NOTE(review): presumably the number of experts routed per token; confirm.
    #[serde(default)]
    pub num_experts_used: usize,
    /// Expert feed-forward width; `0` means "use `intermediate_size`"
    /// (see [`GemmaConfig::expert_ffn_dim`]).
    #[serde(default)]
    pub expert_ffn_size: usize,
    /// Defaults to `1.0` in both the JSON and GGUF paths.
    #[serde(default = "default_expert_weights_scale")]
    pub expert_weights_scale: f32,
}
92
/// Serde fallback for `GemmaConfig::rms_norm_eps` when the key is absent.
fn default_rms_norm_eps() -> f64 {
    1.0e-6_f64
}
/// Serde fallback for `GemmaConfig::rope_theta` when the key is absent.
fn default_rope_theta() -> f64 {
    1.0e4_f64
}
/// Serde fallback for `GemmaConfig::expert_weights_scale` when the key is absent.
fn default_expert_weights_scale() -> f32 {
    1.0_f32
}
102
103impl GemmaConfig {
104    pub fn from_file(path: &Path) -> anyhow::Result<Self> {
105        let data = std::fs::read_to_string(path)?;
106        let mut cfg: Self = serde_json::from_str(&data)?;
107        if cfg.arch == GemmaArch::Gemma {
108            cfg.arch = infer_arch_from_json(&data);
109        }
110        Ok(cfg)
111    }
112
113    pub fn from_gguf(raw: &GgufFile) -> anyhow::Result<Self> {
114        gemma_cfg_from_gguf(raw)
115    }
116
117    pub fn head_dim(&self) -> usize {
118        self.head_dim
119            .unwrap_or(self.hidden_size / self.num_attention_heads)
120    }
121
122    pub fn kv_group_size(&self) -> usize {
123        self.num_attention_heads / self.num_key_value_heads
124    }
125
126    pub fn q_proj_dim(&self) -> usize {
127        self.num_attention_heads * self.head_dim()
128    }
129
130    pub fn kv_proj_dim(&self) -> usize {
131        self.num_key_value_heads * self.head_dim()
132    }
133
134    pub fn layer_style(&self) -> GemmaLayerStyle {
135        match self.arch {
136            GemmaArch::Gemma => GemmaLayerStyle::Gemma,
137            GemmaArch::Gemma2 => GemmaLayerStyle::Gemma2,
138            GemmaArch::Gemma3 => GemmaLayerStyle::Gemma3,
139            GemmaArch::Gemma4 => GemmaLayerStyle::Gemma4,
140        }
141    }
142
143    pub fn active_num_layers(&self) -> usize {
144        self.effective_num_layers.unwrap_or(self.num_hidden_layers)
145    }
146
147    pub fn is_moe(&self) -> bool {
148        self.arch == GemmaArch::Gemma4 && self.num_experts > 0
149    }
150
151    pub fn expert_ffn_dim(&self) -> usize {
152        if self.expert_ffn_size > 0 {
153            self.expert_ffn_size
154        } else {
155            self.intermediate_size
156        }
157    }
158
159    pub fn attn_score_scale(&self) -> Option<f32> {
160        match self.arch {
161            GemmaArch::Gemma => None,
162            GemmaArch::Gemma2 | GemmaArch::Gemma3 | GemmaArch::Gemma4 => {
163                if let Some(s) = self.query_pre_attn_scalar {
164                    Some(1.0 / s)
165                } else {
166                    Some(1.0 / (self.head_dim() as f32).sqrt())
167                }
168            }
169        }
170    }
171
172    pub fn layer_attn_options(&self, layer: usize) -> (MaskKind, Option<f32>, Option<f32>) {
173        let scale = self.attn_score_scale();
174        let softcap = self.attn_logit_softcapping;
175        let mask = match (self.arch, self.sliding_window) {
176            (_, None) => MaskKind::Causal,
177            (GemmaArch::Gemma2, Some(w)) => gemma2_layer_mask(layer, w),
178            (GemmaArch::Gemma3 | GemmaArch::Gemma4, Some(w)) => {
179                gemma_strided_layer_mask(layer, w, self.arch.sliding_window_stride())
180            }
181            _ => MaskKind::Causal,
182        };
183        (mask, scale, softcap)
184    }
185
186    #[cfg(test)]
187    pub(crate) fn tiny_test() -> Self {
188        Self {
189            arch: GemmaArch::Gemma,
190            vocab_size: 32,
191            hidden_size: 16,
192            intermediate_size: 32,
193            num_hidden_layers: 2,
194            num_attention_heads: 4,
195            num_key_value_heads: 2,
196            max_position_embeddings: 64,
197            rms_norm_eps: 1e-6,
198            rope_theta: 10_000.0,
199            tie_word_embeddings: true,
200            attention_bias: false,
201            head_dim: None,
202            attn_logit_softcapping: None,
203            final_logit_softcapping: None,
204            sliding_window: None,
205            query_pre_attn_scalar: None,
206            effective_num_layers: None,
207            num_experts: 0,
208            num_experts_used: 0,
209            expert_ffn_size: 0,
210            expert_weights_scale: 1.0,
211        }
212    }
213}
214
215fn infer_arch_from_json(raw: &str) -> GemmaArch {
216    if raw.contains("\"model_type\"") {
217        if raw.contains("\"gemma2\"") {
218            return GemmaArch::Gemma2;
219        }
220        if raw.contains("\"gemma3\"") {
221            return GemmaArch::Gemma3;
222        }
223    }
224    GemmaArch::Gemma
225}
226
227pub fn gemma_cfg_from_gguf(raw: &GgufFile) -> anyhow::Result<GemmaConfig> {
228    let arch_tag = raw
229        .metadata
230        .get("general.architecture")
231        .and_then(MetaValue::as_str)
232        .unwrap_or("gemma");
233    let arch_prefix = arch_tag;
234    let arch = GemmaArch::from_gguf_tag(arch_tag);
235
236    let get_meta = |k: &str| -> Option<&MetaValue> {
237        raw.metadata.get(k).or_else(|| {
238            let suffix = k.strip_prefix("gemma.")?;
239            if arch_prefix == "gemma" {
240                None
241            } else {
242                let arch_key = format!("{arch_prefix}.{suffix}");
243                raw.metadata.get(&arch_key)
244            }
245        })
246    };
247    let get_u32 = |k: &str| -> anyhow::Result<u32> {
248        get_meta(k)
249            .and_then(MetaValue::as_u32)
250            .ok_or_else(|| anyhow::anyhow!("missing GGUF metadata key: {k}"))
251    };
252    let get_f32 = |k: &str| -> Option<f32> {
253        get_meta(k).and_then(|v| match v {
254            MetaValue::F32(x) => Some(*x),
255            _ => None,
256        })
257    };
258    let get_bool = |k: &str| -> Option<bool> {
259        get_meta(k).and_then(|v| match v {
260            MetaValue::Bool(b) => Some(*b),
261            _ => None,
262        })
263    };
264
265    let hidden_size = get_u32("gemma.embedding_length")? as usize;
266    let num_attention_heads = get_u32("gemma.attention.head_count")? as usize;
267    let head_dim = get_u32("gemma.attention.key_length")
268        .ok()
269        .or_else(|| get_u32("gemma.rope.dimension_count").ok())
270        .map(|v| v as usize);
271
272    Ok(GemmaConfig {
273        arch,
274        vocab_size: get_u32("gemma.vocab_size").unwrap_or(256_000) as usize,
275        hidden_size,
276        intermediate_size: get_u32("gemma.feed_forward_length")? as usize,
277        num_hidden_layers: get_u32("gemma.block_count")? as usize,
278        num_attention_heads,
279        num_key_value_heads: get_u32("gemma.attention.head_count_kv")? as usize,
280        max_position_embeddings: get_u32("gemma.context_length").unwrap_or(8192) as usize,
281        rms_norm_eps: get_f32("gemma.attention.layer_norm_rms_epsilon").unwrap_or(1e-6) as f64,
282        rope_theta: get_f32("gemma.rope.freq_base").unwrap_or(10_000.0) as f64,
283        tie_word_embeddings: get_bool("gemma.tie_word_embeddings").unwrap_or(true),
284        attention_bias: get_bool("gemma.attention.bias").unwrap_or(false),
285        head_dim,
286        attn_logit_softcapping: get_f32("gemma.attn_logit_softcapping"),
287        final_logit_softcapping: get_f32("gemma.final_logit_softcapping"),
288        sliding_window: get_u32("gemma.attention.sliding_window")
289            .ok()
290            .map(|v| v as usize),
291        query_pre_attn_scalar: get_f32("gemma.attention.query_pre_attn_scalar"),
292        effective_num_layers: get_u32("gemma.block_count_effective")
293            .ok()
294            .map(|v| v as usize),
295        num_experts: get_u32("gemma.expert_count").unwrap_or(0) as usize,
296        num_experts_used: get_u32("gemma.expert_used_count").unwrap_or(0) as usize,
297        expert_ffn_size: get_u32("gemma.expert_feed_forward_length").unwrap_or(0) as usize,
298        expert_weights_scale: get_f32("gemma.expert_weights_scale").unwrap_or(1.0),
299    })
300}