rlx-gemma 0.2.0

Gemma / Gemma 2 causal LMs for RLX
Documentation
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Gemma family configuration — HF `config.json` and GGUF metadata.

use rlx_flow::blocks::{GemmaLayerStyle, gemma_strided_layer_mask, gemma2_layer_mask};
use rlx_gguf::{GgufFile, MetaValue};
use rlx_ir::op::MaskKind;
use serde::Deserialize;
use std::path::Path;

/// Generation of the Gemma family a configuration belongs to.
///
/// Deserializes from lowercase names (`"gemma"`, `"gemma2"`, ...) because of
/// `rename_all = "lowercase"`. [`GemmaArch::Gemma`] is the `Default` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum GemmaArch {
    /// Original Gemma; the default, and the fallback for unrecognized GGUF tags.
    #[default]
    Gemma,
    /// Gemma 2 (GGUF tag `gemma2`).
    Gemma2,
    /// Gemma 3 (GGUF tags `gemma3` and `gemma3n`).
    Gemma3,
    /// Gemma 4 (GGUF tags `gemma4` and `gemma4moe`).
    Gemma4,
}

impl GemmaArch {
    /// Stride value passed to the strided sliding-window mask builder.
    ///
    /// Gemma 3 and Gemma 4 report `6`; every other generation reports `0`.
    pub fn sliding_window_stride(self) -> usize {
        if matches!(self, GemmaArch::Gemma3 | GemmaArch::Gemma4) {
            6
        } else {
            0
        }
    }

    /// Translates a GGUF `general.architecture` tag into an architecture.
    ///
    /// Any tag that is not listed here resolves to [`GemmaArch::Gemma`].
    fn from_gguf_tag(tag: &str) -> Self {
        if tag == "gemma2" {
            GemmaArch::Gemma2
        } else if matches!(tag, "gemma3" | "gemma3n") {
            GemmaArch::Gemma3
        } else if matches!(tag, "gemma4" | "gemma4moe") {
            GemmaArch::Gemma4
        } else {
            GemmaArch::Gemma
        }
    }
}

/// Model hyper-parameters for a Gemma-family causal LM.
///
/// Can be deserialized directly (see `from_file`, which reads a JSON file) or
/// assembled from GGUF metadata (see `from_gguf`). Fields marked
/// `#[serde(default)]` may be absent from the JSON input.
#[derive(Debug, Clone, Deserialize)]
pub struct GemmaConfig {
    /// Model generation. Defaults to `GemmaArch::Gemma` when absent; `from_file`
    /// then tries to refine it from the raw JSON text.
    #[serde(default)]
    pub arch: GemmaArch,
    /// Vocabulary size.
    pub vocab_size: usize,
    /// Hidden (embedding) width. Used as the numerator of the derived head size.
    pub hidden_size: usize,
    /// Feed-forward width. Also the fallback for `expert_ffn_dim()` when
    /// `expert_ffn_size` is `0`.
    pub intermediate_size: usize,
    /// Total number of layers. Reported by `active_num_layers()` unless
    /// `effective_num_layers` overrides it.
    pub num_hidden_layers: usize,
    /// Number of query heads.
    pub num_attention_heads: usize,
    /// Number of key/value heads. `kv_group_size()` divides the query head count by it.
    pub num_key_value_heads: usize,
    /// Maximum position count (presumably the supported context length — TODO confirm).
    pub max_position_embeddings: usize,
    /// RMS-norm epsilon; `1e-6` when absent from the JSON input.
    #[serde(default = "default_rms_norm_eps")]
    pub rms_norm_eps: f64,
    /// RoPE base frequency; `10_000.0` when absent from the JSON input.
    #[serde(default = "default_rope_theta")]
    pub rope_theta: f64,
    /// Whether input and output embeddings are tied. Defaults to `false` for
    /// JSON input; the GGUF loader defaults it to `true`.
    #[serde(default)]
    pub tie_word_embeddings: bool,
    /// Whether attention projections carry a bias; defaults to `false`.
    #[serde(default)]
    pub attention_bias: bool,
    /// Explicit per-head width. When `None`, `head_dim()` derives it as
    /// `hidden_size / num_attention_heads`.
    #[serde(default)]
    pub head_dim: Option<usize>,
    /// Optional attention-logit soft cap; returned as the third element of
    /// `layer_attn_options`.
    #[serde(default)]
    pub attn_logit_softcapping: Option<f32>,
    /// Optional soft cap for the final logits (not consumed by the code visible here).
    #[serde(default)]
    pub final_logit_softcapping: Option<f32>,
    /// Sliding-window size. When `None`, `layer_attn_options` yields a causal
    /// mask for every layer.
    #[serde(default)]
    pub sliding_window: Option<usize>,
    /// Optional scalar from which `attn_score_scale()` derives the attention
    /// score scale for Gemma 2 and later.
    #[serde(default)]
    pub query_pre_attn_scalar: Option<f32>,
    /// Optional override for the layer count reported by `active_num_layers()`.
    #[serde(default)]
    pub effective_num_layers: Option<usize>,
    /// Number of experts. A value above `0` on a Gemma 4 config makes
    /// `is_moe()` return `true`.
    #[serde(default)]
    pub num_experts: usize,
    /// Number of experts used (presumably per token — TODO confirm).
    #[serde(default)]
    pub num_experts_used: usize,
    /// Expert feed-forward width. `0` means "use `intermediate_size`" in
    /// `expert_ffn_dim()`.
    #[serde(default)]
    pub expert_ffn_size: usize,
    /// Scale applied to expert weights; `1.0` when absent.
    #[serde(default = "default_expert_weights_scale")]
    pub expert_weights_scale: f32,
}

