Skip to main content

rlx_llama32/
flow.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Fluent LLaMA-3.2 model assembly — tier-0 reference for `rlx-flow`.
17//!
18//! ```rust,ignore
19//! use rlx_models::llama32::Llama32Flow;
20//!
21//! // Prefill logits for the last token
22//! let built = Llama32Flow::for_prefill(&cfg, 1, 128)
23//!     .last_token_logits()
24//!     .profile_near(&weights_path)
25//!     .build(&mut weights)?;
26//!
27//! // Decode step with KV side outputs
28//! let built = Llama32Flow::for_decode(&cfg, 1, 256)
29//!     .custom_mask()
30//!     .profile_decode()
31//!     .build(&mut weights)?;
32//!
33//! // Override one layer while keeping the rest of the recipe
34//! let built = Llama32Flow::for_prefill(&cfg, 1, 128)
35//!     .layer(|ctx| {
36//!         if ctx.index() == 0 {
37//!             ctx.default_stage() // or FlowStage::Custom(...)
38//!         } else {
39//!             ctx.default_stage()
40//!         }
41//!     })
42//!     .build(&mut weights)?;
43//! ```
44
45use std::collections::HashMap;
46use std::fmt;
47use std::path::Path;
48use std::sync::Arc;
49
50use anyhow::Result;
51use rlx_flow::blocks::{
52    LlamaDecodeLayerSpec, LlamaDecoderSpec, RopeTablesStage, llama_prefill_layer_fused,
53};
54use rlx_flow::{BuiltModel, CompileProfile, FlowStage, ModelFlow, SideOutputs};
55use rlx_ir::dynamic::sym;
56use rlx_ir::hir::HirModule;
57use rlx_ir::op::MaskKind;
58use rlx_ir::shape::Dim;
59use rlx_ir::{DType, Graph, Shape};
60
61use super::config::Llama32Config;
62use super::rope::{build_rope_tables, resolve_inv_freq};
63use rlx_core::flow_bridge::{WeightLoaderSource, load_compile_profile};
64use rlx_core::weight_loader::WeightLoader;
65
66/// Tier-1 profile file name colocated with weights.
67pub const LLAMA32_PROFILE_FILE: &str = "llama32.rlx.toml";
68
69/// Resolve compile profile from `llama32.rlx.toml` in the weights directory.
70pub fn llama32_profile_near_weights(weights: &Path, decode: bool) -> CompileProfile {
71    let default = if decode {
72        CompileProfile::llama32_decode()
73    } else {
74        CompileProfile::llama32_prefill()
75    };
76    let dir = weights.parent().unwrap_or_else(|| Path::new("."));
77    load_compile_profile(&dir.join(LLAMA32_PROFILE_FILE), default)
78}
79
/// Which graph a [`Llama32Flow`] assembles.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Llama32Mode {
    /// Whole-sequence forward pass over `input_ids` (no past K/V inputs).
    Prefill,
    /// One-token step (`input_ids` is `[batch, 1]`) against per-layer past K/V inputs.
    Decode,
}
85
86/// Per-layer context for `.layer()` overrides — defaults preserve stock LLaMA blocks.
87pub enum LlamaLayerCtx<'a> {
88    Prefill {
89        index: usize,
90        spec: &'a LlamaDecoderSpec,
91        kv_sink: &'a SideOutputs,
92        export_kv: bool,
93        head_dim: usize,
94        eps: f32,
95    },
96    Decode {
97        index: usize,
98        spec: &'a LlamaDecodeLayerSpec,
99        kv_out: &'a SideOutputs,
100    },
101}
102
103impl LlamaLayerCtx<'_> {
104    pub fn index(&self) -> usize {
105        match self {
106            Self::Prefill { index, .. } | Self::Decode { index, .. } => *index,
107        }
108    }
109
110    /// Stock fused LLaMA layer for this mode (what `.layer()` falls back to).
111    pub fn default_stage(&self) -> FlowStage {
112        match self {
113            Self::Prefill {
114                index,
115                spec,
116                kv_sink,
117                export_kv,
118                head_dim,
119                eps,
120            } => {
121                let mut stages = Vec::new();
122                if *export_kv {
123                    stages.push(FlowStage::LlamaKvTap(
124                        rlx_flow::blocks::LlamaKvTapStage::layer(
125                            *index,
126                            *head_dim,
127                            *eps,
128                            kv_sink.inner(),
129                        ),
130                    ));
131                }
132                stages.push(FlowStage::Named {
133                    name: format!("layer{index}"),
134                    inner: Arc::new(FlowStage::LlamaDecoder(
135                        rlx_flow::blocks::LlamaDecoderStage::layer(*index, (*spec).clone()),
136                    )),
137                });
138                FlowStage::Sequence(stages)
139            }
140            Self::Decode {
141                index,
142                spec,
143                kv_out,
144            } => FlowStage::Named {
145                name: format!("layer{index}"),
146                inner: Arc::new(FlowStage::LlamaDecodeLayer(
147                    rlx_flow::blocks::LlamaDecodeLayerStage::layer(
148                        *index,
149                        (*spec).clone(),
150                        kv_out.inner(),
151                    ),
152                )),
153            },
154        }
155    }
156}
157
158type LayerFn = Arc<dyn Fn(LlamaLayerCtx<'_>) -> FlowStage + Send + Sync>;
159type FlowPatchFn = Arc<dyn Fn(ModelFlow) -> ModelFlow + Send + Sync>;
160
161/// Fluent LLaMA-3.2 flow builder — reads config once, chain modifiers, then `build`.
162///
163/// ```rust,ignore
164/// use rlx_models::llama32::{Llama32Config, Llama32Flow};
165///
166/// let built = Llama32Flow::new(&cfg)
167///     .prefill()
168///     .batch(1)
169///     .seq(128)
170///     .lm_head()
171///     .last_token_logits()
172///     .build(&mut weights)?;
173/// ```
174#[derive(Clone)]
175pub struct Llama32Flow<'a> {
176    cfg: &'a Llama32Config,
177    mode: Llama32Mode,
178    batch: usize,
179    seq: usize,
180    past_seq: usize,
181    dynamic_seq: bool,
182    dynamic_past: bool,
183    with_lm_head: bool,
184    with_kv_outputs: bool,
185    last_logits_only: bool,
186    use_custom_mask: bool,
187    profile: Option<CompileProfile>,
188    before_layers: Vec<FlowStage>,
189    after_layers: Vec<FlowStage>,
190    layer_fn: Option<LayerFn>,
191    flow_patch: Option<FlowPatchFn>,
192}
193
194impl fmt::Debug for Llama32Flow<'_> {
195    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
196        f.debug_struct("Llama32Flow")
197            .field("mode", &self.mode)
198            .field("batch", &self.batch)
199            .field("seq", &self.seq)
200            .field("past_seq", &self.past_seq)
201            .field("dynamic_seq", &self.dynamic_seq)
202            .field("dynamic_past", &self.dynamic_past)
203            .field("with_lm_head", &self.with_lm_head)
204            .field("with_kv_outputs", &self.with_kv_outputs)
205            .field("last_logits_only", &self.last_logits_only)
206            .field("use_custom_mask", &self.use_custom_mask)
207            .field("profile", &self.profile)
208            .field("before_layers", &self.before_layers.len())
209            .field("after_layers", &self.after_layers.len())
210            .field("layer_fn", &self.layer_fn.is_some())
211            .field("flow_patch", &self.flow_patch.is_some())
212            .finish_non_exhaustive()
213    }
214}
215
216impl<'a> Llama32Flow<'a> {
217    pub fn new(cfg: &'a Llama32Config) -> Self {
218        Self {
219            cfg,
220            mode: Llama32Mode::Prefill,
221            batch: 1,
222            seq: 128,
223            past_seq: 0,
224            dynamic_seq: false,
225            dynamic_past: false,
226            with_lm_head: false,
227            with_kv_outputs: false,
228            last_logits_only: false,
229            use_custom_mask: false,
230            profile: None,
231            before_layers: Vec::new(),
232            after_layers: Vec::new(),
233            layer_fn: None,
234            flow_patch: None,
235        }
236    }
237
238    /// Prefill recipe with common batch/seq defaults.
239    pub fn for_prefill(cfg: &'a Llama32Config, batch: usize, seq: usize) -> Self {
240        Self::new(cfg).prefill().batch(batch).seq(seq)
241    }
242
243    /// Decode recipe with common batch/past defaults (includes LM head).
244    pub fn for_decode(cfg: &'a Llama32Config, batch: usize, past_seq: usize) -> Self {
245        Self::new(cfg)
246            .decode()
247            .batch(batch)
248            .past(past_seq)
249            .lm_head()
250    }
251
252    pub fn prefill(mut self) -> Self {
253        self.mode = Llama32Mode::Prefill;
254        self
255    }
256
257    pub fn decode(mut self) -> Self {
258        self.mode = Llama32Mode::Decode;
259        self
260    }
261
262    pub fn batch(mut self, batch: usize) -> Self {
263        self.batch = batch;
264        self
265    }
266
267    /// Prefill sequence length (ignored in decode mode).
268    pub fn seq(mut self, seq: usize) -> Self {
269        self.seq = seq;
270        self
271    }
272
273    /// Decode past length (ignored in prefill mode).
274    pub fn past(mut self, past_seq: usize) -> Self {
275        self.past_seq = past_seq;
276        self
277    }
278
279    /// Symbolic sequence dim (`sym::SEQ`) for dynamic prefill specialization.
280    pub fn dynamic_seq(mut self) -> Self {
281        self.dynamic_seq = true;
282        self
283    }
284
285    /// Symbolic past dim (`sym::PAST_SEQ`) for dynamic decode specialization.
286    pub fn dynamic_past(mut self) -> Self {
287        self.dynamic_past = true;
288        self
289    }
290
291    pub fn lm_head(mut self) -> Self {
292        self.with_lm_head = true;
293        self
294    }
295
296    /// Hidden states only — skip LM head (default for prefill unless `.lm_head()`).
297    pub fn hidden_only(mut self) -> Self {
298        self.with_lm_head = false;
299        self.last_logits_only = false;
300        self
301    }
302
303    pub fn last_token_logits(mut self) -> Self {
304        self.with_lm_head = true;
305        self.last_logits_only = true;
306        self
307    }
308
309    pub fn export_kv(mut self) -> Self {
310        self.with_kv_outputs = true;
311        self
312    }
313
314    pub fn custom_mask(mut self) -> Self {
315        self.use_custom_mask = true;
316        self
317    }
318
319    pub fn profile(mut self, profile: CompileProfile) -> Self {
320        self.profile = Some(profile);
321        self
322    }
323
324    /// Fusion-first prefill profile preset.
325    pub fn profile_prefill(mut self) -> Self {
326        self.profile = Some(CompileProfile::llama32_prefill());
327        self
328    }
329
330    /// Decode / KV-cache profile preset (`Fusable` lowering).
331    pub fn profile_decode(mut self) -> Self {
332        self.profile = Some(CompileProfile::llama32_decode());
333        self
334    }
335
336    pub fn profile_near(mut self, weights_path: &Path) -> Self {
337        let decode = self.mode == Llama32Mode::Decode;
338        self.profile = Some(llama32_profile_near_weights(weights_path, decode));
339        self
340    }
341
342    /// Insert custom stages after embedding, before the layer stack.
343    pub fn before_layers(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
344        self.before_layers.extend(stages);
345        self
346    }
347
348    /// Insert custom stages after the layer stack, before final norm / LM head.
349    pub fn after_layers(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
350        self.after_layers.extend(stages);
351        self
352    }
353
354    /// Override per-layer construction (prefill or decode depending on mode).
355    ///
356    /// Call [`LlamaLayerCtx::default_stage`] to keep stock blocks for unmodified layers.
357    pub fn layer<F>(mut self, f: F) -> Self
358    where
359        F: Fn(LlamaLayerCtx<'_>) -> FlowStage + Send + Sync + 'static,
360    {
361        self.layer_fn = Some(Arc::new(f));
362        self
363    }
364
365    /// Patch the assembled [`ModelFlow`] before build — full flexibility escape hatch.
366    pub fn patch_flow<F>(mut self, f: F) -> Self
367    where
368        F: Fn(ModelFlow) -> ModelFlow + Send + Sync + 'static,
369    {
370        self.flow_patch = Some(Arc::new(f));
371        self
372    }
373
374    pub fn build(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
375        match self.mode {
376            Llama32Mode::Prefill => self.build_prefill(weights),
377            Llama32Mode::Decode => self.build_decode(weights),
378        }
379    }
380
381    fn build_prefill(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
382        if self.dynamic_seq && self.batch != 1 {
383            anyhow::bail!("llama32: dynamic_seq prefill requires batch=1");
384        }
385
386        let cfg = self.cfg;
387        let profile = self.profile.unwrap_or_else(CompileProfile::llama32_prefill);
388        let f = DType::F32;
389        let h = cfg.hidden_size;
390        let eps = cfg.rms_norm_eps as f32;
391        let dh = cfg.head_dim();
392
393        let hidden_shape = prefill_hidden_shape(self.batch, self.seq, h, self.dynamic_seq, f);
394        let input_shape = prefill_input_shape(self.batch, self.seq, self.dynamic_seq);
395
396        let rope_factors = weights.take("rope_freqs.weight").ok().map(|(data, _)| data);
397        let inv_freq = resolve_inv_freq(cfg, rope_factors.as_deref());
398        let (cos_data, sin_data) = build_rope_tables(&inv_freq, cfg.max_position_embeddings);
399
400        let decoder_spec = LlamaDecoderSpec {
401            num_heads: cfg.num_attention_heads,
402            head_dim: dh,
403            num_kv_heads: cfg.num_key_value_heads,
404            eps,
405            mask: MaskKind::Causal,
406            hidden_shape: hidden_shape.clone(),
407        };
408
409        let kv_sink = SideOutputs::new();
410
411        let mut flow = ModelFlow::new("llama32")
412            .with_profile(profile)
413            .input("input_ids", input_shape);
414
415        if self.dynamic_seq && self.with_lm_head && self.last_logits_only {
416            flow = flow.input("last_token_idx", Shape::new(&[self.batch], DType::F32));
417        }
418
419        flow = flow
420            .rope_tables(RopeTablesStage::param(
421                cfg.max_position_embeddings,
422                inv_freq.len(),
423                cos_data,
424                sin_data,
425            ))
426            .zero_beta_named("llama32.zero_beta.hidden", h)
427            .token_embed()
428            .raw_stages(self.before_layers.iter().cloned());
429
430        let layer_fn = self.layer_fn.clone();
431        let export = self.with_kv_outputs;
432        flow = flow.repeat_layers(cfg.num_hidden_layers, {
433            let spec = decoder_spec.clone();
434            let sink = kv_sink.clone();
435            move |i| {
436                if let Some(ref f) = layer_fn {
437                    return f(LlamaLayerCtx::Prefill {
438                        index: i,
439                        spec: &spec,
440                        kv_sink: &sink,
441                        export_kv: export,
442                        head_dim: dh,
443                        eps,
444                    });
445                }
446                let mut stages = Vec::new();
447                if export {
448                    stages.push(FlowStage::LlamaKvTap(
449                        rlx_flow::blocks::LlamaKvTapStage::layer(i, dh, eps, sink.inner()),
450                    ));
451                }
452                stages.push(llama_prefill_layer_fused(i, spec.clone()));
453                if stages.len() == 1 {
454                    stages.into_iter().next().unwrap()
455                } else {
456                    FlowStage::Sequence(stages)
457                }
458            }
459        });
460
461        flow = flow.raw_stages(self.after_layers.iter().cloned());
462
463        if self.with_lm_head && self.last_logits_only {
464            flow = if self.dynamic_seq {
465                flow.gather_last_token_dynamic(self.batch)
466            } else {
467                flow.gather_last_token_at(self.batch, self.seq)
468            };
469        }
470
471        flow = flow.final_norm(eps);
472
473        if let Some(patch) = self.flow_patch {
474            flow = patch(flow);
475        }
476
477        let mut built = if self.with_lm_head {
478            flow.lm_head(cfg.vocab_size, h, cfg.tie_word_embeddings)
479                .build(&mut WeightLoaderSource(weights))?
480        } else {
481            flow.output("hidden")
482                .build(&mut WeightLoaderSource(weights))?
483        };
484
485        if self.with_kv_outputs {
486            built = built.with_extra_hir_outputs(kv_sink.drain());
487        }
488        Ok(built)
489    }
490
491    fn build_decode(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
492        let cfg = self.cfg;
493        let profile = self.profile.unwrap_or_else(CompileProfile::llama32_decode);
494        let f = DType::F32;
495        let h = cfg.hidden_size;
496        let eps = cfg.rms_norm_eps as f32;
497        let dh = cfg.head_dim();
498        let kv_dim = cfg.kv_proj_dim();
499        let half = dh / 2;
500
501        let hidden_shape = Shape::new(&[self.batch, 1, h], f);
502        let past_kv_shape = if self.dynamic_past {
503            Shape::from_dims(
504                &[
505                    Dim::Static(self.batch),
506                    Dim::Dynamic(sym::PAST_SEQ),
507                    Dim::Static(kv_dim),
508                ],
509                f,
510            )
511        } else {
512            Shape::new(&[self.batch, self.past_seq, kv_dim], f)
513        };
514
515        let decode_spec = LlamaDecodeLayerSpec {
516            num_heads: cfg.num_attention_heads,
517            head_dim: dh,
518            num_kv_heads: cfg.num_key_value_heads,
519            kv_group_size: cfg.kv_group_size(),
520            eps,
521            use_custom_mask: self.use_custom_mask,
522            hidden_shape,
523        };
524
525        let kv_out = SideOutputs::new();
526
527        let mut flow = ModelFlow::new("llama32_decode")
528            .with_profile(profile)
529            .input("input_ids", Shape::new(&[self.batch, 1], DType::F32))
530            .input("rope_cos", Shape::new(&[1, half], f))
531            .input("rope_sin", Shape::new(&[1, half], f));
532
533        if self.use_custom_mask {
534            flow = flow.input("mask", Shape::new(&[self.batch, self.past_seq + 1], f));
535        }
536
537        for layer_idx in 0..cfg.num_hidden_layers {
538            flow = flow
539                .input(format!("past_k_{layer_idx}"), past_kv_shape.clone())
540                .input(format!("past_v_{layer_idx}"), past_kv_shape.clone());
541        }
542
543        flow = flow
544            .bind_decode_inputs(cfg.num_hidden_layers, self.use_custom_mask)
545            .zero_beta_named("llama32.zero_beta.hidden", h)
546            .token_embed()
547            .raw_stages(self.before_layers.iter().cloned());
548
549        let layer_fn = self.layer_fn.clone();
550        flow = flow.repeat_layers(cfg.num_hidden_layers, {
551            let spec = decode_spec.clone();
552            let sink = kv_out.clone();
553            move |i| {
554                if let Some(ref f) = layer_fn {
555                    return f(LlamaLayerCtx::Decode {
556                        index: i,
557                        spec: &spec,
558                        kv_out: &sink,
559                    });
560                }
561                LlamaLayerCtx::Decode {
562                    index: i,
563                    spec: &spec,
564                    kv_out: &sink,
565                }
566                .default_stage()
567            }
568        });
569
570        flow = flow.raw_stages(self.after_layers.iter().cloned());
571
572        if let Some(patch) = self.flow_patch {
573            flow = patch(flow);
574        }
575
576        let built = flow
577            .final_norm(eps)
578            .lm_head(cfg.vocab_size, h, cfg.tie_word_embeddings)
579            .build(&mut WeightLoaderSource(weights))?
580            .with_extra_hir_outputs(kv_out.drain());
581
582        Ok(built)
583    }
584}
585
586fn prefill_hidden_shape(
587    batch: usize,
588    seq: usize,
589    hidden: usize,
590    dynamic: bool,
591    dtype: DType,
592) -> Shape {
593    if dynamic {
594        Shape::from_dims(
595            &[
596                Dim::Static(batch),
597                Dim::Dynamic(sym::SEQ),
598                Dim::Static(hidden),
599            ],
600            dtype,
601        )
602    } else {
603        Shape::new(&[batch, seq, hidden], dtype)
604    }
605}
606
607fn prefill_input_shape(batch: usize, seq: usize, dynamic: bool) -> Shape {
608    if dynamic {
609        Shape::from_dims(&[Dim::Static(batch), Dim::Dynamic(sym::SEQ)], DType::F32)
610    } else {
611        Shape::new(&[batch, seq], DType::F32)
612    }
613}
614
615// ── Legacy opt structs + thin wrappers (backward compatible) ─────────
616
617impl<'a> Llama32Flow<'a> {
618    fn from_prefill_opts(cfg: &'a Llama32Config, o: &Llama32PrefillOpts) -> Self {
619        let mut f = Llama32Flow::new(cfg).prefill().batch(o.batch).seq(o.seq);
620        if o.dynamic_seq {
621            f = f.dynamic_seq();
622        }
623        if o.with_lm_head {
624            f = f.lm_head();
625        }
626        if o.with_kv_outputs {
627            f = f.export_kv();
628        }
629        if o.last_logits_only {
630            f = f.last_token_logits();
631        }
632        if let Some(p) = o.profile.clone() {
633            f = f.profile(p);
634        }
635        f
636    }
637
638    fn from_decode_opts(cfg: &'a Llama32Config, o: &Llama32DecodeOpts) -> Self {
639        let mut f = Llama32Flow::new(cfg)
640            .decode()
641            .batch(o.batch)
642            .past(o.past_seq)
643            .lm_head();
644        if o.dynamic_past {
645            f = f.dynamic_past();
646        }
647        if o.use_custom_mask {
648            f = f.custom_mask();
649        }
650        if let Some(p) = o.profile.clone() {
651            f = f.profile(p);
652        }
653        f
654    }
655}
656
657/// Options for the tier-0 LLaMA-3.2 prefill assembly line.
658#[derive(Debug, Clone)]
659pub struct Llama32PrefillOpts {
660    pub batch: usize,
661    pub seq: usize,
662    pub dynamic_seq: bool,
663    pub with_lm_head: bool,
664    pub with_kv_outputs: bool,
665    pub last_logits_only: bool,
666    pub profile: Option<CompileProfile>,
667}
668
669impl Llama32PrefillOpts {
670    pub fn static_prefill(batch: usize, seq: usize) -> Self {
671        Self {
672            batch,
673            seq,
674            dynamic_seq: false,
675            with_lm_head: false,
676            with_kv_outputs: false,
677            last_logits_only: false,
678            profile: None,
679        }
680    }
681}
682
683/// Options for tier-0 LLaMA-3.2 decode (KV-cache) assembly line.
684#[derive(Debug, Clone)]
685pub struct Llama32DecodeOpts {
686    pub batch: usize,
687    pub past_seq: usize,
688    pub dynamic_past: bool,
689    pub use_custom_mask: bool,
690    pub profile: Option<CompileProfile>,
691}
692
693pub fn build_llama32_prefill_flow(
694    cfg: &Llama32Config,
695    weights: &mut dyn WeightLoader,
696    opts: &Llama32PrefillOpts,
697) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
698    build_llama32_prefill_built(cfg, weights, opts)?.into_parts()
699}
700
701pub fn build_llama32_prefill_built(
702    cfg: &Llama32Config,
703    weights: &mut dyn WeightLoader,
704    opts: &Llama32PrefillOpts,
705) -> Result<BuiltModel> {
706    Llama32Flow::from_prefill_opts(cfg, opts).build(weights)
707}
708
709pub fn build_llama32_decode_flow(
710    cfg: &Llama32Config,
711    weights: &mut dyn WeightLoader,
712    opts: &Llama32DecodeOpts,
713) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
714    build_llama32_decode_built(cfg, weights, opts)?.into_parts()
715}
716
717pub fn build_llama32_decode_graph(
718    cfg: &Llama32Config,
719    weights: &mut dyn WeightLoader,
720    opts: &Llama32DecodeOpts,
721) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
722    rlx_core::flow_util::graph_from_built(build_llama32_decode_built(cfg, weights, opts)?)
723}
724
725pub fn build_llama32_decode_built(
726    cfg: &Llama32Config,
727    weights: &mut dyn WeightLoader,
728    opts: &Llama32DecodeOpts,
729) -> Result<BuiltModel> {
730    Llama32Flow::from_decode_opts(cfg, opts).build(weights)
731}