1use std::collections::HashMap;
46use std::fmt;
47use std::path::Path;
48use std::sync::Arc;
49
50use anyhow::Result;
51use rlx_flow::blocks::{
52 LlamaDecodeLayerSpec, LlamaDecoderSpec, RopeTablesStage, llama_prefill_layer_fused,
53};
54use rlx_flow::{BuiltModel, CompileProfile, FlowStage, ModelFlow, SideOutputs};
55use rlx_ir::dynamic::sym;
56use rlx_ir::hir::HirModule;
57use rlx_ir::op::MaskKind;
58use rlx_ir::shape::Dim;
59use rlx_ir::{DType, Graph, Shape};
60
61use super::config::Llama32Config;
62use super::rope::{build_rope_tables, resolve_inv_freq};
63use rlx_core::flow_bridge::{WeightLoaderSource, load_compile_profile};
64use rlx_core::weight_loader::WeightLoader;
65
66pub const LLAMA32_PROFILE_FILE: &str = "llama32.rlx.toml";
68
69pub fn llama32_profile_near_weights(weights: &Path, decode: bool) -> CompileProfile {
71 let default = if decode {
72 CompileProfile::llama32_decode()
73 } else {
74 CompileProfile::llama32_prefill()
75 };
76 let dir = weights.parent().unwrap_or_else(|| Path::new("."));
77 load_compile_profile(&dir.join(LLAMA32_PROFILE_FILE), default)
78}
79
/// Which of the two graphs `Llama32Flow::build` assembles.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Llama32Mode {
    /// Build the prefill graph (the builder's default mode).
    Prefill,
    /// Build the single-token decode graph that takes past K/V tensors as inputs.
    Decode,
}
85
86pub enum LlamaLayerCtx<'a> {
88 Prefill {
89 index: usize,
90 spec: &'a LlamaDecoderSpec,
91 kv_sink: &'a SideOutputs,
92 export_kv: bool,
93 head_dim: usize,
94 eps: f32,
95 },
96 Decode {
97 index: usize,
98 spec: &'a LlamaDecodeLayerSpec,
99 kv_out: &'a SideOutputs,
100 },
101}
102
103impl LlamaLayerCtx<'_> {
104 pub fn index(&self) -> usize {
105 match self {
106 Self::Prefill { index, .. } | Self::Decode { index, .. } => *index,
107 }
108 }
109
110 pub fn default_stage(&self) -> FlowStage {
112 match self {
113 Self::Prefill {
114 index,
115 spec,
116 kv_sink,
117 export_kv,
118 head_dim,
119 eps,
120 } => {
121 let mut stages = Vec::new();
122 if *export_kv {
123 stages.push(FlowStage::LlamaKvTap(
124 rlx_flow::blocks::LlamaKvTapStage::layer(
125 *index,
126 *head_dim,
127 *eps,
128 kv_sink.inner(),
129 ),
130 ));
131 }
132 stages.push(FlowStage::Named {
133 name: format!("layer{index}"),
134 inner: Arc::new(FlowStage::LlamaDecoder(
135 rlx_flow::blocks::LlamaDecoderStage::layer(*index, (*spec).clone()),
136 )),
137 });
138 FlowStage::Sequence(stages)
139 }
140 Self::Decode {
141 index,
142 spec,
143 kv_out,
144 } => FlowStage::Named {
145 name: format!("layer{index}"),
146 inner: Arc::new(FlowStage::LlamaDecodeLayer(
147 rlx_flow::blocks::LlamaDecodeLayerStage::layer(
148 *index,
149 (*spec).clone(),
150 kv_out.inner(),
151 ),
152 )),
153 },
154 }
155 }
156}
157
158type LayerFn = Arc<dyn Fn(LlamaLayerCtx<'_>) -> FlowStage + Send + Sync>;
159type FlowPatchFn = Arc<dyn Fn(ModelFlow) -> ModelFlow + Send + Sync>;
160
161#[derive(Clone)]
175pub struct Llama32Flow<'a> {
176 cfg: &'a Llama32Config,
177 mode: Llama32Mode,
178 batch: usize,
179 seq: usize,
180 past_seq: usize,
181 dynamic_seq: bool,
182 dynamic_past: bool,
183 with_lm_head: bool,
184 with_kv_outputs: bool,
185 last_logits_only: bool,
186 use_custom_mask: bool,
187 profile: Option<CompileProfile>,
188 before_layers: Vec<FlowStage>,
189 after_layers: Vec<FlowStage>,
190 layer_fn: Option<LayerFn>,
191 flow_patch: Option<FlowPatchFn>,
192}
193
194impl fmt::Debug for Llama32Flow<'_> {
195 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
196 f.debug_struct("Llama32Flow")
197 .field("mode", &self.mode)
198 .field("batch", &self.batch)
199 .field("seq", &self.seq)
200 .field("past_seq", &self.past_seq)
201 .field("dynamic_seq", &self.dynamic_seq)
202 .field("dynamic_past", &self.dynamic_past)
203 .field("with_lm_head", &self.with_lm_head)
204 .field("with_kv_outputs", &self.with_kv_outputs)
205 .field("last_logits_only", &self.last_logits_only)
206 .field("use_custom_mask", &self.use_custom_mask)
207 .field("profile", &self.profile)
208 .field("before_layers", &self.before_layers.len())
209 .field("after_layers", &self.after_layers.len())
210 .field("layer_fn", &self.layer_fn.is_some())
211 .field("flow_patch", &self.flow_patch.is_some())
212 .finish_non_exhaustive()
213 }
214}
215
216impl<'a> Llama32Flow<'a> {
217 pub fn new(cfg: &'a Llama32Config) -> Self {
218 Self {
219 cfg,
220 mode: Llama32Mode::Prefill,
221 batch: 1,
222 seq: 128,
223 past_seq: 0,
224 dynamic_seq: false,
225 dynamic_past: false,
226 with_lm_head: false,
227 with_kv_outputs: false,
228 last_logits_only: false,
229 use_custom_mask: false,
230 profile: None,
231 before_layers: Vec::new(),
232 after_layers: Vec::new(),
233 layer_fn: None,
234 flow_patch: None,
235 }
236 }
237
238 pub fn for_prefill(cfg: &'a Llama32Config, batch: usize, seq: usize) -> Self {
240 Self::new(cfg).prefill().batch(batch).seq(seq)
241 }
242
243 pub fn for_decode(cfg: &'a Llama32Config, batch: usize, past_seq: usize) -> Self {
245 Self::new(cfg)
246 .decode()
247 .batch(batch)
248 .past(past_seq)
249 .lm_head()
250 }
251
252 pub fn prefill(mut self) -> Self {
253 self.mode = Llama32Mode::Prefill;
254 self
255 }
256
257 pub fn decode(mut self) -> Self {
258 self.mode = Llama32Mode::Decode;
259 self
260 }
261
262 pub fn batch(mut self, batch: usize) -> Self {
263 self.batch = batch;
264 self
265 }
266
267 pub fn seq(mut self, seq: usize) -> Self {
269 self.seq = seq;
270 self
271 }
272
273 pub fn past(mut self, past_seq: usize) -> Self {
275 self.past_seq = past_seq;
276 self
277 }
278
279 pub fn dynamic_seq(mut self) -> Self {
281 self.dynamic_seq = true;
282 self
283 }
284
285 pub fn dynamic_past(mut self) -> Self {
287 self.dynamic_past = true;
288 self
289 }
290
291 pub fn lm_head(mut self) -> Self {
292 self.with_lm_head = true;
293 self
294 }
295
296 pub fn hidden_only(mut self) -> Self {
298 self.with_lm_head = false;
299 self.last_logits_only = false;
300 self
301 }
302
303 pub fn last_token_logits(mut self) -> Self {
304 self.with_lm_head = true;
305 self.last_logits_only = true;
306 self
307 }
308
309 pub fn export_kv(mut self) -> Self {
310 self.with_kv_outputs = true;
311 self
312 }
313
314 pub fn custom_mask(mut self) -> Self {
315 self.use_custom_mask = true;
316 self
317 }
318
319 pub fn profile(mut self, profile: CompileProfile) -> Self {
320 self.profile = Some(profile);
321 self
322 }
323
324 pub fn profile_prefill(mut self) -> Self {
326 self.profile = Some(CompileProfile::llama32_prefill());
327 self
328 }
329
330 pub fn profile_decode(mut self) -> Self {
332 self.profile = Some(CompileProfile::llama32_decode());
333 self
334 }
335
336 pub fn profile_near(mut self, weights_path: &Path) -> Self {
337 let decode = self.mode == Llama32Mode::Decode;
338 self.profile = Some(llama32_profile_near_weights(weights_path, decode));
339 self
340 }
341
342 pub fn before_layers(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
344 self.before_layers.extend(stages);
345 self
346 }
347
348 pub fn after_layers(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
350 self.after_layers.extend(stages);
351 self
352 }
353
354 pub fn layer<F>(mut self, f: F) -> Self
358 where
359 F: Fn(LlamaLayerCtx<'_>) -> FlowStage + Send + Sync + 'static,
360 {
361 self.layer_fn = Some(Arc::new(f));
362 self
363 }
364
365 pub fn patch_flow<F>(mut self, f: F) -> Self
367 where
368 F: Fn(ModelFlow) -> ModelFlow + Send + Sync + 'static,
369 {
370 self.flow_patch = Some(Arc::new(f));
371 self
372 }
373
374 pub fn build(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
375 match self.mode {
376 Llama32Mode::Prefill => self.build_prefill(weights),
377 Llama32Mode::Decode => self.build_decode(weights),
378 }
379 }
380
381 fn build_prefill(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
382 if self.dynamic_seq && self.batch != 1 {
383 anyhow::bail!("llama32: dynamic_seq prefill requires batch=1");
384 }
385
386 let cfg = self.cfg;
387 let profile = self.profile.unwrap_or_else(CompileProfile::llama32_prefill);
388 let f = DType::F32;
389 let h = cfg.hidden_size;
390 let eps = cfg.rms_norm_eps as f32;
391 let dh = cfg.head_dim();
392
393 let hidden_shape = prefill_hidden_shape(self.batch, self.seq, h, self.dynamic_seq, f);
394 let input_shape = prefill_input_shape(self.batch, self.seq, self.dynamic_seq);
395
396 let rope_factors = weights.take("rope_freqs.weight").ok().map(|(data, _)| data);
397 let inv_freq = resolve_inv_freq(cfg, rope_factors.as_deref());
398 let (cos_data, sin_data) = build_rope_tables(&inv_freq, cfg.max_position_embeddings);
399
400 let decoder_spec = LlamaDecoderSpec {
401 num_heads: cfg.num_attention_heads,
402 head_dim: dh,
403 num_kv_heads: cfg.num_key_value_heads,
404 eps,
405 mask: MaskKind::Causal,
406 hidden_shape: hidden_shape.clone(),
407 };
408
409 let kv_sink = SideOutputs::new();
410
411 let mut flow = ModelFlow::new("llama32")
412 .with_profile(profile)
413 .input("input_ids", input_shape);
414
415 if self.dynamic_seq && self.with_lm_head && self.last_logits_only {
416 flow = flow.input("last_token_idx", Shape::new(&[self.batch], DType::F32));
417 }
418
419 flow = flow
420 .rope_tables(RopeTablesStage::param(
421 cfg.max_position_embeddings,
422 inv_freq.len(),
423 cos_data,
424 sin_data,
425 ))
426 .zero_beta_named("llama32.zero_beta.hidden", h)
427 .token_embed()
428 .raw_stages(self.before_layers.iter().cloned());
429
430 let layer_fn = self.layer_fn.clone();
431 let export = self.with_kv_outputs;
432 flow = flow.repeat_layers(cfg.num_hidden_layers, {
433 let spec = decoder_spec.clone();
434 let sink = kv_sink.clone();
435 move |i| {
436 if let Some(ref f) = layer_fn {
437 return f(LlamaLayerCtx::Prefill {
438 index: i,
439 spec: &spec,
440 kv_sink: &sink,
441 export_kv: export,
442 head_dim: dh,
443 eps,
444 });
445 }
446 let mut stages = Vec::new();
447 if export {
448 stages.push(FlowStage::LlamaKvTap(
449 rlx_flow::blocks::LlamaKvTapStage::layer(i, dh, eps, sink.inner()),
450 ));
451 }
452 stages.push(llama_prefill_layer_fused(i, spec.clone()));
453 if stages.len() == 1 {
454 stages.into_iter().next().unwrap()
455 } else {
456 FlowStage::Sequence(stages)
457 }
458 }
459 });
460
461 flow = flow.raw_stages(self.after_layers.iter().cloned());
462
463 if self.with_lm_head && self.last_logits_only {
464 flow = if self.dynamic_seq {
465 flow.gather_last_token_dynamic(self.batch)
466 } else {
467 flow.gather_last_token_at(self.batch, self.seq)
468 };
469 }
470
471 flow = flow.final_norm(eps);
472
473 if let Some(patch) = self.flow_patch {
474 flow = patch(flow);
475 }
476
477 let mut built = if self.with_lm_head {
478 flow.lm_head(cfg.vocab_size, h, cfg.tie_word_embeddings)
479 .build(&mut WeightLoaderSource(weights))?
480 } else {
481 flow.output("hidden")
482 .build(&mut WeightLoaderSource(weights))?
483 };
484
485 if self.with_kv_outputs {
486 built = built.with_extra_hir_outputs(kv_sink.drain());
487 }
488 Ok(built)
489 }
490
491 fn build_decode(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
492 let cfg = self.cfg;
493 let profile = self.profile.unwrap_or_else(CompileProfile::llama32_decode);
494 let f = DType::F32;
495 let h = cfg.hidden_size;
496 let eps = cfg.rms_norm_eps as f32;
497 let dh = cfg.head_dim();
498 let kv_dim = cfg.kv_proj_dim();
499 let half = dh / 2;
500
501 let hidden_shape = Shape::new(&[self.batch, 1, h], f);
502 let past_kv_shape = if self.dynamic_past {
503 Shape::from_dims(
504 &[
505 Dim::Static(self.batch),
506 Dim::Dynamic(sym::PAST_SEQ),
507 Dim::Static(kv_dim),
508 ],
509 f,
510 )
511 } else {
512 Shape::new(&[self.batch, self.past_seq, kv_dim], f)
513 };
514
515 let decode_spec = LlamaDecodeLayerSpec {
516 num_heads: cfg.num_attention_heads,
517 head_dim: dh,
518 num_kv_heads: cfg.num_key_value_heads,
519 kv_group_size: cfg.kv_group_size(),
520 eps,
521 use_custom_mask: self.use_custom_mask,
522 hidden_shape,
523 };
524
525 let kv_out = SideOutputs::new();
526
527 let mut flow = ModelFlow::new("llama32_decode")
528 .with_profile(profile)
529 .input("input_ids", Shape::new(&[self.batch, 1], DType::F32))
530 .input("rope_cos", Shape::new(&[1, half], f))
531 .input("rope_sin", Shape::new(&[1, half], f));
532
533 if self.use_custom_mask {
534 flow = flow.input("mask", Shape::new(&[self.batch, self.past_seq + 1], f));
535 }
536
537 for layer_idx in 0..cfg.num_hidden_layers {
538 flow = flow
539 .input(format!("past_k_{layer_idx}"), past_kv_shape.clone())
540 .input(format!("past_v_{layer_idx}"), past_kv_shape.clone());
541 }
542
543 flow = flow
544 .bind_decode_inputs(cfg.num_hidden_layers, self.use_custom_mask)
545 .zero_beta_named("llama32.zero_beta.hidden", h)
546 .token_embed()
547 .raw_stages(self.before_layers.iter().cloned());
548
549 let layer_fn = self.layer_fn.clone();
550 flow = flow.repeat_layers(cfg.num_hidden_layers, {
551 let spec = decode_spec.clone();
552 let sink = kv_out.clone();
553 move |i| {
554 if let Some(ref f) = layer_fn {
555 return f(LlamaLayerCtx::Decode {
556 index: i,
557 spec: &spec,
558 kv_out: &sink,
559 });
560 }
561 LlamaLayerCtx::Decode {
562 index: i,
563 spec: &spec,
564 kv_out: &sink,
565 }
566 .default_stage()
567 }
568 });
569
570 flow = flow.raw_stages(self.after_layers.iter().cloned());
571
572 if let Some(patch) = self.flow_patch {
573 flow = patch(flow);
574 }
575
576 let built = flow
577 .final_norm(eps)
578 .lm_head(cfg.vocab_size, h, cfg.tie_word_embeddings)
579 .build(&mut WeightLoaderSource(weights))?
580 .with_extra_hir_outputs(kv_out.drain());
581
582 Ok(built)
583 }
584}
585
586fn prefill_hidden_shape(
587 batch: usize,
588 seq: usize,
589 hidden: usize,
590 dynamic: bool,
591 dtype: DType,
592) -> Shape {
593 if dynamic {
594 Shape::from_dims(
595 &[
596 Dim::Static(batch),
597 Dim::Dynamic(sym::SEQ),
598 Dim::Static(hidden),
599 ],
600 dtype,
601 )
602 } else {
603 Shape::new(&[batch, seq, hidden], dtype)
604 }
605}
606
607fn prefill_input_shape(batch: usize, seq: usize, dynamic: bool) -> Shape {
608 if dynamic {
609 Shape::from_dims(&[Dim::Static(batch), Dim::Dynamic(sym::SEQ)], DType::F32)
610 } else {
611 Shape::new(&[batch, seq], DType::F32)
612 }
613}
614
615impl<'a> Llama32Flow<'a> {
618 fn from_prefill_opts(cfg: &'a Llama32Config, o: &Llama32PrefillOpts) -> Self {
619 let mut f = Llama32Flow::new(cfg).prefill().batch(o.batch).seq(o.seq);
620 if o.dynamic_seq {
621 f = f.dynamic_seq();
622 }
623 if o.with_lm_head {
624 f = f.lm_head();
625 }
626 if o.with_kv_outputs {
627 f = f.export_kv();
628 }
629 if o.last_logits_only {
630 f = f.last_token_logits();
631 }
632 if let Some(p) = o.profile.clone() {
633 f = f.profile(p);
634 }
635 f
636 }
637
638 fn from_decode_opts(cfg: &'a Llama32Config, o: &Llama32DecodeOpts) -> Self {
639 let mut f = Llama32Flow::new(cfg)
640 .decode()
641 .batch(o.batch)
642 .past(o.past_seq)
643 .lm_head();
644 if o.dynamic_past {
645 f = f.dynamic_past();
646 }
647 if o.use_custom_mask {
648 f = f.custom_mask();
649 }
650 if let Some(p) = o.profile.clone() {
651 f = f.profile(p);
652 }
653 f
654 }
655}
656
657#[derive(Debug, Clone)]
659pub struct Llama32PrefillOpts {
660 pub batch: usize,
661 pub seq: usize,
662 pub dynamic_seq: bool,
663 pub with_lm_head: bool,
664 pub with_kv_outputs: bool,
665 pub last_logits_only: bool,
666 pub profile: Option<CompileProfile>,
667}
668
669impl Llama32PrefillOpts {
670 pub fn static_prefill(batch: usize, seq: usize) -> Self {
671 Self {
672 batch,
673 seq,
674 dynamic_seq: false,
675 with_lm_head: false,
676 with_kv_outputs: false,
677 last_logits_only: false,
678 profile: None,
679 }
680 }
681}
682
683#[derive(Debug, Clone)]
685pub struct Llama32DecodeOpts {
686 pub batch: usize,
687 pub past_seq: usize,
688 pub dynamic_past: bool,
689 pub use_custom_mask: bool,
690 pub profile: Option<CompileProfile>,
691}
692
693pub fn build_llama32_prefill_flow(
694 cfg: &Llama32Config,
695 weights: &mut dyn WeightLoader,
696 opts: &Llama32PrefillOpts,
697) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
698 build_llama32_prefill_built(cfg, weights, opts)?.into_parts()
699}
700
701pub fn build_llama32_prefill_built(
702 cfg: &Llama32Config,
703 weights: &mut dyn WeightLoader,
704 opts: &Llama32PrefillOpts,
705) -> Result<BuiltModel> {
706 Llama32Flow::from_prefill_opts(cfg, opts).build(weights)
707}
708
709pub fn build_llama32_decode_flow(
710 cfg: &Llama32Config,
711 weights: &mut dyn WeightLoader,
712 opts: &Llama32DecodeOpts,
713) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
714 build_llama32_decode_built(cfg, weights, opts)?.into_parts()
715}
716
717pub fn build_llama32_decode_graph(
718 cfg: &Llama32Config,
719 weights: &mut dyn WeightLoader,
720 opts: &Llama32DecodeOpts,
721) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
722 rlx_core::flow_util::graph_from_built(build_llama32_decode_built(cfg, weights, opts)?)
723}
724
725pub fn build_llama32_decode_built(
726 cfg: &Llama32Config,
727 weights: &mut dyn WeightLoader,
728 opts: &Llama32DecodeOpts,
729) -> Result<BuiltModel> {
730 Llama32Flow::from_decode_opts(cfg, opts).build(weights)
731}