/// Serde fallback for `rms_norm_eps` when the key is absent from the input.
fn default_rms_norm_eps() -> f64 {
    const FALLBACK_RMS_NORM_EPS: f64 = 1e-6;
    FALLBACK_RMS_NORM_EPS
}

/// Serde fallback for `rope_theta` when the key is absent from the input.
fn default_rope_theta() -> f64 {
    const FALLBACK_ROPE_THETA: f64 = 10_000.0;
    FALLBACK_ROPE_THETA
}

/// Serde fallback for `expert_weights_scale` when the key is absent from the input.
fn default_expert_weights_scale() -> f32 {
    const FALLBACK_EXPERT_WEIGHTS_SCALE: f32 = 1.0;
    FALLBACK_EXPERT_WEIGHTS_SCALE
}

impl GemmaConfig {
    /// Loads a configuration from a JSON file such as a Hugging Face `config.json`.
    ///
    /// If the parsed `arch` is still the default [`GemmaArch::Gemma`], the raw
    /// file text is inspected for a `model_type` hint to pick the generation.
    ///
    /// # Errors
    /// Fails when the file cannot be read or does not deserialize into this struct.
    pub fn from_file(path: &Path) -> anyhow::Result<Self> {
        let data = std::fs::read_to_string(path)?;
        let mut cfg: Self = serde_json::from_str(&data)?;
        if cfg.arch == GemmaArch::Gemma {
            cfg.arch = infer_arch_from_json(&data);
        }
        Ok(cfg)
    }

    /// Builds a configuration from GGUF metadata; see [`gemma_cfg_from_gguf`].
    pub fn from_gguf(raw: &GgufFile) -> anyhow::Result<Self> {
        gemma_cfg_from_gguf(raw)
    }

    /// Per-head width: the explicit `head_dim` if set, otherwise
    /// `hidden_size / num_attention_heads`.
    ///
    /// # Panics
    /// Panics if `head_dim` is `None` and `num_attention_heads` is `0`.
    pub fn head_dim(&self) -> usize {
        // Evaluate the fallback lazily: with an explicit `head_dim` the division
        // must not run, otherwise a zero head count would panic needlessly.
        self.head_dim
            .unwrap_or_else(|| self.hidden_size / self.num_attention_heads)
    }

    /// Number of query heads per key/value head.
    ///
    /// # Panics
    /// Panics if `num_key_value_heads` is `0`.
    pub fn kv_group_size(&self) -> usize {
        self.num_attention_heads / self.num_key_value_heads
    }

    /// Output width of the query projection (`num_attention_heads * head_dim()`).
    pub fn q_proj_dim(&self) -> usize {
        self.num_attention_heads * self.head_dim()
    }

    /// Output width of each key/value projection (`num_key_value_heads * head_dim()`).
    pub fn kv_proj_dim(&self) -> usize {
        self.num_key_value_heads * self.head_dim()
    }

    /// Layer style matching this configuration's architecture.
    pub fn layer_style(&self) -> GemmaLayerStyle {
        match self.arch {
            GemmaArch::Gemma => GemmaLayerStyle::Gemma,
            GemmaArch::Gemma2 => GemmaLayerStyle::Gemma2,
            GemmaArch::Gemma3 => GemmaLayerStyle::Gemma3,
            GemmaArch::Gemma4 => GemmaLayerStyle::Gemma4,
        }
    }

    /// Layer count to use: `effective_num_layers` if set, else `num_hidden_layers`.
    pub fn active_num_layers(&self) -> usize {
        self.effective_num_layers.unwrap_or(self.num_hidden_layers)
    }

