use rlx_flow::blocks::{GemmaLayerStyle, gemma_strided_layer_mask, gemma2_layer_mask};
use rlx_gguf::{GgufFile, MetaValue};
use rlx_ir::op::MaskKind;
use serde::Deserialize;
use std::path::Path;
/// Gemma model generation that a configuration targets.
///
/// Deserializes from the lowercase variant name (`"gemma"`, `"gemma2"`,
/// `"gemma3"`, `"gemma4"`); `Gemma` is the `Default` value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum GemmaArch {
/// First generation; also the fallback for unrecognized GGUF architecture tags.
#[default]
Gemma,
/// Selected by the GGUF architecture tag `gemma2`.
Gemma2,
/// Selected by the GGUF architecture tags `gemma3` and `gemma3n`.
Gemma3,
/// Selected by the GGUF architecture tags `gemma4` and `gemma4moe`.
Gemma4,
}
impl GemmaArch {
    /// Stride handed to the strided sliding-window mask helper.
    ///
    /// Returns `6` for `Gemma3` and `Gemma4`, and `0` for the earlier
    /// generations.
    pub fn sliding_window_stride(self) -> usize {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            Self::Gemma | Self::Gemma2 => 0,
            Self::Gemma3 | Self::Gemma4 => 6,
        }
    }

    /// Maps a GGUF `general.architecture` tag to a generation.
    ///
    /// Any tag that is not listed resolves to `GemmaArch::Gemma`.
    fn from_gguf_tag(tag: &str) -> Self {
        const KNOWN_TAGS: [(&str, GemmaArch); 5] = [
            ("gemma2", GemmaArch::Gemma2),
            ("gemma3", GemmaArch::Gemma3),
            ("gemma3n", GemmaArch::Gemma3),
            ("gemma4", GemmaArch::Gemma4),
            ("gemma4moe", GemmaArch::Gemma4),
        ];
        KNOWN_TAGS
            .iter()
            .find(|(name, _)| *name == tag)
            .map_or(Self::Gemma, |&(_, arch)| arch)
    }
}
/// Hyperparameters describing a Gemma-family model.
///
/// Deserialized from JSON by `GemmaConfig::from_file` or assembled from GGUF
/// metadata by `gemma_cfg_from_gguf`. Fields carrying a `#[serde(default…)]`
/// attribute may be omitted from the JSON document.
#[derive(Debug, Clone, Deserialize)]
pub struct GemmaConfig {
/// Model generation; `GemmaArch::Gemma` when omitted from the JSON.
#[serde(default)]
pub arch: GemmaArch,
/// Vocabulary size (GGUF key `gemma.vocab_size`).
pub vocab_size: usize,
/// Hidden width (GGUF key `gemma.embedding_length`); numerator of the
/// fallback head dimension computed by `head_dim()`.
pub hidden_size: usize,
/// Feed-forward width (GGUF key `gemma.feed_forward_length`); also what
/// `expert_ffn_dim()` returns when `expert_ffn_size` is 0.
pub intermediate_size: usize,
/// Total layer count (GGUF key `gemma.block_count`).
pub num_hidden_layers: usize,
/// Number of attention (query) heads (GGUF key `gemma.attention.head_count`).
pub num_attention_heads: usize,
/// Number of key/value heads (GGUF key `gemma.attention.head_count_kv`).
pub num_key_value_heads: usize,
/// Maximum context length (GGUF key `gemma.context_length`).
pub max_position_embeddings: usize,
/// RMS-norm epsilon; `1e-6` when omitted from the JSON.
#[serde(default = "default_rms_norm_eps")]
pub rms_norm_eps: f64,
/// RoPE base frequency (GGUF key `gemma.rope.freq_base`); `10_000.0` when
/// omitted from the JSON.
#[serde(default = "default_rope_theta")]
pub rope_theta: f64,
/// `false` when omitted from the JSON.
/// NOTE(review): presumably whether the output projection shares the
/// embedding matrix — confirm against the model builder.
#[serde(default)]
pub tie_word_embeddings: bool,
/// `false` when omitted from the JSON.
/// NOTE(review): presumably whether attention projections carry bias
/// terms — confirm against the model builder.
#[serde(default)]
pub attention_bias: bool,
/// Explicit per-head dimension; when `None`, `head_dim()` falls back to
/// `hidden_size / num_attention_heads`.
#[serde(default)]
pub head_dim: Option<usize>,
/// Returned unchanged as the soft-cap element of `layer_attn_options`.
#[serde(default)]
pub attn_logit_softcapping: Option<f32>,
/// GGUF key `gemma.final_logit_softcapping`.
/// NOTE(review): presumably a soft-cap on the output logits — not used in
/// this file, confirm at the consumer.
#[serde(default)]
pub final_logit_softcapping: Option<f32>,
/// Window size handed to the per-layer mask helpers; when `None`,
/// `layer_attn_options` yields `MaskKind::Causal` for every layer.
#[serde(default)]
pub sliding_window: Option<usize>,
/// Optional scalar from which `attn_score_scale()` derives the score scale
/// for Gemma2 and later; ignored for `GemmaArch::Gemma`.
#[serde(default)]
pub query_pre_attn_scalar: Option<f32>,
/// When set, overrides `num_hidden_layers` in `active_num_layers()`.
#[serde(default)]
pub effective_num_layers: Option<usize>,
/// Expert count; `is_moe()` requires this to be non-zero on `Gemma4`.
#[serde(default)]
pub num_experts: usize,
/// GGUF key `gemma.expert_used_count`.
/// NOTE(review): presumably the number of experts routed per token — confirm.
#[serde(default)]
pub num_experts_used: usize,
/// Per-expert feed-forward width; 0 means "use `intermediate_size`"
/// (see `expert_ffn_dim()`).
#[serde(default)]
pub expert_ffn_size: usize,
/// GGUF key `gemma.expert_weights_scale`; `1.0` when omitted.
#[serde(default = "default_expert_weights_scale")]
pub expert_weights_scale: f32,
}
/// Serde fallback for `GemmaConfig::rms_norm_eps` when the JSON omits it.
fn default_rms_norm_eps() -> f64 {
    const DEFAULT_EPS: f64 = 1e-6;
    DEFAULT_EPS
}
/// Serde fallback for `GemmaConfig::rope_theta` when the JSON omits it.
fn default_rope_theta() -> f64 {
    const DEFAULT_THETA: f64 = 10_000.0;
    DEFAULT_THETA
}
/// Serde fallback for `GemmaConfig::expert_weights_scale` when the JSON omits it.
fn default_expert_weights_scale() -> f32 {
    const DEFAULT_SCALE: f32 = 1.0;
    DEFAULT_SCALE
}
impl GemmaConfig {
    /// Loads a configuration from the JSON file at `path`.
    ///
    /// If the parsed `arch` is still the default `GemmaArch::Gemma`, the
    /// generation is re-derived from the raw text via `infer_arch_from_json`.
    ///
    /// # Errors
    ///
    /// Fails when the file cannot be read or does not deserialize into
    /// `GemmaConfig`.
    pub fn from_file(path: &Path) -> anyhow::Result<Self> {
        let data = std::fs::read_to_string(path)?;
        let mut cfg: Self = serde_json::from_str(&data)?;
        if cfg.arch == GemmaArch::Gemma {
            cfg.arch = infer_arch_from_json(&data);
        }
        Ok(cfg)
    }

