Skip to main content

rlx_gemma/
flow.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Fluent Gemma model assembly — tier-0 reference for `rlx-flow`.
17//!
18//! ```rust,ignore
19//! use rlx_models::gemma::GemmaFlow;
20//!
21//! // Prefill logits for the last token
22//! let built = GemmaFlow::for_prefill(&cfg, 1, 128)
23//!     .last_token_logits()
24//!     .profile_near(&weights_path)
25//!     .build(&mut weights)?;
26//!
27//! // Decode step with KV side outputs
28//! let built = GemmaFlow::for_decode(&cfg, 1, 256)
29//!     .custom_mask()
30//!     .profile_decode()
31//!     .build(&mut weights)?;
32//!
33//! // Override one layer while keeping the rest of the recipe
34//! let built = GemmaFlow::for_prefill(&cfg, 1, 128)
35//!     .layer(|ctx| {
36//!         if ctx.index() == 0 {
37//!             ctx.default_stage() // or FlowStage::Custom(...)
38//!         } else {
39//!             ctx.default_stage()
40//!         }
41//!     })
42//!     .build(&mut weights)?;
43//! ```
44
45use std::collections::HashMap;
46use std::fmt;
47use std::path::Path;
48use std::sync::Arc;
49
50use anyhow::Result;
51use rlx_flow::blocks::{
52    DecodeRopeParamsStage, EmbedScaleStage, GemmaDecodeLayerSpec, GemmaDecodeLayerStage,
53    GemmaLayerStyle, GemmaRmsNormStage, LmHeadStage, LogitSoftcapStage, RopeTablesStage,
54    gemma_attn_spec, gemma_prefill_layer_composed,
55};
56use rlx_flow::{BuiltModel, CompileProfile, FlowStage, ModelFlow, SideOutputs};
57use rlx_ir::dynamic::sym;
58use rlx_ir::hir::HirModule;
59use rlx_ir::shape::Dim;
60use rlx_ir::{DType, Graph, Shape};
61
62use super::config::{GemmaArch, GemmaConfig};
63use super::rope::{build_rope_tables, resolve_inv_freq};
64use rlx_core::flow_bridge::{WeightLoaderSource, load_compile_profile};
65use rlx_core::weight_loader::WeightLoader;
66
/// Tier-1 profile file name colocated with weights.
///
/// [`gemma_profile_near_weights`] joins this name onto the parent directory
/// of the weights path to locate the profile.
pub const GEMMA_PROFILE_FILE: &str = "gemma.rlx.toml";
69
70/// Resolve compile profile from `gemma.rlx.toml` in the weights directory.
71pub fn gemma_profile_near_weights(weights: &Path, decode: bool) -> CompileProfile {
72    let default = if decode {
73        CompileProfile::gemma_decode()
74    } else {
75        CompileProfile::gemma_prefill()
76    };
77    let dir = weights.parent().unwrap_or_else(|| Path::new("."));
78    load_compile_profile(&dir.join(GEMMA_PROFILE_FILE), default)
79}
80
/// Which graph a [`GemmaFlow`] assembles when `build` is called.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GemmaMode {
    /// `build` dispatches to the prefill assembly path.
    Prefill,
    /// `build` dispatches to the decode assembly path.
    Decode,
}
86
/// Per-layer context for `.layer()` overrides — defaults preserve stock Gemma blocks.
///
/// One value is constructed per layer index and handed to the user closure;
/// [`GemmaLayerCtx::default_stage`] turns it back into the stock stage.
pub enum GemmaLayerCtx<'a> {
    /// Context for a layer of the prefill graph.
    Prefill {
        /// Zero-based layer index.
        index: usize,
        /// Layer style taken from the model config.
        style: GemmaLayerStyle,
        /// Fully configured self-attention spec for this layer.
        attn: rlx_flow::blocks::SelfAttnPrefillSpec,
        /// Side-output collector handed to the layer when `export_kv` is set.
        kv_sink: &'a SideOutputs,
        /// Whether the default stage should receive `kv_sink`.
        export_kv: bool,
        /// Per-layer head dimension. Not consumed by `default_stage`; it is
        /// available to custom overrides.
        head_dim: usize,
        /// RMS-norm epsilon forwarded to the composed layer.
        eps: f32,
    },
    /// Context for a layer of the decode graph.
    Decode {
        /// Zero-based layer index.
        index: usize,
        /// Complete decode-layer specification.
        spec: GemmaDecodeLayerSpec,
        /// Side-output collector passed to the decode layer stage.
        kv_out: &'a SideOutputs,
    },
}
104
105impl GemmaLayerCtx<'_> {
106    pub fn index(&self) -> usize {
107        match self {
108            Self::Prefill { index, .. } | Self::Decode { index, .. } => *index,
109        }
110    }
111
112    pub fn default_stage(&self) -> FlowStage {
113        match self {
114            Self::Prefill {
115                index,
116                style,
117                attn,
118                kv_sink,
119                export_kv,
120                head_dim: _,
121                eps,
122            } => gemma_prefill_layer_composed(
123                *index,
124                *style,
125                attn.clone(),
126                *eps,
127                if *export_kv {
128                    Some(kv_sink.inner())
129                } else {
130                    None
131                },
132            ),
133            Self::Decode {
134                index,
135                spec,
136                kv_out,
137            } => FlowStage::Named {
138                name: format!("layer{index}"),
139                inner: Arc::new(FlowStage::GemmaDecodeLayer(GemmaDecodeLayerStage::layer(
140                    *index,
141                    spec.clone(),
142                    kv_out.inner(),
143                ))),
144            },
145        }
146    }
147}
148
/// Shared per-layer override callback: maps a [`GemmaLayerCtx`] to the stage
/// emitted for that layer. Stored behind `Arc` so `GemmaFlow` stays `Clone`.
type LayerFn = Arc<dyn Fn(GemmaLayerCtx<'_>) -> FlowStage + Send + Sync>;
/// Shared whole-flow patch callback applied to the assembled [`ModelFlow`].
type FlowPatchFn = Arc<dyn Fn(ModelFlow) -> ModelFlow + Send + Sync>;
151
/// Fluent Gemma flow builder — reads config once, chain modifiers, then `build`.
///
/// Every modifier consumes and returns the builder, so calls chain; nothing
/// is emitted until `build` runs.
///
/// ```rust,ignore
/// use rlx_models::gemma::{GemmaConfig, GemmaFlow};
///
/// let built = GemmaFlow::new(&cfg)
///     .prefill()
///     .batch(1)
///     .seq(128)
///     .lm_head()
///     .last_token_logits()
///     .build(&mut weights)?;
/// ```
#[derive(Clone)]
pub struct GemmaFlow<'a> {
    /// Model configuration all shapes and layer options are derived from.
    cfg: &'a GemmaConfig,
    /// Selects the prefill or decode assembly path in `build`.
    mode: GemmaMode,
    /// Batch dimension used for every graph input.
    batch: usize,
    /// Prefill sequence length.
    seq: usize,
    /// Decode past (KV cache) length.
    past_seq: usize,
    /// Use a symbolic sequence dim in prefill (requires `batch == 1`).
    dynamic_seq: bool,
    /// Use a symbolic past dim in decode; RoPE rows become graph inputs.
    dynamic_past: bool,
    /// Append the LM head so prefill outputs `logits` instead of `hidden`.
    with_lm_head: bool,
    /// Attach the drained KV side outputs to the built prefill model.
    with_kv_outputs: bool,
    /// Gather only the last token before the final norm / LM head.
    last_logits_only: bool,
    /// Decode: add an explicit `mask` input.
    use_custom_mask: bool,
    /// Explicit compile profile; `None` falls back to the per-mode preset.
    profile: Option<CompileProfile>,
    /// Extra stages inserted after embedding, before the layer stack.
    before_layers: Vec<FlowStage>,
    /// Extra stages inserted after the layer stack.
    after_layers: Vec<FlowStage>,
    /// Optional per-layer override; when set it is called for every layer.
    layer_fn: Option<LayerFn>,
    /// Optional whole-flow patch hook.
    flow_patch: Option<FlowPatchFn>,
    /// Prefill from fused `inputs_embeds` (`prefill_hidden` input) instead of token ids.
    prefill_hidden: bool,
    /// Sliding layers read additive `attn_bias` for vision bidirectional spans.
    media_attn_bias: bool,
}
188
189impl fmt::Debug for GemmaFlow<'_> {
190    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191        f.debug_struct("GemmaFlow")
192            .field("mode", &self.mode)
193            .field("batch", &self.batch)
194            .field("seq", &self.seq)
195            .field("past_seq", &self.past_seq)
196            .field("dynamic_seq", &self.dynamic_seq)
197            .field("dynamic_past", &self.dynamic_past)
198            .field("with_lm_head", &self.with_lm_head)
199            .field("with_kv_outputs", &self.with_kv_outputs)
200            .field("last_logits_only", &self.last_logits_only)
201            .field("use_custom_mask", &self.use_custom_mask)
202            .field("profile", &self.profile)
203            .field("before_layers", &self.before_layers.len())
204            .field("after_layers", &self.after_layers.len())
205            .field("layer_fn", &self.layer_fn.is_some())
206            .field("flow_patch", &self.flow_patch.is_some())
207            .finish_non_exhaustive()
208    }
209}
210
211impl<'a> GemmaFlow<'a> {
212    pub fn new(cfg: &'a GemmaConfig) -> Self {
213        Self {
214            cfg,
215            mode: GemmaMode::Prefill,
216            batch: 1,
217            seq: 128,
218            past_seq: 0,
219            dynamic_seq: false,
220            dynamic_past: false,
221            with_lm_head: false,
222            with_kv_outputs: false,
223            last_logits_only: false,
224            use_custom_mask: false,
225            profile: None,
226            before_layers: Vec::new(),
227            after_layers: Vec::new(),
228            layer_fn: None,
229            flow_patch: None,
230            prefill_hidden: false,
231            media_attn_bias: false,
232        }
233    }
234
235    /// Skip token embedding — feed pre-scaled hidden states at `prefill_hidden`.
236    pub fn prefill_from_hidden(mut self) -> Self {
237        self.prefill_hidden = true;
238        self
239    }
240
241    /// Add `attn_bias` input and bidirectional self-attn on sliding layers.
242    pub fn prefill_media_attn_bias(mut self) -> Self {
243        self.media_attn_bias = true;
244        self
245    }
246
247    /// Prefill recipe with common batch/seq defaults.
248    pub fn for_prefill(cfg: &'a GemmaConfig, batch: usize, seq: usize) -> Self {
249        Self::new(cfg).prefill().batch(batch).seq(seq)
250    }
251
252    /// Decode recipe with common batch/past defaults (includes LM head).
253    pub fn for_decode(cfg: &'a GemmaConfig, batch: usize, past_seq: usize) -> Self {
254        Self::new(cfg)
255            .decode()
256            .batch(batch)
257            .past(past_seq)
258            .lm_head()
259    }
260
261    pub fn prefill(mut self) -> Self {
262        self.mode = GemmaMode::Prefill;
263        self
264    }
265
266    pub fn decode(mut self) -> Self {
267        self.mode = GemmaMode::Decode;
268        self
269    }
270
271    pub fn batch(mut self, batch: usize) -> Self {
272        self.batch = batch;
273        self
274    }
275
276    /// Prefill sequence length (ignored in decode mode).
277    pub fn seq(mut self, seq: usize) -> Self {
278        self.seq = seq;
279        self
280    }
281
282    /// Decode past length (ignored in prefill mode).
283    pub fn past(mut self, past_seq: usize) -> Self {
284        self.past_seq = past_seq;
285        self
286    }
287
288    /// Symbolic sequence dim (`sym::SEQ`) for dynamic prefill specialization.
289    pub fn dynamic_seq(mut self) -> Self {
290        self.dynamic_seq = true;
291        self
292    }
293
294    /// Symbolic past dim (`sym::PAST_SEQ`) for dynamic decode specialization.
295    pub fn dynamic_past(mut self) -> Self {
296        self.dynamic_past = true;
297        self
298    }
299
300    pub fn lm_head(mut self) -> Self {
301        self.with_lm_head = true;
302        self
303    }
304
305    /// Hidden states only — skip LM head (default for prefill unless `.lm_head()`).
306    pub fn hidden_only(mut self) -> Self {
307        self.with_lm_head = false;
308        self.last_logits_only = false;
309        self
310    }
311
312    pub fn last_token_logits(mut self) -> Self {
313        self.with_lm_head = true;
314        self.last_logits_only = true;
315        self
316    }
317
318    pub fn export_kv(mut self) -> Self {
319        self.with_kv_outputs = true;
320        self
321    }
322
323    pub fn custom_mask(mut self) -> Self {
324        self.use_custom_mask = true;
325        self
326    }
327
328    pub fn profile(mut self, profile: CompileProfile) -> Self {
329        self.profile = Some(profile);
330        self
331    }
332
333    /// Fusion-first prefill profile preset.
334    pub fn profile_prefill(mut self) -> Self {
335        self.profile = Some(CompileProfile::gemma_prefill());
336        self
337    }
338
339    pub fn profile_decode(mut self) -> Self {
340        self.profile = Some(CompileProfile::gemma_decode());
341        self
342    }
343
344    pub fn profile_near(mut self, weights_path: &Path) -> Self {
345        let decode = self.mode == GemmaMode::Decode;
346        self.profile = Some(gemma_profile_near_weights(weights_path, decode));
347        self
348    }
349
350    /// Insert custom stages after embedding, before the layer stack.
351    pub fn before_layers(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
352        self.before_layers.extend(stages);
353        self
354    }
355
356    /// Insert custom stages after the layer stack, before final norm / LM head.
357    pub fn after_layers(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
358        self.after_layers.extend(stages);
359        self
360    }
361
362    /// Override per-layer construction (prefill or decode depending on mode).
363    ///
364    /// Call [`GemmaLayerCtx::default_stage`] to keep stock blocks for unmodified layers.
365    pub fn layer<F>(mut self, f: F) -> Self
366    where
367        F: Fn(GemmaLayerCtx<'_>) -> FlowStage + Send + Sync + 'static,
368    {
369        self.layer_fn = Some(Arc::new(f));
370        self
371    }
372
373    /// Patch the assembled [`ModelFlow`] before build — full flexibility escape hatch.
374    pub fn patch_flow<F>(mut self, f: F) -> Self
375    where
376        F: Fn(ModelFlow) -> ModelFlow + Send + Sync + 'static,
377    {
378        self.flow_patch = Some(Arc::new(f));
379        self
380    }
381
382    pub fn build(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
383        match self.mode {
384            GemmaMode::Prefill => self.build_prefill(weights),
385            GemmaMode::Decode => self.build_decode(weights),
386        }
387    }
388
389    fn build_prefill(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
390        if self.dynamic_seq && self.batch != 1 {
391            anyhow::bail!("gemma: dynamic_seq prefill requires batch=1");
392        }
393
394        let cfg = self.cfg;
395        let profile = self.profile.unwrap_or_else(CompileProfile::gemma_prefill);
396        let f = DType::F32;
397        let h = cfg.hidden_size;
398        let eps = cfg.rms_norm_eps as f32;
399        let layer_style = cfg.layer_style();
400
401        let hidden_shape = prefill_hidden_shape(self.batch, self.seq, h, self.dynamic_seq, f);
402        let input_shape = prefill_input_shape(self.batch, self.seq, self.dynamic_seq);
403
404        let rope_factors = weights.take("rope_freqs.weight").ok().map(|(data, _)| data);
405        let inv_freq = resolve_inv_freq(cfg, rope_factors.as_deref());
406        let (cos_data, sin_data) = build_rope_tables(&inv_freq, cfg.max_position_embeddings);
407
408        // When Gemma 4 ships split rope_parameters with a distinct
409        // full-attention theta and/or partial_rotary_factor, build a
410        // second cos/sin table under the "global" slot. The per-layer
411        // closure below opts full-attention layers into it via
412        // SelfAttnPrefillSpec::with_rope_table("global").
413        let global_rope =
414            secondary_rope_tables(cfg, cfg.max_position_embeddings, rope_factors.as_deref());
415
416        let kv_sink = SideOutputs::new();
417
418        let mut flow = ModelFlow::new("gemma").with_profile(profile);
419        if self.prefill_hidden {
420            flow = flow.input("prefill_hidden", hidden_shape.clone());
421        } else {
422            flow = flow.input("input_ids", input_shape);
423        }
424
425        if self.dynamic_seq && self.with_lm_head && self.last_logits_only {
426            flow = flow.input("last_token_idx", Shape::new(&[self.batch], DType::F32));
427        }
428
429        if self.media_attn_bias {
430            let nh = cfg.num_attention_heads;
431            if self.dynamic_seq {
432                flow = flow.input(
433                    "attn_bias",
434                    Shape::from_dims(
435                        &[
436                            rlx_ir::shape::Dim::Static(self.batch),
437                            rlx_ir::shape::Dim::Static(nh),
438                            rlx_ir::shape::Dim::Dynamic(rlx_ir::sym::SEQ),
439                            rlx_ir::shape::Dim::Dynamic(rlx_ir::sym::SEQ),
440                        ],
441                        f,
442                    ),
443                );
444            } else {
445                flow = flow.input(
446                    "attn_bias",
447                    Shape::new(&[self.batch, nh, self.seq, self.seq], f),
448                );
449            }
450        }
451
452        flow = flow
453            .rope_tables(RopeTablesStage::param(
454                cfg.max_position_embeddings,
455                inv_freq.len(),
456                cos_data,
457                sin_data,
458            ))
459            .zero_beta_named("gemma.zero_beta.hidden", h);
460
461        if self.prefill_hidden {
462            flow = flow.plugin_named("gemma.prefill_hidden_bind", move |emit, _| {
463                let hidden = emit
464                    .flow_input("prefill_hidden")
465                    .map_err(|e| anyhow::anyhow!("prefill_hidden input: {e}"))?;
466                // Tied LM head still needs the embedding table in params.
467                let _ = emit.load_param("model.embed_tokens.weight", false)?;
468                Ok(Some(hidden))
469            });
470        } else {
471            flow = flow
472                .token_embed()
473                .raw_stage(FlowStage::EmbedScale(EmbedScaleStage::new(h)));
474        }
475
476        flow = flow.raw_stages(self.before_layers.iter().cloned());
477
478        if let Some(g) = &global_rope {
479            flow = flow.raw_stage(FlowStage::RopeTables(RopeTablesStage::param_named(
480                "global",
481                cfg.max_position_embeddings,
482                g.half_dim,
483                g.cos.clone(),
484                g.sin.clone(),
485            )));
486        }
487
488        let layer_fn = self.layer_fn.clone();
489        let export = self.with_kv_outputs;
490        let media_bias = self.media_attn_bias;
491        let num_heads = cfg.num_attention_heads;
492        let num_layers = cfg.active_num_layers();
493        let layer_attn: Vec<_> = (0..num_layers).map(|i| cfg.layer_attn_options(i)).collect();
494        // PLAN.md M2 — Gemma 4 MoE (`gemma4-26b-a4b`) routes the FFN
495        // through `MoeFfnStage` via the upstream
496        // `gemma_moe_prefill_layer_composed` helper. Dense Gemma
497        // (`is_moe() == false`) keeps the existing default stage.
498        let is_moe = cfg.is_moe();
499        let moe_num_experts = cfg.num_experts;
500        let moe_top_k = cfg.num_experts_used;
501        let moe_n_embd = cfg.hidden_size;
502        let moe_n_ff = cfg.expert_ffn_dim();
503        // Gemma 4 12B varies (head_dim, num_kv_heads, n_rot) across
504        // layers — sliding layers stay at the base shape, global
505        // (full-attention) layers may override. For Gemma <=3 every
506        // accessor returns the uniform value so the closure is a
507        // no-op shape-wise.
508        let per_layer: Vec<PerLayerAttn> = (0..num_layers)
509            .map(|i| PerLayerAttn {
510                head_dim: cfg.layer_head_dim(i),
511                num_kv_heads: cfg.layer_num_kv_heads(i),
512                n_rot: cfg.layer_n_rot(i),
513                rope_table: if cfg.is_full_attention_layer(i) && global_rope.is_some() {
514                    Some("global".to_string())
515                } else {
516                    None
517                },
518                k_eq_v: cfg.attention_k_eq_v,
519            })
520            .collect();
521        flow = flow.repeat_layers(num_layers, {
522            let style = layer_style;
523            let sink = kv_sink.clone();
524            move |i| {
525                let (mask, score_scale, softcap) = layer_attn[i];
526                let pl = &per_layer[i];
527                let lh = pl.head_dim;
528                let mut attn = gemma_attn_spec(
529                    i,
530                    num_heads,
531                    pl.head_dim,
532                    pl.num_kv_heads,
533                    pl.n_rot,
534                    mask,
535                    score_scale,
536                    softcap,
537                );
538                if let Some(name) = pl.rope_table.as_ref() {
539                    attn = attn.with_rope_table(name);
540                }
541                if pl.k_eq_v {
542                    attn = attn.with_k_eq_v();
543                }
544                if let Some(ref f) = layer_fn {
545                    return f(GemmaLayerCtx::Prefill {
546                        index: i,
547                        style,
548                        attn: attn.clone(),
549                        kv_sink: &sink,
550                        export_kv: export,
551                        head_dim: lh,
552                        eps,
553                    });
554                }
555                if media_bias {
556                    return crate::multimodal_flow::multimodal_layer_override(
557                        GemmaLayerCtx::Prefill {
558                            index: i,
559                            style,
560                            attn,
561                            kv_sink: &sink,
562                            export_kv: export,
563                            head_dim: lh,
564                            eps,
565                        },
566                        true,
567                    );
568                }
569                if is_moe {
570                    let prefix = format!("model.layers.{i}");
571                    let moe = rlx_flow::blocks::MoeFfnStage::hf(
572                        prefix,
573                        moe_num_experts,
574                        moe_top_k,
575                        moe_n_embd,
576                        moe_n_ff,
577                    );
578                    let kv = if export { Some(sink.inner()) } else { None };
579                    return rlx_flow::blocks::gemma_moe_prefill_layer_composed(
580                        i, style, attn, eps, kv, moe,
581                    );
582                }
583                GemmaLayerCtx::Prefill {
584                    index: i,
585                    style,
586                    attn,
587                    kv_sink: &sink,
588                    export_kv: export,
589                    head_dim: lh,
590                    eps,
591                }
592                .default_stage()
593            }
594        });
595
596        flow = flow.raw_stages(self.after_layers.iter().cloned());
597
598        if self.with_lm_head && self.last_logits_only {
599            flow = if self.dynamic_seq {
600                flow.gather_last_token_dynamic(self.batch)
601            } else {
602                flow.gather_last_token_at(self.batch, self.seq)
603            };
604        }
605
606        flow = flow.raw_stage(FlowStage::GemmaRmsNorm(GemmaRmsNormStage::hf_layer(
607            "model.norm",
608            eps,
609        )));
610
611        if let Some(patch) = self.flow_patch {
612            flow = patch(flow);
613        }
614
615        let mut built = if self.with_lm_head {
616            let lm = if cfg.tie_word_embeddings {
617                FlowStage::LmHead(LmHeadStage::tied(cfg.vocab_size, h))
618            } else {
619                FlowStage::LmHead(LmHeadStage::separate("lm_head.weight", cfg.vocab_size, h))
620            };
621            flow = flow.raw_stage(lm);
622            if let Some(cap) = cfg.final_logit_softcapping {
623                flow = flow.raw_stage(FlowStage::LogitSoftcap(LogitSoftcapStage::new(cap)));
624            }
625            flow.output("logits")
626                .build(&mut WeightLoaderSource(weights))?
627        } else {
628            flow.output("hidden")
629                .build(&mut WeightLoaderSource(weights))?
630        };
631
632        if self.with_kv_outputs {
633            built = built.with_extra_hir_outputs(kv_sink.drain());
634        }
635        Ok(built)
636    }
637
638    fn build_decode(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
639        let cfg = self.cfg;
640        let profile = self.profile.unwrap_or_else(CompileProfile::gemma_decode);
641        let f = DType::F32;
642        let h = cfg.hidden_size;
643        let eps = cfg.rms_norm_eps as f32;
644        let dh = cfg.head_dim();
645        let half = dh / 2;
646
647        let hidden_shape = Shape::new(&[self.batch, 1, h], f);
648
649        let decode_style = cfg.layer_style();
650        let decode_score_scale = cfg.attn_score_scale();
651        let decode_softcap = cfg.attn_logit_softcapping;
652        let decode_arch = cfg.arch;
653        let decode_sliding = cfg.sliding_window;
654
655        let kv_out = SideOutputs::new();
656
657        let rope_factors = weights.take("rope_freqs.weight").ok().map(|(data, _)| data);
658        let inv_freq = resolve_inv_freq(cfg, rope_factors.as_deref());
659        let (rope_cos, rope_sin) = if self.dynamic_past {
660            (Vec::new(), Vec::new())
661        } else {
662            crate::rope::rope_slice(&inv_freq, self.past_seq)
663        };
664
665        // Static-past mode bakes the per-step cos/sin row as a const.
666        // Dynamic-past mode promotes both default and (Gemma 4) global
667        // rope rows to graph inputs so the runner can supply them at
668        // step-time.
669        let global_rope_row = if !self.dynamic_past {
670            secondary_rope_row(cfg, self.past_seq, rope_factors.as_deref())
671        } else {
672            None
673        };
674        let global_params = needs_secondary_rope_params(cfg);
675
676        let mut flow = ModelFlow::new("gemma_decode")
677            .with_profile(profile)
678            .input("input_ids", Shape::new(&[self.batch, 1], DType::F32));
679
680        if self.dynamic_past {
681            flow = flow
682                .input("rope_cos", Shape::new(&[1, half], f))
683                .input("rope_sin", Shape::new(&[1, half], f));
684            if let Some(gp) = global_params {
685                let half_global =
686                    crate::rope::resolve_global_inv_freq(cfg, rope_factors.as_deref())
687                        .map(|v| v.len())
688                        .unwrap_or_else(|| crate::rope::default_inv_freq(gp.theta, gp.n_rot).len());
689                flow = flow
690                    .input("rope_cos_global", Shape::new(&[1, half_global], f))
691                    .input("rope_sin_global", Shape::new(&[1, half_global], f))
692                    .raw_stage(FlowStage::Custom(rlx_flow::blocks::CustomStage::named(
693                        "gemma.bind_global_decode_rope",
694                        |emit, val| {
695                            // Find the freshly declared inputs in the
696                            // HIR and publish them under the "global"
697                            // slot so per-layer dispatch resolves
698                            // state.named["global_cos"]/_sin.
699                            let cos = find_hir_input(emit.hir(), "rope_cos_global")?;
700                            let sin = find_hir_input(emit.hir(), "rope_sin_global")?;
701                            emit.set_named("global_cos", cos);
702                            emit.set_named("global_sin", sin);
703                            Ok(val)
704                        },
705                    )));
706            }
707        }
708
709        if self.use_custom_mask {
710            flow = flow.input("mask", Shape::new(&[self.batch, self.past_seq + 1], f));
711        }
712
713        // Per-layer past-K/V shapes — sliding layers ship the base
714        // num_kv_heads * head_dim, full-attention layers may ship a
715        // smaller (Gemma 4 12B: 1 * 512 = 512 instead of 8 * 256 =
716        // 2048) cache slot.
717        for layer_idx in 0..cfg.num_hidden_layers {
718            let layer_kv_dim = cfg.layer_num_kv_heads(layer_idx) * cfg.layer_head_dim(layer_idx);
719            let shape = if self.dynamic_past {
720                Shape::from_dims(
721                    &[
722                        Dim::Static(self.batch),
723                        Dim::Dynamic(sym::PAST_SEQ),
724                        Dim::Static(layer_kv_dim),
725                    ],
726                    f,
727                )
728            } else {
729                Shape::new(&[self.batch, self.past_seq, layer_kv_dim], f)
730            };
731            flow = flow
732                .input(format!("past_k_{layer_idx}"), shape.clone())
733                .input(format!("past_v_{layer_idx}"), shape);
734        }
735
736        if !self.dynamic_past {
737            flow = flow.raw_stage(FlowStage::DecodeRopeParams(DecodeRopeParamsStage::new(
738                rope_cos, rope_sin, half,
739            )));
740            if let Some(g) = &global_rope_row {
741                flow = flow.raw_stage(FlowStage::DecodeRopeParams(DecodeRopeParamsStage::named(
742                    "global",
743                    g.cos.clone(),
744                    g.sin.clone(),
745                    g.half_dim,
746                )));
747            }
748        }
749
750        flow = flow
751            .bind_decode_inputs(cfg.num_hidden_layers, self.use_custom_mask)
752            .zero_beta_named("gemma.zero_beta.hidden", h)
753            .token_embed()
754            .raw_stage(FlowStage::EmbedScale(EmbedScaleStage::new(h)))
755            .raw_stages(self.before_layers.iter().cloned());
756
757        let layer_fn = self.layer_fn.clone();
758        let use_custom_mask = self.use_custom_mask;
759        let num_heads = cfg.num_attention_heads;
760        let num_layers = cfg.active_num_layers();
761        // Per-layer (head_dim, kv_heads, n_rot) — uniform on Gemma <=3,
762        // diverges on Gemma 4 12B's full-attention layers.
763        let secondary_rope_active = global_rope_row.is_some();
764        let per_layer_decode: Vec<PerLayerAttn> = (0..num_layers)
765            .map(|i| PerLayerAttn {
766                head_dim: cfg.layer_head_dim(i),
767                num_kv_heads: cfg.layer_num_kv_heads(i),
768                n_rot: cfg.layer_n_rot(i),
769                rope_table: if cfg.is_full_attention_layer(i) && secondary_rope_active {
770                    Some("global".to_string())
771                } else {
772                    None
773                },
774                k_eq_v: cfg.attention_k_eq_v,
775            })
776            .collect();
777        // PLAN.md M2 — Gemma 4 MoE (`gemma4-26b-a4b`) decode-side dispatch.
778        let is_moe = cfg.is_moe();
779        let moe_num_experts = cfg.num_experts;
780        let moe_top_k = cfg.num_experts_used;
781        let moe_n_embd = cfg.hidden_size;
782        let moe_n_ff = cfg.expert_ffn_dim();
783        flow = flow.repeat_layers(num_layers, {
784            let sink = kv_out.clone();
785            let hidden_shape = hidden_shape.clone();
786            move |i| {
787                let mask = if use_custom_mask {
788                    rlx_ir::op::MaskKind::Causal
789                } else {
790                    match (decode_arch, decode_sliding) {
791                        (GemmaArch::Gemma2, Some(w)) => rlx_flow::blocks::gemma2_layer_mask(i, w),
792                        // PLAN.md M2 — Gemma 3 / 4 use the strided
793                        // `sliding_window_pattern` (5 sliding + 1
794                        // full for stride 6).
795                        (GemmaArch::Gemma3 | GemmaArch::Gemma4, Some(w)) => {
796                            rlx_flow::blocks::gemma_strided_layer_mask(
797                                i,
798                                w,
799                                decode_arch.sliding_window_stride(),
800                            )
801                        }
802                        _ => rlx_ir::op::MaskKind::Causal,
803                    }
804                };
805                let pl = &per_layer_decode[i];
806                let kv_group_size = num_heads / pl.num_kv_heads;
807                let spec = GemmaDecodeLayerSpec {
808                    style: decode_style,
809                    num_heads,
810                    head_dim: pl.head_dim,
811                    num_kv_heads: pl.num_kv_heads,
812                    kv_group_size,
813                    n_rot: pl.n_rot,
814                    rope_table: pl.rope_table.clone(),
815                    k_eq_v: pl.k_eq_v,
816                    eps,
817                    use_custom_mask,
818                    hidden_shape: hidden_shape.clone(),
819                    mask,
820                    score_scale: decode_score_scale,
821                    attn_logit_softcap: decode_softcap,
822                };
823                if let Some(ref f) = layer_fn {
824                    return f(GemmaLayerCtx::Decode {
825                        index: i,
826                        spec: spec.clone(),
827                        kv_out: &sink,
828                    });
829                }
830                if is_moe {
831                    let prefix = format!("model.layers.{i}");
832                    let moe = rlx_flow::blocks::MoeFfnStage::hf(
833                        prefix,
834                        moe_num_experts,
835                        moe_top_k,
836                        moe_n_embd,
837                        moe_n_ff,
838                    );
839                    return rlx_flow::blocks::gemma_moe_decode_layer_composed(
840                        i,
841                        spec,
842                        sink.inner(),
843                        moe,
844                    );
845                }
846                GemmaLayerCtx::Decode {
847                    index: i,
848                    spec,
849                    kv_out: &sink,
850                }
851                .default_stage()
852            }
853        });
854
855        flow = flow.raw_stages(self.after_layers.iter().cloned());
856
857        if let Some(patch) = self.flow_patch {
858            flow = patch(flow);
859        }
860
861        let mut flow = flow.raw_stage(FlowStage::GemmaRmsNorm(GemmaRmsNormStage::hf_layer(
862            "model.norm",
863            eps,
864        )));
865        let lm = if cfg.tie_word_embeddings {
866            FlowStage::LmHead(LmHeadStage::tied(cfg.vocab_size, h))
867        } else {
868            FlowStage::LmHead(LmHeadStage::separate("lm_head.weight", cfg.vocab_size, h))
869        };
870        flow = flow.raw_stage(lm);
871        if let Some(cap) = cfg.final_logit_softcapping {
872            flow = flow.raw_stage(FlowStage::LogitSoftcap(LogitSoftcapStage::new(cap)));
873        }
874        let built = flow
875            .output("logits")
876            .build(&mut WeightLoaderSource(weights))?
877            .with_extra_hir_outputs(kv_out.drain());
878
879        Ok(built)
880    }
881}
882
883fn prefill_hidden_shape(
884    batch: usize,
885    seq: usize,
886    hidden: usize,
887    dynamic: bool,
888    dtype: DType,
889) -> Shape {
890    if dynamic {
891        Shape::from_dims(
892            &[
893                Dim::Static(batch),
894                Dim::Dynamic(sym::SEQ),
895                Dim::Static(hidden),
896            ],
897            dtype,
898        )
899    } else {
900        Shape::new(&[batch, seq, hidden], dtype)
901    }
902}
903
904fn prefill_input_shape(batch: usize, seq: usize, dynamic: bool) -> Shape {
905    if dynamic {
906        Shape::from_dims(&[Dim::Static(batch), Dim::Dynamic(sym::SEQ)], DType::F32)
907    } else {
908        Shape::new(&[batch, seq], DType::F32)
909    }
910}
911
/// Attention dimensions for a single layer, cached when the flow is
/// built.
///
/// For Gemma <=3 every layer carries the same values. On Gemma 4
/// unified configs the full-attention layers may use a different
/// (head_dim, kv_heads, n_rot) triple and a secondary RoPE table.
#[derive(Clone, Debug)]
struct PerLayerAttn {
    /// Per-head dimension for this layer.
    head_dim: usize,
    /// Number of key/value heads for this layer; the decode path
    /// divides the query head count by this to get the KV group size.
    num_kv_heads: usize,
    /// Rotary dimension count forwarded to the layer spec.
    n_rot: usize,
    /// Optional RoPE table name forwarded to the layer spec.
    /// NOTE(review): presumably `None` selects the default table and
    /// `Some` the secondary (global) one — confirm where this is filled in.
    rope_table: Option<String>,
    /// Forwarded to the layer spec's `k_eq_v` flag.
    k_eq_v: bool,
}
924
/// Cosine/sine RoPE values for the secondary ("global") rotary
/// embedding, together with the number of frequencies per position.
///
/// Produced either for a whole position range
/// (`secondary_rope_tables`) or for one position
/// (`secondary_rope_row`).
#[derive(Clone, Debug)]
struct GlobalRopeTables {
    /// Cosine values. For full tables the in-file tests read entry
    /// `(pos, dim)` at index `pos * half_dim + dim`.
    cos: Vec<f32>,
    /// Sine values, same length and layout as `cos`.
    sin: Vec<f32>,
    /// Length of the inverse-frequency vector the values were built from.
    half_dim: usize,
}
931
932/// Build the Gemma 4 "global" (full-attention) RoPE table when the
933/// unified config carries a distinct rope_theta or
934/// partial_rotary_factor for full-attention layers. Returns `None`
935/// for Gemma <=3 and for Gemma 4 configs that omit the split.
936fn secondary_rope_tables(
937    cfg: &GemmaConfig,
938    max_pos: usize,
939    factors: Option<&[f32]>,
940) -> Option<GlobalRopeTables> {
941    let inv = crate::rope::resolve_global_inv_freq(cfg, factors)?;
942    let (cos, sin) = crate::rope::build_rope_tables(&inv, max_pos);
943    Some(GlobalRopeTables {
944        cos,
945        sin,
946        half_dim: inv.len(),
947    })
948}
949
950/// One-position decode row for the global RoPE.
951fn secondary_rope_row(
952    cfg: &GemmaConfig,
953    pos: usize,
954    factors: Option<&[f32]>,
955) -> Option<GlobalRopeTables> {
956    let inv = crate::rope::resolve_global_inv_freq(cfg, factors)?;
957    let (cos, sin) = crate::rope::rope_slice(&inv, pos);
958    Some(GlobalRopeTables {
959        cos,
960        sin,
961        half_dim: inv.len(),
962    })
963}
964
965fn needs_secondary_rope_params(cfg: &GemmaConfig) -> Option<GlobalRopeParams> {
966    crate::rope::global_rope_params(cfg).map(|(theta, n_rot)| GlobalRopeParams { theta, n_rot })
967}
968
/// Parameters of the secondary ("global") RoPE, as reported by
/// `crate::rope::global_rope_params`.
#[derive(Clone, Copy, Debug)]
struct GlobalRopeParams {
    /// RoPE base (theta) for the global rotary embedding.
    theta: f64,
    /// Rotary dimension count for the global rotary embedding.
    n_rot: usize,
}
974
975fn find_hir_input(hir: &HirModule, name: &str) -> anyhow::Result<rlx_ir::HirNodeId> {
976    use rlx_ir::hir::HirOp;
977    for node in hir.nodes() {
978        if let HirOp::Input { name: n } = &node.op {
979            if n == name {
980                return Ok(node.id);
981            }
982        }
983    }
984    Err(anyhow::anyhow!("gemma decode flow missing input: {name}"))
985}
986
987// ── Legacy opt structs + thin wrappers (backward compatible) ─────────
988
989impl<'a> GemmaFlow<'a> {
990    fn from_prefill_opts(cfg: &'a GemmaConfig, o: &GemmaPrefillOpts) -> Self {
991        let mut f = GemmaFlow::new(cfg).prefill().batch(o.batch).seq(o.seq);
992        if o.dynamic_seq {
993            f = f.dynamic_seq();
994        }
995        if o.prefill_hidden {
996            f = f.prefill_from_hidden();
997        }
998        if o.media_attn_bias {
999            f = f.prefill_media_attn_bias();
1000        }
1001        if o.with_lm_head {
1002            f = f.lm_head();
1003        }
1004        if o.with_kv_outputs {
1005            f = f.export_kv();
1006        }
1007        if o.last_logits_only {
1008            f = f.last_token_logits();
1009        }
1010        if let Some(p) = o.profile.clone() {
1011            f = f.profile(p);
1012        }
1013        f
1014    }
1015
1016    fn from_decode_opts(cfg: &'a GemmaConfig, o: &GemmaDecodeOpts) -> Self {
1017        let mut f = GemmaFlow::new(cfg)
1018            .decode()
1019            .batch(o.batch)
1020            .past(o.past_seq)
1021            .lm_head();
1022        if o.dynamic_past {
1023            f = f.dynamic_past();
1024        }
1025        if o.use_custom_mask {
1026            f = f.custom_mask();
1027        }
1028        if let Some(p) = o.profile.clone() {
1029            f = f.profile(p);
1030        }
1031        f
1032    }
1033}
1034
1035/// Options for the tier-0 Gemma prefill assembly line.
1036#[derive(Debug, Clone)]
1037pub struct GemmaPrefillOpts {
1038    pub batch: usize,
1039    pub seq: usize,
1040    pub dynamic_seq: bool,
1041    pub prefill_hidden: bool,
1042    pub media_attn_bias: bool,
1043    pub with_lm_head: bool,
1044    pub with_kv_outputs: bool,
1045    pub last_logits_only: bool,
1046    pub profile: Option<CompileProfile>,
1047}
1048
1049impl GemmaPrefillOpts {
1050    pub fn static_prefill(batch: usize, seq: usize) -> Self {
1051        Self {
1052            batch,
1053            seq,
1054            dynamic_seq: false,
1055            prefill_hidden: false,
1056            media_attn_bias: false,
1057            with_lm_head: false,
1058            with_kv_outputs: false,
1059            last_logits_only: false,
1060            profile: None,
1061        }
1062    }
1063}
1064
1065/// Options for tier-0 Gemma decode (KV-cache) assembly line.
1066#[derive(Debug, Clone)]
1067pub struct GemmaDecodeOpts {
1068    pub batch: usize,
1069    pub past_seq: usize,
1070    pub dynamic_past: bool,
1071    pub use_custom_mask: bool,
1072    pub profile: Option<CompileProfile>,
1073}
1074
1075pub fn build_gemma_prefill_flow(
1076    cfg: &GemmaConfig,
1077    weights: &mut dyn WeightLoader,
1078    opts: &GemmaPrefillOpts,
1079) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
1080    build_gemma_prefill_built(cfg, weights, opts)?.into_parts()
1081}
1082
1083pub fn build_gemma_prefill_built(
1084    cfg: &GemmaConfig,
1085    weights: &mut dyn WeightLoader,
1086    opts: &GemmaPrefillOpts,
1087) -> Result<BuiltModel> {
1088    GemmaFlow::from_prefill_opts(cfg, opts).build(weights)
1089}
1090
1091pub fn build_gemma_decode_flow(
1092    cfg: &GemmaConfig,
1093    weights: &mut dyn WeightLoader,
1094    opts: &GemmaDecodeOpts,
1095) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
1096    build_gemma_decode_built(cfg, weights, opts)?.into_parts()
1097}
1098
1099pub fn build_gemma_decode_graph(
1100    cfg: &GemmaConfig,
1101    weights: &mut dyn WeightLoader,
1102    opts: &GemmaDecodeOpts,
1103) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
1104    rlx_core::flow_util::graph_from_built(build_gemma_decode_built(cfg, weights, opts)?)
1105}
1106
1107pub fn build_gemma_decode_built(
1108    cfg: &GemmaConfig,
1109    weights: &mut dyn WeightLoader,
1110    opts: &GemmaDecodeOpts,
1111) -> Result<BuiltModel> {
1112    GemmaFlow::from_decode_opts(cfg, opts).build(weights)
1113}
1114
#[cfg(test)]
mod gemma4_tests {
    use super::*;
    use crate::config::{
        GemmaArch, GemmaLayerType, GemmaRopeKind, GemmaRopeMap, GemmaRopeParameters,
    };