    /// `true` only for a Gemma 4 configuration with at least one expert.
    pub fn is_moe(&self) -> bool {
        self.arch == GemmaArch::Gemma4 && self.num_experts > 0
    }

    /// Expert feed-forward width; `expert_ffn_size == 0` falls back to
    /// `intermediate_size`.
    pub fn expert_ffn_dim(&self) -> usize {
        if self.expert_ffn_size > 0 {
            self.expert_ffn_size
        } else {
            self.intermediate_size
        }
    }

    /// Explicit attention score scale, or `None` for first-generation Gemma.
    ///
    /// For Gemma 2 and later the scale is `1 / sqrt(query_pre_attn_scalar)`,
    /// with `head_dim()` standing in for the scalar when it is not configured.
    /// The reference implementation defines the scaling as
    /// `query_pre_attn_scalar ** -0.5`, so both branches must take the square
    /// root; they agree whenever the scalar equals the head width.
    pub fn attn_score_scale(&self) -> Option<f32> {
        match self.arch {
            GemmaArch::Gemma => None,
            GemmaArch::Gemma2 | GemmaArch::Gemma3 | GemmaArch::Gemma4 => {
                // Lazy fallback so `head_dim()` is only consulted when needed.
                let scalar = self
                    .query_pre_attn_scalar
                    .unwrap_or_else(|| self.head_dim() as f32);
                Some(1.0 / scalar.sqrt())
            }
        }
    }

    /// Attention options for `layer`: `(mask, score scale, logit soft cap)`.
    ///
    /// The mask is causal unless a sliding window is configured on a Gemma 2+
    /// architecture, in which case the architecture's mask builder decides per layer.
    pub fn layer_attn_options(&self, layer: usize) -> (MaskKind, Option<f32>, Option<f32>) {
        let scale = self.attn_score_scale();
        let softcap = self.attn_logit_softcapping;
        let mask = match (self.arch, self.sliding_window) {
            (GemmaArch::Gemma2, Some(window)) => gemma2_layer_mask(layer, window),
            (GemmaArch::Gemma3 | GemmaArch::Gemma4, Some(window)) => {
                gemma_strided_layer_mask(layer, window, self.arch.sliding_window_stride())
            }
            // No window configured, or first-generation Gemma, which ignores
            // `sliding_window` entirely.
            (_, None) | (GemmaArch::Gemma, Some(_)) => MaskKind::Causal,
        };
        (mask, scale, softcap)
    }

    /// Minimal dense first-generation configuration for unit tests.
    #[cfg(test)]
    pub(crate) fn tiny_test() -> Self {
        Self {
            arch: GemmaArch::Gemma,
            vocab_size: 32,
            hidden_size: 16,
            intermediate_size: 32,
            num_hidden_layers: 2,
            num_attention_heads: 4,
            num_key_value_heads: 2,
            max_position_embeddings: 64,
            rms_norm_eps: 1e-6,
            rope_theta: 10_000.0,
            tie_word_embeddings: true,
            attention_bias: false,
            head_dim: None,
            attn_logit_softcapping: None,
            final_logit_softcapping: None,
            sliding_window: None,
            query_pre_attn_scalar: None,
            effective_num_layers: None,
            num_experts: 0,
            num_experts_used: 0,
            expert_ffn_size: 0,
            expert_weights_scale: 1.0,
        }
    }
}

/// Infers the Gemma generation from the `model_type` entries of raw JSON text.
///
/// Every `"model_type"` key in `raw` is examined in order and the string value
/// that directly follows it is mapped to an architecture; the first recognized
/// value wins. Only the value attached to the key is considered, so a quoted
/// `"gemma2"` elsewhere in the file (for example in a path or name field) does
/// not influence the result. Text-only Gemma 3 checkpoints report
/// `gemma3_text`, which is accepted alongside `gemma3`.
///
/// Returns [`GemmaArch::Gemma`] when no key is present or no value is recognized.
fn infer_arch_from_json(raw: &str) -> GemmaArch {
    const KEY: &str = "\"model_type\"";

    let mut rest = raw;
    while let Some(pos) = rest.find(KEY) {
        // `KEY` is ASCII and `find` returns a char boundary, so slicing is safe.
        rest = &rest[pos + KEY.len()..];

        // Expect `: "<value>"` after the key, tolerating whitespace.
        let value = rest
            .trim_start()
            .strip_prefix(':')
            .map(str::trim_start)
            .and_then(|s| s.strip_prefix('"'))
            .and_then(|s| s.split('"').next());

        let arch = match value {
            Some("gemma2") => GemmaArch::Gemma2,
            Some("gemma3" | "gemma3_text") => GemmaArch::Gemma3,
            Some("gemma4") => GemmaArch::Gemma4,
            // Unknown or non-string value: keep looking at later occurrences.
            _ => continue,
        };
        return arch;
    }
    GemmaArch::Gemma
}

