1use rlx_flow::blocks::{GemmaLayerStyle, gemma_strided_layer_mask, gemma2_layer_mask};
19use rlx_gguf::{GgufFile, MetaValue};
20use rlx_ir::op::MaskKind;
21use serde::Deserialize;
22use std::path::Path;
23
24#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
25#[serde(rename_all = "lowercase")]
26pub enum GemmaArch {
27 #[default]
28 Gemma,
29 Gemma2,
30 Gemma3,
31 Gemma4,
32}
33
34impl GemmaArch {
35 pub fn sliding_window_stride(self) -> usize {
36 match self {
37 GemmaArch::Gemma3 | GemmaArch::Gemma4 => 6,
38 _ => 0,
39 }
40 }
41
42 fn from_gguf_tag(tag: &str) -> Self {
43 match tag {
44 "gemma2" => GemmaArch::Gemma2,
45 "gemma3" | "gemma3n" => GemmaArch::Gemma3,
46 "gemma4" | "gemma4moe" | "gemma4_unified" | "gemma4_unified_text" => GemmaArch::Gemma4,
47 _ => GemmaArch::Gemma,
48 }
49 }
50}
51
52#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
56#[serde(rename_all = "snake_case")]
57pub enum GemmaLayerType {
58 SlidingAttention,
59 FullAttention,
60}
61
62#[derive(Debug, Clone, Copy, Deserialize, Default)]
67pub struct GemmaRopeParameters {
68 #[serde(default)]
69 pub partial_rotary_factor: Option<f32>,
70 #[serde(default)]
71 pub rope_theta: Option<f32>,
72 #[serde(default)]
73 pub rope_type: Option<GemmaRopeKind>,
74}
75
76#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
77#[serde(rename_all = "snake_case")]
78pub enum GemmaRopeKind {
79 #[default]
80 Default,
81 Proportional,
82 Linear,
83 Dynamic,
84}
85
86#[derive(Debug, Clone, Default, Deserialize)]
87pub struct GemmaRopeMap {
88 #[serde(default)]
89 pub sliding_attention: Option<GemmaRopeParameters>,
90 #[serde(default)]
91 pub full_attention: Option<GemmaRopeParameters>,
92}
93
94#[derive(Debug, Clone, Deserialize)]
95pub struct GemmaConfig {
96 #[serde(default)]
97 pub arch: GemmaArch,
98 pub vocab_size: usize,
99 pub hidden_size: usize,
100 pub intermediate_size: usize,
101 pub num_hidden_layers: usize,
102 pub num_attention_heads: usize,
103 pub num_key_value_heads: usize,
104 pub max_position_embeddings: usize,
105 #[serde(default = "default_rms_norm_eps")]
106 pub rms_norm_eps: f64,
107 #[serde(default = "default_rope_theta")]
108 pub rope_theta: f64,
109 #[serde(default)]
110 pub tie_word_embeddings: bool,
111 #[serde(default)]
112 pub attention_bias: bool,
113 #[serde(default)]
114 pub head_dim: Option<usize>,
115 #[serde(default)]
116 pub attn_logit_softcapping: Option<f32>,
117 #[serde(default)]
118 pub final_logit_softcapping: Option<f32>,
119 #[serde(default)]
120 pub sliding_window: Option<usize>,
121 #[serde(default)]
122 pub query_pre_attn_scalar: Option<f32>,
123 #[serde(default)]
124 pub effective_num_layers: Option<usize>,
125 #[serde(default)]
126 pub num_experts: usize,
127 #[serde(default)]
128 pub num_experts_used: usize,
129 #[serde(default)]
130 pub expert_ffn_size: usize,
131 #[serde(default = "default_expert_weights_scale")]
132 pub expert_weights_scale: f32,
133
134 #[serde(default)]
138 pub layer_types: Vec<GemmaLayerType>,
139 #[serde(default)]
141 pub rope_parameters: GemmaRopeMap,
142 #[serde(default)]
146 pub global_head_dim: Option<usize>,
147 #[serde(default)]
150 pub num_global_key_value_heads: Option<usize>,
151 #[serde(default)]
155 pub attention_k_eq_v: bool,
156 #[serde(default)]
159 pub use_bidirectional_attention: Option<String>,
160}
161
/// Serde default for `GemmaConfig::rms_norm_eps`.
fn default_rms_norm_eps() -> f64 {
    1.0e-6
}
/// Serde default for `GemmaConfig::rope_theta`.
fn default_rope_theta() -> f64 {
    1.0e4
}
/// Serde default for `GemmaConfig::expert_weights_scale`.
fn default_expert_weights_scale() -> f32 {
    1.0_f32
}
171
172impl GemmaConfig {
173 pub fn from_file(path: &Path) -> anyhow::Result<Self> {
174 let data = std::fs::read_to_string(path)?;
175 let value: serde_json::Value = serde_json::from_str(&data)?;
180 let lm_value = match value.get("text_config") {
181 Some(tc) if tc.is_object() => tc.clone(),
182 _ => value.clone(),
183 };
184 let lm_value = normalize_hf_null_usize_fields(lm_value);
185 let mut cfg: Self = serde_json::from_value(lm_value)?;
186 if cfg.arch == GemmaArch::Gemma {
187 cfg.arch = infer_arch_from_json(&data);
188 }
189 Ok(cfg)
190 }
191
192 pub fn from_gguf(raw: &GgufFile) -> anyhow::Result<Self> {
193 gemma_cfg_from_gguf(raw)
194 }
195
196 pub fn head_dim(&self) -> usize {
197 self.head_dim
198 .unwrap_or(self.hidden_size / self.num_attention_heads)
199 }
200
201 pub fn kv_group_size(&self) -> usize {
202 self.num_attention_heads / self.num_key_value_heads
203 }
204
205 pub fn q_proj_dim(&self) -> usize {
206 self.num_attention_heads * self.head_dim()
207 }
208
209 pub fn kv_proj_dim(&self) -> usize {
210 self.num_key_value_heads * self.head_dim()
211 }
212
213 pub fn layer_style(&self) -> GemmaLayerStyle {
214 match self.arch {
215 GemmaArch::Gemma => GemmaLayerStyle::Gemma,
216 GemmaArch::Gemma2 => GemmaLayerStyle::Gemma2,
217 GemmaArch::Gemma3 => GemmaLayerStyle::Gemma3,
218 GemmaArch::Gemma4 => GemmaLayerStyle::Gemma4,
219 }
220 }
221
222 pub fn active_num_layers(&self) -> usize {
223 self.effective_num_layers.unwrap_or(self.num_hidden_layers)
224 }
225
226 pub fn is_moe(&self) -> bool {
227 self.arch == GemmaArch::Gemma4 && self.num_experts > 0
228 }
229
230 pub fn use_bidirectional_vision(&self) -> bool {
232 self.use_bidirectional_attention.as_deref() == Some("vision")
233 }
234
235 pub fn expert_ffn_dim(&self) -> usize {
236 if self.expert_ffn_size > 0 {
237 self.expert_ffn_size
238 } else {
239 self.intermediate_size
240 }
241 }
242
243 pub fn attn_score_scale(&self) -> Option<f32> {
244 match self.arch {
245 GemmaArch::Gemma => None,
246 GemmaArch::Gemma2 | GemmaArch::Gemma3 | GemmaArch::Gemma4 => {
247 if let Some(s) = self.query_pre_attn_scalar {
248 Some(1.0 / s)
249 } else {
250 Some(1.0 / (self.head_dim() as f32).sqrt())
251 }
252 }
253 }
254 }
255
256 pub fn layer_attn_options(&self, layer: usize) -> (MaskKind, Option<f32>, Option<f32>) {
266 let scale = self.attn_score_scale();
267 let softcap = self.attn_logit_softcapping;
268 let mask = match (self.arch, self.sliding_window) {
269 (_, None) => MaskKind::Causal,
270 (GemmaArch::Gemma2, Some(w)) => gemma2_layer_mask(layer, w),
271 (GemmaArch::Gemma3 | GemmaArch::Gemma4, Some(w)) => {
272 gemma_strided_layer_mask(layer, w, self.arch.sliding_window_stride())
273 }
274 _ => MaskKind::Causal,
275 };
276 (mask, scale, softcap)
277 }
278
/// Minimal first-generation configuration for unit tests: 2 layers,
/// 4 attention heads, 2 KV heads, hidden size 16, no optional features.
#[cfg(test)]
pub(crate) fn tiny_test() -> Self {
    Self {
        // Core shape.
        arch: GemmaArch::Gemma,
        vocab_size: 32,
        hidden_size: 16,
        intermediate_size: 32,
        num_hidden_layers: 2,
        num_attention_heads: 4,
        num_key_value_heads: 2,
        max_position_embeddings: 64,
        // Numerics.
        rms_norm_eps: 1e-6,
        rope_theta: 10_000.0,
        tie_word_embeddings: true,
        attention_bias: false,
        // Optional attention knobs, all disabled.
        head_dim: None,
        attn_logit_softcapping: None,
        final_logit_softcapping: None,
        sliding_window: None,
        query_pre_attn_scalar: None,
        effective_num_layers: None,
        // No experts.
        num_experts: 0,
        num_experts_used: 0,
        expert_ffn_size: 0,
        expert_weights_scale: 1.0,
        // No per-layer overrides.
        layer_types: Vec::new(),
        rope_parameters: GemmaRopeMap::default(),
        global_head_dim: None,
        num_global_key_value_heads: None,
        attention_k_eq_v: false,
        use_bidirectional_attention: None,
    }
}
312
313 pub fn is_full_attention_layer(&self, layer: usize) -> bool {
325 if !self.layer_types.is_empty() {
326 return matches!(
327 self.layer_types.get(layer),
328 Some(GemmaLayerType::FullAttention),
329 );
330 }
331 let stride = self.arch.sliding_window_stride();
332 stride > 1 && (layer + 1).is_multiple_of(stride)
333 }
334
335 pub fn layer_head_dim(&self, layer: usize) -> usize {
339 if self.is_full_attention_layer(layer) {
340 self.global_head_dim.unwrap_or_else(|| self.head_dim())
341 } else {
342 self.head_dim()
343 }
344 }
345
346 pub fn layer_num_kv_heads(&self, layer: usize) -> usize {
350 if self.is_full_attention_layer(layer) {
351 self.num_global_key_value_heads
352 .unwrap_or(self.num_key_value_heads)
353 } else {
354 self.num_key_value_heads
355 }
356 }
357
358 pub fn layer_n_rot(&self, layer: usize) -> usize {
362 let dh = self.layer_head_dim(layer);
363 let params = self.layer_rope_parameters(layer);
364 let kind = params
365 .and_then(|p| p.rope_type)
366 .unwrap_or(GemmaRopeKind::Default);
367 let factor = params.and_then(|p| p.partial_rotary_factor);
368 match (kind, factor) {
369 (GemmaRopeKind::Proportional, Some(f)) if f > 0.0 && f < 1.0 => {
370 ((dh as f32) * f).floor() as usize
371 }
372 _ => dh,
373 }
374 }
375
376 pub fn layer_rope_theta(&self, layer: usize) -> f64 {
379 self.layer_rope_parameters(layer)
380 .and_then(|p| p.rope_theta)
381 .map(|t| t as f64)
382 .unwrap_or(self.rope_theta)
383 }
384
385 fn layer_rope_parameters(&self, layer: usize) -> Option<&GemmaRopeParameters> {
386 if self.is_full_attention_layer(layer) {
387 self.rope_parameters.full_attention.as_ref()
388 } else {
389 self.rope_parameters.sliding_attention.as_ref()
390 }
391 }
392}
393
394fn normalize_hf_null_usize_fields(mut value: serde_json::Value) -> serde_json::Value {
396 let Some(obj) = value.as_object_mut() else {
397 return value;
398 };
399 for key in [
400 "num_experts",
401 "num_experts_used",
402 "top_k_experts",
403 "expert_ffn_size",
404 "moe_intermediate_size",
405 "hidden_size_per_layer_input",
406 ] {
407 if obj.get(key).is_some_and(|v| v.is_null()) {
408 obj.insert(key.to_string(), serde_json::Value::from(0usize));
409 }
410 }
411 value
412}
413
414fn infer_arch_from_json(raw: &str) -> GemmaArch {
415 if raw.contains("\"gemma4_unified\"")
420 || raw.contains("\"gemma4_unified_text\"")
421 || raw.contains("\"gemma4\"")
422 || raw.contains("\"gemma4moe\"")
423 || raw.contains("Gemma4UnifiedForConditionalGeneration")
424 || raw.contains("Gemma4ForCausalLM")
425 {
426 return GemmaArch::Gemma4;
427 }
428 if raw.contains("\"model_type\"") {
429 if raw.contains("\"gemma2\"") {
430 return GemmaArch::Gemma2;
431 }
432 if raw.contains("\"gemma3\"") {
433 return GemmaArch::Gemma3;
434 }
435 }
436 GemmaArch::Gemma
437}
438
439pub fn gemma_cfg_from_gguf(raw: &GgufFile) -> anyhow::Result<GemmaConfig> {
440 let arch_tag = raw
441 .metadata
442 .get("general.architecture")
443 .and_then(MetaValue::as_str)
444 .unwrap_or("gemma");
445 let arch_prefix = arch_tag;
446 let arch = GemmaArch::from_gguf_tag(arch_tag);
447
448 let get_meta = |k: &str| -> Option<&MetaValue> {
449 raw.metadata.get(k).or_else(|| {
450 let suffix = k.strip_prefix("gemma.")?;
451 if arch_prefix == "gemma" {
452 None
453 } else {
454 let arch_key = format!("{arch_prefix}.{suffix}");
455 raw.metadata.get(&arch_key)
456 }
457 })
458 };
459 let get_u32 = |k: &str| -> anyhow::Result<u32> {
460 get_meta(k)
461 .and_then(MetaValue::as_u32)
462 .ok_or_else(|| anyhow::anyhow!("missing GGUF metadata key: {k}"))
463 };
464 let get_f32 = |k: &str| -> Option<f32> {
465 get_meta(k).and_then(|v| match v {
466 MetaValue::F32(x) => Some(*x),
467 _ => None,
468 })
469 };
470 let get_bool = |k: &str| -> Option<bool> {
471 get_meta(k).and_then(|v| match v {
472 MetaValue::Bool(b) => Some(*b),
473 _ => None,
474 })
475 };
476
477 let hidden_size = get_u32("gemma.embedding_length")? as usize;
478 let num_attention_heads = get_u32("gemma.attention.head_count")? as usize;
479 let head_dim = get_u32("gemma.attention.key_length")
480 .ok()
481 .or_else(|| get_u32("gemma.rope.dimension_count").ok())
482 .map(|v| v as usize);
483
484 Ok(GemmaConfig {
485 arch,
486 vocab_size: get_u32("gemma.vocab_size").unwrap_or(256_000) as usize,
487 hidden_size,
488 intermediate_size: get_u32("gemma.feed_forward_length")? as usize,
489 num_hidden_layers: get_u32("gemma.block_count")? as usize,
490 num_attention_heads,
491 num_key_value_heads: get_u32("gemma.attention.head_count_kv")? as usize,
492 max_position_embeddings: get_u32("gemma.context_length").unwrap_or(8192) as usize,
493 rms_norm_eps: get_f32("gemma.attention.layer_norm_rms_epsilon").unwrap_or(1e-6) as f64,
494 rope_theta: get_f32("gemma.rope.freq_base").unwrap_or(10_000.0) as f64,
495 tie_word_embeddings: get_bool("gemma.tie_word_embeddings").unwrap_or(true),
496 attention_bias: get_bool("gemma.attention.bias").unwrap_or(false),
497 head_dim,
498 attn_logit_softcapping: get_f32("gemma.attn_logit_softcapping"),
499 final_logit_softcapping: get_f32("gemma.final_logit_softcapping"),
500 sliding_window: get_u32("gemma.attention.sliding_window")
501 .ok()
502 .map(|v| v as usize),
503 query_pre_attn_scalar: get_f32("gemma.attention.query_pre_attn_scalar"),
504 effective_num_layers: get_u32("gemma.block_count_effective")
505 .ok()
506 .map(|v| v as usize),
507 num_experts: get_u32("gemma.expert_count").unwrap_or(0) as usize,
508 num_experts_used: get_u32("gemma.expert_used_count").unwrap_or(0) as usize,
509 expert_ffn_size: get_u32("gemma.expert_feed_forward_length").unwrap_or(0) as usize,
510 expert_weights_scale: get_f32("gemma.expert_weights_scale").unwrap_or(1.0),
511 layer_types: Vec::new(),
516 rope_parameters: GemmaRopeMap::default(),
517 global_head_dim: None,
518 num_global_key_value_heads: None,
519 attention_k_eq_v: false,
520 use_bidirectional_attention: None,
521 })
522}
523
#[cfg(test)]
mod tests {
    use super::*;

