1use std::collections::HashMap;
46use std::fmt;
47use std::path::Path;
48use std::sync::Arc;
49
50use anyhow::Result;
51use rlx_flow::blocks::{
52 DecodeRopeParamsStage, EmbedScaleStage, GemmaDecodeLayerSpec, GemmaDecodeLayerStage,
53 GemmaLayerStyle, GemmaRmsNormStage, LmHeadStage, LogitSoftcapStage, RopeTablesStage,
54 gemma_attn_spec, gemma_prefill_layer_composed,
55};
56use rlx_flow::{BuiltModel, CompileProfile, FlowStage, ModelFlow, SideOutputs};
57use rlx_ir::dynamic::sym;
58use rlx_ir::hir::HirModule;
59use rlx_ir::shape::Dim;
60use rlx_ir::{DType, Graph, Shape};
61
62use super::config::{GemmaArch, GemmaConfig};
63use super::rope::{build_rope_tables, resolve_inv_freq};
64use rlx_core::flow_bridge::{WeightLoaderSource, load_compile_profile};
65use rlx_core::weight_loader::WeightLoader;
66
/// File name of the optional compile-profile file looked up in the same directory as the
/// weights file (see [`gemma_profile_near_weights`]).
pub const GEMMA_PROFILE_FILE: &str = "gemma.rlx.toml";
69
70pub fn gemma_profile_near_weights(weights: &Path, decode: bool) -> CompileProfile {
72 let default = if decode {
73 CompileProfile::gemma_decode()
74 } else {
75 CompileProfile::gemma_prefill()
76 };
77 let dir = weights.parent().unwrap_or_else(|| Path::new("."));
78 load_compile_profile(&dir.join(GEMMA_PROFILE_FILE), default)
79}
80
/// Which kind of graph [`GemmaFlow::build`] produces.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GemmaMode {
    /// Whole-prompt pass fed by `input_ids` (or `prefill_hidden`).
    Prefill,
    /// Single new token per sequence, fed by `input_ids` of length 1 plus past K/V inputs.
    Decode,
}
86
/// Per-layer context handed to a custom layer factory registered with `GemmaFlow::layer`.
///
/// It carries everything the builder's dense default layer is made from;
/// [`GemmaLayerCtx::default_stage`] turns it into that default (non-MoE, non-multimodal) stage.
pub enum GemmaLayerCtx<'a> {
    /// Context for one prefill layer.
    Prefill {
        /// Zero-based decoder layer index.
        index: usize,
        /// Layer style taken from `GemmaConfig::layer_style`.
        style: GemmaLayerStyle,
        /// Attention spec already specialised for this layer (head geometry, mask, optional
        /// secondary RoPE table, optional K=V sharing).
        attn: rlx_flow::blocks::SelfAttnPrefillSpec,
        /// Sink that collects per-layer KV side outputs.
        kv_sink: &'a SideOutputs,
        /// Whether this layer should write its K/V into `kv_sink`.
        export_kv: bool,
        /// Head dimension of this layer (can differ between sliding and full-attention layers).
        head_dim: usize,
        /// RMSNorm epsilon.
        eps: f32,
    },
    /// Context for one single-token decode layer.
    Decode {
        /// Zero-based decoder layer index.
        index: usize,
        /// Complete decode-layer specification.
        spec: GemmaDecodeLayerSpec,
        /// Sink that collects the layer's KV side outputs.
        kv_out: &'a SideOutputs,
    },
}
104
105impl GemmaLayerCtx<'_> {
106 pub fn index(&self) -> usize {
107 match self {
108 Self::Prefill { index, .. } | Self::Decode { index, .. } => *index,
109 }
110 }
111
112 pub fn default_stage(&self) -> FlowStage {
113 match self {
114 Self::Prefill {
115 index,
116 style,
117 attn,
118 kv_sink,
119 export_kv,
120 head_dim: _,
121 eps,
122 } => gemma_prefill_layer_composed(
123 *index,
124 *style,
125 attn.clone(),
126 *eps,
127 if *export_kv {
128 Some(kv_sink.inner())
129 } else {
130 None
131 },
132 ),
133 Self::Decode {
134 index,
135 spec,
136 kv_out,
137 } => FlowStage::Named {
138 name: format!("layer{index}"),
139 inner: Arc::new(FlowStage::GemmaDecodeLayer(GemmaDecodeLayerStage::layer(
140 *index,
141 spec.clone(),
142 kv_out.inner(),
143 ))),
144 },
145 }
146 }
147}
148
/// Caller-supplied factory that builds the stage for one decoder layer from its
/// [`GemmaLayerCtx`] (registered with `GemmaFlow::layer`).
type LayerFn = Arc<dyn Fn(GemmaLayerCtx<'_>) -> FlowStage + Send + Sync>;
/// Caller-supplied hook that rewrites the assembled [`ModelFlow`] once during build
/// (registered with `GemmaFlow::patch_flow`).
type FlowPatchFn = Arc<dyn Fn(ModelFlow) -> ModelFlow + Send + Sync>;
151
/// Builder for Gemma prefill and decode graphs.
///
/// Start from [`GemmaFlow::new`] (or `for_prefill` / `for_decode`), chain the setter methods,
/// then call `build` to obtain a [`BuiltModel`].
#[derive(Clone)]
pub struct GemmaFlow<'a> {
    // Model configuration the graph is derived from.
    cfg: &'a GemmaConfig,
    // Which graph `build` emits.
    mode: GemmaMode,
    // Batch size used in every input shape.
    batch: usize,
    // Static prefill sequence length (unused with `dynamic_seq` and in decode).
    seq: usize,
    // Decode: number of cached positions already in the past K/V inputs.
    past_seq: usize,
    // Prefill: sequence axis is symbolic (`sym::SEQ`).
    dynamic_seq: bool,
    // Decode: past axis is symbolic (`sym::PAST_SEQ`) and RoPE rows become inputs.
    dynamic_past: bool,
    // Prefill: emit `logits` instead of `hidden`.
    with_lm_head: bool,
    // Prefill: attach per-layer KV side outputs.
    with_kv_outputs: bool,
    // Prefill: gather only the last position before the final norm.
    last_logits_only: bool,
    // Decode: take an explicit `mask` input.
    use_custom_mask: bool,
    // Compile-profile override; a mode-specific default is used when `None`.
    profile: Option<CompileProfile>,
    // Extra stages inserted after the embedding.
    before_layers: Vec<FlowStage>,
    // Extra stages inserted after the decoder stack.
    after_layers: Vec<FlowStage>,
    // Optional custom per-layer factory.
    layer_fn: Option<LayerFn>,
    // Optional whole-flow rewrite hook.
    flow_patch: Option<FlowPatchFn>,
    // Prefill: feed `prefill_hidden` instead of `input_ids`.
    prefill_hidden: bool,
    // Prefill: add an `attn_bias` input and use the multimodal layer override.
    media_attn_bias: bool,
}
188
189impl fmt::Debug for GemmaFlow<'_> {
190 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191 f.debug_struct("GemmaFlow")
192 .field("mode", &self.mode)
193 .field("batch", &self.batch)
194 .field("seq", &self.seq)
195 .field("past_seq", &self.past_seq)
196 .field("dynamic_seq", &self.dynamic_seq)
197 .field("dynamic_past", &self.dynamic_past)
198 .field("with_lm_head", &self.with_lm_head)
199 .field("with_kv_outputs", &self.with_kv_outputs)
200 .field("last_logits_only", &self.last_logits_only)
201 .field("use_custom_mask", &self.use_custom_mask)
202 .field("profile", &self.profile)
203 .field("before_layers", &self.before_layers.len())
204 .field("after_layers", &self.after_layers.len())
205 .field("layer_fn", &self.layer_fn.is_some())
206 .field("flow_patch", &self.flow_patch.is_some())
207 .finish_non_exhaustive()
208 }
209}
210
211impl<'a> GemmaFlow<'a> {
212 pub fn new(cfg: &'a GemmaConfig) -> Self {
213 Self {
214 cfg,
215 mode: GemmaMode::Prefill,
216 batch: 1,
217 seq: 128,
218 past_seq: 0,
219 dynamic_seq: false,
220 dynamic_past: false,
221 with_lm_head: false,
222 with_kv_outputs: false,
223 last_logits_only: false,
224 use_custom_mask: false,
225 profile: None,
226 before_layers: Vec::new(),
227 after_layers: Vec::new(),
228 layer_fn: None,
229 flow_patch: None,
230 prefill_hidden: false,
231 media_attn_bias: false,
232 }
233 }
234
235 pub fn prefill_from_hidden(mut self) -> Self {
237 self.prefill_hidden = true;
238 self
239 }
240
241 pub fn prefill_media_attn_bias(mut self) -> Self {
243 self.media_attn_bias = true;
244 self
245 }
246
247 pub fn for_prefill(cfg: &'a GemmaConfig, batch: usize, seq: usize) -> Self {
249 Self::new(cfg).prefill().batch(batch).seq(seq)
250 }
251
252 pub fn for_decode(cfg: &'a GemmaConfig, batch: usize, past_seq: usize) -> Self {
254 Self::new(cfg)
255 .decode()
256 .batch(batch)
257 .past(past_seq)
258 .lm_head()
259 }
260
261 pub fn prefill(mut self) -> Self {
262 self.mode = GemmaMode::Prefill;
263 self
264 }
265
266 pub fn decode(mut self) -> Self {
267 self.mode = GemmaMode::Decode;
268 self
269 }
270
271 pub fn batch(mut self, batch: usize) -> Self {
272 self.batch = batch;
273 self
274 }
275
276 pub fn seq(mut self, seq: usize) -> Self {
278 self.seq = seq;
279 self
280 }
281
282 pub fn past(mut self, past_seq: usize) -> Self {
284 self.past_seq = past_seq;
285 self
286 }
287
288 pub fn dynamic_seq(mut self) -> Self {
290 self.dynamic_seq = true;
291 self
292 }
293
294 pub fn dynamic_past(mut self) -> Self {
296 self.dynamic_past = true;
297 self
298 }
299
300 pub fn lm_head(mut self) -> Self {
301 self.with_lm_head = true;
302 self
303 }
304
305 pub fn hidden_only(mut self) -> Self {
307 self.with_lm_head = false;
308 self.last_logits_only = false;
309 self
310 }
311
312 pub fn last_token_logits(mut self) -> Self {
313 self.with_lm_head = true;
314 self.last_logits_only = true;
315 self
316 }
317
318 pub fn export_kv(mut self) -> Self {
319 self.with_kv_outputs = true;
320 self
321 }
322
323 pub fn custom_mask(mut self) -> Self {
324 self.use_custom_mask = true;
325 self
326 }
327
328 pub fn profile(mut self, profile: CompileProfile) -> Self {
329 self.profile = Some(profile);
330 self
331 }
332
333 pub fn profile_prefill(mut self) -> Self {
335 self.profile = Some(CompileProfile::gemma_prefill());
336 self
337 }
338
339 pub fn profile_decode(mut self) -> Self {
340 self.profile = Some(CompileProfile::gemma_decode());
341 self
342 }
343
344 pub fn profile_near(mut self, weights_path: &Path) -> Self {
345 let decode = self.mode == GemmaMode::Decode;
346 self.profile = Some(gemma_profile_near_weights(weights_path, decode));
347 self
348 }
349
350 pub fn before_layers(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
352 self.before_layers.extend(stages);
353 self
354 }
355
356 pub fn after_layers(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
358 self.after_layers.extend(stages);
359 self
360 }
361
362 pub fn layer<F>(mut self, f: F) -> Self
366 where
367 F: Fn(GemmaLayerCtx<'_>) -> FlowStage + Send + Sync + 'static,
368 {
369 self.layer_fn = Some(Arc::new(f));
370 self
371 }
372
373 pub fn patch_flow<F>(mut self, f: F) -> Self
375 where
376 F: Fn(ModelFlow) -> ModelFlow + Send + Sync + 'static,
377 {
378 self.flow_patch = Some(Arc::new(f));
379 self
380 }
381
382 pub fn build(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
383 match self.mode {
384 GemmaMode::Prefill => self.build_prefill(weights),
385 GemmaMode::Decode => self.build_decode(weights),
386 }
387 }
388
389 fn build_prefill(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
390 if self.dynamic_seq && self.batch != 1 {
391 anyhow::bail!("gemma: dynamic_seq prefill requires batch=1");
392 }
393
394 let cfg = self.cfg;
395 let profile = self.profile.unwrap_or_else(CompileProfile::gemma_prefill);
396 let f = DType::F32;
397 let h = cfg.hidden_size;
398 let eps = cfg.rms_norm_eps as f32;
399 let layer_style = cfg.layer_style();
400
401 let hidden_shape = prefill_hidden_shape(self.batch, self.seq, h, self.dynamic_seq, f);
402 let input_shape = prefill_input_shape(self.batch, self.seq, self.dynamic_seq);
403
404 let rope_factors = weights.take("rope_freqs.weight").ok().map(|(data, _)| data);
405 let inv_freq = resolve_inv_freq(cfg, rope_factors.as_deref());
406 let (cos_data, sin_data) = build_rope_tables(&inv_freq, cfg.max_position_embeddings);
407
408 let global_rope =
414 secondary_rope_tables(cfg, cfg.max_position_embeddings, rope_factors.as_deref());
415
416 let kv_sink = SideOutputs::new();
417
418 let mut flow = ModelFlow::new("gemma").with_profile(profile);
419 if self.prefill_hidden {
420 flow = flow.input("prefill_hidden", hidden_shape.clone());
421 } else {
422 flow = flow.input("input_ids", input_shape);
423 }
424
425 if self.dynamic_seq && self.with_lm_head && self.last_logits_only {
426 flow = flow.input("last_token_idx", Shape::new(&[self.batch], DType::F32));
427 }
428
429 if self.media_attn_bias {
430 let nh = cfg.num_attention_heads;
431 if self.dynamic_seq {
432 flow = flow.input(
433 "attn_bias",
434 Shape::from_dims(
435 &[
436 rlx_ir::shape::Dim::Static(self.batch),
437 rlx_ir::shape::Dim::Static(nh),
438 rlx_ir::shape::Dim::Dynamic(rlx_ir::sym::SEQ),
439 rlx_ir::shape::Dim::Dynamic(rlx_ir::sym::SEQ),
440 ],
441 f,
442 ),
443 );
444 } else {
445 flow = flow.input(
446 "attn_bias",
447 Shape::new(&[self.batch, nh, self.seq, self.seq], f),
448 );
449 }
450 }
451
452 flow = flow
453 .rope_tables(RopeTablesStage::param(
454 cfg.max_position_embeddings,
455 inv_freq.len(),
456 cos_data,
457 sin_data,
458 ))
459 .zero_beta_named("gemma.zero_beta.hidden", h);
460
461 if self.prefill_hidden {
462 flow = flow.plugin_named("gemma.prefill_hidden_bind", move |emit, _| {
463 let hidden = emit
464 .flow_input("prefill_hidden")
465 .map_err(|e| anyhow::anyhow!("prefill_hidden input: {e}"))?;
466 let _ = emit.load_param("model.embed_tokens.weight", false)?;
468 Ok(Some(hidden))
469 });
470 } else {
471 flow = flow
472 .token_embed()
473 .raw_stage(FlowStage::EmbedScale(EmbedScaleStage::new(h)));
474 }
475
476 flow = flow.raw_stages(self.before_layers.iter().cloned());
477
478 if let Some(g) = &global_rope {
479 flow = flow.raw_stage(FlowStage::RopeTables(RopeTablesStage::param_named(
480 "global",
481 cfg.max_position_embeddings,
482 g.half_dim,
483 g.cos.clone(),
484 g.sin.clone(),
485 )));
486 }
487
488 let layer_fn = self.layer_fn.clone();
489 let export = self.with_kv_outputs;
490 let media_bias = self.media_attn_bias;
491 let num_heads = cfg.num_attention_heads;
492 let num_layers = cfg.active_num_layers();
493 let layer_attn: Vec<_> = (0..num_layers).map(|i| cfg.layer_attn_options(i)).collect();
494 let is_moe = cfg.is_moe();
499 let moe_num_experts = cfg.num_experts;
500 let moe_top_k = cfg.num_experts_used;
501 let moe_n_embd = cfg.hidden_size;
502 let moe_n_ff = cfg.expert_ffn_dim();
503 let per_layer: Vec<PerLayerAttn> = (0..num_layers)
509 .map(|i| PerLayerAttn {
510 head_dim: cfg.layer_head_dim(i),
511 num_kv_heads: cfg.layer_num_kv_heads(i),
512 n_rot: cfg.layer_n_rot(i),
513 rope_table: if cfg.is_full_attention_layer(i) && global_rope.is_some() {
514 Some("global".to_string())
515 } else {
516 None
517 },
518 k_eq_v: cfg.attention_k_eq_v,
519 })
520 .collect();
521 flow = flow.repeat_layers(num_layers, {
522 let style = layer_style;
523 let sink = kv_sink.clone();
524 move |i| {
525 let (mask, score_scale, softcap) = layer_attn[i];
526 let pl = &per_layer[i];
527 let lh = pl.head_dim;
528 let mut attn = gemma_attn_spec(
529 i,
530 num_heads,
531 pl.head_dim,
532 pl.num_kv_heads,
533 pl.n_rot,
534 mask,
535 score_scale,
536 softcap,
537 );
538 if let Some(name) = pl.rope_table.as_ref() {
539 attn = attn.with_rope_table(name);
540 }
541 if pl.k_eq_v {
542 attn = attn.with_k_eq_v();
543 }
544 if let Some(ref f) = layer_fn {
545 return f(GemmaLayerCtx::Prefill {
546 index: i,
547 style,
548 attn: attn.clone(),
549 kv_sink: &sink,
550 export_kv: export,
551 head_dim: lh,
552 eps,
553 });
554 }
555 if media_bias {
556 return crate::multimodal_flow::multimodal_layer_override(
557 GemmaLayerCtx::Prefill {
558 index: i,
559 style,
560 attn,
561 kv_sink: &sink,
562 export_kv: export,
563 head_dim: lh,
564 eps,
565 },
566 true,
567 );
568 }
569 if is_moe {
570 let prefix = format!("model.layers.{i}");
571 let moe = rlx_flow::blocks::MoeFfnStage::hf(
572 prefix,
573 moe_num_experts,
574 moe_top_k,
575 moe_n_embd,
576 moe_n_ff,
577 );
578 let kv = if export { Some(sink.inner()) } else { None };
579 return rlx_flow::blocks::gemma_moe_prefill_layer_composed(
580 i, style, attn, eps, kv, moe,
581 );
582 }
583 GemmaLayerCtx::Prefill {
584 index: i,
585 style,
586 attn,
587 kv_sink: &sink,
588 export_kv: export,
589 head_dim: lh,
590 eps,
591 }
592 .default_stage()
593 }
594 });
595
596 flow = flow.raw_stages(self.after_layers.iter().cloned());
597
598 if self.with_lm_head && self.last_logits_only {
599 flow = if self.dynamic_seq {
600 flow.gather_last_token_dynamic(self.batch)
601 } else {
602 flow.gather_last_token_at(self.batch, self.seq)
603 };
604 }
605
606 flow = flow.raw_stage(FlowStage::GemmaRmsNorm(GemmaRmsNormStage::hf_layer(
607 "model.norm",
608 eps,
609 )));
610
611 if let Some(patch) = self.flow_patch {
612 flow = patch(flow);
613 }
614
615 let mut built = if self.with_lm_head {
616 let lm = if cfg.tie_word_embeddings {
617 FlowStage::LmHead(LmHeadStage::tied(cfg.vocab_size, h))
618 } else {
619 FlowStage::LmHead(LmHeadStage::separate("lm_head.weight", cfg.vocab_size, h))
620 };
621 flow = flow.raw_stage(lm);
622 if let Some(cap) = cfg.final_logit_softcapping {
623 flow = flow.raw_stage(FlowStage::LogitSoftcap(LogitSoftcapStage::new(cap)));
624 }
625 flow.output("logits")
626 .build(&mut WeightLoaderSource(weights))?
627 } else {
628 flow.output("hidden")
629 .build(&mut WeightLoaderSource(weights))?
630 };
631
632 if self.with_kv_outputs {
633 built = built.with_extra_hir_outputs(kv_sink.drain());
634 }
635 Ok(built)
636 }
637
638 fn build_decode(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
639 let cfg = self.cfg;
640 let profile = self.profile.unwrap_or_else(CompileProfile::gemma_decode);
641 let f = DType::F32;
642 let h = cfg.hidden_size;
643 let eps = cfg.rms_norm_eps as f32;
644 let dh = cfg.head_dim();
645 let half = dh / 2;
646
647 let hidden_shape = Shape::new(&[self.batch, 1, h], f);
648
649 let decode_style = cfg.layer_style();
650 let decode_score_scale = cfg.attn_score_scale();
651 let decode_softcap = cfg.attn_logit_softcapping;
652 let decode_arch = cfg.arch;
653 let decode_sliding = cfg.sliding_window;
654
655 let kv_out = SideOutputs::new();
656
657 let rope_factors = weights.take("rope_freqs.weight").ok().map(|(data, _)| data);
658 let inv_freq = resolve_inv_freq(cfg, rope_factors.as_deref());
659 let (rope_cos, rope_sin) = if self.dynamic_past {
660 (Vec::new(), Vec::new())
661 } else {
662 crate::rope::rope_slice(&inv_freq, self.past_seq)
663 };
664
665 let global_rope_row = if !self.dynamic_past {
670 secondary_rope_row(cfg, self.past_seq, rope_factors.as_deref())
671 } else {
672 None
673 };
674 let global_params = needs_secondary_rope_params(cfg);
675
676 let mut flow = ModelFlow::new("gemma_decode")
677 .with_profile(profile)
678 .input("input_ids", Shape::new(&[self.batch, 1], DType::F32));
679
680 if self.dynamic_past {
681 flow = flow
682 .input("rope_cos", Shape::new(&[1, half], f))
683 .input("rope_sin", Shape::new(&[1, half], f));
684 if let Some(gp) = global_params {
685 let half_global =
686 crate::rope::resolve_global_inv_freq(cfg, rope_factors.as_deref())
687 .map(|v| v.len())
688 .unwrap_or_else(|| crate::rope::default_inv_freq(gp.theta, gp.n_rot).len());
689 flow = flow
690 .input("rope_cos_global", Shape::new(&[1, half_global], f))
691 .input("rope_sin_global", Shape::new(&[1, half_global], f))
692 .raw_stage(FlowStage::Custom(rlx_flow::blocks::CustomStage::named(
693 "gemma.bind_global_decode_rope",
694 |emit, val| {
695 let cos = find_hir_input(emit.hir(), "rope_cos_global")?;
700 let sin = find_hir_input(emit.hir(), "rope_sin_global")?;
701 emit.set_named("global_cos", cos);
702 emit.set_named("global_sin", sin);
703 Ok(val)
704 },
705 )));
706 }
707 }
708
709 if self.use_custom_mask {
710 flow = flow.input("mask", Shape::new(&[self.batch, self.past_seq + 1], f));
711 }
712
713 for layer_idx in 0..cfg.num_hidden_layers {
718 let layer_kv_dim = cfg.layer_num_kv_heads(layer_idx) * cfg.layer_head_dim(layer_idx);
719 let shape = if self.dynamic_past {
720 Shape::from_dims(
721 &[
722 Dim::Static(self.batch),
723 Dim::Dynamic(sym::PAST_SEQ),
724 Dim::Static(layer_kv_dim),
725 ],
726 f,
727 )
728 } else {
729 Shape::new(&[self.batch, self.past_seq, layer_kv_dim], f)
730 };
731 flow = flow
732 .input(format!("past_k_{layer_idx}"), shape.clone())
733 .input(format!("past_v_{layer_idx}"), shape);
734 }
735
736 if !self.dynamic_past {
737 flow = flow.raw_stage(FlowStage::DecodeRopeParams(DecodeRopeParamsStage::new(
738 rope_cos, rope_sin, half,
739 )));
740 if let Some(g) = &global_rope_row {
741 flow = flow.raw_stage(FlowStage::DecodeRopeParams(DecodeRopeParamsStage::named(
742 "global",
743 g.cos.clone(),
744 g.sin.clone(),
745 g.half_dim,
746 )));
747 }
748 }
749
750 flow = flow
751 .bind_decode_inputs(cfg.num_hidden_layers, self.use_custom_mask)
752 .zero_beta_named("gemma.zero_beta.hidden", h)
753 .token_embed()
754 .raw_stage(FlowStage::EmbedScale(EmbedScaleStage::new(h)))
755 .raw_stages(self.before_layers.iter().cloned());
756
757 let layer_fn = self.layer_fn.clone();
758 let use_custom_mask = self.use_custom_mask;
759 let num_heads = cfg.num_attention_heads;
760 let num_layers = cfg.active_num_layers();
761 let secondary_rope_active = global_rope_row.is_some();
764 let per_layer_decode: Vec<PerLayerAttn> = (0..num_layers)
765 .map(|i| PerLayerAttn {
766 head_dim: cfg.layer_head_dim(i),
767 num_kv_heads: cfg.layer_num_kv_heads(i),
768 n_rot: cfg.layer_n_rot(i),
769 rope_table: if cfg.is_full_attention_layer(i) && secondary_rope_active {
770 Some("global".to_string())
771 } else {
772 None
773 },
774 k_eq_v: cfg.attention_k_eq_v,
775 })
776 .collect();
777 let is_moe = cfg.is_moe();
779 let moe_num_experts = cfg.num_experts;
780 let moe_top_k = cfg.num_experts_used;
781 let moe_n_embd = cfg.hidden_size;
782 let moe_n_ff = cfg.expert_ffn_dim();
783 flow = flow.repeat_layers(num_layers, {
784 let sink = kv_out.clone();
785 let hidden_shape = hidden_shape.clone();
786 move |i| {
787 let mask = if use_custom_mask {
788 rlx_ir::op::MaskKind::Causal
789 } else {
790 match (decode_arch, decode_sliding) {
791 (GemmaArch::Gemma2, Some(w)) => rlx_flow::blocks::gemma2_layer_mask(i, w),
792 (GemmaArch::Gemma3 | GemmaArch::Gemma4, Some(w)) => {
796 rlx_flow::blocks::gemma_strided_layer_mask(
797 i,
798 w,
799 decode_arch.sliding_window_stride(),
800 )
801 }
802 _ => rlx_ir::op::MaskKind::Causal,
803 }
804 };
805 let pl = &per_layer_decode[i];
806 let kv_group_size = num_heads / pl.num_kv_heads;
807 let spec = GemmaDecodeLayerSpec {
808 style: decode_style,
809 num_heads,
810 head_dim: pl.head_dim,
811 num_kv_heads: pl.num_kv_heads,
812 kv_group_size,
813 n_rot: pl.n_rot,
814 rope_table: pl.rope_table.clone(),
815 k_eq_v: pl.k_eq_v,
816 eps,
817 use_custom_mask,
818 hidden_shape: hidden_shape.clone(),
819 mask,
820 score_scale: decode_score_scale,
821 attn_logit_softcap: decode_softcap,
822 };
823 if let Some(ref f) = layer_fn {
824 return f(GemmaLayerCtx::Decode {
825 index: i,
826 spec: spec.clone(),
827 kv_out: &sink,
828 });
829 }
830 if is_moe {
831 let prefix = format!("model.layers.{i}");
832 let moe = rlx_flow::blocks::MoeFfnStage::hf(
833 prefix,
834 moe_num_experts,
835 moe_top_k,
836 moe_n_embd,
837 moe_n_ff,
838 );
839 return rlx_flow::blocks::gemma_moe_decode_layer_composed(
840 i,
841 spec,
842 sink.inner(),
843 moe,
844 );
845 }
846 GemmaLayerCtx::Decode {
847 index: i,
848 spec,
849 kv_out: &sink,
850 }
851 .default_stage()
852 }
853 });
854
855 flow = flow.raw_stages(self.after_layers.iter().cloned());
856
857 if let Some(patch) = self.flow_patch {
858 flow = patch(flow);
859 }
860
861 let mut flow = flow.raw_stage(FlowStage::GemmaRmsNorm(GemmaRmsNormStage::hf_layer(
862 "model.norm",
863 eps,
864 )));
865 let lm = if cfg.tie_word_embeddings {
866 FlowStage::LmHead(LmHeadStage::tied(cfg.vocab_size, h))
867 } else {
868 FlowStage::LmHead(LmHeadStage::separate("lm_head.weight", cfg.vocab_size, h))
869 };
870 flow = flow.raw_stage(lm);
871 if let Some(cap) = cfg.final_logit_softcapping {
872 flow = flow.raw_stage(FlowStage::LogitSoftcap(LogitSoftcapStage::new(cap)));
873 }
874 let built = flow
875 .output("logits")
876 .build(&mut WeightLoaderSource(weights))?
877 .with_extra_hir_outputs(kv_out.drain());
878
879 Ok(built)
880 }
881}
882
883fn prefill_hidden_shape(
884 batch: usize,
885 seq: usize,
886 hidden: usize,
887 dynamic: bool,
888 dtype: DType,
889) -> Shape {
890 if dynamic {
891 Shape::from_dims(
892 &[
893 Dim::Static(batch),
894 Dim::Dynamic(sym::SEQ),
895 Dim::Static(hidden),
896 ],
897 dtype,
898 )
899 } else {
900 Shape::new(&[batch, seq, hidden], dtype)
901 }
902}
903
904fn prefill_input_shape(batch: usize, seq: usize, dynamic: bool) -> Shape {
905 if dynamic {
906 Shape::from_dims(&[Dim::Static(batch), Dim::Dynamic(sym::SEQ)], DType::F32)
907 } else {
908 Shape::new(&[batch, seq], DType::F32)
909 }
910}
911
/// Attention geometry for one decoder layer, precomputed from `GemmaConfig` before the
/// per-layer closures are built (the closures take ownership of the resulting vector).
#[derive(Debug, Clone)]
struct PerLayerAttn {
    // Per-layer head dimension (`GemmaConfig::layer_head_dim`).
    head_dim: usize,
    // Per-layer KV head count (`GemmaConfig::layer_num_kv_heads`).
    num_kv_heads: usize,
    // Number of rotary dimensions (`GemmaConfig::layer_n_rot`).
    n_rot: usize,
    // Name of a secondary RoPE table/row to use (`"global"`), or `None` for the default one.
    rope_table: Option<String>,
    // Mirrors `GemmaConfig::attention_k_eq_v`.
    k_eq_v: bool,
}
924
/// Cos/sin data for the secondary ("global") RoPE.
///
/// Holds either a full table (see `secondary_rope_tables`) or a single position row (see
/// `secondary_rope_row`).
#[derive(Debug, Clone)]
struct GlobalRopeTables {
    cos: Vec<f32>,
    sin: Vec<f32>,
    // Number of rotary frequencies (`inv_freq.len()`), i.e. the row width.
    half_dim: usize,
}
931
932fn secondary_rope_tables(
937 cfg: &GemmaConfig,
938 max_pos: usize,
939 factors: Option<&[f32]>,
940) -> Option<GlobalRopeTables> {
941 let inv = crate::rope::resolve_global_inv_freq(cfg, factors)?;
942 let (cos, sin) = crate::rope::build_rope_tables(&inv, max_pos);
943 Some(GlobalRopeTables {
944 cos,
945 sin,
946 half_dim: inv.len(),
947 })
948}
949
950fn secondary_rope_row(
952 cfg: &GemmaConfig,
953 pos: usize,
954 factors: Option<&[f32]>,
955) -> Option<GlobalRopeTables> {
956 let inv = crate::rope::resolve_global_inv_freq(cfg, factors)?;
957 let (cos, sin) = crate::rope::rope_slice(&inv, pos);
958 Some(GlobalRopeTables {
959 cos,
960 sin,
961 half_dim: inv.len(),
962 })
963}
964
965fn needs_secondary_rope_params(cfg: &GemmaConfig) -> Option<GlobalRopeParams> {
966 crate::rope::global_rope_params(cfg).map(|(theta, n_rot)| GlobalRopeParams { theta, n_rot })
967}
968
/// Secondary RoPE parameters as reported by `rope::global_rope_params`.
#[derive(Debug, Clone, Copy)]
struct GlobalRopeParams {
    // RoPE base frequency.
    theta: f64,
    // Number of rotary dimensions; `default_inv_freq(theta, n_rot)` gives the frequencies.
    n_rot: usize,
}
974
975fn find_hir_input(hir: &HirModule, name: &str) -> anyhow::Result<rlx_ir::HirNodeId> {
976 use rlx_ir::hir::HirOp;
977 for node in hir.nodes() {
978 if let HirOp::Input { name: n } = &node.op {
979 if n == name {
980 return Ok(node.id);
981 }
982 }
983 }
984 Err(anyhow::anyhow!("gemma decode flow missing input: {name}"))
985}
986
987impl<'a> GemmaFlow<'a> {
990 fn from_prefill_opts(cfg: &'a GemmaConfig, o: &GemmaPrefillOpts) -> Self {
991 let mut f = GemmaFlow::new(cfg).prefill().batch(o.batch).seq(o.seq);
992 if o.dynamic_seq {
993 f = f.dynamic_seq();
994 }
995 if o.prefill_hidden {
996 f = f.prefill_from_hidden();
997 }
998 if o.media_attn_bias {
999 f = f.prefill_media_attn_bias();
1000 }
1001 if o.with_lm_head {
1002 f = f.lm_head();
1003 }
1004 if o.with_kv_outputs {
1005 f = f.export_kv();
1006 }
1007 if o.last_logits_only {
1008 f = f.last_token_logits();
1009 }
1010 if let Some(p) = o.profile.clone() {
1011 f = f.profile(p);
1012 }
1013 f
1014 }
1015
1016 fn from_decode_opts(cfg: &'a GemmaConfig, o: &GemmaDecodeOpts) -> Self {
1017 let mut f = GemmaFlow::new(cfg)
1018 .decode()
1019 .batch(o.batch)
1020 .past(o.past_seq)
1021 .lm_head();
1022 if o.dynamic_past {
1023 f = f.dynamic_past();
1024 }
1025 if o.use_custom_mask {
1026 f = f.custom_mask();
1027 }
1028 if let Some(p) = o.profile.clone() {
1029 f = f.profile(p);
1030 }
1031 f
1032 }
1033}
1034
/// Plain-data options for a prefill build; mapped one-to-one onto [`GemmaFlow`] setters.
#[derive(Debug, Clone)]
pub struct GemmaPrefillOpts {
    /// Batch size.
    pub batch: usize,
    /// Static sequence length (ignored when `dynamic_seq`).
    pub seq: usize,
    /// Symbolic sequence axis; requires `batch == 1`.
    pub dynamic_seq: bool,
    /// Take a `prefill_hidden` input instead of `input_ids`.
    pub prefill_hidden: bool,
    /// Add an `attn_bias` input and use the multimodal layer override.
    pub media_attn_bias: bool,
    /// Emit `logits` instead of `hidden`.
    pub with_lm_head: bool,
    /// Attach per-layer KV side outputs.
    pub with_kv_outputs: bool,
    /// Only produce logits for the last position (also enables the LM head).
    pub last_logits_only: bool,
    /// Compile-profile override; `None` uses `CompileProfile::gemma_prefill()`.
    pub profile: Option<CompileProfile>,
}
1048
1049impl GemmaPrefillOpts {
1050 pub fn static_prefill(batch: usize, seq: usize) -> Self {
1051 Self {
1052 batch,
1053 seq,
1054 dynamic_seq: false,
1055 prefill_hidden: false,
1056 media_attn_bias: false,
1057 with_lm_head: false,
1058 with_kv_outputs: false,
1059 last_logits_only: false,
1060 profile: None,
1061 }
1062 }
1063}
1064
/// Plain-data options for a decode build; mapped onto [`GemmaFlow`] setters (LM head always on).
#[derive(Debug, Clone)]
pub struct GemmaDecodeOpts {
    /// Batch size.
    pub batch: usize,
    /// Number of cached positions (static past length, RoPE position, mask width).
    pub past_seq: usize,
    /// Symbolic past length with RoPE rows supplied as inputs.
    pub dynamic_past: bool,
    /// Take an explicit `mask` input of shape `[batch, past_seq + 1]`.
    pub use_custom_mask: bool,
    /// Compile-profile override; `None` uses `CompileProfile::gemma_decode()`.
    pub profile: Option<CompileProfile>,
}
1074
1075pub fn build_gemma_prefill_flow(
1076 cfg: &GemmaConfig,
1077 weights: &mut dyn WeightLoader,
1078 opts: &GemmaPrefillOpts,
1079) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
1080 build_gemma_prefill_built(cfg, weights, opts)?.into_parts()
1081}
1082
1083pub fn build_gemma_prefill_built(
1084 cfg: &GemmaConfig,
1085 weights: &mut dyn WeightLoader,
1086 opts: &GemmaPrefillOpts,
1087) -> Result<BuiltModel> {
1088 GemmaFlow::from_prefill_opts(cfg, opts).build(weights)
1089}
1090
1091pub fn build_gemma_decode_flow(
1092 cfg: &GemmaConfig,
1093 weights: &mut dyn WeightLoader,
1094 opts: &GemmaDecodeOpts,
1095) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
1096 build_gemma_decode_built(cfg, weights, opts)?.into_parts()
1097}
1098
1099pub fn build_gemma_decode_graph(
1100 cfg: &GemmaConfig,
1101 weights: &mut dyn WeightLoader,
1102 opts: &GemmaDecodeOpts,
1103) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
1104 rlx_core::flow_util::graph_from_built(build_gemma_decode_built(cfg, weights, opts)?)
1105}
1106
1107pub fn build_gemma_decode_built(
1108 cfg: &GemmaConfig,
1109 weights: &mut dyn WeightLoader,
1110 opts: &GemmaDecodeOpts,
1111) -> Result<BuiltModel> {
1112 GemmaFlow::from_decode_opts(cfg, opts).build(weights)
1113}
1114
#[cfg(test)]
mod gemma4_tests {
    use super::*;
    use crate::config::{
        GemmaArch, GemmaLayerType, GemmaRopeKind, GemmaRopeMap, GemmaRopeParameters,
    };

