use crate::error::{Result, RullamaError};
use crate::gguf::{GgufReader, GgufValue};
/// Attention kind of a single layer.
///
/// [`Gemma4Config::from_gguf`] derives one value per layer from the boolean
/// `gemma4.attention.sliding_window_pattern` array.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LayerKind {
    /// Pattern flag `true`: the layer uses the sliding-window (`*_swa`) KV head
    /// count and head dimension.
    SlidingWindow,
    /// Pattern flag `false`: the layer uses the global KV head count and head
    /// dimension.
    Global,
}
/// Model hyper-parameters and tokenizer ids for a Gemma 4 checkpoint.
///
/// Built from GGUF metadata by [`Gemma4Config::from_gguf`]. Each field notes the
/// metadata key it is read from and, for optional keys, the fallback used when
/// the key is absent.
#[derive(Debug, Clone, PartialEq)]
pub struct Gemma4Config {
    /// Number of blocks (`gemma4.block_count`).
    pub n_layers: u32,
    /// Embedding width (`gemma4.embedding_length`).
    pub d_model: u32,
    /// Context length (`gemma4.context_length`).
    pub max_pos: u32,
    /// Attention head count (`gemma4.attention.head_count`).
    pub n_heads: u32,
    /// KV head count used by sliding-window layers (`gemma4.attention.head_count_kv`).
    pub n_kv_heads_swa: u32,
    /// KV head count used by global layers
    /// (`gemma4.attention.global_head_count_kv`, defaults to `n_kv_heads_swa`).
    pub n_kv_heads_global: u32,
    /// Head dimension used by global layers (`gemma4.attention.key_length`).
    pub head_dim_global: u32,
    /// Head dimension used by sliding-window layers (`gemma4.attention.key_length_swa`).
    pub head_dim_swa: u32,
    /// RMS-norm epsilon (`gemma4.attention.layer_norm_rms_epsilon`).
    pub rms_norm_eps: f32,
    /// Sliding-window size (`gemma4.attention.sliding_window`).
    pub sliding_window: u32,
    /// Per-layer attention kind from `gemma4.attention.sliding_window_pattern`;
    /// `from_gguf` guarantees one entry per layer.
    pub layer_kinds: Vec<LayerKind>,
    /// Value of `gemma4.attention.shared_kv_layers` (0 when absent).
    pub shared_kv_layers: u32,
    /// Per-layer feed-forward width from `gemma4.feed_forward_length`; a scalar
    /// value is repeated for every layer. `from_gguf` guarantees one entry per layer.
    pub ffn_inter: Vec<u32>,
    /// RoPE frequency base (`gemma4.rope.freq_base`).
    pub rope_freq_base: f32,
    /// RoPE frequency base, `_swa` variant (`gemma4.rope.freq_base_swa`).
    pub rope_freq_base_swa: f32,
    /// RoPE dimension count (`gemma4.rope.dimension_count`, defaults to
    /// `head_dim_global / 4`).
    pub rope_dim_global: u32,
    /// RoPE dimension count, `_swa` variant (`gemma4.rope.dimension_count_swa`,
    /// defaults to `head_dim_swa`).
    pub rope_dim_swa: u32,
    /// Final logit soft-capping value (`gemma4.final_logit_softcapping`).
    pub final_logit_softcap: f32,
    /// Per-layer input embedding width (`gemma4.embedding_length_per_layer_input`);
    /// 0 when absent.
    pub ple_dim: u32,
    /// Number of entries in `tokenizer.ggml.tokens`.
    pub vocab_size: u32,
    /// `tokenizer.ggml.bos_token_id`, if present.
    pub bos_id: Option<u32>,
    /// `tokenizer.ggml.eos_token_ids` if present, otherwise the single
    /// `tokenizer.ggml.eos_token_id`, otherwise empty.
    pub eos_ids: Vec<u32>,
    /// `tokenizer.ggml.padding_token_id`, if present.
    pub pad_id: Option<u32>,
    /// `tokenizer.ggml.unknown_token_id`, if present.
    pub unk_id: Option<u32>,
}
impl Gemma4Config {
    /// Builds a [`Gemma4Config`] from the metadata of a GGUF file.
    ///
    /// Required keys are read with `r.get`. Optional keys are read with
    /// `r.get_opt`, and absent optional keys fall back to the defaults noted
    /// inline.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - `general.architecture` is not `"gemma4"`;
    /// - a required key is missing or has the wrong type;
    /// - `sliding_window_pattern` or `feed_forward_length` does not provide
    ///   exactly `n_layers` entries;
    /// - a feed-forward width or the vocabulary size does not fit in `u32`
    ///   (instead of being silently truncated).
    pub fn from_gguf(r: &GgufReader) -> Result<Self> {
        const FFN_KEY: &str = "gemma4.feed_forward_length";

        let arch = r.get("general.architecture")?.as_str()?;
        if arch != "gemma4" {
            return Err(RullamaError::Config(format!(
                "expected architecture 'gemma4', got '{arch}'"
            )));
        }

        let n_layers = r.get("gemma4.block_count")?.as_u32()?;
        let d_model = r.get("gemma4.embedding_length")?.as_u32()?;
        let max_pos = r.get("gemma4.context_length")?.as_u32()?;
        let n_heads = r.get("gemma4.attention.head_count")?.as_u32()?;
        let n_kv_heads_swa = r.get("gemma4.attention.head_count_kv")?.as_u32()?;
        // Global layers reuse the sliding-window KV head count unless overridden.
        let n_kv_heads_global =
            Self::opt_u32(r, "gemma4.attention.global_head_count_kv")?.unwrap_or(n_kv_heads_swa);
        let head_dim_global = r.get("gemma4.attention.key_length")?.as_u32()?;
        let head_dim_swa = r.get("gemma4.attention.key_length_swa")?.as_u32()?;
        let rms_norm_eps = r.get("gemma4.attention.layer_norm_rms_epsilon")?.as_f32()?;
        let sliding_window = r.get("gemma4.attention.sliding_window")?.as_u32()?;
        let shared_kv_layers = Self::opt_u32(r, "gemma4.attention.shared_kv_layers")?.unwrap_or(0);

        // One flag per layer: `true` marks a sliding-window layer, `false` a global one.
        let pattern = r
            .get("gemma4.attention.sliding_window_pattern")?
            .as_bool_array()?;
        // Compare in `usize` so an oversized array cannot wrap around and pass the check.
        if pattern.len() != n_layers as usize {
            return Err(RullamaError::Config(format!(
                "sliding_window_pattern length {} != n_layers {}",
                pattern.len(),
                n_layers
            )));
        }
        let layer_kinds: Vec<LayerKind> = pattern
            .iter()
            .map(|&sliding| {
                if sliding {
                    LayerKind::SlidingWindow
                } else {
                    LayerKind::Global
                }
            })
            .collect();

        // The FFN width is stored either per layer (integer array) or once for all
        // layers (scalar, broadcast to every layer). Wider or signed integer arrays
        // are narrowed with a range check rather than an `as` cast, which would
        // silently wrap negative or oversized values.
        let ffn_inter: Vec<u32> = match r.get(FFN_KEY)? {
            GgufValue::ArrayU32(v) => v.clone(),
            GgufValue::ArrayU64(v) => Self::checked_u32_vec(FFN_KEY, v.as_slice())?,
            GgufValue::ArrayI32(v) => Self::checked_u32_vec(FFN_KEY, v.as_slice())?,
            GgufValue::ArrayI64(v) => Self::checked_u32_vec(FFN_KEY, v.as_slice())?,
            scalar => vec![scalar.as_u32()?; n_layers as usize],
        };
        if ffn_inter.len() != n_layers as usize {
            return Err(RullamaError::Config(format!(
                "feed_forward_length array length {} != n_layers {}",
                ffn_inter.len(),
                n_layers
            )));
        }

        let rope_freq_base = r.get("gemma4.rope.freq_base")?.as_f32()?;
        let rope_freq_base_swa = r.get("gemma4.rope.freq_base_swa")?.as_f32()?;
        // NOTE(review): the quarter-of-head-dim fallback presumably reflects partial
        // rotary embedding on global layers — confirm against the reference model.
        let rope_dim_global =
            Self::opt_u32(r, "gemma4.rope.dimension_count")?.unwrap_or(head_dim_global / 4);
        // Defaults to the full sliding-window head dimension.
        let rope_dim_swa =
            Self::opt_u32(r, "gemma4.rope.dimension_count_swa")?.unwrap_or(head_dim_swa);
        let final_logit_softcap = r.get("gemma4.final_logit_softcapping")?.as_f32()?;
        // 0 when absent; `has_ple` tests for a non-zero value.
        let ple_dim = Self::opt_u32(r, "gemma4.embedding_length_per_layer_input")?.unwrap_or(0);

        let tokens = r.get("tokenizer.ggml.tokens")?.as_string_array()?;
        let vocab_size = Self::checked_u32("tokenizer.ggml.tokens length", tokens.len())?;
        let bos_id = Self::opt_u32(r, "tokenizer.ggml.bos_token_id")?;
        let pad_id = Self::opt_u32(r, "tokenizer.ggml.padding_token_id")?;
        let unk_id = Self::opt_u32(r, "tokenizer.ggml.unknown_token_id")?;
        // Prefer the multi-id array; fall back to the single id, else no EOS ids.
        let eos_ids: Vec<u32> = match r.get_opt("tokenizer.ggml.eos_token_ids") {
            Some(v) => v.as_u32_array()?,
            None => Self::opt_u32(r, "tokenizer.ggml.eos_token_id")?
                .into_iter()
                .collect(),
        };

        Ok(Self {
            n_layers,
            d_model,
            max_pos,
            n_heads,
            n_kv_heads_swa,
            n_kv_heads_global,
            head_dim_global,
            head_dim_swa,
            rms_norm_eps,
            sliding_window,
            layer_kinds,
            shared_kv_layers,
            ffn_inter,
            rope_freq_base,
            rope_freq_base_swa,
            rope_dim_global,
            rope_dim_swa,
            final_logit_softcap,
            ple_dim,
            vocab_size,
            bos_id,
            eos_ids,
            pad_id,
            unk_id,
        })
    }

