1use rlx_flow::blocks::{GemmaLayerStyle, gemma_strided_layer_mask, gemma2_layer_mask};
19use rlx_gguf::{GgufFile, MetaValue};
20use rlx_ir::op::MaskKind;
21use serde::Deserialize;
22use std::path::Path;
23
24#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
25#[serde(rename_all = "lowercase")]
26pub enum GemmaArch {
27 #[default]
28 Gemma,
29 Gemma2,
30 Gemma3,
31 Gemma4,
32}
33
34impl GemmaArch {
35 pub fn sliding_window_stride(self) -> usize {
36 match self {
37 GemmaArch::Gemma3 | GemmaArch::Gemma4 => 6,
38 _ => 0,
39 }
40 }
41
42 fn from_gguf_tag(tag: &str) -> Self {
43 match tag {
44 "gemma2" => GemmaArch::Gemma2,
45 "gemma3" | "gemma3n" => GemmaArch::Gemma3,
46 "gemma4" | "gemma4moe" => GemmaArch::Gemma4,
47 _ => GemmaArch::Gemma,
48 }
49 }
50}
51
52#[derive(Debug, Clone, Deserialize)]
53pub struct GemmaConfig {
54 #[serde(default)]
55 pub arch: GemmaArch,
56 pub vocab_size: usize,
57 pub hidden_size: usize,
58 pub intermediate_size: usize,
59 pub num_hidden_layers: usize,
60 pub num_attention_heads: usize,
61 pub num_key_value_heads: usize,
62 pub max_position_embeddings: usize,
63 #[serde(default = "default_rms_norm_eps")]
64 pub rms_norm_eps: f64,
65 #[serde(default = "default_rope_theta")]
66 pub rope_theta: f64,
67 #[serde(default)]
68 pub tie_word_embeddings: bool,
69 #[serde(default)]
70 pub attention_bias: bool,
71 #[serde(default)]
72 pub head_dim: Option<usize>,
73 #[serde(default)]
74 pub attn_logit_softcapping: Option<f32>,
75 #[serde(default)]
76 pub final_logit_softcapping: Option<f32>,
77 #[serde(default)]
78 pub sliding_window: Option<usize>,
79 #[serde(default)]
80 pub query_pre_attn_scalar: Option<f32>,
81 #[serde(default)]
82 pub effective_num_layers: Option<usize>,
83 #[serde(default)]
84 pub num_experts: usize,
85 #[serde(default)]
86 pub num_experts_used: usize,
87 #[serde(default)]
88 pub expert_ffn_size: usize,
89 #[serde(default = "default_expert_weights_scale")]
90 pub expert_weights_scale: f32,
91}
92
93fn default_rms_norm_eps() -> f64 {
94 1e-6
95}
96fn default_rope_theta() -> f64 {
97 10_000.0
98}
99fn default_expert_weights_scale() -> f32 {
100 1.0
101}
102
103impl GemmaConfig {
104 pub fn from_file(path: &Path) -> anyhow::Result<Self> {
105 let data = std::fs::read_to_string(path)?;
106 let mut cfg: Self = serde_json::from_str(&data)?;
107 if cfg.arch == GemmaArch::Gemma {
108 cfg.arch = infer_arch_from_json(&data);
109 }
110 Ok(cfg)
111 }
112
113 pub fn from_gguf(raw: &GgufFile) -> anyhow::Result<Self> {
114 gemma_cfg_from_gguf(raw)
115 }
116
117 pub fn head_dim(&self) -> usize {
118 self.head_dim
119 .unwrap_or(self.hidden_size / self.num_attention_heads)
120 }
121
122 pub fn kv_group_size(&self) -> usize {
123 self.num_attention_heads / self.num_key_value_heads
124 }
125
126 pub fn q_proj_dim(&self) -> usize {
127 self.num_attention_heads * self.head_dim()
128 }
129
130 pub fn kv_proj_dim(&self) -> usize {
131 self.num_key_value_heads * self.head_dim()
132 }
133
134 pub fn layer_style(&self) -> GemmaLayerStyle {
135 match self.arch {
136 GemmaArch::Gemma => GemmaLayerStyle::Gemma,
137 GemmaArch::Gemma2 => GemmaLayerStyle::Gemma2,
138 GemmaArch::Gemma3 => GemmaLayerStyle::Gemma3,
139 GemmaArch::Gemma4 => GemmaLayerStyle::Gemma4,
140 }
141 }
142
143 pub fn active_num_layers(&self) -> usize {
144 self.effective_num_layers.unwrap_or(self.num_hidden_layers)
145 }
146
147 pub fn is_moe(&self) -> bool {
148 self.arch == GemmaArch::Gemma4 && self.num_experts > 0
149 }
150
151 pub fn expert_ffn_dim(&self) -> usize {
152 if self.expert_ffn_size > 0 {
153 self.expert_ffn_size
154 } else {
155 self.intermediate_size
156 }
157 }
158
159 pub fn attn_score_scale(&self) -> Option<f32> {
160 match self.arch {
161 GemmaArch::Gemma => None,
162 GemmaArch::Gemma2 | GemmaArch::Gemma3 | GemmaArch::Gemma4 => {
163 if let Some(s) = self.query_pre_attn_scalar {
164 Some(1.0 / s)
165 } else {
166 Some(1.0 / (self.head_dim() as f32).sqrt())
167 }
168 }
169 }
170 }
171
172 pub fn layer_attn_options(&self, layer: usize) -> (MaskKind, Option<f32>, Option<f32>) {
173 let scale = self.attn_score_scale();
174 let softcap = self.attn_logit_softcapping;
175 let mask = match (self.arch, self.sliding_window) {
176 (_, None) => MaskKind::Causal,
177 (GemmaArch::Gemma2, Some(w)) => gemma2_layer_mask(layer, w),
178 (GemmaArch::Gemma3 | GemmaArch::Gemma4, Some(w)) => {
179 gemma_strided_layer_mask(layer, w, self.arch.sliding_window_stride())
180 }
181 _ => MaskKind::Causal,
182 };
183 (mask, scale, softcap)
184 }
185
186 #[cfg(test)]
187 pub(crate) fn tiny_test() -> Self {
188 Self {
189 arch: GemmaArch::Gemma,
190 vocab_size: 32,
191 hidden_size: 16,
192 intermediate_size: 32,
193 num_hidden_layers: 2,
194 num_attention_heads: 4,
195 num_key_value_heads: 2,
196 max_position_embeddings: 64,
197 rms_norm_eps: 1e-6,
198 rope_theta: 10_000.0,
199 tie_word_embeddings: true,
200 attention_bias: false,
201 head_dim: None,
202 attn_logit_softcapping: None,
203 final_logit_softcapping: None,
204 sliding_window: None,
205 query_pre_attn_scalar: None,
206 effective_num_layers: None,
207 num_experts: 0,
208 num_experts_used: 0,
209 expert_ffn_size: 0,
210 expert_weights_scale: 1.0,
211 }
212 }
213}
214
215fn infer_arch_from_json(raw: &str) -> GemmaArch {
216 if raw.contains("\"model_type\"") {
217 if raw.contains("\"gemma2\"") {
218 return GemmaArch::Gemma2;
219 }
220 if raw.contains("\"gemma3\"") {
221 return GemmaArch::Gemma3;
222 }
223 }
224 GemmaArch::Gemma
225}
226
227pub fn gemma_cfg_from_gguf(raw: &GgufFile) -> anyhow::Result<GemmaConfig> {
228 let arch_tag = raw
229 .metadata
230 .get("general.architecture")
231 .and_then(MetaValue::as_str)
232 .unwrap_or("gemma");
233 let arch_prefix = arch_tag;
234 let arch = GemmaArch::from_gguf_tag(arch_tag);
235
236 let get_meta = |k: &str| -> Option<&MetaValue> {
237 raw.metadata.get(k).or_else(|| {
238 let suffix = k.strip_prefix("gemma.")?;
239 if arch_prefix == "gemma" {
240 None
241 } else {
242 let arch_key = format!("{arch_prefix}.{suffix}");
243 raw.metadata.get(&arch_key)
244 }
245 })
246 };
247 let get_u32 = |k: &str| -> anyhow::Result<u32> {
248 get_meta(k)
249 .and_then(MetaValue::as_u32)
250 .ok_or_else(|| anyhow::anyhow!("missing GGUF metadata key: {k}"))
251 };
252 let get_f32 = |k: &str| -> Option<f32> {
253 get_meta(k).and_then(|v| match v {
254 MetaValue::F32(x) => Some(*x),
255 _ => None,
256 })
257 };
258 let get_bool = |k: &str| -> Option<bool> {
259 get_meta(k).and_then(|v| match v {
260 MetaValue::Bool(b) => Some(*b),
261 _ => None,
262 })
263 };
264
265 let hidden_size = get_u32("gemma.embedding_length")? as usize;
266 let num_attention_heads = get_u32("gemma.attention.head_count")? as usize;
267 let head_dim = get_u32("gemma.attention.key_length")
268 .ok()
269 .or_else(|| get_u32("gemma.rope.dimension_count").ok())
270 .map(|v| v as usize);
271
272 Ok(GemmaConfig {
273 arch,
274 vocab_size: get_u32("gemma.vocab_size").unwrap_or(256_000) as usize,
275 hidden_size,
276 intermediate_size: get_u32("gemma.feed_forward_length")? as usize,
277 num_hidden_layers: get_u32("gemma.block_count")? as usize,
278 num_attention_heads,
279 num_key_value_heads: get_u32("gemma.attention.head_count_kv")? as usize,
280 max_position_embeddings: get_u32("gemma.context_length").unwrap_or(8192) as usize,
281 rms_norm_eps: get_f32("gemma.attention.layer_norm_rms_epsilon").unwrap_or(1e-6) as f64,
282 rope_theta: get_f32("gemma.rope.freq_base").unwrap_or(10_000.0) as f64,
283 tie_word_embeddings: get_bool("gemma.tie_word_embeddings").unwrap_or(true),
284 attention_bias: get_bool("gemma.attention.bias").unwrap_or(false),
285 head_dim,
286 attn_logit_softcapping: get_f32("gemma.attn_logit_softcapping"),
287 final_logit_softcapping: get_f32("gemma.final_logit_softcapping"),
288 sliding_window: get_u32("gemma.attention.sliding_window")
289 .ok()
290 .map(|v| v as usize),
291 query_pre_attn_scalar: get_f32("gemma.attention.query_pre_attn_scalar"),
292 effective_num_layers: get_u32("gemma.block_count_effective")
293 .ok()
294 .map(|v| v as usize),
295 num_experts: get_u32("gemma.expert_count").unwrap_or(0) as usize,
296 num_experts_used: get_u32("gemma.expert_used_count").unwrap_or(0) as usize,
297 expert_ffn_size: get_u32("gemma.expert_feed_forward_length").unwrap_or(0) as usize,
298 expert_weights_scale: get_f32("gemma.expert_weights_scale").unwrap_or(1.0),
299 })
300}