/// Assembles a [`GemmaConfig`] from the metadata table of a parsed GGUF file.
///
/// The architecture comes from `general.architecture` (treated as `"gemma"`
/// when absent or not a string). Every setting is looked up under the `gemma.`
/// namespace first and, on a miss, under the file's own architecture tag.
///
/// # Errors
/// Fails when any of `embedding_length`, `attention.head_count`,
/// `feed_forward_length`, `block_count` or `attention.head_count_kv` is
/// missing or not readable as `u32`. All other keys have defaults or are optional.
pub fn gemma_cfg_from_gguf(raw: &GgufFile) -> anyhow::Result<GemmaConfig> {
    let tag = raw
        .metadata
        .get("general.architecture")
        .and_then(MetaValue::as_str)
        .unwrap_or("gemma");

    // Exact key first; otherwise swap the `gemma.` namespace for the file's tag.
    let lookup = |key: &str| -> Option<&MetaValue> {
        if let Some(hit) = raw.metadata.get(key) {
            return Some(hit);
        }
        let suffix = key.strip_prefix("gemma.")?;
        if tag == "gemma" {
            return None;
        }
        let namespaced = format!("{tag}.{suffix}");
        raw.metadata.get(&namespaced)
    };
    let optional_u32 = |key: &str| -> Option<u32> { lookup(key).and_then(MetaValue::as_u32) };
    let required_u32 = |k: &str| -> anyhow::Result<u32> {
        match optional_u32(k) {
            Some(value) => Ok(value),
            None => Err(anyhow::anyhow!("missing GGUF metadata key: {k}")),
        }
    };
    let optional_f32 = |key: &str| -> Option<f32> {
        if let Some(MetaValue::F32(x)) = lookup(key) {
            Some(*x)
        } else {
            None
        }
    };
    let optional_bool = |key: &str| -> Option<bool> {
        if let Some(MetaValue::Bool(b)) = lookup(key) {
            Some(*b)
        } else {
            None
        }
    };

    // Required keys are read in a fixed order so that the first missing one is
    // the one named in the returned error.
    let hidden_size = required_u32("gemma.embedding_length")? as usize;
    let num_attention_heads = required_u32("gemma.attention.head_count")? as usize;
    let intermediate_size = required_u32("gemma.feed_forward_length")? as usize;
    let num_hidden_layers = required_u32("gemma.block_count")? as usize;
    let num_key_value_heads = required_u32("gemma.attention.head_count_kv")? as usize;

    // Head width: the key length if present, else the RoPE dimension count.
    let head_dim = optional_u32("gemma.attention.key_length")
        .or_else(|| optional_u32("gemma.rope.dimension_count"))
        .map(|v| v as usize);

    let rms_norm_eps =
        f64::from(optional_f32("gemma.attention.layer_norm_rms_epsilon").unwrap_or(1e-6_f32));
    let rope_theta = f64::from(optional_f32("gemma.rope.freq_base").unwrap_or(10_000.0_f32));

    Ok(GemmaConfig {
        arch: GemmaArch::from_gguf_tag(tag),
        vocab_size: optional_u32("gemma.vocab_size").unwrap_or(256_000) as usize,
        hidden_size,
        intermediate_size,
        num_hidden_layers,
        num_attention_heads,
        num_key_value_heads,
        max_position_embeddings: optional_u32("gemma.context_length").unwrap_or(8192) as usize,
        rms_norm_eps,
        rope_theta,
        tie_word_embeddings: optional_bool("gemma.tie_word_embeddings").unwrap_or(true),
        attention_bias: optional_bool("gemma.attention.bias").unwrap_or(false),
        head_dim,
        attn_logit_softcapping: optional_f32("gemma.attn_logit_softcapping"),
        final_logit_softcapping: optional_f32("gemma.final_logit_softcapping"),
        sliding_window: optional_u32("gemma.attention.sliding_window").map(|v| v as usize),
        query_pre_attn_scalar: optional_f32("gemma.attention.query_pre_attn_scalar"),
        effective_num_layers: optional_u32("gemma.block_count_effective").map(|v| v as usize),
        num_experts: optional_u32("gemma.expert_count").unwrap_or(0) as usize,
        num_experts_used: optional_u32("gemma.expert_used_count").unwrap_or(0) as usize,
        expert_ffn_size: optional_u32("gemma.expert_feed_forward_length").unwrap_or(0) as usize,
        expert_weights_scale: optional_f32("gemma.expert_weights_scale").unwrap_or(1.0),
    })
}