    /// Unified (multimodal-style) config fixture: the language-model settings
    /// live in the nested `text_config` object, with 48 layers where every
    /// sixth one is `full_attention`.
    const GEMMA_4_12B_CONFIG: &str = r#"{
        "architectures": ["Gemma4UnifiedForConditionalGeneration"],
        "model_type": "gemma4_unified",
        "tie_word_embeddings": true,
        "text_config": {
            "model_type": "gemma4_unified_text",
            "vocab_size": 262144,
            "hidden_size": 3840,
            "intermediate_size": 15360,
            "num_hidden_layers": 48,
            "num_attention_heads": 16,
            "num_key_value_heads": 8,
            "num_global_key_value_heads": 1,
            "head_dim": 256,
            "global_head_dim": 512,
            "attention_k_eq_v": true,
            "max_position_embeddings": 131072,
            "rms_norm_eps": 1e-6,
            "tie_word_embeddings": true,
            "attention_bias": false,
            "final_logit_softcapping": 30.0,
            "sliding_window": 1024,
            "layer_types": [
                "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
                "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
                "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
                "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
                "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
                "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
                "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
                "sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
            ],
            "rope_parameters": {
                "full_attention": { "partial_rotary_factor": 0.25, "rope_theta": 1000000.0, "rope_type": "proportional" },
                "sliding_attention": { "rope_theta": 10000.0, "rope_type": "default" }
            }
        }
    }"#;

    /// Writes the fixture to a temp file, parses it with `from_file`, and
    /// removes the file again.
    ///
    /// The file name embeds `tag` and the process id so concurrent test
    /// processes cannot clobber each other, and the file is removed before the
    /// parse result is unwrapped so a parse failure does not leak it.
    fn load_fixture(tag: &str) -> GemmaConfig {
        let file_name = format!("rlx_gemma_gemma4_12b_{tag}_{}.json", std::process::id());
        let path = std::env::temp_dir().join(file_name);
        std::fs::write(&path, GEMMA_4_12B_CONFIG).expect("fixture should be writable");
        let parsed = GemmaConfig::from_file(&path);
        std::fs::remove_file(&path).ok();
        parsed.expect("fixture config should parse")
    }

    #[test]
    fn gemma_4_12b_unified_config_parses_text_subtree() {
        let cfg = load_fixture("test_config");

        assert_eq!(cfg.arch, GemmaArch::Gemma4);
        assert_eq!(cfg.vocab_size, 262_144);
        assert_eq!(cfg.hidden_size, 3840);
        assert_eq!(cfg.intermediate_size, 15_360);
        assert_eq!(cfg.num_hidden_layers, 48);
        assert_eq!(cfg.num_attention_heads, 16);
        assert_eq!(cfg.num_key_value_heads, 8);
        assert_eq!(cfg.head_dim(), 256);
        assert_eq!(cfg.global_head_dim, Some(512));
        assert_eq!(cfg.num_global_key_value_heads, Some(1));
        assert!(cfg.attention_k_eq_v);
        assert_eq!(cfg.sliding_window, Some(1024));
        assert_eq!(cfg.final_logit_softcapping, Some(30.0));
        assert!(cfg.tie_word_embeddings);
        assert_eq!(cfg.layer_types.len(), 48);
        assert_eq!(cfg.arch.sliding_window_stride(), 6);
    }

    #[test]
    fn hf_null_moe_fields_default_to_zero() {
        let json = r#"{"num_experts": null, "top_k_experts": null}"#;
        let v = normalize_hf_null_usize_fields(serde_json::from_str(json).unwrap());
        let obj = v.as_object().unwrap();
        assert_eq!(obj["num_experts"], 0);
        assert_eq!(obj["top_k_experts"], 0);
    }

    #[test]
    fn gemma_4_12b_per_layer_dispatch() {
        let cfg = load_fixture("dispatch_config");

        // Layer 0 is a sliding-attention layer.
        assert!(!cfg.is_full_attention_layer(0));
        assert_eq!(cfg.layer_head_dim(0), 256);
        assert_eq!(cfg.layer_num_kv_heads(0), 8);
        assert_eq!(cfg.layer_n_rot(0), 256);
        assert!((cfg.layer_rope_theta(0) - 10_000.0).abs() < 1e-3);

        // Layer 5 is the first full-attention layer.
        assert!(cfg.is_full_attention_layer(5));
        assert_eq!(cfg.layer_head_dim(5), 512);
        assert_eq!(cfg.layer_num_kv_heads(5), 1);
        assert_eq!(cfg.layer_n_rot(5), 128);
        assert!((cfg.layer_rope_theta(5) - 1_000_000.0).abs() < 1e-3);

        // The last layer is full-attention as well.
        assert!(cfg.is_full_attention_layer(47));
    }

    #[test]
    fn pre_gemma4_archs_keep_uniform_layer_shape() {
        let mut cfg = GemmaConfig::tiny_test();
        cfg.arch = GemmaArch::Gemma3;
        cfg.head_dim = Some(64);
        cfg.num_key_value_heads = 2;
        cfg.rope_theta = 1_000.0;
        for i in 0..cfg.num_hidden_layers {
            assert_eq!(cfg.layer_head_dim(i), 64);
            assert_eq!(cfg.layer_num_kv_heads(i), 2);
            assert_eq!(cfg.layer_n_rot(i), 64);
            assert!((cfg.layer_rope_theta(i) - 1_000.0).abs() < 1e-3);
        }
    }

    #[test]
    fn infer_arch_picks_up_gemma4_markers() {
        assert_eq!(
            infer_arch_from_json(r#"{"model_type":"gemma4_unified"}"#),
            GemmaArch::Gemma4,
        );
        assert_eq!(
            infer_arch_from_json(r#"{"architectures":["Gemma4UnifiedForConditionalGeneration"]}"#),
            GemmaArch::Gemma4,
        );
        assert_eq!(
            infer_arch_from_json(r#"{"model_type":"gemma3"}"#),
            GemmaArch::Gemma3,
        );
    }
}