use std::collections::HashMap;
use std::fmt;
use std::path::Path;
use std::sync::Arc;
use anyhow::Result;
use rlx_flow::blocks::{
DecodeRopeParamsStage, EmbedScaleStage, GemmaDecodeLayerSpec, GemmaDecodeLayerStage,
GemmaLayerStyle, GemmaRmsNormStage, LmHeadStage, LogitSoftcapStage, RopeTablesStage,
gemma_attn_spec, gemma_prefill_layer_composed,
};
use rlx_flow::{BuiltModel, CompileProfile, FlowStage, ModelFlow, SideOutputs};
use rlx_ir::dynamic::sym;
use rlx_ir::hir::HirModule;
use rlx_ir::shape::Dim;
use rlx_ir::{DType, Graph, Shape};
use super::config::{GemmaArch, GemmaConfig};
use super::rope::{build_rope_tables, resolve_inv_freq};
use rlx_core::flow_bridge::{WeightLoaderSource, load_compile_profile};
use rlx_core::weight_loader::WeightLoader;
/// File name of the optional compile profile that
/// [`gemma_profile_near_weights`] looks for in the directory containing the
/// weights file.
pub const GEMMA_PROFILE_FILE: &str = "gemma.rlx.toml";
/// Loads the compile profile stored as [`GEMMA_PROFILE_FILE`] next to `weights`.
///
/// The built-in Gemma decode profile (`decode == true`) or prefill profile
/// (`decode == false`) is handed to `load_compile_profile` as the default;
/// presumably it is used when the file is absent — confirm in `flow_bridge`.
/// If `weights` has no parent component, `.` is searched instead.
pub fn gemma_profile_near_weights(weights: &Path, decode: bool) -> CompileProfile {
    let fallback = match decode {
        true => CompileProfile::gemma_decode(),
        false => CompileProfile::gemma_prefill(),
    };
    let search_dir = match weights.parent() {
        Some(parent) => parent,
        None => Path::new("."),
    };
    let profile_path = search_dir.join(GEMMA_PROFILE_FILE);
    load_compile_profile(&profile_path, fallback)
}
/// Which graph [`GemmaFlow::build`] produces.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GemmaMode {
    /// Whole-prompt pass fed by `input_ids` (or `prefill_hidden`).
    Prefill,
    /// Single-token step: `input_ids` of shape `[batch, 1]` plus per-layer
    /// `past_k_*` / `past_v_*` cache inputs.
    Decode,
}
/// Everything a per-layer hook (see [`GemmaFlow::layer`]) needs to build the
/// stage for one transformer layer.
pub enum GemmaLayerCtx<'a> {
    /// Context for a prefill-graph layer.
    Prefill {
        /// Zero-based layer index.
        index: usize,
        /// Layer style derived from the config (`cfg.layer_style()`).
        style: GemmaLayerStyle,
        /// Attention spec built by `gemma_attn_spec` for this layer, with the
        /// optional rope-table / k-eq-v adjustments already applied.
        attn: rlx_flow::blocks::SelfAttnPrefillSpec,
        /// Collector for this layer's K/V side outputs.
        kv_sink: &'a SideOutputs,
        /// Whether K/V side outputs should be written to `kv_sink`.
        export_kv: bool,
        /// Per-layer head dimension (`cfg.layer_head_dim(index)`).
        head_dim: usize,
        /// RMSNorm epsilon (`cfg.rms_norm_eps` as `f32`).
        eps: f32,
    },
    /// Context for a decode-graph layer.
    Decode {
        /// Zero-based layer index.
        index: usize,
        /// Fully resolved decode-layer spec (heads, mask, rope table, ...).
        spec: GemmaDecodeLayerSpec,
        /// Collector for the layer's updated K/V outputs.
        kv_out: &'a SideOutputs,
    },
}
impl GemmaLayerCtx<'_> {
    /// Zero-based index of the layer this context describes.
    pub fn index(&self) -> usize {
        match self {
            Self::Prefill { index, .. } => *index,
            Self::Decode { index, .. } => *index,
        }
    }

    /// Builds the stage the flow would use for this layer without a hook.
    ///
    /// Prefill layers use `gemma_prefill_layer_composed`, with the K/V sink
    /// passed only when `export_kv` is set. Decode layers wrap a
    /// `GemmaDecodeLayerStage` in a stage named `layer{index}`. This is always
    /// the dense (non-MoE) layer.
    pub fn default_stage(&self) -> FlowStage {
        match self {
            Self::Prefill {
                index,
                style,
                attn,
                kv_sink,
                export_kv,
                eps,
                ..
            } => {
                let attn = attn.clone();
                let kv = export_kv.then(|| kv_sink.inner());
                gemma_prefill_layer_composed(*index, *style, attn, *eps, kv)
            }
            Self::Decode {
                index,
                spec,
                kv_out,
            } => {
                let layer = GemmaDecodeLayerStage::layer(*index, spec.clone(), kv_out.inner());
                FlowStage::Named {
                    name: format!("layer{index}"),
                    inner: Arc::new(FlowStage::GemmaDecodeLayer(layer)),
                }
            }
        }
    }
}
/// Per-layer stage factory installed with [`GemmaFlow::layer`].
type LayerFn = Arc<dyn Fn(GemmaLayerCtx<'_>) -> FlowStage + Send + Sync>;
/// Whole-flow rewrite installed with [`GemmaFlow::patch_flow`].
type FlowPatchFn = Arc<dyn Fn(ModelFlow) -> ModelFlow + Send + Sync>;
/// Builder for Gemma prefill and decode graphs.
///
/// Configure it with the chained setters on `GemmaFlow`, then call
/// [`GemmaFlow::build`], which dispatches on the selected [`GemmaMode`].
#[derive(Clone)]
pub struct GemmaFlow<'a> {
    /// Model hyperparameters the graph is derived from.
    cfg: &'a GemmaConfig,
    /// Which graph `build` produces.
    mode: GemmaMode,
    /// Batch size baked into every input shape.
    batch: usize,
    /// Prefill sequence length (replaced by `sym::SEQ` when `dynamic_seq`).
    seq: usize,
    /// Decode: number of cached past positions (static case).
    past_seq: usize,
    /// Prefill: make the sequence axis symbolic.
    dynamic_seq: bool,
    /// Decode: make the past-KV axis symbolic and take RoPE rows as inputs.
    dynamic_past: bool,
    /// Prefill: emit `logits` instead of `hidden` (decode always emits logits).
    with_lm_head: bool,
    /// Prefill: attach per-layer K/V side outputs to the built model.
    with_kv_outputs: bool,
    /// Prefill: keep only the last position before the final norm / LM head.
    last_logits_only: bool,
    /// Decode: add a runtime `mask` input.
    use_custom_mask: bool,
    /// Explicit compile profile; the mode's built-in profile when `None`.
    profile: Option<CompileProfile>,
    /// Extra stages inserted before the transformer layers.
    before_layers: Vec<FlowStage>,
    /// Extra stages inserted after the transformer layers.
    after_layers: Vec<FlowStage>,
    /// Optional per-layer stage override.
    layer_fn: Option<LayerFn>,
    /// Optional rewrite applied to the assembled flow.
    flow_patch: Option<FlowPatchFn>,
    /// Prefill: start from a `prefill_hidden` input instead of token embedding.
    prefill_hidden: bool,
    /// Prefill: add an `attn_bias` input and use the multimodal layer override.
    media_attn_bias: bool,
}
impl fmt::Debug for GemmaFlow<'_> {
    /// Summarises the builder settings.
    ///
    /// The config reference is omitted (hence `finish_non_exhaustive`).
    /// Stage lists are shown by length and closures by presence. Every flag
    /// that changes the built graph is listed, including `prefill_hidden` and
    /// `media_attn_bias`, which were previously missing.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("GemmaFlow")
            .field("mode", &self.mode)
            .field("batch", &self.batch)
            .field("seq", &self.seq)
            .field("past_seq", &self.past_seq)
            .field("dynamic_seq", &self.dynamic_seq)
            .field("dynamic_past", &self.dynamic_past)
            .field("with_lm_head", &self.with_lm_head)
            .field("with_kv_outputs", &self.with_kv_outputs)
            .field("last_logits_only", &self.last_logits_only)
            .field("use_custom_mask", &self.use_custom_mask)
            .field("prefill_hidden", &self.prefill_hidden)
            .field("media_attn_bias", &self.media_attn_bias)
            .field("profile", &self.profile)
            .field("before_layers", &self.before_layers.len())
            .field("after_layers", &self.after_layers.len())
            .field("layer_fn", &self.layer_fn.is_some())
            .field("flow_patch", &self.flow_patch.is_some())
            .finish_non_exhaustive()
    }
}
impl<'a> GemmaFlow<'a> {
pub fn new(cfg: &'a GemmaConfig) -> Self {
Self {
cfg,
mode: GemmaMode::Prefill,
batch: 1,
seq: 128,
past_seq: 0,
dynamic_seq: false,
dynamic_past: false,
with_lm_head: false,
with_kv_outputs: false,
last_logits_only: false,
use_custom_mask: false,
profile: None,
before_layers: Vec::new(),
after_layers: Vec::new(),
layer_fn: None,
flow_patch: None,
prefill_hidden: false,
media_attn_bias: false,
}
}
pub fn prefill_from_hidden(mut self) -> Self {
self.prefill_hidden = true;
self
}
pub fn prefill_media_attn_bias(mut self) -> Self {
self.media_attn_bias = true;
self
}
pub fn for_prefill(cfg: &'a GemmaConfig, batch: usize, seq: usize) -> Self {
Self::new(cfg).prefill().batch(batch).seq(seq)
}
pub fn for_decode(cfg: &'a GemmaConfig, batch: usize, past_seq: usize) -> Self {
Self::new(cfg)
.decode()
.batch(batch)
.past(past_seq)
.lm_head()
}
pub fn prefill(mut self) -> Self {
self.mode = GemmaMode::Prefill;
self
}
pub fn decode(mut self) -> Self {
self.mode = GemmaMode::Decode;
self
}
pub fn batch(mut self, batch: usize) -> Self {
self.batch = batch;
self
}
pub fn seq(mut self, seq: usize) -> Self {
self.seq = seq;
self
}
pub fn past(mut self, past_seq: usize) -> Self {
self.past_seq = past_seq;
self
}
pub fn dynamic_seq(mut self) -> Self {
self.dynamic_seq = true;
self
}
pub fn dynamic_past(mut self) -> Self {
self.dynamic_past = true;
self
}
pub fn lm_head(mut self) -> Self {
self.with_lm_head = true;
self
}
pub fn hidden_only(mut self) -> Self {
self.with_lm_head = false;
self.last_logits_only = false;
self
}
pub fn last_token_logits(mut self) -> Self {
self.with_lm_head = true;
self.last_logits_only = true;
self
}
pub fn export_kv(mut self) -> Self {
self.with_kv_outputs = true;
self
}
pub fn custom_mask(mut self) -> Self {
self.use_custom_mask = true;
self
}
pub fn profile(mut self, profile: CompileProfile) -> Self {
self.profile = Some(profile);
self
}
pub fn profile_prefill(mut self) -> Self {
self.profile = Some(CompileProfile::gemma_prefill());
self
}
pub fn profile_decode(mut self) -> Self {
self.profile = Some(CompileProfile::gemma_decode());
self
}
pub fn profile_near(mut self, weights_path: &Path) -> Self {
let decode = self.mode == GemmaMode::Decode;
self.profile = Some(gemma_profile_near_weights(weights_path, decode));
self
}
pub fn before_layers(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
self.before_layers.extend(stages);
self
}
pub fn after_layers(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
self.after_layers.extend(stages);
self
}
pub fn layer<F>(mut self, f: F) -> Self
where
F: Fn(GemmaLayerCtx<'_>) -> FlowStage + Send + Sync + 'static,
{
self.layer_fn = Some(Arc::new(f));
self
}
pub fn patch_flow<F>(mut self, f: F) -> Self
where
F: Fn(ModelFlow) -> ModelFlow + Send + Sync + 'static,
{
self.flow_patch = Some(Arc::new(f));
self
}
pub fn build(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
match self.mode {
GemmaMode::Prefill => self.build_prefill(weights),
GemmaMode::Decode => self.build_decode(weights),
}
}
fn build_prefill(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
if self.dynamic_seq && self.batch != 1 {
anyhow::bail!("gemma: dynamic_seq prefill requires batch=1");
}
let cfg = self.cfg;
let profile = self.profile.unwrap_or_else(CompileProfile::gemma_prefill);
let f = DType::F32;
let h = cfg.hidden_size;
let eps = cfg.rms_norm_eps as f32;
let layer_style = cfg.layer_style();
let hidden_shape = prefill_hidden_shape(self.batch, self.seq, h, self.dynamic_seq, f);
let input_shape = prefill_input_shape(self.batch, self.seq, self.dynamic_seq);
let rope_factors = weights.take("rope_freqs.weight").ok().map(|(data, _)| data);
let inv_freq = resolve_inv_freq(cfg, rope_factors.as_deref());
let (cos_data, sin_data) = build_rope_tables(&inv_freq, cfg.max_position_embeddings);
let global_rope =
secondary_rope_tables(cfg, cfg.max_position_embeddings, rope_factors.as_deref());
let kv_sink = SideOutputs::new();
let mut flow = ModelFlow::new("gemma").with_profile(profile);
if self.prefill_hidden {
flow = flow.input("prefill_hidden", hidden_shape.clone());
} else {
flow = flow.input("input_ids", input_shape);
}
if self.dynamic_seq && self.with_lm_head && self.last_logits_only {
flow = flow.input("last_token_idx", Shape::new(&[self.batch], DType::F32));
}
if self.media_attn_bias {
let nh = cfg.num_attention_heads;
if self.dynamic_seq {
flow = flow.input(
"attn_bias",
Shape::from_dims(
&[
rlx_ir::shape::Dim::Static(self.batch),
rlx_ir::shape::Dim::Static(nh),
rlx_ir::shape::Dim::Dynamic(rlx_ir::sym::SEQ),
rlx_ir::shape::Dim::Dynamic(rlx_ir::sym::SEQ),
],
f,
),
);
} else {
flow = flow.input(
"attn_bias",
Shape::new(&[self.batch, nh, self.seq, self.seq], f),
);
}
}
flow = flow
.rope_tables(RopeTablesStage::param(
cfg.max_position_embeddings,
inv_freq.len(),
cos_data,
sin_data,
))
.zero_beta_named("gemma.zero_beta.hidden", h);
if self.prefill_hidden {
flow = flow.plugin_named("gemma.prefill_hidden_bind", move |emit, _| {
let hidden = emit
.flow_input("prefill_hidden")
.map_err(|e| anyhow::anyhow!("prefill_hidden input: {e}"))?;
let _ = emit.load_param("model.embed_tokens.weight", false)?;
Ok(Some(hidden))
});
} else {
flow = flow
.token_embed()
.raw_stage(FlowStage::EmbedScale(EmbedScaleStage::new(h)));
}
flow = flow.raw_stages(self.before_layers.iter().cloned());
if let Some(g) = &global_rope {
flow = flow.raw_stage(FlowStage::RopeTables(RopeTablesStage::param_named(
"global",
cfg.max_position_embeddings,
g.half_dim,
g.cos.clone(),
g.sin.clone(),
)));
}
let layer_fn = self.layer_fn.clone();
let export = self.with_kv_outputs;
let media_bias = self.media_attn_bias;
let num_heads = cfg.num_attention_heads;
let num_layers = cfg.active_num_layers();
let layer_attn: Vec<_> = (0..num_layers).map(|i| cfg.layer_attn_options(i)).collect();
let is_moe = cfg.is_moe();
let moe_num_experts = cfg.num_experts;
let moe_top_k = cfg.num_experts_used;
let moe_n_embd = cfg.hidden_size;
let moe_n_ff = cfg.expert_ffn_dim();
let per_layer: Vec<PerLayerAttn> = (0..num_layers)
.map(|i| PerLayerAttn {
head_dim: cfg.layer_head_dim(i),
num_kv_heads: cfg.layer_num_kv_heads(i),
n_rot: cfg.layer_n_rot(i),
rope_table: if cfg.is_full_attention_layer(i) && global_rope.is_some() {
Some("global".to_string())
} else {
None
},
k_eq_v: cfg.attention_k_eq_v,
})
.collect();
flow = flow.repeat_layers(num_layers, {
let style = layer_style;
let sink = kv_sink.clone();
move |i| {
let (mask, score_scale, softcap) = layer_attn[i];
let pl = &per_layer[i];
let lh = pl.head_dim;
let mut attn = gemma_attn_spec(
i,
num_heads,
pl.head_dim,
pl.num_kv_heads,
pl.n_rot,
mask,
score_scale,
softcap,
);
if let Some(name) = pl.rope_table.as_ref() {
attn = attn.with_rope_table(name);
}
if pl.k_eq_v {
attn = attn.with_k_eq_v();
}
if let Some(ref f) = layer_fn {
return f(GemmaLayerCtx::Prefill {
index: i,
style,
attn: attn.clone(),
kv_sink: &sink,
export_kv: export,
head_dim: lh,
eps,
});
}
if media_bias {
return crate::multimodal_flow::multimodal_layer_override(
GemmaLayerCtx::Prefill {
index: i,
style,
attn,
kv_sink: &sink,
export_kv: export,
head_dim: lh,
eps,
},
true,
);
}
if is_moe {
let prefix = format!("model.layers.{i}");
let moe = rlx_flow::blocks::MoeFfnStage::hf(
prefix,
moe_num_experts,
moe_top_k,
moe_n_embd,
moe_n_ff,
);
let kv = if export { Some(sink.inner()) } else { None };
return rlx_flow::blocks::gemma_moe_prefill_layer_composed(
i, style, attn, eps, kv, moe,
);
}
GemmaLayerCtx::Prefill {
index: i,
style,
attn,
kv_sink: &sink,
export_kv: export,
head_dim: lh,
eps,
}
.default_stage()
}
});
flow = flow.raw_stages(self.after_layers.iter().cloned());
if self.with_lm_head && self.last_logits_only {
flow = if self.dynamic_seq {
flow.gather_last_token_dynamic(self.batch)
} else {
flow.gather_last_token_at(self.batch, self.seq)
};
}
flow = flow.raw_stage(FlowStage::GemmaRmsNorm(GemmaRmsNormStage::hf_layer(
"model.norm",
eps,
)));
if let Some(patch) = self.flow_patch {
flow = patch(flow);
}
let mut built = if self.with_lm_head {
let lm = if cfg.tie_word_embeddings {
FlowStage::LmHead(LmHeadStage::tied(cfg.vocab_size, h))
} else {
FlowStage::LmHead(LmHeadStage::separate("lm_head.weight", cfg.vocab_size, h))
};
flow = flow.raw_stage(lm);
if let Some(cap) = cfg.final_logit_softcapping {
flow = flow.raw_stage(FlowStage::LogitSoftcap(LogitSoftcapStage::new(cap)));
}
flow.output("logits")
.build(&mut WeightLoaderSource(weights))?
} else {
flow.output("hidden")
.build(&mut WeightLoaderSource(weights))?
};
if self.with_kv_outputs {
built = built.with_extra_hir_outputs(kv_sink.drain());
}
Ok(built)
}
fn build_decode(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
let cfg = self.cfg;
let profile = self.profile.unwrap_or_else(CompileProfile::gemma_decode);
let f = DType::F32;
let h = cfg.hidden_size;
let eps = cfg.rms_norm_eps as f32;
let dh = cfg.head_dim();
let half = dh / 2;
let hidden_shape = Shape::new(&[self.batch, 1, h], f);
let decode_style = cfg.layer_style();
let decode_score_scale = cfg.attn_score_scale();
let decode_softcap = cfg.attn_logit_softcapping;
let decode_arch = cfg.arch;
let decode_sliding = cfg.sliding_window;
let kv_out = SideOutputs::new();
let rope_factors = weights.take("rope_freqs.weight").ok().map(|(data, _)| data);
let inv_freq = resolve_inv_freq(cfg, rope_factors.as_deref());
let (rope_cos, rope_sin) = if self.dynamic_past {
(Vec::new(), Vec::new())
} else {
crate::rope::rope_slice(&inv_freq, self.past_seq)
};
let global_rope_row = if !self.dynamic_past {
secondary_rope_row(cfg, self.past_seq, rope_factors.as_deref())
} else {
None
};
let global_params = needs_secondary_rope_params(cfg);
let mut flow = ModelFlow::new("gemma_decode")
.with_profile(profile)
.input("input_ids", Shape::new(&[self.batch, 1], DType::F32));
if self.dynamic_past {
flow = flow
.input("rope_cos", Shape::new(&[1, half], f))
.input("rope_sin", Shape::new(&[1, half], f));
if let Some(gp) = global_params {
let half_global =
crate::rope::resolve_global_inv_freq(cfg, rope_factors.as_deref())
.map(|v| v.len())
.unwrap_or_else(|| crate::rope::default_inv_freq(gp.theta, gp.n_rot).len());
flow = flow
.input("rope_cos_global", Shape::new(&[1, half_global], f))
.input("rope_sin_global", Shape::new(&[1, half_global], f))
.raw_stage(FlowStage::Custom(rlx_flow::blocks::CustomStage::named(
"gemma.bind_global_decode_rope",
|emit, val| {
let cos = find_hir_input(emit.hir(), "rope_cos_global")?;
let sin = find_hir_input(emit.hir(), "rope_sin_global")?;
emit.set_named("global_cos", cos);
emit.set_named("global_sin", sin);
Ok(val)
},
)));
}
}
if self.use_custom_mask {
flow = flow.input("mask", Shape::new(&[self.batch, self.past_seq + 1], f));
}
for layer_idx in 0..cfg.num_hidden_layers {
let layer_kv_dim = cfg.layer_num_kv_heads(layer_idx) * cfg.layer_head_dim(layer_idx);
let shape = if self.dynamic_past {
Shape::from_dims(
&[
Dim::Static(self.batch),
Dim::Dynamic(sym::PAST_SEQ),
Dim::Static(layer_kv_dim),
],
f,
)
} else {
Shape::new(&[self.batch, self.past_seq, layer_kv_dim], f)
};
flow = flow
.input(format!("past_k_{layer_idx}"), shape.clone())
.input(format!("past_v_{layer_idx}"), shape);
}
if !self.dynamic_past {
flow = flow.raw_stage(FlowStage::DecodeRopeParams(DecodeRopeParamsStage::new(
rope_cos, rope_sin, half,
)));
if let Some(g) = &global_rope_row {
flow = flow.raw_stage(FlowStage::DecodeRopeParams(DecodeRopeParamsStage::named(
"global",
g.cos.clone(),
g.sin.clone(),
g.half_dim,
)));
}
}
flow = flow
.bind_decode_inputs(cfg.num_hidden_layers, self.use_custom_mask)
.zero_beta_named("gemma.zero_beta.hidden", h)
.token_embed()
.raw_stage(FlowStage::EmbedScale(EmbedScaleStage::new(h)))
.raw_stages(self.before_layers.iter().cloned());
let layer_fn = self.layer_fn.clone();
let use_custom_mask = self.use_custom_mask;
let num_heads = cfg.num_attention_heads;
let num_layers = cfg.active_num_layers();
let secondary_rope_active = global_rope_row.is_some();
let per_layer_decode: Vec<PerLayerAttn> = (0..num_layers)
.map(|i| PerLayerAttn {
head_dim: cfg.layer_head_dim(i),
num_kv_heads: cfg.layer_num_kv_heads(i),
n_rot: cfg.layer_n_rot(i),
rope_table: if cfg.is_full_attention_layer(i) && secondary_rope_active {
Some("global".to_string())
} else {
None
},
k_eq_v: cfg.attention_k_eq_v,
})
.collect();
let is_moe = cfg.is_moe();
let moe_num_experts = cfg.num_experts;
let moe_top_k = cfg.num_experts_used;
let moe_n_embd = cfg.hidden_size;
let moe_n_ff = cfg.expert_ffn_dim();
flow = flow.repeat_layers(num_layers, {
let sink = kv_out.clone();
let hidden_shape = hidden_shape.clone();
move |i| {
let mask = if use_custom_mask {
rlx_ir::op::MaskKind::Causal
} else {
match (decode_arch, decode_sliding) {
(GemmaArch::Gemma2, Some(w)) => rlx_flow::blocks::gemma2_layer_mask(i, w),
(GemmaArch::Gemma3 | GemmaArch::Gemma4, Some(w)) => {
rlx_flow::blocks::gemma_strided_layer_mask(
i,
w,
decode_arch.sliding_window_stride(),
)
}
_ => rlx_ir::op::MaskKind::Causal,
}
};
let pl = &per_layer_decode[i];
let kv_group_size = num_heads / pl.num_kv_heads;
let spec = GemmaDecodeLayerSpec {
style: decode_style,
num_heads,
head_dim: pl.head_dim,
num_kv_heads: pl.num_kv_heads,
kv_group_size,
n_rot: pl.n_rot,
rope_table: pl.rope_table.clone(),
k_eq_v: pl.k_eq_v,
eps,
use_custom_mask,
hidden_shape: hidden_shape.clone(),
mask,
score_scale: decode_score_scale,
attn_logit_softcap: decode_softcap,
};
if let Some(ref f) = layer_fn {
return f(GemmaLayerCtx::Decode {
index: i,
spec: spec.clone(),
kv_out: &sink,
});
}
if is_moe {
let prefix = format!("model.layers.{i}");
let moe = rlx_flow::blocks::MoeFfnStage::hf(
prefix,
moe_num_experts,
moe_top_k,
moe_n_embd,
moe_n_ff,
);
return rlx_flow::blocks::gemma_moe_decode_layer_composed(
i,
spec,
sink.inner(),
moe,
);
}
GemmaLayerCtx::Decode {
index: i,
spec,
kv_out: &sink,
}
.default_stage()
}
});
flow = flow.raw_stages(self.after_layers.iter().cloned());
if let Some(patch) = self.flow_patch {
flow = patch(flow);
}
let mut flow = flow.raw_stage(FlowStage::GemmaRmsNorm(GemmaRmsNormStage::hf_layer(
"model.norm",
eps,
)));
let lm = if cfg.tie_word_embeddings {
FlowStage::LmHead(LmHeadStage::tied(cfg.vocab_size, h))
} else {
FlowStage::LmHead(LmHeadStage::separate("lm_head.weight", cfg.vocab_size, h))
};
flow = flow.raw_stage(lm);
if let Some(cap) = cfg.final_logit_softcapping {
flow = flow.raw_stage(FlowStage::LogitSoftcap(LogitSoftcapStage::new(cap)));
}
let built = flow
.output("logits")
.build(&mut WeightLoaderSource(weights))?
.with_extra_hir_outputs(kv_out.drain());
Ok(built)
}
}
/// Shape of the `prefill_hidden` input: `[batch, seq, hidden]` with `dtype`.
///
/// When `dynamic` is set, the sequence axis is the symbolic `sym::SEQ` and
/// `seq` is ignored.
fn prefill_hidden_shape(
    batch: usize,
    seq: usize,
    hidden: usize,
    dynamic: bool,
    dtype: DType,
) -> Shape {
    if !dynamic {
        return Shape::new(&[batch, seq, hidden], dtype);
    }
    let dims = [
        Dim::Static(batch),
        Dim::Dynamic(sym::SEQ),
        Dim::Static(hidden),
    ];
    Shape::from_dims(&dims, dtype)
}
/// Shape of the prefill `input_ids` input: `[batch, seq]` as `F32`.
///
/// When `dynamic` is set, the sequence axis is the symbolic `sym::SEQ`.
fn prefill_input_shape(batch: usize, seq: usize, dynamic: bool) -> Shape {
    if !dynamic {
        return Shape::new(&[batch, seq], DType::F32);
    }
    let dims = [Dim::Static(batch), Dim::Dynamic(sym::SEQ)];
    Shape::from_dims(&dims, DType::F32)
}
/// Attention geometry resolved per layer from the config; the per-layer
/// closures capture a `Vec` of these instead of the config itself.
#[derive(Debug, Clone)]
struct PerLayerAttn {
    /// `cfg.layer_head_dim(i)`.
    head_dim: usize,
    /// `cfg.layer_num_kv_heads(i)`.
    num_kv_heads: usize,
    /// `cfg.layer_n_rot(i)`: number of rotated dimensions.
    n_rot: usize,
    /// `Some("global")` for full-attention layers when a secondary RoPE table
    /// is in use; `None` selects the default table.
    rope_table: Option<String>,
    /// Copied from `cfg.attention_k_eq_v`.
    k_eq_v: bool,
}
/// Cos/sin data for the secondary ("global") RoPE frequencies. It holds either
/// full tables (`secondary_rope_tables`) or a single position's row
/// (`secondary_rope_row`).
#[derive(Debug, Clone)]
struct GlobalRopeTables {
    cos: Vec<f32>,
    sin: Vec<f32>,
    /// Number of inverse frequencies (entries per position).
    half_dim: usize,
}
/// Builds cos/sin tables for every position below `max_pos` using the
/// secondary ("global") inverse frequencies.
///
/// Returns `None` when `resolve_global_inv_freq` yields none for `cfg`.
fn secondary_rope_tables(
    cfg: &GemmaConfig,
    max_pos: usize,
    factors: Option<&[f32]>,
) -> Option<GlobalRopeTables> {
    crate::rope::resolve_global_inv_freq(cfg, factors).map(|inv_freq| {
        let half_dim = inv_freq.len();
        let (cos, sin) = crate::rope::build_rope_tables(&inv_freq, max_pos);
        GlobalRopeTables { cos, sin, half_dim }
    })
}
/// Computes the single cos/sin row for position `pos` using the secondary
/// ("global") inverse frequencies.
///
/// Returns `None` when `resolve_global_inv_freq` yields none for `cfg`.
fn secondary_rope_row(
    cfg: &GemmaConfig,
    pos: usize,
    factors: Option<&[f32]>,
) -> Option<GlobalRopeTables> {
    crate::rope::resolve_global_inv_freq(cfg, factors).map(|inv_freq| {
        let half_dim = inv_freq.len();
        let (cos, sin) = crate::rope::rope_slice(&inv_freq, pos);
        GlobalRopeTables { cos, sin, half_dim }
    })
}
/// Returns the secondary RoPE `theta` / `n_rot`, or `None` when
/// `global_rope_params` reports no secondary RoPE for `cfg`.
fn needs_secondary_rope_params(cfg: &GemmaConfig) -> Option<GlobalRopeParams> {
    let (theta, n_rot) = crate::rope::global_rope_params(cfg)?;
    Some(GlobalRopeParams { theta, n_rot })
}
/// Secondary RoPE parameters as returned by `crate::rope::global_rope_params`.
#[derive(Debug, Clone, Copy)]
struct GlobalRopeParams {
    /// RoPE base frequency.
    theta: f64,
    /// Number of rotated dimensions.
    n_rot: usize,
}
/// Returns the id of the first `HirOp::Input` node named `name`.
///
/// # Errors
/// Fails when no input node with that name exists in `hir`.
fn find_hir_input(hir: &HirModule, name: &str) -> anyhow::Result<rlx_ir::HirNodeId> {
    use rlx_ir::hir::HirOp;
    for node in hir.nodes() {
        match &node.op {
            HirOp::Input { name: input_name } if input_name == name => return Ok(node.id),
            _ => {}
        }
    }
    Err(anyhow::anyhow!("gemma decode flow missing input: {name}"))
}
impl<'a> GemmaFlow<'a> {
fn from_prefill_opts(cfg: &'a GemmaConfig, o: &GemmaPrefillOpts) -> Self {
let mut f = GemmaFlow::new(cfg).prefill().batch(o.batch).seq(o.seq);
if o.dynamic_seq {
f = f.dynamic_seq();
}
if o.prefill_hidden {
f = f.prefill_from_hidden();
}
if o.media_attn_bias {
f = f.prefill_media_attn_bias();
}
if o.with_lm_head {
f = f.lm_head();
}
if o.with_kv_outputs {
f = f.export_kv();
}
if o.last_logits_only {
f = f.last_token_logits();
}
if let Some(p) = o.profile.clone() {
f = f.profile(p);
}
f
}
fn from_decode_opts(cfg: &'a GemmaConfig, o: &GemmaDecodeOpts) -> Self {
let mut f = GemmaFlow::new(cfg)
.decode()
.batch(o.batch)
.past(o.past_seq)
.lm_head();
if o.dynamic_past {
f = f.dynamic_past();
}
if o.use_custom_mask {
f = f.custom_mask();
}
if let Some(p) = o.profile.clone() {
f = f.profile(p);
}
f
}
}
/// Plain-data options for [`build_gemma_prefill_built`]; each flag maps to
/// the matching [`GemmaFlow`] setter.
#[derive(Debug, Clone)]
pub struct GemmaPrefillOpts {
    pub batch: usize,
    pub seq: usize,
    /// See `GemmaFlow::dynamic_seq` (requires `batch == 1`).
    pub dynamic_seq: bool,
    /// See `GemmaFlow::prefill_from_hidden`.
    pub prefill_hidden: bool,
    /// See `GemmaFlow::prefill_media_attn_bias`.
    pub media_attn_bias: bool,
    /// See `GemmaFlow::lm_head`.
    pub with_lm_head: bool,
    /// See `GemmaFlow::export_kv`.
    pub with_kv_outputs: bool,
    /// See `GemmaFlow::last_token_logits` (also enables the LM head).
    pub last_logits_only: bool,
    /// Explicit compile profile; the built-in prefill profile when `None`.
    pub profile: Option<CompileProfile>,
}
impl GemmaPrefillOpts {
    /// Options for a static-shape `batch` x `seq` prefill with every optional
    /// feature off: token-id input, hidden-state output, default profile.
    pub fn static_prefill(batch: usize, seq: usize) -> Self {
        Self {
            profile: None,
            last_logits_only: false,
            with_kv_outputs: false,
            with_lm_head: false,
            media_attn_bias: false,
            prefill_hidden: false,
            dynamic_seq: false,
            seq,
            batch,
        }
    }
}
/// Plain-data options for [`build_gemma_decode_built`]. The decode graph
/// always includes the LM head and the K/V side outputs.
#[derive(Debug, Clone)]
pub struct GemmaDecodeOpts {
    pub batch: usize,
    /// Cached past positions (static case; also sizes the custom `mask`).
    pub past_seq: usize,
    /// See `GemmaFlow::dynamic_past`.
    pub dynamic_past: bool,
    /// See `GemmaFlow::custom_mask`.
    pub use_custom_mask: bool,
    /// Explicit compile profile; the built-in decode profile when `None`.
    pub profile: Option<CompileProfile>,
}
/// Builds the prefill graph for `opts` and splits it into the HIR module and
/// a name-keyed `f32` buffer map (presumably the parameters — see
/// `BuiltModel::into_parts`).
pub fn build_gemma_prefill_flow(
    cfg: &GemmaConfig,
    weights: &mut dyn WeightLoader,
    opts: &GemmaPrefillOpts,
) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
    let built = build_gemma_prefill_built(cfg, weights, opts)?;
    built.into_parts()
}
/// Builds the prefill [`BuiltModel`] described by `opts`.
pub fn build_gemma_prefill_built(
    cfg: &GemmaConfig,
    weights: &mut dyn WeightLoader,
    opts: &GemmaPrefillOpts,
) -> Result<BuiltModel> {
    let flow = GemmaFlow::from_prefill_opts(cfg, opts);
    flow.build(weights)
}
/// Builds the decode graph for `opts` and splits it into the HIR module and
/// a name-keyed `f32` buffer map (presumably the parameters — see
/// `BuiltModel::into_parts`).
pub fn build_gemma_decode_flow(
    cfg: &GemmaConfig,
    weights: &mut dyn WeightLoader,
    opts: &GemmaDecodeOpts,
) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
    let built = build_gemma_decode_built(cfg, weights, opts)?;
    built.into_parts()
}
/// Builds the decode model for `opts` and converts it into a [`Graph`] plus
/// its name-keyed `f32` buffers via `flow_util::graph_from_built`.
pub fn build_gemma_decode_graph(
    cfg: &GemmaConfig,
    weights: &mut dyn WeightLoader,
    opts: &GemmaDecodeOpts,
) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
    let built = build_gemma_decode_built(cfg, weights, opts)?;
    rlx_core::flow_util::graph_from_built(built)
}
/// Builds the decode [`BuiltModel`] described by `opts`.
pub fn build_gemma_decode_built(
    cfg: &GemmaConfig,
    weights: &mut dyn WeightLoader,
    opts: &GemmaDecodeOpts,
) -> Result<BuiltModel> {
    let flow = GemmaFlow::from_decode_opts(cfg, opts);
    flow.build(weights)
}
#[cfg(test)]
mod gemma4_tests {
    use super::*;
    use crate::config::{
        GemmaArch, GemmaLayerType, GemmaRopeKind, GemmaRopeMap, GemmaRopeParameters,
    };

