use std::collections::HashMap;
use std::fmt;
use std::path::Path;
use std::sync::Arc;
use anyhow::Result;
use rlx_flow::blocks::{
DecodeRopeParamsStage, EmbedScaleStage, GemmaDecodeLayerSpec, GemmaDecodeLayerStage,
GemmaLayerStyle, GemmaRmsNormStage, LmHeadStage, LogitSoftcapStage, RopeTablesStage,
gemma_attn_spec, gemma_prefill_layer_composed,
};
use rlx_flow::{BuiltModel, CompileProfile, FlowStage, ModelFlow, SideOutputs};
use rlx_ir::dynamic::sym;
use rlx_ir::hir::HirModule;
use rlx_ir::shape::Dim;
use rlx_ir::{DType, Graph, Shape};
use super::config::{GemmaArch, GemmaConfig};
use super::rope::{build_rope_tables, resolve_inv_freq};
use rlx_core::flow_bridge::{WeightLoaderSource, load_compile_profile};
use rlx_core::weight_loader::WeightLoader;
/// File name of the Gemma compile-profile file that
/// [`gemma_profile_near_weights`] looks for in the directory holding the weights.
pub const GEMMA_PROFILE_FILE: &str = "gemma.rlx.toml";
/// Resolves the compile profile for a model whose weights live at `weights`.
///
/// The built-in default is `CompileProfile::gemma_decode()` when `decode` is
/// true and `CompileProfile::gemma_prefill()` otherwise. That default is handed
/// to `load_compile_profile` together with the path of [`GEMMA_PROFILE_FILE`]
/// inside the directory containing `weights` (`.` when the path has no parent).
pub fn gemma_profile_near_weights(weights: &Path, decode: bool) -> CompileProfile {
    let fallback = match decode {
        true => CompileProfile::gemma_decode(),
        false => CompileProfile::gemma_prefill(),
    };
    let profile_path = match weights.parent() {
        Some(dir) => dir.join(GEMMA_PROFILE_FILE),
        None => Path::new(".").join(GEMMA_PROFILE_FILE),
    };
    load_compile_profile(&profile_path, fallback)
}
/// Which graph a [`GemmaFlow`] produces when built.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GemmaMode {
/// Build the prefill graph (`GemmaFlow::build_prefill`).
Prefill,
/// Build the single-token decode graph (`GemmaFlow::build_decode`).
Decode,
}
/// Per-layer context handed to a custom layer hook (see `GemmaFlow::layer`).
///
/// One value is created for every layer while the flow is being assembled; the
/// hook may return its own stage or fall back to [`GemmaLayerCtx::default_stage`].
pub enum GemmaLayerCtx<'a> {
/// Context for one layer of the prefill graph.
Prefill {
/// Zero-based layer index.
index: usize,
/// Layer style taken from the model configuration.
style: GemmaLayerStyle,
/// Self-attention spec for this layer.
attn: rlx_flow::blocks::SelfAttnPrefillSpec,
/// Side-output collector for this layer's KV tensors.
kv_sink: &'a SideOutputs,
/// Whether the default stage should be given `kv_sink` at all.
export_kv: bool,
/// Head dimension; provided for hooks, not read by `default_stage`.
head_dim: usize,
/// RMS-norm epsilon.
eps: f32,
},
/// Context for one layer of the decode graph.
Decode {
/// Zero-based layer index.
index: usize,
/// Full decode-layer spec (heads, mask, scale, ...).
spec: GemmaDecodeLayerSpec,
/// Side-output collector for this layer's KV tensors.
kv_out: &'a SideOutputs,
},
}
impl GemmaLayerCtx<'_> {
    /// Zero-based index of the layer this context describes, in either mode.
    pub fn index(&self) -> usize {
        match *self {
            Self::Prefill { index, .. } => index,
            Self::Decode { index, .. } => index,
        }
    }

    /// Builds the stock (dense) stage for this layer.
    ///
    /// * Prefill: `gemma_prefill_layer_composed`, receiving the KV sink only
    ///   when `export_kv` is set. `head_dim` is not used.
    /// * Decode: a `GemmaDecodeLayerStage` wrapped in a `FlowStage::Named`
    ///   called `layer{index}`.
    pub fn default_stage(&self) -> FlowStage {
        match self {
            Self::Decode {
                index,
                spec,
                kv_out,
            } => {
                let name = format!("layer{index}");
                let layer = GemmaDecodeLayerStage::layer(*index, spec.clone(), kv_out.inner());
                FlowStage::Named {
                    name,
                    inner: Arc::new(FlowStage::GemmaDecodeLayer(layer)),
                }
            }
            Self::Prefill {
                index,
                style,
                attn,
                kv_sink,
                export_kv,
                eps,
                ..
            } => gemma_prefill_layer_composed(
                *index,
                *style,
                attn.clone(),
                *eps,
                // Only expose the sink when KV export was requested.
                export_kv.then(|| kv_sink.inner()),
            ),
        }
    }
}
/// Shared, thread-safe hook that turns a per-layer context into the stage for that layer.
type LayerFn = Arc<dyn Fn(GemmaLayerCtx<'_>) -> FlowStage + Send + Sync>;
/// Shared, thread-safe hook that rewrites the partially assembled [`ModelFlow`].
type FlowPatchFn = Arc<dyn Fn(ModelFlow) -> ModelFlow + Send + Sync>;
/// Builder that assembles a Gemma prefill or decode [`ModelFlow`] from a
/// [`GemmaConfig`] and turns it into a [`BuiltModel`] via `build`.
#[derive(Clone)]
pub struct GemmaFlow<'a> {
/// Model configuration; borrowed for the lifetime of the builder.
cfg: &'a GemmaConfig,
/// Selects `build_prefill` or `build_decode`.
mode: GemmaMode,
/// Static batch dimension used in every input shape.
batch: usize,
/// Prefill sequence length (static shapes and last-token gather).
seq: usize,
/// Decode past length (static KV-cache shape, rope slice, mask width).
past_seq: usize,
/// Prefill: make the sequence dimension symbolic.
dynamic_seq: bool,
/// Decode: make the past dimension symbolic and take rope values as inputs.
dynamic_past: bool,
/// Prefill: append the LM head and output `logits` instead of `hidden`.
with_lm_head: bool,
/// Prefill: attach the collected KV tensors as extra outputs.
with_kv_outputs: bool,
/// Prefill: gather the last token before the final norm (needs `with_lm_head`).
last_logits_only: bool,
/// Decode: add a `mask` input and force the causal mask kind.
use_custom_mask: bool,
/// Explicit compile profile; the mode's default is used when `None`.
profile: Option<CompileProfile>,
/// Extra stages inserted after the embedding scale, before the layers.
before_layers: Vec<FlowStage>,
/// Extra stages inserted right after the layers.
after_layers: Vec<FlowStage>,
/// Optional per-layer override hook.
layer_fn: Option<LayerFn>,
/// Optional whole-flow patch hook.
flow_patch: Option<FlowPatchFn>,
}
impl fmt::Debug for GemmaFlow<'_> {
    /// Prints the scalar settings verbatim; stage lists are shown by length and
    /// hooks by presence. `cfg` is omitted, hence the non-exhaustive finish.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let before_count = self.before_layers.len();
        let after_count = self.after_layers.len();
        let has_layer_fn = self.layer_fn.is_some();
        let has_flow_patch = self.flow_patch.is_some();

        let mut out = f.debug_struct("GemmaFlow");
        out.field("mode", &self.mode);
        out.field("batch", &self.batch);
        out.field("seq", &self.seq);
        out.field("past_seq", &self.past_seq);
        out.field("dynamic_seq", &self.dynamic_seq);
        out.field("dynamic_past", &self.dynamic_past);
        out.field("with_lm_head", &self.with_lm_head);
        out.field("with_kv_outputs", &self.with_kv_outputs);
        out.field("last_logits_only", &self.last_logits_only);
        out.field("use_custom_mask", &self.use_custom_mask);
        out.field("profile", &self.profile);
        out.field("before_layers", &before_count);
        out.field("after_layers", &after_count);
        out.field("layer_fn", &has_layer_fn);
        out.field("flow_patch", &has_flow_patch);
        out.finish_non_exhaustive()
    }
}
impl<'a> GemmaFlow<'a> {
/// Creates a builder with the defaults: prefill mode, `batch = 1`,
/// `seq = 128`, `past_seq = 0`, every flag off, no profile, no extra stages
/// and no hooks.
pub fn new(cfg: &'a GemmaConfig) -> Self {
    Self {
        cfg,
        // Mode and static dimensions.
        mode: GemmaMode::Prefill,
        batch: 1,
        seq: 128,
        past_seq: 0,
        // Shape and output flags all start disabled.
        dynamic_seq: false,
        dynamic_past: false,
        with_lm_head: false,
        with_kv_outputs: false,
        last_logits_only: false,
        use_custom_mask: false,
        // Customisation points start empty.
        profile: None,
        before_layers: Vec::new(),
        after_layers: Vec::new(),
        layer_fn: None,
        flow_patch: None,
    }
}
pub fn for_prefill(cfg: &'a GemmaConfig, batch: usize, seq: usize) -> Self {
Self::new(cfg).prefill().batch(batch).seq(seq)
}
pub fn for_decode(cfg: &'a GemmaConfig, batch: usize, past_seq: usize) -> Self {
Self::new(cfg)
.decode()
.batch(batch)
.past(past_seq)
.lm_head()
}
pub fn prefill(mut self) -> Self {
self.mode = GemmaMode::Prefill;
self
}
pub fn decode(mut self) -> Self {
self.mode = GemmaMode::Decode;
self
}
pub fn batch(mut self, batch: usize) -> Self {
self.batch = batch;
self
}
pub fn seq(mut self, seq: usize) -> Self {
self.seq = seq;
self
}
pub fn past(mut self, past_seq: usize) -> Self {
self.past_seq = past_seq;
self
}
pub fn dynamic_seq(mut self) -> Self {
self.dynamic_seq = true;
self
}
pub fn dynamic_past(mut self) -> Self {
self.dynamic_past = true;
self
}
pub fn lm_head(mut self) -> Self {
self.with_lm_head = true;
self
}
pub fn hidden_only(mut self) -> Self {
self.with_lm_head = false;
self.last_logits_only = false;
self
}
pub fn last_token_logits(mut self) -> Self {
self.with_lm_head = true;
self.last_logits_only = true;
self
}
pub fn export_kv(mut self) -> Self {
self.with_kv_outputs = true;
self
}
pub fn custom_mask(mut self) -> Self {
self.use_custom_mask = true;
self
}
pub fn profile(mut self, profile: CompileProfile) -> Self {
self.profile = Some(profile);
self
}
pub fn profile_prefill(mut self) -> Self {
self.profile = Some(CompileProfile::gemma_prefill());
self
}
pub fn profile_decode(mut self) -> Self {
self.profile = Some(CompileProfile::gemma_decode());
self
}
pub fn profile_near(mut self, weights_path: &Path) -> Self {
let decode = self.mode == GemmaMode::Decode;
self.profile = Some(gemma_profile_near_weights(weights_path, decode));
self
}
pub fn before_layers(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
self.before_layers.extend(stages);
self
}
pub fn after_layers(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
self.after_layers.extend(stages);
self
}
pub fn layer<F>(mut self, f: F) -> Self
where
F: Fn(GemmaLayerCtx<'_>) -> FlowStage + Send + Sync + 'static,
{
self.layer_fn = Some(Arc::new(f));
self
}
pub fn patch_flow<F>(mut self, f: F) -> Self
where
F: Fn(ModelFlow) -> ModelFlow + Send + Sync + 'static,
{
self.flow_patch = Some(Arc::new(f));
self
}
/// Consumes the builder and assembles the graph for the selected mode,
/// pulling tensors from `weights`.
///
/// # Errors
/// Propagates whatever `build_prefill` / `build_decode` return.
pub fn build(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
    let mode = self.mode;
    match mode {
        GemmaMode::Decode => self.build_decode(weights),
        GemmaMode::Prefill => self.build_prefill(weights),
    }
}
/// Assembles the prefill graph.
///
/// Stage order: rope tables, zero beta, token embedding, embedding scale,
/// `before_layers`, the `cfg.active_num_layers()` transformer layers,
/// `after_layers`, optional last-token gather, final RMS norm (`model.norm`),
/// optional flow patch, then the LM head and optional logit softcap when
/// `with_lm_head` is set, otherwise the `hidden` output.
///
/// Graph inputs are `input_ids` and, for dynamic-sequence last-token logits,
/// `last_token_idx`. When `with_kv_outputs` is set the KV tensors collected by
/// the layers are attached as extra outputs.
///
/// Per layer, a custom `layer_fn` hook wins over everything else; otherwise MoE
/// configs get `gemma_moe_prefill_layer_composed` and dense configs get
/// `GemmaLayerCtx::default_stage`.
///
/// # Errors
/// * `dynamic_seq` combined with `batch != 1`.
/// * Any error from building the flow against `weights`.
///
/// A failed lookup of `rope_freqs.weight` is not an error: it is mapped to
/// `None` before `resolve_inv_freq` is called.
fn build_prefill(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
    if self.dynamic_seq && self.batch != 1 {
        anyhow::bail!("gemma: dynamic_seq prefill requires batch=1");
    }
    let cfg = self.cfg;
    let profile = self.profile.unwrap_or_else(CompileProfile::gemma_prefill);
    let f = DType::F32;
    let h = cfg.hidden_size;
    let eps = cfg.rms_norm_eps as f32;
    let dh = cfg.head_dim();
    let layer_style = cfg.layer_style();
    // NOTE(review): this value is never read. It is kept because it is the only
    // visible use of `prefill_hidden_shape`; confirm before deleting either.
    let _hidden_shape = prefill_hidden_shape(self.batch, self.seq, h, self.dynamic_seq, f);
    let input_shape = prefill_input_shape(self.batch, self.seq, self.dynamic_seq);

    // Optional rope scaling factors: a failed lookup falls back to `None`.
    let rope_factors = weights.take("rope_freqs.weight").ok().map(|(data, _)| data);
    let inv_freq = resolve_inv_freq(cfg, rope_factors.as_deref());
    let (cos_data, sin_data) = build_rope_tables(&inv_freq, cfg.max_position_embeddings);

    let kv_sink = SideOutputs::new();
    let mut flow = ModelFlow::new("gemma")
        .with_profile(profile)
        .input("input_ids", input_shape);
    // With a symbolic sequence length the last position is supplied by the caller.
    if self.dynamic_seq && self.with_lm_head && self.last_logits_only {
        flow = flow.input("last_token_idx", Shape::new(&[self.batch], DType::F32));
    }
    flow = flow
        .rope_tables(RopeTablesStage::param(
            cfg.max_position_embeddings,
            inv_freq.len(),
            cos_data,
            sin_data,
        ))
        .zero_beta_named("gemma.zero_beta.hidden", h)
        .token_embed()
        .raw_stage(FlowStage::EmbedScale(EmbedScaleStage::new(h)))
        .raw_stages(self.before_layers.iter().cloned());

    // Copy everything the per-layer closure needs so it does not borrow `self`/`cfg`.
    let layer_fn = self.layer_fn.clone();
    let export = self.with_kv_outputs;
    let num_heads = cfg.num_attention_heads;
    let num_kv_heads = cfg.num_key_value_heads;
    let num_layers = cfg.active_num_layers();
    let layer_attn: Vec<_> = (0..num_layers).map(|i| cfg.layer_attn_options(i)).collect();
    let is_moe = cfg.is_moe();
    let moe_num_experts = cfg.num_experts;
    let moe_top_k = cfg.num_experts_used;
    let moe_n_embd = cfg.hidden_size;
    let moe_n_ff = cfg.expert_ffn_dim();
    flow = flow.repeat_layers(num_layers, {
        let style = layer_style;
        let sink = kv_sink.clone();
        move |i| {
            let (mask, score_scale, softcap) = layer_attn[i];
            let attn =
                gemma_attn_spec(i, num_heads, dh, num_kv_heads, mask, score_scale, softcap);
            // A user hook takes precedence over the MoE/dense selection below.
            // `attn` is moved (not cloned): this branch returns immediately.
            if let Some(ref f) = layer_fn {
                return f(GemmaLayerCtx::Prefill {
                    index: i,
                    style,
                    attn,
                    kv_sink: &sink,
                    export_kv: export,
                    head_dim: dh,
                    eps,
                });
            }
            if is_moe {
                let prefix = format!("model.layers.{i}");
                let moe = rlx_flow::blocks::MoeFfnStage::hf(
                    prefix,
                    moe_num_experts,
                    moe_top_k,
                    moe_n_embd,
                    moe_n_ff,
                );
                let kv = if export { Some(sink.inner()) } else { None };
                return rlx_flow::blocks::gemma_moe_prefill_layer_composed(
                    i, style, attn, eps, kv, moe,
                );
            }
            GemmaLayerCtx::Prefill {
                index: i,
                style,
                attn,
                kv_sink: &sink,
                export_kv: export,
                head_dim: dh,
                eps,
            }
            .default_stage()
        }
    });
    flow = flow.raw_stages(self.after_layers.iter().cloned());

    // Reduce to the last token before the final norm when only its logits are needed.
    if self.with_lm_head && self.last_logits_only {
        flow = if self.dynamic_seq {
            flow.gather_last_token_dynamic(self.batch)
        } else {
            flow.gather_last_token_at(self.batch, self.seq)
        };
    }
    flow = flow.raw_stage(FlowStage::GemmaRmsNorm(GemmaRmsNormStage::hf_layer(
        "model.norm",
        eps,
    )));
    // The patch sees the flow after the final norm and before any LM head.
    if let Some(patch) = self.flow_patch {
        flow = patch(flow);
    }

    let mut built = if self.with_lm_head {
        let lm = if cfg.tie_word_embeddings {
            FlowStage::LmHead(LmHeadStage::tied(cfg.vocab_size, h))
        } else {
            FlowStage::LmHead(LmHeadStage::separate("lm_head.weight", cfg.vocab_size, h))
        };
        flow = flow.raw_stage(lm);
        if let Some(cap) = cfg.final_logit_softcapping {
            flow = flow.raw_stage(FlowStage::LogitSoftcap(LogitSoftcapStage::new(cap)));
        }
        flow.output("logits")
            .build(&mut WeightLoaderSource(weights))?
    } else {
        flow.output("hidden")
            .build(&mut WeightLoaderSource(weights))?
    };
    if self.with_kv_outputs {
        built = built.with_extra_hir_outputs(kv_sink.drain());
    }
    Ok(built)
}
/// Assembles the single-token decode graph.
///
/// Graph inputs, in declaration order:
/// * `input_ids` with shape `[batch, 1]`;
/// * `rope_cos` / `rope_sin` with shape `[1, head_dim / 2]` when `dynamic_past`;
/// * `mask` with shape `[batch, past_seq + 1]` when `use_custom_mask`;
/// * `past_k_{i}` / `past_v_{i}` for every `i < cfg.num_hidden_layers`.
///
/// With a static past length the rope values for `past_seq` are baked in through
/// a `DecodeRopeParams` stage instead of being inputs.
///
/// The LM head is always appended and the output is `logits`; the KV tensors
/// collected by the layers are always attached as extra outputs. The builder's
/// `with_lm_head`, `last_logits_only`, `with_kv_outputs`, `seq` and
/// `dynamic_seq` settings are not read here.
///
/// Per layer, a custom `layer_fn` hook wins over everything else; otherwise MoE
/// configs get `gemma_moe_decode_layer_composed` and dense configs get
/// `GemmaLayerCtx::default_stage`.
///
/// # Errors
/// Any error from building the flow against `weights`. A failed lookup of
/// `rope_freqs.weight` is not an error: it is mapped to `None`.
fn build_decode(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
    let cfg = self.cfg;
    let profile = self.profile.unwrap_or_else(CompileProfile::gemma_decode);
    let f = DType::F32;
    let h = cfg.hidden_size;
    let eps = cfg.rms_norm_eps as f32;
    let dh = cfg.head_dim();
    let kv_dim = cfg.kv_proj_dim();
    let half = dh / 2;
    let hidden_shape = Shape::new(&[self.batch, 1, h], f);
    let past_kv_shape = if self.dynamic_past {
        Shape::from_dims(
            &[
                Dim::Static(self.batch),
                Dim::Dynamic(sym::PAST_SEQ),
                Dim::Static(kv_dim),
            ],
            f,
        )
    } else {
        Shape::new(&[self.batch, self.past_seq, kv_dim], f)
    };
    let decode_style = cfg.layer_style();
    let decode_score_scale = cfg.attn_score_scale();
    let decode_softcap = cfg.attn_logit_softcapping;
    let decode_arch = cfg.arch;
    let decode_sliding = cfg.sliding_window;
    let kv_out = SideOutputs::new();

    // Optional rope scaling factors: a failed lookup falls back to `None`.
    let rope_factors = weights.take("rope_freqs.weight").ok().map(|(data, _)| data);
    let inv_freq = resolve_inv_freq(cfg, rope_factors.as_deref());
    // Only a static past length has rope values that can be baked into the graph.
    let static_rope = if self.dynamic_past {
        None
    } else {
        Some(crate::rope::rope_slice(&inv_freq, self.past_seq))
    };

    let mut flow = ModelFlow::new("gemma_decode")
        .with_profile(profile)
        .input("input_ids", Shape::new(&[self.batch, 1], DType::F32));
    if self.dynamic_past {
        flow = flow
            .input("rope_cos", Shape::new(&[1, half], f))
            .input("rope_sin", Shape::new(&[1, half], f));
    }
    if self.use_custom_mask {
        // NOTE(review): this shape uses the static `past_seq` even when
        // `dynamic_past` is set — confirm that combination is intended/supported.
        flow = flow.input("mask", Shape::new(&[self.batch, self.past_seq + 1], f));
    }
    // NOTE(review): KV inputs (and `bind_decode_inputs` below) are sized by
    // `num_hidden_layers`, while the layers themselves are built for
    // `active_num_layers()`. Confirm the two always agree, or that the extra
    // inputs are expected by callers.
    for layer_idx in 0..cfg.num_hidden_layers {
        flow = flow
            .input(format!("past_k_{layer_idx}"), past_kv_shape.clone())
            .input(format!("past_v_{layer_idx}"), past_kv_shape.clone());
    }
    if let Some((cos, sin)) = static_rope {
        flow = flow.raw_stage(FlowStage::DecodeRopeParams(DecodeRopeParamsStage {
            cos,
            sin,
            half_dim: half,
        }));
    }
    flow = flow
        .bind_decode_inputs(cfg.num_hidden_layers, self.use_custom_mask)
        .zero_beta_named("gemma.zero_beta.hidden", h)
        .token_embed()
        .raw_stage(FlowStage::EmbedScale(EmbedScaleStage::new(h)))
        .raw_stages(self.before_layers.iter().cloned());

    // Copy everything the per-layer closure needs so it does not borrow `self`/`cfg`.
    let layer_fn = self.layer_fn.clone();
    let use_custom_mask = self.use_custom_mask;
    let num_heads = cfg.num_attention_heads;
    let num_kv_heads = cfg.num_key_value_heads;
    let kv_group_size = cfg.kv_group_size();
    let num_layers = cfg.active_num_layers();
    let is_moe = cfg.is_moe();
    let moe_num_experts = cfg.num_experts;
    let moe_top_k = cfg.num_experts_used;
    let moe_n_embd = cfg.hidden_size;
    let moe_n_ff = cfg.expert_ffn_dim();
    flow = flow.repeat_layers(num_layers, {
        let sink = kv_out.clone();
        // `hidden_shape` is moved into the closure; it has no other user.
        move |i| {
            // A caller-supplied mask replaces the per-layer sliding-window selection.
            let mask = if use_custom_mask {
                rlx_ir::op::MaskKind::Causal
            } else {
                match (decode_arch, decode_sliding) {
                    (GemmaArch::Gemma2, Some(w)) => rlx_flow::blocks::gemma2_layer_mask(i, w),
                    (GemmaArch::Gemma3 | GemmaArch::Gemma4, Some(w)) => {
                        rlx_flow::blocks::gemma_strided_layer_mask(
                            i,
                            w,
                            decode_arch.sliding_window_stride(),
                        )
                    }
                    _ => rlx_ir::op::MaskKind::Causal,
                }
            };
            let spec = GemmaDecodeLayerSpec {
                style: decode_style,
                num_heads,
                head_dim: dh,
                num_kv_heads,
                kv_group_size,
                eps,
                use_custom_mask,
                hidden_shape: hidden_shape.clone(),
                mask,
                score_scale: decode_score_scale,
                attn_logit_softcap: decode_softcap,
            };
            // A user hook takes precedence over the MoE/dense selection below.
            // `spec` is moved (not cloned): this branch returns immediately.
            if let Some(ref f) = layer_fn {
                return f(GemmaLayerCtx::Decode {
                    index: i,
                    spec,
                    kv_out: &sink,
                });
            }
            if is_moe {
                let prefix = format!("model.layers.{i}");
                let moe = rlx_flow::blocks::MoeFfnStage::hf(
                    prefix,
                    moe_num_experts,
                    moe_top_k,
                    moe_n_embd,
                    moe_n_ff,
                );
                return rlx_flow::blocks::gemma_moe_decode_layer_composed(
                    i,
                    spec,
                    sink.inner(),
                    moe,
                );
            }
            GemmaLayerCtx::Decode {
                index: i,
                spec,
                kv_out: &sink,
            }
            .default_stage()
        }
    });
    flow = flow.raw_stages(self.after_layers.iter().cloned());
    // NOTE(review): the patch runs *before* the final norm here, whereas
    // `build_prefill` applies it *after* the final norm. Confirm which is intended.
    if let Some(patch) = self.flow_patch {
        flow = patch(flow);
    }
    flow = flow.raw_stage(FlowStage::GemmaRmsNorm(GemmaRmsNormStage::hf_layer(
        "model.norm",
        eps,
    )));
    let lm = if cfg.tie_word_embeddings {
        FlowStage::LmHead(LmHeadStage::tied(cfg.vocab_size, h))
    } else {
        FlowStage::LmHead(LmHeadStage::separate("lm_head.weight", cfg.vocab_size, h))
    };
    flow = flow.raw_stage(lm);
    if let Some(cap) = cfg.final_logit_softcapping {
        flow = flow.raw_stage(FlowStage::LogitSoftcap(LogitSoftcapStage::new(cap)));
    }
    let built = flow
        .output("logits")
        .build(&mut WeightLoaderSource(weights))?
        .with_extra_hir_outputs(kv_out.drain());
    Ok(built)
}
}
/// Shape of the prefill hidden state: `[batch, seq, hidden]`, with the middle
/// dimension replaced by the symbolic `sym::SEQ` when `dynamic` is set (in
/// which case `seq` is ignored).
fn prefill_hidden_shape(
    batch: usize,
    seq: usize,
    hidden: usize,
    dynamic: bool,
    dtype: DType,
) -> Shape {
    if !dynamic {
        return Shape::new(&[batch, seq, hidden], dtype);
    }
    let dims = [
        Dim::Static(batch),
        Dim::Dynamic(sym::SEQ),
        Dim::Static(hidden),
    ];
    Shape::from_dims(&dims, dtype)
}
/// Shape of the prefill `input_ids` tensor: `[batch, seq]` as `DType::F32`,
/// with the second dimension replaced by the symbolic `sym::SEQ` when
/// `dynamic` is set (in which case `seq` is ignored).
fn prefill_input_shape(batch: usize, seq: usize, dynamic: bool) -> Shape {
    if !dynamic {
        return Shape::new(&[batch, seq], DType::F32);
    }
    let dims = [Dim::Static(batch), Dim::Dynamic(sym::SEQ)];
    Shape::from_dims(&dims, DType::F32)
}
impl<'a> GemmaFlow<'a> {
    /// Translates a [`GemmaPrefillOpts`] into a prefill builder.
    ///
    /// `last_logits_only` implies the LM head even when `with_lm_head` is
    /// false, matching what `last_token_logits()` does.
    fn from_prefill_opts(cfg: &'a GemmaConfig, o: &GemmaPrefillOpts) -> Self {
        Self {
            batch: o.batch,
            seq: o.seq,
            dynamic_seq: o.dynamic_seq,
            with_lm_head: o.with_lm_head || o.last_logits_only,
            with_kv_outputs: o.with_kv_outputs,
            last_logits_only: o.last_logits_only,
            profile: o.profile.clone(),
            ..GemmaFlow::new(cfg).prefill()
        }
    }

