Skip to main content

rlx_gemma/
config.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Gemma family configuration — HF `config.json` and GGUF metadata.
17
18use rlx_flow::blocks::{GemmaLayerStyle, gemma_strided_layer_mask, gemma2_layer_mask};
19use rlx_gguf::{GgufFile, MetaValue};
20use rlx_ir::op::MaskKind;
21use serde::Deserialize;
22use std::path::Path;
23
24#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
25#[serde(rename_all = "lowercase")]
26pub enum GemmaArch {
27    #[default]
28    Gemma,
29    Gemma2,
30    Gemma3,
31    Gemma4,
32}
33
34impl GemmaArch {
35    pub fn sliding_window_stride(self) -> usize {
36        match self {
37            GemmaArch::Gemma3 | GemmaArch::Gemma4 => 6,
38            _ => 0,
39        }
40    }
41
42    fn from_gguf_tag(tag: &str) -> Self {
43        match tag {
44            "gemma2" => GemmaArch::Gemma2,
45            "gemma3" | "gemma3n" => GemmaArch::Gemma3,
46            "gemma4" | "gemma4moe" | "gemma4_unified" | "gemma4_unified_text" => GemmaArch::Gemma4,
47            _ => GemmaArch::Gemma,
48        }
49    }
50}
51
52/// One entry in the Gemma 4 `text_config.layer_types` array. The
53/// repeating "5 sliding + 1 full" Gemma 3 pattern is just a special
54/// case of this richer per-layer schema.
55#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
56#[serde(rename_all = "snake_case")]
57pub enum GemmaLayerType {
58    SlidingAttention,
59    FullAttention,
60}
61
62/// Nested rope_parameters block. Gemma 4 12B carries per-attention-kind
63/// rope parameters: sliding layers use `theta=1e4` with full rotation,
64/// full-attention layers use `theta=1e6` with `partial_rotary_factor`
65/// (p-RoPE rotating only the leading slice).
66#[derive(Debug, Clone, Copy, Deserialize, Default)]
67pub struct GemmaRopeParameters {
68    #[serde(default)]
69    pub partial_rotary_factor: Option<f32>,
70    #[serde(default)]
71    pub rope_theta: Option<f32>,
72    #[serde(default)]
73    pub rope_type: Option<GemmaRopeKind>,
74}
75
76#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
77#[serde(rename_all = "snake_case")]
78pub enum GemmaRopeKind {
79    #[default]
80    Default,
81    Proportional,
82    Linear,
83    Dynamic,
84}
85
86#[derive(Debug, Clone, Default, Deserialize)]
87pub struct GemmaRopeMap {
88    #[serde(default)]
89    pub sliding_attention: Option<GemmaRopeParameters>,
90    #[serde(default)]
91    pub full_attention: Option<GemmaRopeParameters>,
92}
93
94#[derive(Debug, Clone, Deserialize)]
95pub struct GemmaConfig {
96    #[serde(default)]
97    pub arch: GemmaArch,
98    pub vocab_size: usize,
99    pub hidden_size: usize,
100    pub intermediate_size: usize,
101    pub num_hidden_layers: usize,
102    pub num_attention_heads: usize,
103    pub num_key_value_heads: usize,
104    pub max_position_embeddings: usize,
105    #[serde(default = "default_rms_norm_eps")]
106    pub rms_norm_eps: f64,
107    #[serde(default = "default_rope_theta")]
108    pub rope_theta: f64,
109    #[serde(default)]
110    pub tie_word_embeddings: bool,
111    #[serde(default)]
112    pub attention_bias: bool,
113    #[serde(default)]
114    pub head_dim: Option<usize>,
115    #[serde(default)]
116    pub attn_logit_softcapping: Option<f32>,
117    #[serde(default)]
118    pub final_logit_softcapping: Option<f32>,
119    #[serde(default)]
120    pub sliding_window: Option<usize>,
121    #[serde(default)]
122    pub query_pre_attn_scalar: Option<f32>,
123    #[serde(default)]
124    pub effective_num_layers: Option<usize>,
125    #[serde(default)]
126    pub num_experts: usize,
127    #[serde(default)]
128    pub num_experts_used: usize,
129    #[serde(default)]
130    pub expert_ffn_size: usize,
131    #[serde(default = "default_expert_weights_scale")]
132    pub expert_weights_scale: f32,
133
134    // ── Gemma 4 unified additions ──────────────────────────────────
135    /// Per-layer attention kind. Empty for Gemma <=3 — fall back to
136    /// the strided pattern derived from `arch.sliding_window_stride`.
137    #[serde(default)]
138    pub layer_types: Vec<GemmaLayerType>,
139    /// Per-attention-kind rope settings. Empty for Gemma <=3.
140    #[serde(default)]
141    pub rope_parameters: GemmaRopeMap,
142    /// Head dim for full-attention (global) layers. `None` ⇒ reuse
143    /// the base `head_dim`. Gemma 4 12B sets this to 512 while the
144    /// sliding `head_dim` stays at 256.
145    #[serde(default)]
146    pub global_head_dim: Option<usize>,
147    /// Num KV heads for full-attention layers. `None` ⇒ reuse the
148    /// base `num_key_value_heads`. Gemma 4 12B sets this to 1.
149    #[serde(default)]
150    pub num_global_key_value_heads: Option<usize>,
151    /// When true (Gemma 4 12B), the K projection is reused as V at
152    /// load time — weights only ship `.k_proj` and `.v_proj` becomes
153    /// an alias.
154    #[serde(default)]
155    pub attention_k_eq_v: bool,
156    /// When `"vision"`, media placeholder spans use bidirectional
157    /// attention on sliding layers (Gemma 4 unified).
158    #[serde(default)]
159    pub use_bidirectional_attention: Option<String>,
160}
161
/// Serde default for `GemmaConfig::rms_norm_eps` (`1e-6`).
fn default_rms_norm_eps() -> f64 {
    1.0e-6
}

