Skip to main content

rlx_gemma/
flow.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Fluent Gemma model assembly — tier-0 reference for `rlx-flow`.
17//!
18//! ```rust,ignore
19//! use rlx_models::gemma::GemmaFlow;
20//!
21//! // Prefill logits for the last token
22//! let built = GemmaFlow::for_prefill(&cfg, 1, 128)
23//!     .last_token_logits()
24//!     .profile_near(&weights_path)
25//!     .build(&mut weights)?;
26//!
27//! // Decode step with KV side outputs
28//! let built = GemmaFlow::for_decode(&cfg, 1, 256)
29//!     .custom_mask()
30//!     .profile_decode()
31//!     .build(&mut weights)?;
32//!
33//! // Override one layer while keeping the rest of the recipe
34//! let built = GemmaFlow::for_prefill(&cfg, 1, 128)
35//!     .layer(|ctx| {
36//!         if ctx.index() == 0 {
37//!             ctx.default_stage() // or FlowStage::Custom(...)
38//!         } else {
39//!             ctx.default_stage()
40//!         }
41//!     })
42//!     .build(&mut weights)?;
43//! ```
44
45use std::collections::HashMap;
46use std::fmt;
47use std::path::Path;
48use std::sync::Arc;
49
50use anyhow::Result;
51use rlx_flow::blocks::{
52    DecodeRopeParamsStage, EmbedScaleStage, GemmaDecodeLayerSpec, GemmaDecodeLayerStage,
53    GemmaLayerStyle, GemmaRmsNormStage, LmHeadStage, LogitSoftcapStage, RopeTablesStage,
54    gemma_attn_spec, gemma_prefill_layer_composed,
55};
56use rlx_flow::{BuiltModel, CompileProfile, FlowStage, ModelFlow, SideOutputs};
57use rlx_ir::dynamic::sym;
58use rlx_ir::hir::HirModule;
59use rlx_ir::shape::Dim;
60use rlx_ir::{DType, Graph, Shape};
61
62use super::config::{GemmaArch, GemmaConfig};
63use super::rope::{build_rope_tables, resolve_inv_freq};
64use rlx_core::flow_bridge::{WeightLoaderSource, load_compile_profile};
65use rlx_core::weight_loader::WeightLoader;
66
/// Tier-1 profile file name colocated with weights.
///
/// [`gemma_profile_near_weights`] joins this name onto the parent directory of
/// the weights path to locate the profile file.
pub const GEMMA_PROFILE_FILE: &str = "gemma.rlx.toml";
69
70/// Resolve compile profile from `gemma.rlx.toml` in the weights directory.
71pub fn gemma_profile_near_weights(weights: &Path, decode: bool) -> CompileProfile {
72    let default = if decode {
73        CompileProfile::gemma_decode()
74    } else {
75        CompileProfile::gemma_prefill()
76    };
77    let dir = weights.parent().unwrap_or_else(|| Path::new("."));
78    load_compile_profile(&dir.join(GEMMA_PROFILE_FILE), default)
79}
80
/// Which graph [`GemmaFlow::build`] assembles.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GemmaMode {
    /// Build the prefill graph (`build` dispatches to the prefill assembly).
    Prefill,
    /// Build the single-step decode graph (`build` dispatches to the decode assembly).
    Decode,
}
86
/// Per-layer context for `.layer()` overrides — defaults preserve stock Gemma blocks.
///
/// One value is passed to the `.layer()` callback for every layer index. Call
/// [`GemmaLayerCtx::default_stage`] to obtain the stock (non-MoE) stage for that
/// layer, or return any other [`FlowStage`] to replace it.
pub enum GemmaLayerCtx<'a> {
    /// Context for one layer of the prefill graph.
    Prefill {
        /// Zero-based layer index.
        index: usize,
        /// Layer style taken from the model config.
        style: GemmaLayerStyle,
        /// Attention spec produced by `gemma_attn_spec` for this layer.
        attn: rlx_flow::blocks::SelfAttnPrefillSpec,
        /// Sink that collects KV side outputs; only used when `export_kv` is true.
        kv_sink: &'a SideOutputs,
        /// Whether the stock stage should be wired to `kv_sink`.
        export_kv: bool,
        /// Per-head dimension; carried for custom stages, ignored by `default_stage`.
        head_dim: usize,
        /// RMS-norm epsilon from the model config.
        eps: f32,
    },
    /// Context for one layer of the decode graph.
    Decode {
        /// Zero-based layer index.
        index: usize,
        /// Fully populated decode-layer spec for this layer.
        spec: GemmaDecodeLayerSpec,
        /// Sink that the stock decode stage is wired to.
        kv_out: &'a SideOutputs,
    },
}
104
105impl GemmaLayerCtx<'_> {
106    pub fn index(&self) -> usize {
107        match self {
108            Self::Prefill { index, .. } | Self::Decode { index, .. } => *index,
109        }
110    }
111
112    pub fn default_stage(&self) -> FlowStage {
113        match self {
114            Self::Prefill {
115                index,
116                style,
117                attn,
118                kv_sink,
119                export_kv,
120                head_dim: _,
121                eps,
122            } => gemma_prefill_layer_composed(
123                *index,
124                *style,
125                attn.clone(),
126                *eps,
127                if *export_kv {
128                    Some(kv_sink.inner())
129                } else {
130                    None
131                },
132            ),
133            Self::Decode {
134                index,
135                spec,
136                kv_out,
137            } => FlowStage::Named {
138                name: format!("layer{index}"),
139                inner: Arc::new(FlowStage::GemmaDecodeLayer(GemmaDecodeLayerStage::layer(
140                    *index,
141                    spec.clone(),
142                    kv_out.inner(),
143                ))),
144            },
145        }
146    }
147}
148
/// Shared, thread-safe per-layer override callback stored by `GemmaFlow::layer`.
type LayerFn = Arc<dyn Fn(GemmaLayerCtx<'_>) -> FlowStage + Send + Sync>;
/// Shared, thread-safe whole-flow rewrite callback stored by `GemmaFlow::patch_flow`.
type FlowPatchFn = Arc<dyn Fn(ModelFlow) -> ModelFlow + Send + Sync>;
151
/// Fluent Gemma flow builder — reads config once, chain modifiers, then `build`.
///
/// ```rust,ignore
/// use rlx_models::gemma::{GemmaConfig, GemmaFlow};
///
/// let built = GemmaFlow::new(&cfg)
///     .prefill()
///     .batch(1)
///     .seq(128)
///     .lm_head()
///     .last_token_logits()
///     .build(&mut weights)?;
/// ```
#[derive(Clone)]
pub struct GemmaFlow<'a> {
    /// Borrowed model configuration; read when the graph is built.
    cfg: &'a GemmaConfig,
    /// Selects the prefill or decode assembly in `build`.
    mode: GemmaMode,
    /// Static batch dimension used for every input shape.
    batch: usize,
    /// Static prefill sequence length (not read by the decode assembly).
    seq: usize,
    /// Static decode past length (not read by the prefill assembly).
    past_seq: usize,
    /// Prefill: use the symbolic `sym::SEQ` dim instead of `seq`.
    dynamic_seq: bool,
    /// Decode: use the symbolic `sym::PAST_SEQ` dim instead of `past_seq`.
    dynamic_past: bool,
    /// Prefill: append the LM head and output `"logits"` instead of `"hidden"`.
    /// The decode assembly appends the LM head regardless of this flag.
    with_lm_head: bool,
    /// Prefill: attach drained KV side outputs to the built model.
    /// The decode assembly attaches them regardless of this flag.
    with_kv_outputs: bool,
    /// Prefill: gather only the last token before the final norm (requires `with_lm_head`).
    last_logits_only: bool,
    /// Decode: declare a `"mask"` input and mark layer specs as custom-masked.
    use_custom_mask: bool,
    /// Explicit compile profile; `None` falls back to the mode's preset at build time.
    profile: Option<CompileProfile>,
    /// Extra stages inserted after the embedding scale, before the layer stack.
    before_layers: Vec<FlowStage>,
    /// Extra stages inserted after the layer stack.
    after_layers: Vec<FlowStage>,
    /// Optional per-layer override installed by `.layer()`.
    layer_fn: Option<LayerFn>,
    /// Optional whole-flow rewrite installed by `.patch_flow()`.
    flow_patch: Option<FlowPatchFn>,
}
184
185impl fmt::Debug for GemmaFlow<'_> {
186    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
187        f.debug_struct("GemmaFlow")
188            .field("mode", &self.mode)
189            .field("batch", &self.batch)
190            .field("seq", &self.seq)
191            .field("past_seq", &self.past_seq)
192            .field("dynamic_seq", &self.dynamic_seq)
193            .field("dynamic_past", &self.dynamic_past)
194            .field("with_lm_head", &self.with_lm_head)
195            .field("with_kv_outputs", &self.with_kv_outputs)
196            .field("last_logits_only", &self.last_logits_only)
197            .field("use_custom_mask", &self.use_custom_mask)
198            .field("profile", &self.profile)
199            .field("before_layers", &self.before_layers.len())
200            .field("after_layers", &self.after_layers.len())
201            .field("layer_fn", &self.layer_fn.is_some())
202            .field("flow_patch", &self.flow_patch.is_some())
203            .finish_non_exhaustive()
204    }
205}
206
207impl<'a> GemmaFlow<'a> {
208    pub fn new(cfg: &'a GemmaConfig) -> Self {
209        Self {
210            cfg,
211            mode: GemmaMode::Prefill,
212            batch: 1,
213            seq: 128,
214            past_seq: 0,
215            dynamic_seq: false,
216            dynamic_past: false,
217            with_lm_head: false,
218            with_kv_outputs: false,
219            last_logits_only: false,
220            use_custom_mask: false,
221            profile: None,
222            before_layers: Vec::new(),
223            after_layers: Vec::new(),
224            layer_fn: None,
225            flow_patch: None,
226        }
227    }
228
229    /// Prefill recipe with common batch/seq defaults.
230    pub fn for_prefill(cfg: &'a GemmaConfig, batch: usize, seq: usize) -> Self {
231        Self::new(cfg).prefill().batch(batch).seq(seq)
232    }
233
234    /// Decode recipe with common batch/past defaults (includes LM head).
235    pub fn for_decode(cfg: &'a GemmaConfig, batch: usize, past_seq: usize) -> Self {
236        Self::new(cfg)
237            .decode()
238            .batch(batch)
239            .past(past_seq)
240            .lm_head()
241    }
242
243    pub fn prefill(mut self) -> Self {
244        self.mode = GemmaMode::Prefill;
245        self
246    }
247
248    pub fn decode(mut self) -> Self {
249        self.mode = GemmaMode::Decode;
250        self
251    }
252
253    pub fn batch(mut self, batch: usize) -> Self {
254        self.batch = batch;
255        self
256    }
257
258    /// Prefill sequence length (ignored in decode mode).
259    pub fn seq(mut self, seq: usize) -> Self {
260        self.seq = seq;
261        self
262    }
263
264    /// Decode past length (ignored in prefill mode).
265    pub fn past(mut self, past_seq: usize) -> Self {
266        self.past_seq = past_seq;
267        self
268    }
269
270    /// Symbolic sequence dim (`sym::SEQ`) for dynamic prefill specialization.
271    pub fn dynamic_seq(mut self) -> Self {
272        self.dynamic_seq = true;
273        self
274    }
275
276    /// Symbolic past dim (`sym::PAST_SEQ`) for dynamic decode specialization.
277    pub fn dynamic_past(mut self) -> Self {
278        self.dynamic_past = true;
279        self
280    }
281
282    pub fn lm_head(mut self) -> Self {
283        self.with_lm_head = true;
284        self
285    }
286
287    /// Hidden states only — skip LM head (default for prefill unless `.lm_head()`).
288    pub fn hidden_only(mut self) -> Self {
289        self.with_lm_head = false;
290        self.last_logits_only = false;
291        self
292    }
293
294    pub fn last_token_logits(mut self) -> Self {
295        self.with_lm_head = true;
296        self.last_logits_only = true;
297        self
298    }
299
300    pub fn export_kv(mut self) -> Self {
301        self.with_kv_outputs = true;
302        self
303    }
304
305    pub fn custom_mask(mut self) -> Self {
306        self.use_custom_mask = true;
307        self
308    }
309
310    pub fn profile(mut self, profile: CompileProfile) -> Self {
311        self.profile = Some(profile);
312        self
313    }
314
315    /// Fusion-first prefill profile preset.
316    pub fn profile_prefill(mut self) -> Self {
317        self.profile = Some(CompileProfile::gemma_prefill());
318        self
319    }
320
321    pub fn profile_decode(mut self) -> Self {
322        self.profile = Some(CompileProfile::gemma_decode());
323        self
324    }
325
326    pub fn profile_near(mut self, weights_path: &Path) -> Self {
327        let decode = self.mode == GemmaMode::Decode;
328        self.profile = Some(gemma_profile_near_weights(weights_path, decode));
329        self
330    }
331
332    /// Insert custom stages after embedding, before the layer stack.
333    pub fn before_layers(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
334        self.before_layers.extend(stages);
335        self
336    }
337
338    /// Insert custom stages after the layer stack, before final norm / LM head.
339    pub fn after_layers(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
340        self.after_layers.extend(stages);
341        self
342    }
343
344    /// Override per-layer construction (prefill or decode depending on mode).
345    ///
346    /// Call [`GemmaLayerCtx::default_stage`] to keep stock blocks for unmodified layers.
347    pub fn layer<F>(mut self, f: F) -> Self
348    where
349        F: Fn(GemmaLayerCtx<'_>) -> FlowStage + Send + Sync + 'static,
350    {
351        self.layer_fn = Some(Arc::new(f));
352        self
353    }
354
355    /// Patch the assembled [`ModelFlow`] before build — full flexibility escape hatch.
356    pub fn patch_flow<F>(mut self, f: F) -> Self
357    where
358        F: Fn(ModelFlow) -> ModelFlow + Send + Sync + 'static,
359    {
360        self.flow_patch = Some(Arc::new(f));
361        self
362    }
363
364    pub fn build(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
365        match self.mode {
366            GemmaMode::Prefill => self.build_prefill(weights),
367            GemmaMode::Decode => self.build_decode(weights),
368        }
369    }
370
    /// Assemble the prefill graph.
    ///
    /// Stage order: RoPE tables → zero beta → token embed → embed scale →
    /// `before_layers` → layer stack → `after_layers` → optional last-token
    /// gather → final RMS norm → optional `flow_patch` → optional LM head and
    /// logit softcap → output (`"logits"` or `"hidden"`).
    ///
    /// # Errors
    /// Fails when `dynamic_seq` is combined with `batch != 1`, or when the
    /// underlying `ModelFlow::build` fails.
    fn build_prefill(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
        // Symbolic sequence length is only accepted for a single batch row.
        if self.dynamic_seq && self.batch != 1 {
            anyhow::bail!("gemma: dynamic_seq prefill requires batch=1");
        }

        let cfg = self.cfg;
        // Explicit profile wins; otherwise fall back to the prefill preset.
        let profile = self.profile.unwrap_or_else(CompileProfile::gemma_prefill);
        let f = DType::F32;
        let h = cfg.hidden_size;
        let eps = cfg.rms_norm_eps as f32;
        let dh = cfg.head_dim();
        let layer_style = cfg.layer_style();

        // NOTE(review): computed but never read; the binding keeps
        // `prefill_hidden_shape` referenced.
        let _hidden_shape = prefill_hidden_shape(self.batch, self.seq, h, self.dynamic_seq, f);
        let input_shape = prefill_input_shape(self.batch, self.seq, self.dynamic_seq);

        // Optional RoPE frequency factors: a failed lookup is treated as "absent".
        let rope_factors = weights.take("rope_freqs.weight").ok().map(|(data, _)| data);
        let inv_freq = resolve_inv_freq(cfg, rope_factors.as_deref());
        let (cos_data, sin_data) = build_rope_tables(&inv_freq, cfg.max_position_embeddings);

        // Collects per-layer KV side outputs; drained at the end when export is on.
        let kv_sink = SideOutputs::new();

        let mut flow = ModelFlow::new("gemma")
            .with_profile(profile)
            .input("input_ids", input_shape);

        // The dynamic last-token gather needs an extra index input.
        if self.dynamic_seq && self.with_lm_head && self.last_logits_only {
            flow = flow.input("last_token_idx", Shape::new(&[self.batch], DType::F32));
        }

        flow = flow
            .rope_tables(RopeTablesStage::param(
                cfg.max_position_embeddings,
                inv_freq.len(),
                cos_data,
                sin_data,
            ))
            .zero_beta_named("gemma.zero_beta.hidden", h)
            .token_embed()
            .raw_stage(FlowStage::EmbedScale(EmbedScaleStage::new(h)))
            .raw_stages(self.before_layers.iter().cloned());

        // Values captured by the per-layer closure below.
        let layer_fn = self.layer_fn.clone();
        let export = self.with_kv_outputs;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let num_layers = cfg.active_num_layers();
        // Per-layer (mask, score scale, softcap) tuples, precomputed once.
        let layer_attn: Vec<_> = (0..num_layers).map(|i| cfg.layer_attn_options(i)).collect();
        // PLAN.md M2 — Gemma 4 MoE (`gemma4-26b-a4b`) routes the FFN
        // through `MoeFfnStage` via the upstream
        // `gemma_moe_prefill_layer_composed` helper. Dense Gemma
        // (`is_moe() == false`) keeps the existing default stage.
        let is_moe = cfg.is_moe();
        let moe_num_experts = cfg.num_experts;
        let moe_top_k = cfg.num_experts_used;
        let moe_n_embd = cfg.hidden_size;
        let moe_n_ff = cfg.expert_ffn_dim();
        flow = flow.repeat_layers(num_layers, {
            let style = layer_style;
            let sink = kv_sink.clone();
            move |i| {
                let (mask, score_scale, softcap) = layer_attn[i];
                let attn =
                    gemma_attn_spec(i, num_heads, dh, num_kv_heads, mask, score_scale, softcap);
                // A user hook takes precedence over both the MoE and dense defaults.
                // NOTE(review): for MoE configs, a hook that returns
                // `default_stage()` gets the dense stage, not the MoE one — confirm intended.
                if let Some(ref f) = layer_fn {
                    return f(GemmaLayerCtx::Prefill {
                        index: i,
                        style,
                        attn: attn.clone(),
                        kv_sink: &sink,
                        export_kv: export,
                        head_dim: dh,
                        eps,
                    });
                }
                if is_moe {
                    let prefix = format!("model.layers.{i}");
                    let moe = rlx_flow::blocks::MoeFfnStage::hf(
                        prefix,
                        moe_num_experts,
                        moe_top_k,
                        moe_n_embd,
                        moe_n_ff,
                    );
                    let kv = if export { Some(sink.inner()) } else { None };
                    return rlx_flow::blocks::gemma_moe_prefill_layer_composed(
                        i, style, attn, eps, kv, moe,
                    );
                }
                // Dense default: the stock composed prefill layer.
                GemmaLayerCtx::Prefill {
                    index: i,
                    style,
                    attn,
                    kv_sink: &sink,
                    export_kv: export,
                    head_dim: dh,
                    eps,
                }
                .default_stage()
            }
        });

        flow = flow.raw_stages(self.after_layers.iter().cloned());

        // Gather happens before the final norm, so norm and LM head see one token.
        if self.with_lm_head && self.last_logits_only {
            flow = if self.dynamic_seq {
                flow.gather_last_token_dynamic(self.batch)
            } else {
                flow.gather_last_token_at(self.batch, self.seq)
            };
        }

        flow = flow.raw_stage(FlowStage::GemmaRmsNorm(GemmaRmsNormStage::hf_layer(
            "model.norm",
            eps,
        )));

        // NOTE(review): here the patch runs AFTER the final norm, whereas the
        // decode assembly applies it BEFORE the final norm — confirm which is intended.
        if let Some(patch) = self.flow_patch {
            flow = patch(flow);
        }

        let mut built = if self.with_lm_head {
            // Tied embeddings reuse the token-embedding weight; otherwise load `lm_head.weight`.
            let lm = if cfg.tie_word_embeddings {
                FlowStage::LmHead(LmHeadStage::tied(cfg.vocab_size, h))
            } else {
                FlowStage::LmHead(LmHeadStage::separate("lm_head.weight", cfg.vocab_size, h))
            };
            flow = flow.raw_stage(lm);
            if let Some(cap) = cfg.final_logit_softcapping {
                flow = flow.raw_stage(FlowStage::LogitSoftcap(LogitSoftcapStage::new(cap)));
            }
            flow.output("logits")
                .build(&mut WeightLoaderSource(weights))?
        } else {
            flow.output("hidden")
                .build(&mut WeightLoaderSource(weights))?
        };

        // KV side outputs are only attached when export was requested.
        if self.with_kv_outputs {
            built = built.with_extra_hir_outputs(kv_sink.drain());
        }
        Ok(built)
    }
514
    /// Assemble the single-step decode graph.
    ///
    /// Stage order: inputs (`input_ids`, optional `rope_cos`/`rope_sin`,
    /// optional `mask`, per-layer `past_k_*`/`past_v_*`) → optional static
    /// RoPE params → bind decode inputs → zero beta → token embed → embed
    /// scale → `before_layers` → layer stack → `after_layers` → optional
    /// `flow_patch` → final RMS norm → LM head → optional logit softcap →
    /// `"logits"` output, with the KV side outputs always attached.
    ///
    /// `with_lm_head`, `with_kv_outputs`, `last_logits_only`, `seq` and
    /// `dynamic_seq` are not read here.
    ///
    /// # Errors
    /// Propagates failures from the underlying `ModelFlow::build`.
    fn build_decode(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
        let cfg = self.cfg;
        // Explicit profile wins; otherwise fall back to the decode preset.
        let profile = self.profile.unwrap_or_else(CompileProfile::gemma_decode);
        let f = DType::F32;
        let h = cfg.hidden_size;
        let eps = cfg.rms_norm_eps as f32;
        let dh = cfg.head_dim();
        let kv_dim = cfg.kv_proj_dim();
        // Width of the RoPE cos/sin rows.
        let half = dh / 2;

        // Decode processes exactly one new token per step.
        let hidden_shape = Shape::new(&[self.batch, 1, h], f);
        let past_kv_shape = if self.dynamic_past {
            Shape::from_dims(
                &[
                    Dim::Static(self.batch),
                    Dim::Dynamic(sym::PAST_SEQ),
                    Dim::Static(kv_dim),
                ],
                f,
            )
        } else {
            Shape::new(&[self.batch, self.past_seq, kv_dim], f)
        };

        // Config values copied out so the layer closure can own them.
        let decode_style = cfg.layer_style();
        let decode_score_scale = cfg.attn_score_scale();
        let decode_softcap = cfg.attn_logit_softcapping;
        let decode_arch = cfg.arch;
        let decode_sliding = cfg.sliding_window;

        // Collects the per-layer KV side outputs; drained after the build.
        let kv_out = SideOutputs::new();

        // Optional RoPE frequency factors: a failed lookup is treated as "absent".
        let rope_factors = weights.take("rope_freqs.weight").ok().map(|(data, _)| data);
        let inv_freq = resolve_inv_freq(cfg, rope_factors.as_deref());
        // With a dynamic past the cos/sin rows arrive as graph inputs instead,
        // so nothing is precomputed here.
        let (rope_cos, rope_sin) = if self.dynamic_past {
            (Vec::new(), Vec::new())
        } else {
            crate::rope::rope_slice(&inv_freq, self.past_seq)
        };

        let mut flow = ModelFlow::new("gemma_decode")
            .with_profile(profile)
            .input("input_ids", Shape::new(&[self.batch, 1], DType::F32));

        if self.dynamic_past {
            flow = flow
                .input("rope_cos", Shape::new(&[1, half], f))
                .input("rope_sin", Shape::new(&[1, half], f));
        }

        // NOTE(review): the mask shape uses the static `past_seq` even when
        // `dynamic_past` is set — confirm this combination is supported.
        if self.use_custom_mask {
            flow = flow.input("mask", Shape::new(&[self.batch, self.past_seq + 1], f));
        }

        // NOTE(review): KV inputs are declared for `num_hidden_layers`, while the
        // layer stack below is built for `active_num_layers()` — confirm these agree.
        for layer_idx in 0..cfg.num_hidden_layers {
            flow = flow
                .input(format!("past_k_{layer_idx}"), past_kv_shape.clone())
                .input(format!("past_v_{layer_idx}"), past_kv_shape.clone());
        }

        // Static past: bake the precomputed cos/sin rows in as parameters.
        if !self.dynamic_past {
            flow = flow.raw_stage(FlowStage::DecodeRopeParams(DecodeRopeParamsStage {
                cos: rope_cos,
                sin: rope_sin,
                half_dim: half,
            }));
        }

        flow = flow
            .bind_decode_inputs(cfg.num_hidden_layers, self.use_custom_mask)
            .zero_beta_named("gemma.zero_beta.hidden", h)
            .token_embed()
            .raw_stage(FlowStage::EmbedScale(EmbedScaleStage::new(h)))
            .raw_stages(self.before_layers.iter().cloned());

        // Values captured by the per-layer closure below.
        let layer_fn = self.layer_fn.clone();
        let use_custom_mask = self.use_custom_mask;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let kv_group_size = cfg.kv_group_size();
        let num_layers = cfg.active_num_layers();
        // PLAN.md M2 — Gemma 4 MoE (`gemma4-26b-a4b`) decode-side dispatch.
        let is_moe = cfg.is_moe();
        let moe_num_experts = cfg.num_experts;
        let moe_top_k = cfg.num_experts_used;
        let moe_n_embd = cfg.hidden_size;
        let moe_n_ff = cfg.expert_ffn_dim();
        flow = flow.repeat_layers(num_layers, {
            let sink = kv_out.clone();
            let hidden_shape = hidden_shape.clone();
            move |i| {
                // With a caller-supplied mask the spec's mask kind is plain causal;
                // otherwise it is derived from the architecture's sliding-window rules.
                let mask = if use_custom_mask {
                    rlx_ir::op::MaskKind::Causal
                } else {
                    match (decode_arch, decode_sliding) {
                        (GemmaArch::Gemma2, Some(w)) => rlx_flow::blocks::gemma2_layer_mask(i, w),
                        // PLAN.md M2 — Gemma 3 / 4 use the strided
                        // `sliding_window_pattern` (5 sliding + 1
                        // full for stride 6).
                        (GemmaArch::Gemma3 | GemmaArch::Gemma4, Some(w)) => {
                            rlx_flow::blocks::gemma_strided_layer_mask(
                                i,
                                w,
                                decode_arch.sliding_window_stride(),
                            )
                        }
                        _ => rlx_ir::op::MaskKind::Causal,
                    }
                };
                let spec = GemmaDecodeLayerSpec {
                    style: decode_style,
                    num_heads,
                    head_dim: dh,
                    num_kv_heads,
                    kv_group_size,
                    eps,
                    use_custom_mask,
                    hidden_shape: hidden_shape.clone(),
                    mask,
                    score_scale: decode_score_scale,
                    attn_logit_softcap: decode_softcap,
                };
                // A user hook takes precedence over both the MoE and dense defaults.
                if let Some(ref f) = layer_fn {
                    return f(GemmaLayerCtx::Decode {
                        index: i,
                        spec: spec.clone(),
                        kv_out: &sink,
                    });
                }
                if is_moe {
                    let prefix = format!("model.layers.{i}");
                    let moe = rlx_flow::blocks::MoeFfnStage::hf(
                        prefix,
                        moe_num_experts,
                        moe_top_k,
                        moe_n_embd,
                        moe_n_ff,
                    );
                    return rlx_flow::blocks::gemma_moe_decode_layer_composed(
                        i,
                        spec,
                        sink.inner(),
                        moe,
                    );
                }
                // Dense default: the stock named decode layer.
                GemmaLayerCtx::Decode {
                    index: i,
                    spec,
                    kv_out: &sink,
                }
                .default_stage()
            }
        });

        flow = flow.raw_stages(self.after_layers.iter().cloned());

        // NOTE(review): here the patch runs BEFORE the final norm, whereas the
        // prefill assembly applies it AFTER the final norm — confirm which is intended.
        if let Some(patch) = self.flow_patch {
            flow = patch(flow);
        }

        let mut flow = flow.raw_stage(FlowStage::GemmaRmsNorm(GemmaRmsNormStage::hf_layer(
            "model.norm",
            eps,
        )));
        // Tied embeddings reuse the token-embedding weight; otherwise load `lm_head.weight`.
        let lm = if cfg.tie_word_embeddings {
            FlowStage::LmHead(LmHeadStage::tied(cfg.vocab_size, h))
        } else {
            FlowStage::LmHead(LmHeadStage::separate("lm_head.weight", cfg.vocab_size, h))
        };
        flow = flow.raw_stage(lm);
        if let Some(cap) = cfg.final_logit_softcapping {
            flow = flow.raw_stage(FlowStage::LogitSoftcap(LogitSoftcapStage::new(cap)));
        }
        // Decode always exposes the KV side outputs alongside the logits.
        let built = flow
            .output("logits")
            .build(&mut WeightLoaderSource(weights))?
            .with_extra_hir_outputs(kv_out.drain());

        Ok(built)
    }
695}
696
697fn prefill_hidden_shape(
698    batch: usize,
699    seq: usize,
700    hidden: usize,
701    dynamic: bool,
702    dtype: DType,
703) -> Shape {
704    if dynamic {
705        Shape::from_dims(
706            &[
707                Dim::Static(batch),
708                Dim::Dynamic(sym::SEQ),
709                Dim::Static(hidden),
710            ],
711            dtype,
712        )
713    } else {
714        Shape::new(&[batch, seq, hidden], dtype)
715    }
716}
717
718fn prefill_input_shape(batch: usize, seq: usize, dynamic: bool) -> Shape {
719    if dynamic {
720        Shape::from_dims(&[Dim::Static(batch), Dim::Dynamic(sym::SEQ)], DType::F32)
721    } else {
722        Shape::new(&[batch, seq], DType::F32)
723    }
724}
725
726// ── Legacy opt structs + thin wrappers (backward compatible) ─────────
727
728impl<'a> GemmaFlow<'a> {
729    fn from_prefill_opts(cfg: &'a GemmaConfig, o: &GemmaPrefillOpts) -> Self {
730        let mut f = GemmaFlow::new(cfg).prefill().batch(o.batch).seq(o.seq);
731        if o.dynamic_seq {
732            f = f.dynamic_seq();
733        }
734        if o.with_lm_head {
735            f = f.lm_head();
736        }
737        if o.with_kv_outputs {
738            f = f.export_kv();
739        }
740        if o.last_logits_only {
741            f = f.last_token_logits();
742        }
743        if let Some(p) = o.profile.clone() {
744            f = f.profile(p);
745        }
746        f
747    }
748
749    fn from_decode_opts(cfg: &'a GemmaConfig, o: &GemmaDecodeOpts) -> Self {
750        let mut f = GemmaFlow::new(cfg)
751            .decode()
752            .batch(o.batch)
753            .past(o.past_seq)
754            .lm_head();
755        if o.dynamic_past {
756            f = f.dynamic_past();
757        }
758        if o.use_custom_mask {
759            f = f.custom_mask();
760        }
761        if let Some(p) = o.profile.clone() {
762            f = f.profile(p);
763        }
764        f
765    }
766}
767
/// Options for the tier-0 Gemma prefill assembly line.
///
/// Legacy counterpart of the fluent `GemmaFlow` prefill modifiers.
#[derive(Debug, Clone)]
pub struct GemmaPrefillOpts {
    /// Static batch dimension.
    pub batch: usize,
    /// Static sequence length.
    pub seq: usize,
    /// Use the symbolic sequence dim instead of `seq`.
    pub dynamic_seq: bool,
    /// Append the LM head and output logits.
    pub with_lm_head: bool,
    /// Attach the per-layer KV side outputs to the built model.
    pub with_kv_outputs: bool,
    /// Produce logits for the last token only.
    pub last_logits_only: bool,
    /// Explicit compile profile; `None` uses the prefill preset.
    pub profile: Option<CompileProfile>,
}
779
780impl GemmaPrefillOpts {
781    pub fn static_prefill(batch: usize, seq: usize) -> Self {
782        Self {
783            batch,
784            seq,
785            dynamic_seq: false,
786            with_lm_head: false,
787            with_kv_outputs: false,
788            last_logits_only: false,
789            profile: None,
790        }
791    }
792}
793
/// Options for tier-0 Gemma decode (KV-cache) assembly line.
///
/// Legacy counterpart of the fluent `GemmaFlow` decode modifiers.
#[derive(Debug, Clone)]
pub struct GemmaDecodeOpts {
    /// Static batch dimension.
    pub batch: usize,
    /// Static past (KV-cache) length.
    pub past_seq: usize,
    /// Use the symbolic past dim instead of `past_seq`.
    pub dynamic_past: bool,
    /// Declare a caller-supplied `"mask"` input.
    pub use_custom_mask: bool,
    /// Explicit compile profile; `None` uses the decode preset.
    pub profile: Option<CompileProfile>,
}
803
804pub fn build_gemma_prefill_flow(
805    cfg: &GemmaConfig,
806    weights: &mut dyn WeightLoader,
807    opts: &GemmaPrefillOpts,
808) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
809    build_gemma_prefill_built(cfg, weights, opts)?.into_parts()
810}
811
812pub fn build_gemma_prefill_built(
813    cfg: &GemmaConfig,
814    weights: &mut dyn WeightLoader,
815    opts: &GemmaPrefillOpts,
816) -> Result<BuiltModel> {
817    GemmaFlow::from_prefill_opts(cfg, opts).build(weights)
818}
819
820pub fn build_gemma_decode_flow(
821    cfg: &GemmaConfig,
822    weights: &mut dyn WeightLoader,
823    opts: &GemmaDecodeOpts,
824) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
825    build_gemma_decode_built(cfg, weights, opts)?.into_parts()
826}
827
828pub fn build_gemma_decode_graph(
829    cfg: &GemmaConfig,
830    weights: &mut dyn WeightLoader,
831    opts: &GemmaDecodeOpts,
832) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
833    rlx_core::flow_util::graph_from_built(build_gemma_decode_built(cfg, weights, opts)?)
834}
835
836pub fn build_gemma_decode_built(
837    cfg: &GemmaConfig,
838    weights: &mut dyn WeightLoader,
839    opts: &GemmaDecodeOpts,
840) -> Result<BuiltModel> {
841    GemmaFlow::from_decode_opts(cfg, opts).build(weights)
842}