1use std::collections::HashMap;
46use std::fmt;
47use std::path::Path;
48use std::sync::Arc;
49
50use anyhow::Result;
51use rlx_flow::blocks::{
52 DecodeRopeParamsStage, EmbedScaleStage, GemmaDecodeLayerSpec, GemmaDecodeLayerStage,
53 GemmaLayerStyle, GemmaRmsNormStage, LmHeadStage, LogitSoftcapStage, RopeTablesStage,
54 gemma_attn_spec, gemma_prefill_layer_composed,
55};
56use rlx_flow::{BuiltModel, CompileProfile, FlowStage, ModelFlow, SideOutputs};
57use rlx_ir::dynamic::sym;
58use rlx_ir::hir::HirModule;
59use rlx_ir::shape::Dim;
60use rlx_ir::{DType, Graph, Shape};
61
62use super::config::{GemmaArch, GemmaConfig};
63use super::rope::{build_rope_tables, resolve_inv_freq};
64use rlx_core::flow_bridge::{WeightLoaderSource, load_compile_profile};
65use rlx_core::weight_loader::WeightLoader;
66
/// File name of the optional Gemma compile profile.
///
/// `gemma_profile_near_weights` looks for this file in the directory that contains the
/// weights file.
pub const GEMMA_PROFILE_FILE: &str = "gemma.rlx.toml";
69
70pub fn gemma_profile_near_weights(weights: &Path, decode: bool) -> CompileProfile {
72 let default = if decode {
73 CompileProfile::gemma_decode()
74 } else {
75 CompileProfile::gemma_prefill()
76 };
77 let dir = weights.parent().unwrap_or_else(|| Path::new("."));
78 load_compile_profile(&dir.join(GEMMA_PROFILE_FILE), default)
79}
80
/// Selects which graph variant `GemmaFlow` builds.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GemmaMode {
    /// Whole-prompt processing; this is what `GemmaFlow::new` starts with.
    Prefill,
    /// One new token per sequence, attending over per-layer past K/V inputs.
    Decode,
}
86
/// Per-layer context passed to a custom layer builder installed with `GemmaFlow::layer`.
///
/// It carries what the stock layer stage is built from. `GemmaLayerCtx::default_stage`
/// rebuilds that stock stage, so a custom builder can wrap or fall back to it.
pub enum GemmaLayerCtx<'a> {
    /// Context for one prefill layer.
    Prefill {
        /// Layer index; the builder produces indices in `0..active_num_layers()`.
        index: usize,
        /// Layer style taken from `GemmaConfig::layer_style`.
        style: GemmaLayerStyle,
        /// Attention spec that `gemma_attn_spec` produced for this layer.
        attn: rlx_flow::blocks::SelfAttnPrefillSpec,
        /// Side-output sink shared by all layers. The stock stage only receives it when
        /// `export_kv` is true.
        kv_sink: &'a SideOutputs,
        /// Whether the builder was asked to export KV side outputs (`GemmaFlow::export_kv`).
        export_kv: bool,
        /// Attention head dimension. `default_stage` ignores it; it is provided for
        /// custom builders.
        head_dim: usize,
        /// RMS-norm epsilon (`GemmaConfig::rms_norm_eps` cast to `f32`).
        eps: f32,
    },
    /// Context for one decode layer.
    Decode {
        /// Layer index; the builder produces indices in `0..active_num_layers()`.
        index: usize,
        /// Complete decode-layer spec: heads, mask kind, scaling, softcap, hidden shape, etc.
        spec: GemmaDecodeLayerSpec,
        /// Side-output sink for the layer's KV outputs. Decode always attaches its
        /// contents to the built model.
        kv_out: &'a SideOutputs,
    },
}
104
105impl GemmaLayerCtx<'_> {
106 pub fn index(&self) -> usize {
107 match self {
108 Self::Prefill { index, .. } | Self::Decode { index, .. } => *index,
109 }
110 }
111
112 pub fn default_stage(&self) -> FlowStage {
113 match self {
114 Self::Prefill {
115 index,
116 style,
117 attn,
118 kv_sink,
119 export_kv,
120 head_dim: _,
121 eps,
122 } => gemma_prefill_layer_composed(
123 *index,
124 *style,
125 attn.clone(),
126 *eps,
127 if *export_kv {
128 Some(kv_sink.inner())
129 } else {
130 None
131 },
132 ),
133 Self::Decode {
134 index,
135 spec,
136 kv_out,
137 } => FlowStage::Named {
138 name: format!("layer{index}"),
139 inner: Arc::new(FlowStage::GemmaDecodeLayer(GemmaDecodeLayerStage::layer(
140 *index,
141 spec.clone(),
142 kv_out.inner(),
143 ))),
144 },
145 }
146 }
147}
148
/// Per-layer override installed with `GemmaFlow::layer`: maps a layer context to the
/// stage emitted for that layer.
type LayerFn = Arc<dyn Fn(GemmaLayerCtx<'_>) -> FlowStage + Send + Sync>;
/// Whole-flow transform installed with `GemmaFlow::patch_flow`.
type FlowPatchFn = Arc<dyn Fn(ModelFlow) -> ModelFlow + Send + Sync>;
151
/// Builder for Gemma computation graphs in prefill or decode mode.
///
/// Configure it with the chained setters and finish with `build`. Several settings are
/// read by only one of the two modes; the other mode ignores them.
#[derive(Clone)]
pub struct GemmaFlow<'a> {
    /// Model hyper-parameters, borrowed for the builder's lifetime.
    cfg: &'a GemmaConfig,
    /// Which graph `build` produces.
    mode: GemmaMode,
    /// Static batch size. Prefill with `dynamic_seq` requires 1.
    batch: usize,
    /// Static prefill sequence length; read by prefill only.
    seq: usize,
    /// Static past-KV length; read by decode only. It also sets the custom-mask width.
    past_seq: usize,
    /// Prefill: use the symbolic `sym::SEQ` axis instead of `seq`.
    dynamic_seq: bool,
    /// Decode: use the symbolic `sym::PAST_SEQ` axis for the past-KV inputs.
    dynamic_past: bool,
    /// Prefill: end in logits instead of hidden states. Decode always ends in logits.
    with_lm_head: bool,
    /// Prefill: attach the KV side outputs as extra outputs. Decode always attaches them.
    with_kv_outputs: bool,
    /// Prefill: keep only the last position before the final norm. Only takes effect
    /// together with `with_lm_head`.
    last_logits_only: bool,
    /// Decode: add a `mask` input and give every layer a plain causal mask kind.
    use_custom_mask: bool,
    /// Compile-profile override. `None` uses the mode's built-in Gemma profile.
    profile: Option<CompileProfile>,
    /// Extra stages inserted after embedding scaling, before the layer stack.
    before_layers: Vec<FlowStage>,
    /// Extra stages inserted right after the layer stack.
    after_layers: Vec<FlowStage>,
    /// Optional per-layer override; it takes precedence over the dense and MoE defaults.
    layer_fn: Option<LayerFn>,
    /// Optional whole-flow transform applied near the end of `build`.
    flow_patch: Option<FlowPatchFn>,
}
184
185impl fmt::Debug for GemmaFlow<'_> {
186 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
187 f.debug_struct("GemmaFlow")
188 .field("mode", &self.mode)
189 .field("batch", &self.batch)
190 .field("seq", &self.seq)
191 .field("past_seq", &self.past_seq)
192 .field("dynamic_seq", &self.dynamic_seq)
193 .field("dynamic_past", &self.dynamic_past)
194 .field("with_lm_head", &self.with_lm_head)
195 .field("with_kv_outputs", &self.with_kv_outputs)
196 .field("last_logits_only", &self.last_logits_only)
197 .field("use_custom_mask", &self.use_custom_mask)
198 .field("profile", &self.profile)
199 .field("before_layers", &self.before_layers.len())
200 .field("after_layers", &self.after_layers.len())
201 .field("layer_fn", &self.layer_fn.is_some())
202 .field("flow_patch", &self.flow_patch.is_some())
203 .finish_non_exhaustive()
204 }
205}
206
207impl<'a> GemmaFlow<'a> {
208 pub fn new(cfg: &'a GemmaConfig) -> Self {
209 Self {
210 cfg,
211 mode: GemmaMode::Prefill,
212 batch: 1,
213 seq: 128,
214 past_seq: 0,
215 dynamic_seq: false,
216 dynamic_past: false,
217 with_lm_head: false,
218 with_kv_outputs: false,
219 last_logits_only: false,
220 use_custom_mask: false,
221 profile: None,
222 before_layers: Vec::new(),
223 after_layers: Vec::new(),
224 layer_fn: None,
225 flow_patch: None,
226 }
227 }
228
229 pub fn for_prefill(cfg: &'a GemmaConfig, batch: usize, seq: usize) -> Self {
231 Self::new(cfg).prefill().batch(batch).seq(seq)
232 }
233
234 pub fn for_decode(cfg: &'a GemmaConfig, batch: usize, past_seq: usize) -> Self {
236 Self::new(cfg)
237 .decode()
238 .batch(batch)
239 .past(past_seq)
240 .lm_head()
241 }
242
243 pub fn prefill(mut self) -> Self {
244 self.mode = GemmaMode::Prefill;
245 self
246 }
247
248 pub fn decode(mut self) -> Self {
249 self.mode = GemmaMode::Decode;
250 self
251 }
252
253 pub fn batch(mut self, batch: usize) -> Self {
254 self.batch = batch;
255 self
256 }
257
258 pub fn seq(mut self, seq: usize) -> Self {
260 self.seq = seq;
261 self
262 }
263
264 pub fn past(mut self, past_seq: usize) -> Self {
266 self.past_seq = past_seq;
267 self
268 }
269
270 pub fn dynamic_seq(mut self) -> Self {
272 self.dynamic_seq = true;
273 self
274 }
275
276 pub fn dynamic_past(mut self) -> Self {
278 self.dynamic_past = true;
279 self
280 }
281
282 pub fn lm_head(mut self) -> Self {
283 self.with_lm_head = true;
284 self
285 }
286
287 pub fn hidden_only(mut self) -> Self {
289 self.with_lm_head = false;
290 self.last_logits_only = false;
291 self
292 }
293
294 pub fn last_token_logits(mut self) -> Self {
295 self.with_lm_head = true;
296 self.last_logits_only = true;
297 self
298 }
299
300 pub fn export_kv(mut self) -> Self {
301 self.with_kv_outputs = true;
302 self
303 }
304
305 pub fn custom_mask(mut self) -> Self {
306 self.use_custom_mask = true;
307 self
308 }
309
310 pub fn profile(mut self, profile: CompileProfile) -> Self {
311 self.profile = Some(profile);
312 self
313 }
314
315 pub fn profile_prefill(mut self) -> Self {
317 self.profile = Some(CompileProfile::gemma_prefill());
318 self
319 }
320
321 pub fn profile_decode(mut self) -> Self {
322 self.profile = Some(CompileProfile::gemma_decode());
323 self
324 }
325
326 pub fn profile_near(mut self, weights_path: &Path) -> Self {
327 let decode = self.mode == GemmaMode::Decode;
328 self.profile = Some(gemma_profile_near_weights(weights_path, decode));
329 self
330 }
331
332 pub fn before_layers(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
334 self.before_layers.extend(stages);
335 self
336 }
337
338 pub fn after_layers(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
340 self.after_layers.extend(stages);
341 self
342 }
343
344 pub fn layer<F>(mut self, f: F) -> Self
348 where
349 F: Fn(GemmaLayerCtx<'_>) -> FlowStage + Send + Sync + 'static,
350 {
351 self.layer_fn = Some(Arc::new(f));
352 self
353 }
354
355 pub fn patch_flow<F>(mut self, f: F) -> Self
357 where
358 F: Fn(ModelFlow) -> ModelFlow + Send + Sync + 'static,
359 {
360 self.flow_patch = Some(Arc::new(f));
361 self
362 }
363
364 pub fn build(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
365 match self.mode {
366 GemmaMode::Prefill => self.build_prefill(weights),
367 GemmaMode::Decode => self.build_decode(weights),
368 }
369 }
370
371 fn build_prefill(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
372 if self.dynamic_seq && self.batch != 1 {
373 anyhow::bail!("gemma: dynamic_seq prefill requires batch=1");
374 }
375
376 let cfg = self.cfg;
377 let profile = self.profile.unwrap_or_else(CompileProfile::gemma_prefill);
378 let f = DType::F32;
379 let h = cfg.hidden_size;
380 let eps = cfg.rms_norm_eps as f32;
381 let dh = cfg.head_dim();
382 let layer_style = cfg.layer_style();
383
384 let _hidden_shape = prefill_hidden_shape(self.batch, self.seq, h, self.dynamic_seq, f);
385 let input_shape = prefill_input_shape(self.batch, self.seq, self.dynamic_seq);
386
387 let rope_factors = weights.take("rope_freqs.weight").ok().map(|(data, _)| data);
388 let inv_freq = resolve_inv_freq(cfg, rope_factors.as_deref());
389 let (cos_data, sin_data) = build_rope_tables(&inv_freq, cfg.max_position_embeddings);
390
391 let kv_sink = SideOutputs::new();
392
393 let mut flow = ModelFlow::new("gemma")
394 .with_profile(profile)
395 .input("input_ids", input_shape);
396
397 if self.dynamic_seq && self.with_lm_head && self.last_logits_only {
398 flow = flow.input("last_token_idx", Shape::new(&[self.batch], DType::F32));
399 }
400
401 flow = flow
402 .rope_tables(RopeTablesStage::param(
403 cfg.max_position_embeddings,
404 inv_freq.len(),
405 cos_data,
406 sin_data,
407 ))
408 .zero_beta_named("gemma.zero_beta.hidden", h)
409 .token_embed()
410 .raw_stage(FlowStage::EmbedScale(EmbedScaleStage::new(h)))
411 .raw_stages(self.before_layers.iter().cloned());
412
413 let layer_fn = self.layer_fn.clone();
414 let export = self.with_kv_outputs;
415 let num_heads = cfg.num_attention_heads;
416 let num_kv_heads = cfg.num_key_value_heads;
417 let num_layers = cfg.active_num_layers();
418 let layer_attn: Vec<_> = (0..num_layers).map(|i| cfg.layer_attn_options(i)).collect();
419 let is_moe = cfg.is_moe();
424 let moe_num_experts = cfg.num_experts;
425 let moe_top_k = cfg.num_experts_used;
426 let moe_n_embd = cfg.hidden_size;
427 let moe_n_ff = cfg.expert_ffn_dim();
428 flow = flow.repeat_layers(num_layers, {
429 let style = layer_style;
430 let sink = kv_sink.clone();
431 move |i| {
432 let (mask, score_scale, softcap) = layer_attn[i];
433 let attn =
434 gemma_attn_spec(i, num_heads, dh, num_kv_heads, mask, score_scale, softcap);
435 if let Some(ref f) = layer_fn {
436 return f(GemmaLayerCtx::Prefill {
437 index: i,
438 style,
439 attn: attn.clone(),
440 kv_sink: &sink,
441 export_kv: export,
442 head_dim: dh,
443 eps,
444 });
445 }
446 if is_moe {
447 let prefix = format!("model.layers.{i}");
448 let moe = rlx_flow::blocks::MoeFfnStage::hf(
449 prefix,
450 moe_num_experts,
451 moe_top_k,
452 moe_n_embd,
453 moe_n_ff,
454 );
455 let kv = if export { Some(sink.inner()) } else { None };
456 return rlx_flow::blocks::gemma_moe_prefill_layer_composed(
457 i, style, attn, eps, kv, moe,
458 );
459 }
460 GemmaLayerCtx::Prefill {
461 index: i,
462 style,
463 attn,
464 kv_sink: &sink,
465 export_kv: export,
466 head_dim: dh,
467 eps,
468 }
469 .default_stage()
470 }
471 });
472
473 flow = flow.raw_stages(self.after_layers.iter().cloned());
474
475 if self.with_lm_head && self.last_logits_only {
476 flow = if self.dynamic_seq {
477 flow.gather_last_token_dynamic(self.batch)
478 } else {
479 flow.gather_last_token_at(self.batch, self.seq)
480 };
481 }
482
483 flow = flow.raw_stage(FlowStage::GemmaRmsNorm(GemmaRmsNormStage::hf_layer(
484 "model.norm",
485 eps,
486 )));
487
488 if let Some(patch) = self.flow_patch {
489 flow = patch(flow);
490 }
491
492 let mut built = if self.with_lm_head {
493 let lm = if cfg.tie_word_embeddings {
494 FlowStage::LmHead(LmHeadStage::tied(cfg.vocab_size, h))
495 } else {
496 FlowStage::LmHead(LmHeadStage::separate("lm_head.weight", cfg.vocab_size, h))
497 };
498 flow = flow.raw_stage(lm);
499 if let Some(cap) = cfg.final_logit_softcapping {
500 flow = flow.raw_stage(FlowStage::LogitSoftcap(LogitSoftcapStage::new(cap)));
501 }
502 flow.output("logits")
503 .build(&mut WeightLoaderSource(weights))?
504 } else {
505 flow.output("hidden")
506 .build(&mut WeightLoaderSource(weights))?
507 };
508
509 if self.with_kv_outputs {
510 built = built.with_extra_hir_outputs(kv_sink.drain());
511 }
512 Ok(built)
513 }
514
515 fn build_decode(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
516 let cfg = self.cfg;
517 let profile = self.profile.unwrap_or_else(CompileProfile::gemma_decode);
518 let f = DType::F32;
519 let h = cfg.hidden_size;
520 let eps = cfg.rms_norm_eps as f32;
521 let dh = cfg.head_dim();
522 let kv_dim = cfg.kv_proj_dim();
523 let half = dh / 2;
524
525 let hidden_shape = Shape::new(&[self.batch, 1, h], f);
526 let past_kv_shape = if self.dynamic_past {
527 Shape::from_dims(
528 &[
529 Dim::Static(self.batch),
530 Dim::Dynamic(sym::PAST_SEQ),
531 Dim::Static(kv_dim),
532 ],
533 f,
534 )
535 } else {
536 Shape::new(&[self.batch, self.past_seq, kv_dim], f)
537 };
538
539 let decode_style = cfg.layer_style();
540 let decode_score_scale = cfg.attn_score_scale();
541 let decode_softcap = cfg.attn_logit_softcapping;
542 let decode_arch = cfg.arch;
543 let decode_sliding = cfg.sliding_window;
544
545 let kv_out = SideOutputs::new();
546
547 let rope_factors = weights.take("rope_freqs.weight").ok().map(|(data, _)| data);
548 let inv_freq = resolve_inv_freq(cfg, rope_factors.as_deref());
549 let (rope_cos, rope_sin) = if self.dynamic_past {
550 (Vec::new(), Vec::new())
551 } else {
552 crate::rope::rope_slice(&inv_freq, self.past_seq)
553 };
554
555 let mut flow = ModelFlow::new("gemma_decode")
556 .with_profile(profile)
557 .input("input_ids", Shape::new(&[self.batch, 1], DType::F32));
558
559 if self.dynamic_past {
560 flow = flow
561 .input("rope_cos", Shape::new(&[1, half], f))
562 .input("rope_sin", Shape::new(&[1, half], f));
563 }
564
565 if self.use_custom_mask {
566 flow = flow.input("mask", Shape::new(&[self.batch, self.past_seq + 1], f));
567 }
568
569 for layer_idx in 0..cfg.num_hidden_layers {
570 flow = flow
571 .input(format!("past_k_{layer_idx}"), past_kv_shape.clone())
572 .input(format!("past_v_{layer_idx}"), past_kv_shape.clone());
573 }
574
575 if !self.dynamic_past {
576 flow = flow.raw_stage(FlowStage::DecodeRopeParams(DecodeRopeParamsStage {
577 cos: rope_cos,
578 sin: rope_sin,
579 half_dim: half,
580 }));
581 }
582
583 flow = flow
584 .bind_decode_inputs(cfg.num_hidden_layers, self.use_custom_mask)
585 .zero_beta_named("gemma.zero_beta.hidden", h)
586 .token_embed()
587 .raw_stage(FlowStage::EmbedScale(EmbedScaleStage::new(h)))
588 .raw_stages(self.before_layers.iter().cloned());
589
590 let layer_fn = self.layer_fn.clone();
591 let use_custom_mask = self.use_custom_mask;
592 let num_heads = cfg.num_attention_heads;
593 let num_kv_heads = cfg.num_key_value_heads;
594 let kv_group_size = cfg.kv_group_size();
595 let num_layers = cfg.active_num_layers();
596 let is_moe = cfg.is_moe();
598 let moe_num_experts = cfg.num_experts;
599 let moe_top_k = cfg.num_experts_used;
600 let moe_n_embd = cfg.hidden_size;
601 let moe_n_ff = cfg.expert_ffn_dim();
602 flow = flow.repeat_layers(num_layers, {
603 let sink = kv_out.clone();
604 let hidden_shape = hidden_shape.clone();
605 move |i| {
606 let mask = if use_custom_mask {
607 rlx_ir::op::MaskKind::Causal
608 } else {
609 match (decode_arch, decode_sliding) {
610 (GemmaArch::Gemma2, Some(w)) => rlx_flow::blocks::gemma2_layer_mask(i, w),
611 (GemmaArch::Gemma3 | GemmaArch::Gemma4, Some(w)) => {
615 rlx_flow::blocks::gemma_strided_layer_mask(
616 i,
617 w,
618 decode_arch.sliding_window_stride(),
619 )
620 }
621 _ => rlx_ir::op::MaskKind::Causal,
622 }
623 };
624 let spec = GemmaDecodeLayerSpec {
625 style: decode_style,
626 num_heads,
627 head_dim: dh,
628 num_kv_heads,
629 kv_group_size,
630 eps,
631 use_custom_mask,
632 hidden_shape: hidden_shape.clone(),
633 mask,
634 score_scale: decode_score_scale,
635 attn_logit_softcap: decode_softcap,
636 };
637 if let Some(ref f) = layer_fn {
638 return f(GemmaLayerCtx::Decode {
639 index: i,
640 spec: spec.clone(),
641 kv_out: &sink,
642 });
643 }
644 if is_moe {
645 let prefix = format!("model.layers.{i}");
646 let moe = rlx_flow::blocks::MoeFfnStage::hf(
647 prefix,
648 moe_num_experts,
649 moe_top_k,
650 moe_n_embd,
651 moe_n_ff,
652 );
653 return rlx_flow::blocks::gemma_moe_decode_layer_composed(
654 i,
655 spec,
656 sink.inner(),
657 moe,
658 );
659 }
660 GemmaLayerCtx::Decode {
661 index: i,
662 spec,
663 kv_out: &sink,
664 }
665 .default_stage()
666 }
667 });
668
669 flow = flow.raw_stages(self.after_layers.iter().cloned());
670
671 if let Some(patch) = self.flow_patch {
672 flow = patch(flow);
673 }
674
675 let mut flow = flow.raw_stage(FlowStage::GemmaRmsNorm(GemmaRmsNormStage::hf_layer(
676 "model.norm",
677 eps,
678 )));
679 let lm = if cfg.tie_word_embeddings {
680 FlowStage::LmHead(LmHeadStage::tied(cfg.vocab_size, h))
681 } else {
682 FlowStage::LmHead(LmHeadStage::separate("lm_head.weight", cfg.vocab_size, h))
683 };
684 flow = flow.raw_stage(lm);
685 if let Some(cap) = cfg.final_logit_softcapping {
686 flow = flow.raw_stage(FlowStage::LogitSoftcap(LogitSoftcapStage::new(cap)));
687 }
688 let built = flow
689 .output("logits")
690 .build(&mut WeightLoaderSource(weights))?
691 .with_extra_hir_outputs(kv_out.drain());
692
693 Ok(built)
694 }
695}
696
697fn prefill_hidden_shape(
698 batch: usize,
699 seq: usize,
700 hidden: usize,
701 dynamic: bool,
702 dtype: DType,
703) -> Shape {
704 if dynamic {
705 Shape::from_dims(
706 &[
707 Dim::Static(batch),
708 Dim::Dynamic(sym::SEQ),
709 Dim::Static(hidden),
710 ],
711 dtype,
712 )
713 } else {
714 Shape::new(&[batch, seq, hidden], dtype)
715 }
716}
717
718fn prefill_input_shape(batch: usize, seq: usize, dynamic: bool) -> Shape {
719 if dynamic {
720 Shape::from_dims(&[Dim::Static(batch), Dim::Dynamic(sym::SEQ)], DType::F32)
721 } else {
722 Shape::new(&[batch, seq], DType::F32)
723 }
724}
725
726impl<'a> GemmaFlow<'a> {
729 fn from_prefill_opts(cfg: &'a GemmaConfig, o: &GemmaPrefillOpts) -> Self {
730 let mut f = GemmaFlow::new(cfg).prefill().batch(o.batch).seq(o.seq);
731 if o.dynamic_seq {
732 f = f.dynamic_seq();
733 }
734 if o.with_lm_head {
735 f = f.lm_head();
736 }
737 if o.with_kv_outputs {
738 f = f.export_kv();
739 }
740 if o.last_logits_only {
741 f = f.last_token_logits();
742 }
743 if let Some(p) = o.profile.clone() {
744 f = f.profile(p);
745 }
746 f
747 }
748
749 fn from_decode_opts(cfg: &'a GemmaConfig, o: &GemmaDecodeOpts) -> Self {
750 let mut f = GemmaFlow::new(cfg)
751 .decode()
752 .batch(o.batch)
753 .past(o.past_seq)
754 .lm_head();
755 if o.dynamic_past {
756 f = f.dynamic_past();
757 }
758 if o.use_custom_mask {
759 f = f.custom_mask();
760 }
761 if let Some(p) = o.profile.clone() {
762 f = f.profile(p);
763 }
764 f
765 }
766}
767
/// Plain-data options for `build_gemma_prefill_flow` / `build_gemma_prefill_built`.
///
/// These are mapped onto the `GemmaFlow` prefill builder.
#[derive(Debug, Clone)]
pub struct GemmaPrefillOpts {
    /// Static batch size; must be 1 when `dynamic_seq` is set.
    pub batch: usize,
    /// Static sequence length; ignored for the input shape when `dynamic_seq` is set.
    pub seq: usize,
    /// Use the symbolic `sym::SEQ` axis for `input_ids`.
    pub dynamic_seq: bool,
    /// End in an LM head and a `logits` output instead of `hidden`.
    pub with_lm_head: bool,
    /// Attach the layers' KV side outputs to the built model.
    pub with_kv_outputs: bool,
    /// Keep only the last position's logits. Only meaningful together with `with_lm_head`.
    pub last_logits_only: bool,
    /// Compile-profile override. `None` uses `CompileProfile::gemma_prefill()`.
    pub profile: Option<CompileProfile>,
}
779
780impl GemmaPrefillOpts {
781 pub fn static_prefill(batch: usize, seq: usize) -> Self {
782 Self {
783 batch,
784 seq,
785 dynamic_seq: false,
786 with_lm_head: false,
787 with_kv_outputs: false,
788 last_logits_only: false,
789 profile: None,
790 }
791 }
792}
793
/// Plain-data options for the `build_gemma_decode_*` entry points.
///
/// These are mapped onto the `GemmaFlow` decode builder, which always enables the LM head.
#[derive(Debug, Clone)]
pub struct GemmaDecodeOpts {
    /// Static batch size.
    pub batch: usize,
    /// Static past-KV length. It is also used for the custom-mask width (`past_seq + 1`).
    pub past_seq: usize,
    /// Use the symbolic `sym::PAST_SEQ` axis for the cache inputs and take rotary values
    /// from `rope_cos` / `rope_sin` inputs.
    pub dynamic_past: bool,
    /// Add a `mask` input and give every layer a plain causal mask kind.
    pub use_custom_mask: bool,
    /// Compile-profile override. `None` uses `CompileProfile::gemma_decode()`.
    pub profile: Option<CompileProfile>,
}
803
804pub fn build_gemma_prefill_flow(
805 cfg: &GemmaConfig,
806 weights: &mut dyn WeightLoader,
807 opts: &GemmaPrefillOpts,
808) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
809 build_gemma_prefill_built(cfg, weights, opts)?.into_parts()
810}
811
812pub fn build_gemma_prefill_built(
813 cfg: &GemmaConfig,
814 weights: &mut dyn WeightLoader,
815 opts: &GemmaPrefillOpts,
816) -> Result<BuiltModel> {
817 GemmaFlow::from_prefill_opts(cfg, opts).build(weights)
818}
819
820pub fn build_gemma_decode_flow(
821 cfg: &GemmaConfig,
822 weights: &mut dyn WeightLoader,
823 opts: &GemmaDecodeOpts,
824) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
825 build_gemma_decode_built(cfg, weights, opts)?.into_parts()
826}
827
828pub fn build_gemma_decode_graph(
829 cfg: &GemmaConfig,
830 weights: &mut dyn WeightLoader,
831 opts: &GemmaDecodeOpts,
832) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
833 rlx_core::flow_util::graph_from_built(build_gemma_decode_built(cfg, weights, opts)?)
834}
835
836pub fn build_gemma_decode_built(
837 cfg: &GemmaConfig,
838 weights: &mut dyn WeightLoader,
839 opts: &GemmaDecodeOpts,
840) -> Result<BuiltModel> {
841 GemmaFlow::from_decode_opts(cfg, opts).build(weights)
842}