    /// Fixture: a Gemma 4 config with 12B-like attention dimensions but
    /// only 12 layers, a stride-6 sliding/full layer pattern, and split
    /// RoPE parameters for sliding vs. full attention.
    fn gemma4_12b_like() -> GemmaConfig {
        let mut cfg = GemmaConfig::tiny_test();

        // Architecture and widths.
        cfg.arch = GemmaArch::Gemma4;
        cfg.hidden_size = 3840;
        cfg.intermediate_size = 15_360;
        cfg.num_hidden_layers = 12; // kept small so tests stay fast

        // Attention geometry: sliding layers vs. global (full) layers.
        cfg.num_attention_heads = 16;
        cfg.num_key_value_heads = 8;
        cfg.head_dim = Some(256);
        cfg.global_head_dim = Some(512);
        cfg.num_global_key_value_heads = Some(1);
        cfg.attention_k_eq_v = true;
        cfg.sliding_window = Some(1024);

        // Head / positional settings.
        cfg.final_logit_softcapping = Some(30.0);
        cfg.tie_word_embeddings = true;
        cfg.max_position_embeddings = 4096;
        cfg.rope_theta = 10_000.0;

        // Stride-6 pattern: counting layers from 1, every 6th is full.
        cfg.layer_types = (1..=cfg.num_hidden_layers)
            .map(|nth| match nth % 6 {
                0 => GemmaLayerType::FullAttention,
                _ => GemmaLayerType::SlidingAttention,
            })
            .collect();

        let sliding_rope = GemmaRopeParameters {
            rope_theta: Some(10_000.0),
            rope_type: Some(GemmaRopeKind::Default),
            partial_rotary_factor: None,
        };
        let full_rope = GemmaRopeParameters {
            rope_theta: Some(1_000_000.0),
            rope_type: Some(GemmaRopeKind::Proportional),
            partial_rotary_factor: Some(0.25),
        };
        cfg.rope_parameters = GemmaRopeMap {
            sliding_attention: Some(sliding_rope),
            full_attention: Some(full_rope),
        };
        cfg
    }