    /// Translates a [`GemmaDecodeOpts`] into a decode builder with the LM-head
    /// flag set.
    fn from_decode_opts(cfg: &'a GemmaConfig, o: &GemmaDecodeOpts) -> Self {
        Self {
            batch: o.batch,
            past_seq: o.past_seq,
            dynamic_past: o.dynamic_past,
            use_custom_mask: o.use_custom_mask,
            profile: o.profile.clone(),
            ..GemmaFlow::new(cfg).decode().lm_head()
        }
    }
}
/// Plain-data options for the `build_gemma_prefill_*` entry points.
#[derive(Debug, Clone)]
pub struct GemmaPrefillOpts {
/// Static batch dimension.
pub batch: usize,
/// Sequence length (used for static shapes).
pub seq: usize,
/// Make the sequence dimension symbolic.
pub dynamic_seq: bool,
/// Append the LM head and output logits.
pub with_lm_head: bool,
/// Expose per-layer KV tensors as extra outputs.
pub with_kv_outputs: bool,
/// Produce logits for the last token only; this also enables the LM head.
pub last_logits_only: bool,
/// Explicit compile profile; the prefill default is used when `None`.
pub profile: Option<CompileProfile>,
}
impl GemmaPrefillOpts {
    /// Options for a fully static prefill of `batch` x `seq` tokens: every
    /// flag is off and no explicit profile is set.
    pub fn static_prefill(batch: usize, seq: usize) -> Self {
        Self {
            // No explicit profile.
            profile: None,
            // All optional behaviour disabled.
            last_logits_only: false,
            with_kv_outputs: false,
            with_lm_head: false,
            dynamic_seq: false,
            // Caller-provided static dimensions.
            seq,
            batch,
        }
    }
}
/// Plain-data options for the `build_gemma_decode_*` entry points.
#[derive(Debug, Clone)]
pub struct GemmaDecodeOpts {
/// Static batch dimension.
pub batch: usize,
/// Number of already cached positions (used for static shapes).
pub past_seq: usize,
/// Make the past dimension symbolic.
pub dynamic_past: bool,
/// Add a caller-supplied `mask` input.
pub use_custom_mask: bool,
/// Explicit compile profile; the decode default is used when `None`.
pub profile: Option<CompileProfile>,
}
/// Builds the prefill model described by `opts` and splits it into its HIR
/// module and named `f32` parameter map via `BuiltModel::into_parts`.
///
/// # Errors
/// Propagates failures from the build and from `into_parts`.
pub fn build_gemma_prefill_flow(
    cfg: &GemmaConfig,
    weights: &mut dyn WeightLoader,
    opts: &GemmaPrefillOpts,
) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
    let built = build_gemma_prefill_built(cfg, weights, opts)?;
    built.into_parts()
}
/// Builds the prefill model described by `opts` and returns it as a
/// [`BuiltModel`].
///
/// # Errors
/// Propagates failures from `GemmaFlow::build`.
pub fn build_gemma_prefill_built(
    cfg: &GemmaConfig,
    weights: &mut dyn WeightLoader,
    opts: &GemmaPrefillOpts,
) -> Result<BuiltModel> {
    let flow = GemmaFlow::from_prefill_opts(cfg, opts);
    flow.build(weights)
}
/// Builds the decode model described by `opts` and splits it into its HIR
/// module and named `f32` parameter map via `BuiltModel::into_parts`.
///
/// # Errors
/// Propagates failures from the build and from `into_parts`.
pub fn build_gemma_decode_flow(
    cfg: &GemmaConfig,
    weights: &mut dyn WeightLoader,
    opts: &GemmaDecodeOpts,
) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
    let built = build_gemma_decode_built(cfg, weights, opts)?;
    built.into_parts()
}
/// Builds the decode model described by `opts` and converts it with
/// `rlx_core::flow_util::graph_from_built` into a [`Graph`] plus its named
/// `f32` parameter map.
///
/// # Errors
/// Propagates failures from the build and from the conversion.
pub fn build_gemma_decode_graph(
    cfg: &GemmaConfig,
    weights: &mut dyn WeightLoader,
    opts: &GemmaDecodeOpts,
) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
    let built = build_gemma_decode_built(cfg, weights, opts)?;
    rlx_core::flow_util::graph_from_built(built)
}
/// Builds the decode model described by `opts` and returns it as a
/// [`BuiltModel`].
///
/// # Errors
/// Propagates failures from `GemmaFlow::build`.
pub fn build_gemma_decode_built(
    cfg: &GemmaConfig,
    weights: &mut dyn WeightLoader,
    opts: &GemmaDecodeOpts,
) -> Result<BuiltModel> {
    let flow = GemmaFlow::from_decode_opts(cfg, opts);
    flow.build(weights)
}