    /// Builds a configuration from GGUF metadata; see `gemma_cfg_from_gguf`.
    pub fn from_gguf(raw: &GgufFile) -> anyhow::Result<Self> {
        gemma_cfg_from_gguf(raw)
    }

    /// Per-head dimension: the explicit `head_dim` field when present,
    /// otherwise `hidden_size / num_attention_heads`.
    ///
    /// # Panics
    ///
    /// Panics on division by zero when `head_dim` is `None` and
    /// `num_attention_heads` is 0.
    pub fn head_dim(&self) -> usize {
        // Lazy fallback: the division must not run (and must not be able to
        // panic) when an explicit head dimension is configured.
        self.head_dim
            .unwrap_or_else(|| self.hidden_size / self.num_attention_heads)
    }

    /// Number of query heads per key/value head.
    ///
    /// # Panics
    ///
    /// Panics on division by zero when `num_key_value_heads` is 0.
    pub fn kv_group_size(&self) -> usize {
        self.num_attention_heads / self.num_key_value_heads
    }

    /// Width of the query projection: `num_attention_heads * head_dim()`.
    pub fn q_proj_dim(&self) -> usize {
        self.num_attention_heads * self.head_dim()
    }

    /// Width of each key/value projection: `num_key_value_heads * head_dim()`.
    pub fn kv_proj_dim(&self) -> usize {
        self.num_key_value_heads * self.head_dim()
    }