    /// Reads an optional scalar `u32`.
    ///
    /// Returns `Ok(None)` when `key` is absent. Returns an error when the key is
    /// present but `as_u32` rejects its value.
    fn opt_u32(r: &GgufReader, key: &str) -> Result<Option<u32>> {
        match r.get_opt(key) {
            Some(v) => Ok(Some(v.as_u32()?)),
            None => Ok(None),
        }
    }

    /// Narrows `x` to `u32`.
    ///
    /// Returns a config error naming `what` if the value is negative or larger
    /// than `u32::MAX`.
    fn checked_u32<T>(what: &str, x: T) -> Result<u32>
    where
        T: Copy + std::fmt::Display,
        u32: std::convert::TryFrom<T>,
    {
        <u32 as std::convert::TryFrom<T>>::try_from(x).map_err(|_| {
            RullamaError::Config(format!("{what}: value {x} does not fit in u32"))
        })
    }

    /// Narrows every element of `values` to `u32`, failing on the first
    /// out-of-range element.
    fn checked_u32_vec<T>(what: &str, values: &[T]) -> Result<Vec<u32>>
    where
        T: Copy + std::fmt::Display,
        u32: std::convert::TryFrom<T>,
    {
        values.iter().map(|&x| Self::checked_u32(what, x)).collect()
    }