/// Serde default for `GemmaConfig::rope_theta` (`1e4`).
fn default_rope_theta() -> f64 {
    1.0e4
}

/// Serde default for `GemmaConfig::expert_weights_scale` (identity scale).
fn default_expert_weights_scale() -> f32 {
    1.0_f32
}
171
172impl GemmaConfig {
173    pub fn from_file(path: &Path) -> anyhow::Result<Self> {
174        let data = std::fs::read_to_string(path)?;
175        // Gemma 4 unified (e.g. `google/gemma-4-12B`) nests the LM
176        // hyperparameters under `text_config` because the same file
177        // also carries vision + audio configs. Pick that subtree if
178        // it looks like the unified shape, otherwise stay flat.
179        let value: serde_json::Value = serde_json::from_str(&data)?;
180        let lm_value = match value.get("text_config") {
181            Some(tc) if tc.is_object() => tc.clone(),
182            _ => value.clone(),
183        };
184        let lm_value = normalize_hf_null_usize_fields(lm_value);
185        let mut cfg: Self = serde_json::from_value(lm_value)?;
186        if cfg.arch == GemmaArch::Gemma {
187            cfg.arch = infer_arch_from_json(&data);
188        }
189        Ok(cfg)
190    }
191
192    pub fn from_gguf(raw: &GgufFile) -> anyhow::Result<Self> {
193        gemma_cfg_from_gguf(raw)
194    }
195
196    pub fn head_dim(&self) -> usize {
197        self.head_dim
198            .unwrap_or(self.hidden_size / self.num_attention_heads)
199    }
200
201    pub fn kv_group_size(&self) -> usize {
202        self.num_attention_heads / self.num_key_value_heads
203    }
204
205    pub fn q_proj_dim(&self) -> usize {
206        self.num_attention_heads * self.head_dim()
207    }
208
209    pub fn kv_proj_dim(&self) -> usize {
210        self.num_key_value_heads * self.head_dim()
211    }
212
213    pub fn layer_style(&self) -> GemmaLayerStyle {
214        match self.arch {
215            GemmaArch::Gemma => GemmaLayerStyle::Gemma,
216            GemmaArch::Gemma2 => GemmaLayerStyle::Gemma2,
217            GemmaArch::Gemma3 => GemmaLayerStyle::Gemma3,
218            GemmaArch::Gemma4 => GemmaLayerStyle::Gemma4,
219        }
220    }
221
222    pub fn active_num_layers(&self) -> usize {
223        self.effective_num_layers.unwrap_or(self.num_hidden_layers)
224    }
225
226    pub fn is_moe(&self) -> bool {
227        self.arch == GemmaArch::Gemma4 && self.num_experts > 0
228    }
229
230    /// Gemma 4 unified: bidirectional attention inside vision/audio spans.
231    pub fn use_bidirectional_vision(&self) -> bool {
232        self.use_bidirectional_attention.as_deref() == Some("vision")
233    }
234
235    pub fn expert_ffn_dim(&self) -> usize {
236        if self.expert_ffn_size > 0 {
237            self.expert_ffn_size
238        } else {
239            self.intermediate_size
240        }
241    }
242
243    pub fn attn_score_scale(&self) -> Option<f32> {
244        match self.arch {
245            GemmaArch::Gemma => None,
246            GemmaArch::Gemma2 | GemmaArch::Gemma3 | GemmaArch::Gemma4 => {
247                if let Some(s) = self.query_pre_attn_scalar {
248                    Some(1.0 / s)
249                } else {
250                    Some(1.0 / (self.head_dim() as f32).sqrt())
251                }
252            }
253        }
254    }
255
256    /// Per-layer attention options driving the prefill self-attn block:
257    /// `(mask kind, softmax score scale, attention logit soft-cap)`.
258    /// The mask varies across Gemma variants:
259    ///
260    /// - Gemma 1 / no sliding window → all-causal.
261    /// - Gemma 2 → alternating sliding-window via [`gemma2_layer_mask`].
262    /// - Gemma 3 / 4 → strided pattern via
263    ///   [`gemma_strided_layer_mask`] (stride-6: every 6th layer is
264    ///   full causal, others are sliding-window).
265    pub fn layer_attn_options(&self, layer: usize) -> (MaskKind, Option<f32>, Option<f32>) {
266        let scale = self.attn_score_scale();
267        let softcap = self.attn_logit_softcapping;
268        let mask = match (self.arch, self.sliding_window) {
269            (_, None) => MaskKind::Causal,
270            (GemmaArch::Gemma2, Some(w)) => gemma2_layer_mask(layer, w),
271            (GemmaArch::Gemma3 | GemmaArch::Gemma4, Some(w)) => {
272                gemma_strided_layer_mask(layer, w, self.arch.sliding_window_stride())
273            }
274            _ => MaskKind::Causal,
275        };
276        (mask, scale, softcap)
277    }
278
279    #[cfg(test)]
280    pub(crate) fn tiny_test() -> Self {
281        Self {
282            arch: GemmaArch::Gemma,
283            vocab_size: 32,
284            hidden_size: 16,
285            intermediate_size: 32,
286            num_hidden_layers: 2,
287            num_attention_heads: 4,
288            num_key_value_heads: 2,
289            max_position_embeddings: 64,
290            rms_norm_eps: 1e-6,
291            rope_theta: 10_000.0,
292            tie_word_embeddings: true,
293            attention_bias: false,
294            head_dim: None,
295            attn_logit_softcapping: None,
296            final_logit_softcapping: None,
297            sliding_window: None,
298            query_pre_attn_scalar: None,
299            effective_num_layers: None,
300            num_experts: 0,
301            num_experts_used: 0,
302            expert_ffn_size: 0,
303            expert_weights_scale: 1.0,
304            layer_types: Vec::new(),
305            rope_parameters: GemmaRopeMap::default(),
306            global_head_dim: None,
307            num_global_key_value_heads: None,
308            attention_k_eq_v: false,
309            use_bidirectional_attention: None,
310        }
311    }
312
313    // ── Per-layer dispatch (Gemma 4 unified). ──────────────────────
314    //
315    // For Gemma 1/2/3 the `layer_types` array is empty and these
316    // helpers reduce to the existing strided pattern; for Gemma 4
317    // they read the explicit array so each layer can ship its own
318    // (head_dim, num_kv_heads, n_rot, rope_theta).
319
320    /// Whether layer `i` is a full-attention (global) layer rather
321    /// than a sliding-window one. Falls back to the strided pattern
322    /// (every `stride`-th layer is global) when `layer_types` is
323    /// unset.
324    pub fn is_full_attention_layer(&self, layer: usize) -> bool {
325        if !self.layer_types.is_empty() {
326            return matches!(
327                self.layer_types.get(layer),
328                Some(GemmaLayerType::FullAttention),
329            );
330        }
331        let stride = self.arch.sliding_window_stride();
332        stride > 1 && (layer + 1).is_multiple_of(stride)
333    }
334
335    /// Per-layer head_dim. Sliding layers always use the base
336    /// `head_dim`; full-attention layers use `global_head_dim` when
337    /// set (Gemma 4 12B: 512 vs base 256).
338    pub fn layer_head_dim(&self, layer: usize) -> usize {
339        if self.is_full_attention_layer(layer) {
340            self.global_head_dim.unwrap_or_else(|| self.head_dim())
341        } else {
342            self.head_dim()
343        }
344    }
345
346    /// Per-layer KV head count. Sliding layers use
347    /// `num_key_value_heads`; full-attention layers use
348    /// `num_global_key_value_heads` when set (Gemma 4 12B: 1 vs 8).
349    pub fn layer_num_kv_heads(&self, layer: usize) -> usize {
350        if self.is_full_attention_layer(layer) {
351            self.num_global_key_value_heads
352                .unwrap_or(self.num_key_value_heads)
353        } else {
354            self.num_key_value_heads
355        }
356    }
357
358    /// Number of leading per-head dimensions that get RoPE-rotated
359    /// in layer `i`. Returns `layer_head_dim` for "default" RoPE,
360    /// or `floor(partial_rotary_factor * head_dim)` for p-RoPE.
361    pub fn layer_n_rot(&self, layer: usize) -> usize {
362        let dh = self.layer_head_dim(layer);
363        let params = self.layer_rope_parameters(layer);
364        let kind = params
365            .and_then(|p| p.rope_type)
366            .unwrap_or(GemmaRopeKind::Default);
367        let factor = params.and_then(|p| p.partial_rotary_factor);
368        match (kind, factor) {
369            (GemmaRopeKind::Proportional, Some(f)) if f > 0.0 && f < 1.0 => {
370                ((dh as f32) * f).floor() as usize
371            }
372            _ => dh,
373        }
374    }
375
376    /// RoPE base frequency for layer `i`. Falls back to the
377    /// top-level `rope_theta` when the unified map omits the entry.
378    pub fn layer_rope_theta(&self, layer: usize) -> f64 {
379        self.layer_rope_parameters(layer)
380            .and_then(|p| p.rope_theta)
381            .map(|t| t as f64)
382            .unwrap_or(self.rope_theta)
383    }
384
385    fn layer_rope_parameters(&self, layer: usize) -> Option<&GemmaRopeParameters> {
386        if self.is_full_attention_layer(layer) {
387            self.rope_parameters.full_attention.as_ref()
388        } else {
389            self.rope_parameters.sliding_attention.as_ref()
390        }
391    }
392}
393
394/// HF dense Gemma 4 checkpoints use JSON `null` for unused MoE keys.
395fn normalize_hf_null_usize_fields(mut value: serde_json::Value) -> serde_json::Value {
396    let Some(obj) = value.as_object_mut() else {
397        return value;
398    };
399    for key in [
400        "num_experts",
401        "num_experts_used",
402        "top_k_experts",
403        "expert_ffn_size",
404        "moe_intermediate_size",
405        "hidden_size_per_layer_input",
406    ] {
407        if obj.get(key).is_some_and(|v| v.is_null()) {
408            obj.insert(key.to_string(), serde_json::Value::from(0usize));
409        }
410    }
411    value
412}
413
414fn infer_arch_from_json(raw: &str) -> GemmaArch {
415    // Detect Gemma 4 first — its unified config also contains a
416    // nested `gemma4_unified_text` model_type that we want to catch
417    // even when the outer `model_type` is `gemma4_unified` or the
418    // architecture is `Gemma4UnifiedForConditionalGeneration`.
419    if raw.contains("\"gemma4_unified\"")
420        || raw.contains("\"gemma4_unified_text\"")
421        || raw.contains("\"gemma4\"")
422        || raw.contains("\"gemma4moe\"")
423        || raw.contains("Gemma4UnifiedForConditionalGeneration")
424        || raw.contains("Gemma4ForCausalLM")
425    {
426        return GemmaArch::Gemma4;
427    }
428    if raw.contains("\"model_type\"") {
429        if raw.contains("\"gemma2\"") {
430            return GemmaArch::Gemma2;
431        }
432        if raw.contains("\"gemma3\"") {
433            return GemmaArch::Gemma3;
434        }
435    }
436    GemmaArch::Gemma
437}
438
439pub fn gemma_cfg_from_gguf(raw: &GgufFile) -> anyhow::Result<GemmaConfig> {
440    let arch_tag = raw
441        .metadata
442        .get("general.architecture")
443        .and_then(MetaValue::as_str)
444        .unwrap_or("gemma");
445    let arch_prefix = arch_tag;
446    let arch = GemmaArch::from_gguf_tag(arch_tag);
447
448    let get_meta = |k: &str| -> Option<&MetaValue> {
449        raw.metadata.get(k).or_else(|| {
450            let suffix = k.strip_prefix("gemma.")?;
451            if arch_prefix == "gemma" {
452                None
453            } else {
454                let arch_key = format!("{arch_prefix}.{suffix}");
455                raw.metadata.get(&arch_key)
456            }
457        })
458    };
459    let get_u32 = |k: &str| -> anyhow::Result<u32> {
460        get_meta(k)
461            .and_then(MetaValue::as_u32)
462            .ok_or_else(|| anyhow::anyhow!("missing GGUF metadata key: {k}"))
463    };
464    let get_f32 = |k: &str| -> Option<f32> {
465        get_meta(k).and_then(|v| match v {
466            MetaValue::F32(x) => Some(*x),
467            _ => None,
468        })
469    };
470    let get_bool = |k: &str| -> Option<bool> {
471        get_meta(k).and_then(|v| match v {
472            MetaValue::Bool(b) => Some(*b),
473            _ => None,
474        })
475    };
476
477    let hidden_size = get_u32("gemma.embedding_length")? as usize;
478    let num_attention_heads = get_u32("gemma.attention.head_count")? as usize;
479    let head_dim = get_u32("gemma.attention.key_length")
480        .ok()
481        .or_else(|| get_u32("gemma.rope.dimension_count").ok())
482        .map(|v| v as usize);
483
484    Ok(GemmaConfig {
485        arch,
486        vocab_size: get_u32("gemma.vocab_size").unwrap_or(256_000) as usize,
487        hidden_size,
488        intermediate_size: get_u32("gemma.feed_forward_length")? as usize,
489        num_hidden_layers: get_u32("gemma.block_count")? as usize,
490        num_attention_heads,
491        num_key_value_heads: get_u32("gemma.attention.head_count_kv")? as usize,
492        max_position_embeddings: get_u32("gemma.context_length").unwrap_or(8192) as usize,
493        rms_norm_eps: get_f32("gemma.attention.layer_norm_rms_epsilon").unwrap_or(1e-6) as f64,
494        rope_theta: get_f32("gemma.rope.freq_base").unwrap_or(10_000.0) as f64,
495        tie_word_embeddings: get_bool("gemma.tie_word_embeddings").unwrap_or(true),
496        attention_bias: get_bool("gemma.attention.bias").unwrap_or(false),
497        head_dim,
498        attn_logit_softcapping: get_f32("gemma.attn_logit_softcapping"),
499        final_logit_softcapping: get_f32("gemma.final_logit_softcapping"),
500        sliding_window: get_u32("gemma.attention.sliding_window")
501            .ok()
502            .map(|v| v as usize),
503        query_pre_attn_scalar: get_f32("gemma.attention.query_pre_attn_scalar"),
504        effective_num_layers: get_u32("gemma.block_count_effective")
505            .ok()
506            .map(|v| v as usize),
507        num_experts: get_u32("gemma.expert_count").unwrap_or(0) as usize,
508        num_experts_used: get_u32("gemma.expert_used_count").unwrap_or(0) as usize,
509        expert_ffn_size: get_u32("gemma.expert_feed_forward_length").unwrap_or(0) as usize,
510        expert_weights_scale: get_f32("gemma.expert_weights_scale").unwrap_or(1.0),
511        // GGUF doesn't carry the Gemma 4 unified per-layer schema
512        // yet; the dense path falls back to the strided pattern and
513        // uniform head dims that match every Gemma 4 GGUF currently
514        // emitted by llama.cpp.
515        layer_types: Vec::new(),
516        rope_parameters: GemmaRopeMap::default(),
517        global_head_dim: None,
518        num_global_key_value_heads: None,
519        attention_k_eq_v: false,
520        use_bidirectional_attention: None,
521    })
522}
523
#[cfg(test)]
mod tests {
    use super::*;