    /// A Gemma 4 config shaped like the 12B checkpoint but truncated to 12 layers.
    ///
    /// * Sliding layers: 8 KV heads × 256 head dim.
    /// * Every sixth layer is full attention, with a single KV head of width 512.
    /// * Full-attention RoPE uses theta 1e6 with a 0.25 partial-rotary factor.
    fn gemma4_12b_like() -> GemmaConfig {
        let mut cfg = GemmaConfig::tiny_test();
        cfg.arch = GemmaArch::Gemma4;
        cfg.hidden_size = 3840;
        cfg.intermediate_size = 15_360;
        cfg.num_hidden_layers = 12;
        cfg.num_attention_heads = 16;
        cfg.num_key_value_heads = 8;
        cfg.head_dim = Some(256);
        cfg.global_head_dim = Some(512);
        cfg.num_global_key_value_heads = Some(1);
        cfg.attention_k_eq_v = true;
        cfg.sliding_window = Some(1024);
        cfg.final_logit_softcapping = Some(30.0);
        cfg.tie_word_embeddings = true;
        cfg.max_position_embeddings = 4096;
        cfg.rope_theta = 10_000.0;
        // Layers 5 and 11 ((i + 1) % 6 == 0) are full attention; the rest are sliding.
        cfg.layer_types = (0..cfg.num_hidden_layers)
            .map(|i| {
                if (i + 1) % 6 == 0 {
                    GemmaLayerType::FullAttention
                } else {
                    GemmaLayerType::SlidingAttention
                }
            })
            .collect();
        cfg.rope_parameters = GemmaRopeMap {
            sliding_attention: Some(GemmaRopeParameters {
                rope_theta: Some(10_000.0),
                rope_type: Some(GemmaRopeKind::Default),
                partial_rotary_factor: None,
            }),
            full_attention: Some(GemmaRopeParameters {
                rope_theta: Some(1_000_000.0),
                rope_type: Some(GemmaRopeKind::Proportional),
                partial_rotary_factor: Some(0.25),
            }),
        };
        cfg
    }

