use super::config::{GemmaArch, GemmaConfig};
pub use rlx_flow::rope::{
build_tables as build_rope_tables, default_inv_freq, inv_freq_with_factors,
};
/// Combine `base` inverse frequencies with optional per-frequency `factors`.
///
/// The factors are used (via `inv_freq_with_factors`) only when they are
/// present, non-empty and exactly as long as `base`. In every other case
/// `base` is returned unchanged as an owned copy, so a mis-sized factor table
/// is ignored instead of being indexed out of range.
pub fn apply_rope_freq_factors(base: &[f64], factors: Option<&[f32]>) -> Vec<f64> {
    let usable = factors.filter(|f| !f.is_empty() && f.len() == base.len());
    if let Some(scale) = usable {
        inv_freq_with_factors(base, scale)
    } else {
        base.to_vec()
    }
}
/// Index of the first layer that is *not* full attention, if there is one.
fn first_sliding_layer(cfg: &GemmaConfig) -> Option<usize> {
    // The range starts at 0, so the position in the iterator is the layer index.
    (0..cfg.num_hidden_layers).position(|layer| !cfg.is_full_attention_layer(layer))
}
/// Index of the first full-attention layer, if there is one.
fn first_full_layer(cfg: &GemmaConfig) -> Option<usize> {
    // The range starts at 0, so the position in the iterator is the layer index.
    (0..cfg.num_hidden_layers).position(|layer| cfg.is_full_attention_layer(layer))
}
/// Rope `(theta, n_rot)` for the default (sliding-window) frequency table.
///
/// For Gemma 4 configs that contain at least one sliding layer, these are the
/// per-layer values of the first such layer. Otherwise they are the
/// model-wide `rope_theta` and `head_dim()`.
pub fn sliding_rope_params(cfg: &GemmaConfig) -> (f64, usize) {
    let gemma4_sliding = if cfg.arch == GemmaArch::Gemma4 {
        first_sliding_layer(cfg)
    } else {
        None
    };
    match gemma4_sliding {
        Some(layer) => (cfg.layer_rope_theta(layer), cfg.layer_n_rot(layer)),
        None => (cfg.rope_theta, cfg.head_dim()),
    }
}
pub fn global_rope_params(cfg: &GemmaConfig) -> Option<(f64, usize)> {
if cfg.arch != GemmaArch::Gemma4 || cfg.layer_types.is_empty() {
return None;
}
let fi = first_full_layer(cfg)?;
let si = first_sliding_layer(cfg)?;
let theta = cfg.layer_rope_theta(fi);
let n_rot = cfg.layer_n_rot(fi);
let differs = (theta - cfg.layer_rope_theta(si)).abs() > 1e-3 || n_rot != cfg.layer_n_rot(si);
differs.then_some((theta, n_rot))
}
/// Inverse-frequency table for the default (sliding-window) rope.
///
/// Built from [`sliding_rope_params`] with `default_inv_freq`. `rope_freq_factors`
/// are folded in only when their length matches the table; otherwise they are
/// ignored (see [`apply_rope_freq_factors`]).
pub fn resolve_inv_freq(cfg: &GemmaConfig, rope_freq_factors: Option<&[f32]>) -> Vec<f64> {
    let (theta, n_rot) = sliding_rope_params(cfg);
    apply_rope_freq_factors(&default_inv_freq(theta, n_rot), rope_freq_factors)
}
/// Inverse-frequency table for Gemma 4 full-attention ("global") layers.
///
/// Returns `None` exactly when [`global_rope_params`] does. Otherwise the
/// table is `default_inv_freq(theta, n_rot)` for the full-attention params,
/// with `rope_freq_factors` folded in according to their length:
///
/// * equal to that table's length: applied directly via
///   [`apply_rope_freq_factors`];
/// * equal to the length of the table built for the whole `global_head_dim`:
///   applied to that wider table, which is then truncated to its first
///   `n_rot / 2` entries;
/// * absent, empty, or any other length: ignored, and the unscaled table is
///   returned.
///
/// NOTE(review): the second case takes its values from a table built with
/// `global_head_dim` rather than `n_rot`. The two tables may hold different
/// values at the same index, so that branch is not simply "the first table
/// with extra scaling". Confirm this is intended for the configured rope kind.
pub fn resolve_global_inv_freq(
    cfg: &GemmaConfig,
    rope_freq_factors: Option<&[f32]>,
) -> Option<Vec<f64>> {
    let (theta, n_rot) = global_rope_params(cfg)?;
    let rotary = default_inv_freq(theta, n_rot);

    let factors = match rope_freq_factors {
        Some(f) if !f.is_empty() => f,
        _ => return Some(rotary),
    };

    if factors.len() == rotary.len() {
        return Some(apply_rope_freq_factors(&rotary, Some(factors)));
    }

    // Factors sized for the whole global head: scale that wider table, then
    // keep only the leading `n_rot / 2` frequencies.
    let resized = cfg.global_head_dim.and_then(|gdh| {
        let whole_head = default_inv_freq(theta, gdh);
        (factors.len() == whole_head.len()).then(|| {
            let mut scaled = inv_freq_with_factors(&whole_head, factors);
            scaled.truncate(n_rot / 2);
            scaled
        })
    });
    Some(resized.unwrap_or(rotary))
}
/// Cosine and sine rows for a single position `pos`.
///
/// Element `i` of each returned vector is `cos`/`sin` of `pos * inv_freq[i]`.
/// The angle is computed in `f64` and narrowed to `f32` only at the end. Both
/// vectors have the same length as `inv_freq`.
pub fn rope_slice(inv_freq: &[f64], pos: usize) -> (Vec<f32>, Vec<f32>) {
    let position = pos as f64;
    inv_freq
        .iter()
        .map(|&freq| {
            let (sin, cos) = (position * freq).sin_cos();
            (cos as f32, sin as f32)
        })
        .unzip()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{
        GemmaArch, GemmaConfig, GemmaLayerType, GemmaRopeKind, GemmaRopeMap, GemmaRopeParameters,
    };

    /// Gemma 4 12B-style fixture.
    ///
    /// 48 layers, where every 6th layer (1-based) is full attention. Sliding
    /// rope is configured with theta 1e4 and `head_dim` 256. Full-attention
    /// rope is configured with theta 1e6, proportional kind, partial rotary
    /// factor 0.25, and `global_head_dim` 512.
    fn gemma4_12b_cfg() -> GemmaConfig {
        GemmaConfig {
            arch: GemmaArch::Gemma4,
            vocab_size: 262_144,
            hidden_size: 3840,
            intermediate_size: 15_360,
            num_hidden_layers: 48,
            num_attention_heads: 16,
            num_key_value_heads: 8,
            max_position_embeddings: 8192,
            rms_norm_eps: 1e-6,
            rope_theta: 10_000.0,
            tie_word_embeddings: true,
            attention_bias: false,
            head_dim: Some(256),
            attn_logit_softcapping: None,
            final_logit_softcapping: Some(30.0),
            sliding_window: Some(1024),
            query_pre_attn_scalar: None,
            effective_num_layers: None,
            num_experts: 0,
            num_experts_used: 0,
            expert_ffn_size: 0,
            expert_weights_scale: 1.0,
            layer_types: (0..48)
                .map(|i| {
                    if (i + 1) % 6 == 0 {
                        GemmaLayerType::FullAttention
                    } else {
                        GemmaLayerType::SlidingAttention
                    }
                })
                .collect(),
            rope_parameters: GemmaRopeMap {
                sliding_attention: Some(GemmaRopeParameters {
                    rope_theta: Some(10_000.0),
                    rope_type: Some(GemmaRopeKind::Default),
                    partial_rotary_factor: None,
                }),
                full_attention: Some(GemmaRopeParameters {
                    rope_theta: Some(1_000_000.0),
                    rope_type: Some(GemmaRopeKind::Proportional),
                    partial_rotary_factor: Some(0.25),
                }),
            },
            global_head_dim: Some(512),
            num_global_key_value_heads: Some(1),
            attention_k_eq_v: true,
            use_bidirectional_attention: Some("vision".into()),
        }
    }

    #[test]
    fn tables_len_matches_max_pos() {
        let inv = default_inv_freq(10_000.0, 8);
        let (c, s) = build_rope_tables(&inv, 4);
        assert_eq!(c.len(), 4 * inv.len());
        assert_eq!(s.len(), c.len());
    }

    #[test]
    fn gguf_rope_factors_sized_for_global_head_do_not_break_sliding() {
        let cfg = gemma4_12b_cfg();
        let factors = vec![1.0f32; 256];
        let sliding = resolve_inv_freq(&cfg, Some(&factors));
        assert_eq!(sliding.len(), 128);
        let global = resolve_global_inv_freq(&cfg, Some(&factors)).expect("global table");
        assert_eq!(global.len(), 64);
    }

    #[test]
    fn mismatched_factors_are_skipped_not_panicked() {
        let cfg = gemma4_12b_cfg();
        let factors = vec![1.0f32; 999];
        let sliding = resolve_inv_freq(&cfg, Some(&factors));
        assert_eq!(sliding.len(), 128);
        // The global path must also ignore the mis-sized factors rather than
        // panic, yielding exactly the unscaled table.
        let global = resolve_global_inv_freq(&cfg, Some(&factors)).expect("global table");
        assert_eq!(global.len(), 64);
        assert_eq!(Some(global), resolve_global_inv_freq(&cfg, None));
    }

    #[test]
    fn empty_factors_behave_like_absent_factors() {
        let cfg = gemma4_12b_cfg();
        let empty: &[f32] = &[];
        assert_eq!(resolve_inv_freq(&cfg, Some(empty)), resolve_inv_freq(&cfg, None));
        assert_eq!(
            resolve_global_inv_freq(&cfg, Some(empty)),
            resolve_global_inv_freq(&cfg, None)
        );
    }

    #[test]
    fn rope_slice_at_origin_is_identity_rotation() {
        let inv = default_inv_freq(10_000.0, 8);
        let (c, s) = rope_slice(&inv, 0);
        assert_eq!(c.len(), inv.len());
        assert_eq!(s.len(), inv.len());
        assert!(c.iter().all(|&v| v == 1.0));
        assert!(s.iter().all(|&v| v == 0.0));
    }
}