    #[test]
    fn secondary_rope_emits_distinct_table_for_full_attention() {
        let cfg = gemma4_12b_like();
        let max_pos = cfg.max_position_embeddings;
        let tables = secondary_rope_tables(&cfg, max_pos, None)
            .expect("Gemma 4 split rope_parameters should produce a secondary table");

        // n_rot = 512 * 0.25 = 128, so each position has 64 frequencies.
        let half = 64;
        assert_eq!(tables.half_dim, half);
        assert_eq!(tables.cos.len(), max_pos * half);
        assert_eq!(tables.sin.len(), tables.cos.len());

        // Position 0 is always (cos, sin) = (1, 0), whatever theta is.
        assert!((tables.cos[0] - 1.0).abs() < 1e-6);
        assert!(tables.sin[0].abs() < 1e-6);

        // For dim >= 1 the frequency exponent matters: the two thetas
        // must disagree at dim 5, and the table must follow the global one.
        let dim = 5;
        let global_inv = crate::rope::default_inv_freq(1_000_000.0, 128);
        let sliding_inv = crate::rope::default_inv_freq(10_000.0, 128);
        assert!((global_inv[dim] - sliding_inv[dim]).abs() > 1e-3);

        let expected_cos = (1.0 * global_inv[dim]).cos();
        let table_cos = tables.cos[half + dim]; // row for pos=1, column dim=5
        assert!((table_cos as f64 - expected_cos).abs() < 1e-5);
    }

    #[test]
    fn per_layer_kv_dims_diverge_on_full_attention() {
        let cfg = gemma4_12b_like();
        // Layer 0 is sliding: 8 heads * 256 = 2048.
        // Layers 5 and 11 are full: 1 head * 512 = 512.
        for (layer, expected_kv_width) in [(0, 2048), (5, 512), (11, 512)] {
            assert_eq!(
                cfg.layer_num_kv_heads(layer) * cfg.layer_head_dim(layer),
                expected_kv_width
            );
        }
    }

    #[test]
    fn no_secondary_table_when_params_match() {
        // Gemma 3-shaped config (uniform rope): no secondary table
        // should be emitted even though arch is Gemma4, e.g. a tuned
        // variant with collapsed rope_parameters.
        let mut cfg = gemma4_12b_like();
        cfg.rope_parameters.full_attention = cfg.rope_parameters.sliding_attention;
        cfg.global_head_dim = None;
        cfg.num_global_key_value_heads = None;

        let tables = secondary_rope_tables(&cfg, cfg.max_position_embeddings, None);
        assert!(tables.is_none());
    }
}