    /// Layer style matching this configuration's generation.
    pub fn layer_style(&self) -> GemmaLayerStyle {
        match self.arch {
            GemmaArch::Gemma => GemmaLayerStyle::Gemma,
            GemmaArch::Gemma2 => GemmaLayerStyle::Gemma2,
            GemmaArch::Gemma3 => GemmaLayerStyle::Gemma3,
            GemmaArch::Gemma4 => GemmaLayerStyle::Gemma4,
        }
    }

    /// Layer count to use: `effective_num_layers` when set, otherwise
    /// `num_hidden_layers`.
    pub fn active_num_layers(&self) -> usize {
        self.effective_num_layers.unwrap_or(self.num_hidden_layers)
    }

    /// `true` only for `Gemma4` configurations with at least one expert.
    pub fn is_moe(&self) -> bool {
        self.arch == GemmaArch::Gemma4 && self.num_experts > 0
    }

    /// Per-expert feed-forward width: `expert_ffn_size` when non-zero,
    /// otherwise `intermediate_size`.
    pub fn expert_ffn_dim(&self) -> usize {
        if self.expert_ffn_size > 0 {
            self.expert_ffn_size
        } else {
            self.intermediate_size
        }
    }

    /// Explicit attention-score multiplier, if this generation needs one.
    ///
    /// Returns `None` for `GemmaArch::Gemma`. For later generations the
    /// multiplier is `1 / sqrt(query_pre_attn_scalar)`, falling back to
    /// `1 / sqrt(head_dim())` when the scalar is not configured.
    pub fn attn_score_scale(&self) -> Option<f32> {
        match self.arch {
            GemmaArch::Gemma => None,
            GemmaArch::Gemma2 | GemmaArch::Gemma3 | GemmaArch::Gemma4 => {
                // The reference implementations apply
                // `query_pre_attn_scalar ** -0.5`, so the configured scalar
                // goes through the same square root as the head-dim fallback.
                let pre_scalar = self
                    .query_pre_attn_scalar
                    .unwrap_or_else(|| self.head_dim() as f32);
                Some(1.0 / pre_scalar.sqrt())
            }
        }
    }

    /// Attention options for layer index `layer`, as
    /// `(mask, score_scale, logit_softcap)`.
    ///
    /// The mask is `MaskKind::Causal` unless a sliding window is configured
    /// on a generation that uses one; `Gemma2` and `Gemma3`/`Gemma4` delegate
    /// to their respective per-layer mask helpers.
    pub fn layer_attn_options(&self, layer: usize) -> (MaskKind, Option<f32>, Option<f32>) {
        let mask = match (self.arch, self.sliding_window) {
            // No window configured, or a generation that never windows.
            (_, None) | (GemmaArch::Gemma, Some(_)) => MaskKind::Causal,
            (GemmaArch::Gemma2, Some(window)) => gemma2_layer_mask(layer, window),
            (GemmaArch::Gemma3 | GemmaArch::Gemma4, Some(window)) => {
                gemma_strided_layer_mask(layer, window, self.arch.sliding_window_stride())
            }
        };
        (mask, self.attn_score_scale(), self.attn_logit_softcapping)
    }