    /// Tiny-test config reshaped to resemble a Gemma 4 12B checkpoint.
    ///
    /// Every sixth layer is full attention, with a 512-dim global head, one
    /// global KV head, and its own RoPE parameters (theta 1e6, partial rotary
    /// 0.25). The remaining layers use sliding-window attention with eight
    /// 256-dim KV heads.
    fn gemma4_12b_like() -> GemmaConfig {
        let mut cfg = GemmaConfig::tiny_test();
        cfg.arch = GemmaArch::Gemma4;
        cfg.hidden_size = 3840;
        cfg.intermediate_size = 15_360;
        cfg.num_hidden_layers = 12;
        cfg.num_attention_heads = 16;
        cfg.num_key_value_heads = 8;
        cfg.head_dim = Some(256);
        cfg.global_head_dim = Some(512);
        cfg.num_global_key_value_heads = Some(1);
        cfg.attention_k_eq_v = true;
        cfg.sliding_window = Some(1024);
        cfg.final_logit_softcapping = Some(30.0);
        cfg.tie_word_embeddings = true;
        cfg.max_position_embeddings = 4096;
        cfg.rope_theta = 10_000.0;
        cfg.layer_types = (0..cfg.num_hidden_layers)
            .map(|i| {
                if (i + 1) % 6 == 0 {
                    GemmaLayerType::FullAttention
                } else {
                    GemmaLayerType::SlidingAttention
                }
            })
            .collect();
        cfg.rope_parameters = GemmaRopeMap {
            sliding_attention: Some(GemmaRopeParameters {
                rope_theta: Some(10_000.0),
                rope_type: Some(GemmaRopeKind::Default),
                partial_rotary_factor: None,
            }),
            full_attention: Some(GemmaRopeParameters {
                rope_theta: Some(1_000_000.0),
                rope_type: Some(GemmaRopeKind::Proportional),
                partial_rotary_factor: Some(0.25),
            }),
        };
        cfg
    }