    /// Trimmed copy of `google/gemma-4-12B`'s `config.json`. It keeps only
    /// the fields the loader actually consumes, plus the surrounding shape
    /// (top-level `model_type`, nested `text_config`) that proves we unwrap
    /// the unified layout correctly.
    const GEMMA_4_12B_CONFIG: &str = r#"{
      "architectures": ["Gemma4UnifiedForConditionalGeneration"],
      "model_type": "gemma4_unified",
      "tie_word_embeddings": true,
      "text_config": {
        "model_type": "gemma4_unified_text",
        "vocab_size": 262144,
        "hidden_size": 3840,
        "intermediate_size": 15360,
        "num_hidden_layers": 48,
        "num_attention_heads": 16,
        "num_key_value_heads": 8,
        "num_global_key_value_heads": 1,
        "head_dim": 256,
        "global_head_dim": 512,
        "attention_k_eq_v": true,
        "max_position_embeddings": 131072,
        "rms_norm_eps": 1e-6,
        "tie_word_embeddings": true,
        "attention_bias": false,
        "final_logit_softcapping": 30.0,
        "sliding_window": 1024,
        "layer_types": [
          "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
          "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
          "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
          "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
          "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
          "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
          "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
          "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
        ],
        "rope_parameters": {
          "full_attention":    { "partial_rotary_factor": 0.25, "rope_theta": 1000000.0, "rope_type": "proportional" },
          "sliding_attention": { "rope_theta": 10000.0, "rope_type": "default" }
        }
      }
    }"#;

    /// Write `json` to a temp file, load it with [`GemmaConfig::from_file`],
    /// and delete the file again.
    ///
    /// The process id in the name keeps concurrently running test binaries
    /// apart. `tag` keeps tests inside one binary apart. The file is removed
    /// before the load result is checked, so a failing parse does not leave
    /// it behind.
    fn load_config(tag: &str, json: &str) -> GemmaConfig {
        let path = std::env::temp_dir().join(format!(
            "rlx_gemma_{tag}_{}_config.json",
            std::process::id()
        ));
        std::fs::write(&path, json).expect("write temp config");
        let loaded = GemmaConfig::from_file(&path);
        std::fs::remove_file(&path).ok();
        loaded.unwrap_or_else(|err| panic!("loading {tag} config failed: {err:#}"))
    }

    #[test]
    fn gemma_4_12b_unified_config_parses_text_subtree() {
        let cfg = load_config("gemma4_12b_parse", GEMMA_4_12B_CONFIG);

        assert_eq!(cfg.arch, GemmaArch::Gemma4);
        assert_eq!(cfg.vocab_size, 262_144);
        assert_eq!(cfg.hidden_size, 3840);
        assert_eq!(cfg.intermediate_size, 15_360);
        assert_eq!(cfg.num_hidden_layers, 48);
        assert_eq!(cfg.num_attention_heads, 16);
        assert_eq!(cfg.num_key_value_heads, 8);
        assert_eq!(cfg.head_dim(), 256);
        assert_eq!(cfg.global_head_dim, Some(512));
        assert_eq!(cfg.num_global_key_value_heads, Some(1));
        assert!(cfg.attention_k_eq_v);
        assert_eq!(cfg.sliding_window, Some(1024));
        assert_eq!(cfg.final_logit_softcapping, Some(30.0));
        assert!(cfg.tie_word_embeddings);
        assert_eq!(cfg.layer_types.len(), 48);
        // Stride-6 sliding-window pattern carried over from Gemma 3.
        assert_eq!(cfg.arch.sliding_window_stride(), 6);
    }

    #[test]
    fn hf_null_moe_fields_default_to_zero() {
        let json = r#"{"num_experts": null, "top_k_experts": null}"#;
        let v = normalize_hf_null_usize_fields(serde_json::from_str(json).unwrap());
        let obj = v.as_object().unwrap();
        assert_eq!(obj["num_experts"], 0);
        assert_eq!(obj["top_k_experts"], 0);
    }

    #[test]
    fn gemma_4_12b_per_layer_dispatch() {
        let cfg = load_config("gemma4_12b_dispatch", GEMMA_4_12B_CONFIG);

        // Sliding layer 0: base shapes + full rotary on theta=1e4.
        assert!(!cfg.is_full_attention_layer(0));
        assert_eq!(cfg.layer_head_dim(0), 256);
        assert_eq!(cfg.layer_num_kv_heads(0), 8);
        assert_eq!(cfg.layer_n_rot(0), 256);
        assert!((cfg.layer_rope_theta(0) - 10_000.0).abs() < 1e-3);

        // Full-attention layer 5 (1-indexed: 6th layer): global shapes,
        // p-RoPE (0.25 of head_dim_full=512 → 128), theta=1e6.
        assert!(cfg.is_full_attention_layer(5));
        assert_eq!(cfg.layer_head_dim(5), 512);
        assert_eq!(cfg.layer_num_kv_heads(5), 1);
        assert_eq!(cfg.layer_n_rot(5), 128);
        assert!((cfg.layer_rope_theta(5) - 1_000_000.0).abs() < 1e-3);

        // Last layer (index 47, 1-indexed 48) is also full-attention.
        assert!(cfg.is_full_attention_layer(47));
    }

    #[test]
    fn pre_gemma4_archs_keep_uniform_layer_shape() {
        // Without `layer_types` / `rope_parameters`, the per-layer accessors
        // collapse to the base values, so Gemma 3 / 2 / 1 continue to
        // round-trip the existing flow.
        let mut cfg = GemmaConfig::tiny_test();
        cfg.arch = GemmaArch::Gemma3;
        cfg.head_dim = Some(64);
        cfg.num_key_value_heads = 2;
        cfg.rope_theta = 1_000.0;
        for i in 0..cfg.num_hidden_layers {
            assert_eq!(cfg.layer_head_dim(i), 64);
            assert_eq!(cfg.layer_num_kv_heads(i), 2);
            assert_eq!(cfg.layer_n_rot(i), 64);
            assert!((cfg.layer_rope_theta(i) - 1_000.0).abs() < 1e-3);
        }
    }

    #[test]
    fn gemma3_without_layer_types_uses_stride_six() {
        let mut cfg = GemmaConfig::tiny_test();
        cfg.arch = GemmaArch::Gemma3;
        cfg.num_hidden_layers = 12;
        let full: Vec<usize> = (0..cfg.num_hidden_layers)
            .filter(|&i| cfg.is_full_attention_layer(i))
            .collect();
        assert_eq!(full, vec![5, 11]);
    }

    #[test]
    fn tiny_config_derived_dims() {
        let cfg = GemmaConfig::tiny_test();
        assert_eq!(cfg.head_dim(), 4);
        assert_eq!(cfg.q_proj_dim(), 16);
        assert_eq!(cfg.kv_proj_dim(), 8);
        assert_eq!(cfg.kv_group_size(), 2);
        assert_eq!(cfg.active_num_layers(), 2);
        assert_eq!(cfg.expert_ffn_dim(), 32);
        assert!(!cfg.is_moe());
        assert_eq!(cfg.attn_score_scale(), None);
    }

    #[test]
    fn gguf_architecture_tags_map_to_families() {
        assert_eq!(GemmaArch::from_gguf_tag("gemma"), GemmaArch::Gemma);
        assert_eq!(GemmaArch::from_gguf_tag("gemma2"), GemmaArch::Gemma2);
        assert_eq!(GemmaArch::from_gguf_tag("gemma3"), GemmaArch::Gemma3);
        assert_eq!(GemmaArch::from_gguf_tag("gemma3n"), GemmaArch::Gemma3);
        assert_eq!(GemmaArch::from_gguf_tag("gemma4moe"), GemmaArch::Gemma4);
        assert_eq!(GemmaArch::from_gguf_tag("gemma4_unified_text"), GemmaArch::Gemma4);
        assert_eq!(GemmaArch::from_gguf_tag("llama"), GemmaArch::Gemma);
    }

    #[test]
    fn infer_arch_picks_up_gemma4_markers() {
        assert_eq!(
            infer_arch_from_json(r#"{"model_type":"gemma4_unified"}"#),
            GemmaArch::Gemma4,
        );
        assert_eq!(
            infer_arch_from_json(r#"{"architectures":["Gemma4UnifiedForConditionalGeneration"]}"#),
            GemmaArch::Gemma4,
        );
        assert_eq!(
            infer_arch_from_json(r#"{"model_type":"gemma3"}"#),
            GemmaArch::Gemma3,
        );
    }
}