    /// Minimal first-generation configuration for unit tests.
    #[cfg(test)]
    pub(crate) fn tiny_test() -> Self {
        Self {
            arch: GemmaArch::Gemma,
            vocab_size: 32,
            hidden_size: 16,
            intermediate_size: 32,
            num_hidden_layers: 2,
            num_attention_heads: 4,
            num_key_value_heads: 2,
            max_position_embeddings: 64,
            rms_norm_eps: 1e-6,
            rope_theta: 10_000.0,
            tie_word_embeddings: true,
            attention_bias: false,
            head_dim: None,
            attn_logit_softcapping: None,
            final_logit_softcapping: None,
            sliding_window: None,
            query_pre_attn_scalar: None,
            effective_num_layers: None,
            num_experts: 0,
            num_experts_used: 0,
            expert_ffn_size: 0,
            expert_weights_scale: 1.0,
        }
    }
}
/// Best-effort guess of the model generation from raw config JSON text.
///
/// This is a textual heuristic, not a JSON parse: when the text mentions a
/// `"model_type"` key, the first entry of the table below whose quoted value
/// occurs anywhere in the text wins. Flat Gemma 3 text configs carry the
/// model type `"gemma3_text"`, so that spelling is recognized alongside
/// `"gemma3"`. Anything else resolves to `GemmaArch::Gemma`.
fn infer_arch_from_json(raw: &str) -> GemmaArch {
    // Needles include the surrounding quotes so only whole JSON strings match.
    // Order is the precedence when several values appear in one document.
    const MODEL_TYPES: [(&str, GemmaArch); 3] = [
        ("\"gemma2\"", GemmaArch::Gemma2),
        ("\"gemma3\"", GemmaArch::Gemma3),
        ("\"gemma3_text\"", GemmaArch::Gemma3),
    ];

    if !raw.contains("\"model_type\"") {
        return GemmaArch::Gemma;
    }
    MODEL_TYPES
        .iter()
        .find(|(needle, _)| raw.contains(*needle))
        .map_or(GemmaArch::Gemma, |&(_, arch)| arch)
}
pub fn gemma_cfg_from_gguf(raw: &GgufFile) -> anyhow::Result<GemmaConfig> {
let arch_tag = raw
.metadata
.get("general.architecture")
.and_then(MetaValue::as_str)
.unwrap_or("gemma");
let arch_prefix = arch_tag;
let arch = GemmaArch::from_gguf_tag(arch_tag);
let get_meta = |k: &str| -> Option<&MetaValue> {
raw.metadata.get(k).or_else(|| {
let suffix = k.strip_prefix("gemma.")?;
if arch_prefix == "gemma" {
None
} else {
let arch_key = format!("{arch_prefix}.{suffix}");
raw.metadata.get(&arch_key)
}
})
};
let get_u32 = |k: &str| -> anyhow::Result<u32> {
get_meta(k)
.and_then(MetaValue::as_u32)
.ok_or_else(|| anyhow::anyhow!("missing GGUF metadata key: {k}"))
};
let get_f32 = |k: &str| -> Option<f32> {
get_meta(k).and_then(|v| match v {
MetaValue::F32(x) => Some(*x),
_ => None,
})
};
let get_bool = |k: &str| -> Option<bool> {
get_meta(k).and_then(|v| match v {
MetaValue::Bool(b) => Some(*b),
_ => None,
})
};
let hidden_size = get_u32("gemma.embedding_length")? as usize;
let num_attention_heads = get_u32("gemma.attention.head_count")? as usize;
let head_dim = get_u32("gemma.attention.key_length")
.ok()
.or_else(|| get_u32("gemma.rope.dimension_count").ok())
.map(|v| v as usize);
Ok(GemmaConfig {
arch,
vocab_size: get_u32("gemma.vocab_size").unwrap_or(256_000) as usize,
hidden_size,
intermediate_size: get_u32("gemma.feed_forward_length")? as usize,
num_hidden_layers: get_u32("gemma.block_count")? as usize,
num_attention_heads,
num_key_value_heads: get_u32("gemma.attention.head_count_kv")? as usize,
max_position_embeddings: get_u32("gemma.context_length").unwrap_or(8192) as usize,
rms_norm_eps: get_f32("gemma.attention.layer_norm_rms_epsilon").unwrap_or(1e-6) as f64,
rope_theta: get_f32("gemma.rope.freq_base").unwrap_or(10_000.0) as f64,
tie_word_embeddings: get_bool("gemma.tie_word_embeddings").unwrap_or(true),
attention_bias: get_bool("gemma.attention.bias").unwrap_or(false),
head_dim,
attn_logit_softcapping: get_f32("gemma.attn_logit_softcapping"),
final_logit_softcapping: get_f32("gemma.final_logit_softcapping"),
sliding_window: get_u32("gemma.attention.sliding_window")
.ok()
.map(|v| v as usize),
query_pre_attn_scalar: get_f32("gemma.attention.query_pre_attn_scalar"),
effective_num_layers: get_u32("gemma.block_count_effective")
.ok()
.map(|v| v as usize),
num_experts: get_u32("gemma.expert_count").unwrap_or(0) as usize,
num_experts_used: get_u32("gemma.expert_used_count").unwrap_or(0) as usize,
expert_ffn_size: get_u32("gemma.expert_feed_forward_length").unwrap_or(0) as usize,
expert_weights_scale: get_f32("gemma.expert_weights_scale").unwrap_or(1.0),
})
}