    /// Returns `true` when `ple_dim` is non-zero, i.e. when
    /// `gemma4.embedding_length_per_layer_input` was present and non-zero.
    pub fn has_ple(&self) -> bool {
        self.ple_dim > 0
    }

    /// Returns the attention kind of layer `i`.
    ///
    /// # Panics
    ///
    /// Panics if `i` is out of range for `layer_kinds`. `from_gguf` sizes
    /// `layer_kinds` to `n_layers`.
    pub fn kind(&self, i: u32) -> LayerKind {
        self.layer_kinds[i as usize]
    }

    /// Returns the feed-forward width of layer `i`.
    ///
    /// # Panics
    ///
    /// Panics if `i` is out of range for `ffn_inter`. `from_gguf` sizes
    /// `ffn_inter` to `n_layers`.
    pub fn ffn(&self, i: u32) -> u32 {
        self.ffn_inter[i as usize]
    }

    /// Returns the KV head count for layer `i`, chosen by its [`LayerKind`].
    ///
    /// # Panics
    ///
    /// Panics under the same condition as [`Gemma4Config::kind`].
    pub fn n_kv_heads(&self, i: u32) -> u32 {
        match self.kind(i) {
            LayerKind::SlidingWindow => self.n_kv_heads_swa,
            LayerKind::Global => self.n_kv_heads_global,
        }
    }

    /// Returns the head dimension for layer `i`, chosen by its [`LayerKind`].
    ///
    /// # Panics
    ///
    /// Panics under the same condition as [`Gemma4Config::kind`].
    pub fn head_dim(&self, i: u32) -> u32 {
        match self.kind(i) {
            LayerKind::SlidingWindow => self.head_dim_swa,
            LayerKind::Global => self.head_dim_global,
        }
    }
}