    #[test]
    fn secondary_rope_emits_distinct_table_for_full_attention() {
        let cfg = gemma4_12b_like();
        let tables = secondary_rope_tables(&cfg, cfg.max_position_embeddings, None)
            .expect("Gemma 4 split rope_parameters should produce a secondary table");
        // 512 * 0.25 = 128 rotary dims, hence 64 frequencies per position.
        assert_eq!(tables.half_dim, 64);
        assert_eq!(tables.cos.len(), cfg.max_position_embeddings * 64);
        assert_eq!(tables.sin.len(), tables.cos.len());
        // Position 0 is the identity rotation.
        assert!((tables.cos[0] - 1.0).abs() < 1e-6);
        assert!(tables.sin[0].abs() < 1e-6);
        // The global and sliding frequencies must actually differ...
        let global_inv = crate::rope::default_inv_freq(1_000_000.0, 128);
        let sliding_inv = crate::rope::default_inv_freq(10_000.0, 128);
        assert!((global_inv[5] - sliding_inv[5]).abs() > 1e-3);
        // ...and the table must use the global ones: row 1 (position 1),
        // column 5.
        let global_cos_p1_d5 = (1.0 * global_inv[5]).cos();
        let global_sample = tables.cos[64 + 5];
        assert!((global_sample as f64 - global_cos_p1_d5).abs() < 1e-5);
    }

    #[test]
    fn per_layer_kv_dims_diverge_on_full_attention() {
        let cfg = gemma4_12b_like();
        // Sliding layer: 8 KV heads * 256.
        assert_eq!(cfg.layer_num_kv_heads(0) * cfg.layer_head_dim(0), 2048);
        // Full-attention layers (5 and 11): 1 KV head * 512.
        assert_eq!(cfg.layer_num_kv_heads(5) * cfg.layer_head_dim(5), 512);
        assert_eq!(cfg.layer_num_kv_heads(11) * cfg.layer_head_dim(11), 512);
    }

    #[test]
    fn no_secondary_table_when_params_match() {
        let mut cfg = gemma4_12b_like();
        // Same RoPE for both layer types and no distinct global geometry.
        cfg.rope_parameters.full_attention = cfg.rope_parameters.sliding_attention;
        cfg.global_head_dim = None;
        cfg.num_global_key_value_heads = None;
        assert!(secondary_rope_tables(&cfg, cfg.max_position_embeddings, None).is_none());
    }
}