    #[test]
    fn secondary_rope_emits_distinct_table_for_full_attention() {
        let cfg = gemma4_12b_like();
        let tables = secondary_rope_tables(&cfg, cfg.max_position_embeddings, None)
            .expect("Gemma 4 split rope_parameters should produce a secondary table");
        // Expected: 512 (global head dim) × 0.25 = 128 rotary dims, i.e. 64 frequencies.
        assert_eq!(tables.half_dim, 64);
        // One row of `half_dim` entries per position.
        assert_eq!(tables.cos.len(), cfg.max_position_embeddings * 64);
        assert_eq!(tables.sin.len(), tables.cos.len());

        // Position 0 is an identity rotation: cos = 1, sin = 0.
        assert!((tables.cos[0] - 1.0).abs() < 1e-6);
        assert!(tables.sin[0].abs() < 1e-6);

        // The global (theta 1e6) frequencies differ from the sliding (theta 1e4) ones, and row 1
        // of the table matches cos(1 * inv_freq[5]) for the global frequencies.
        let global_inv = crate::rope::default_inv_freq(1_000_000.0, 128);
        let sliding_inv = crate::rope::default_inv_freq(10_000.0, 128);
        assert!((global_inv[5] - sliding_inv[5]).abs() > 1e-3);
        let global_cos_p1_d5 = (1.0 * global_inv[5]).cos();
        let global_sample = tables.cos[64 + 5];
        assert!((global_sample as f64 - global_cos_p1_d5).abs() < 1e-5);
    }

    #[test]
    fn per_layer_kv_dims_diverge_on_full_attention() {
        let cfg = gemma4_12b_like();
        // Sliding layer: 8 KV heads × 256 = 2048.
        assert_eq!(cfg.layer_num_kv_heads(0) * cfg.layer_head_dim(0), 2048);
        // Full-attention layers: 1 KV head × 512 = 512.
        assert_eq!(cfg.layer_num_kv_heads(5) * cfg.layer_head_dim(5), 512);
        assert_eq!(cfg.layer_num_kv_heads(11) * cfg.layer_head_dim(11), 512);
    }

    #[test]
    fn no_secondary_table_when_params_match() {
        let mut cfg = gemma4_12b_like();
        // Make full-attention layers indistinguishable from sliding ones: same RoPE
        // parameters and no separate head geometry. No secondary table should be produced.
        cfg.rope_parameters.full_attention = cfg.rope_parameters.sliding_attention;
        cfg.global_head_dim = None;
        cfg.num_global_key_value_heads = None;
        assert!(secondary_rope_tables(&cfg, cfg.max_position_embeddings, None).is_none());
    }
}