use std::collections::HashMap;
use std::fmt;
use std::path::Path;
use std::sync::Arc;
use anyhow::Result;
use rlx_flow::blocks::{
LlamaDecodeLayerSpec, LlamaDecoderSpec, RopeTablesStage, llama_prefill_layer_fused,
};
use rlx_flow::{BuiltModel, CompileProfile, FlowStage, ModelFlow, SideOutputs};
use rlx_ir::dynamic::sym;
use rlx_ir::hir::HirModule;
use rlx_ir::op::MaskKind;
use rlx_ir::shape::Dim;
use rlx_ir::{DType, Graph, Shape};
use super::config::Llama32Config;
use super::rope::{build_rope_tables, resolve_inv_freq};
use rlx_core::flow_bridge::{WeightLoaderSource, load_compile_profile};
use rlx_core::weight_loader::WeightLoader;
/// File name of the optional compile-profile file looked up in the same directory as the
/// model weights.
pub const LLAMA32_PROFILE_FILE: &str = "llama32.rlx.toml";

/// Resolves the compile profile for a Llama 3.2 model whose weights live at `weights`.
///
/// The candidate file is [`LLAMA32_PROFILE_FILE`] inside the weights' parent directory, or
/// inside `.` when `weights` has no parent. The path is passed to `load_compile_profile`
/// together with the built-in default for the requested mode: the decode profile when
/// `decode` is `true`, otherwise the prefill profile. How a missing or unreadable file is
/// handled is decided by `load_compile_profile`.
pub fn llama32_profile_near_weights(weights: &Path, decode: bool) -> CompileProfile {
    let dir = weights.parent().unwrap_or(Path::new("."));
    let profile_path = dir.join(LLAMA32_PROFILE_FILE);
    let fallback = if decode {
        CompileProfile::llama32_decode()
    } else {
        CompileProfile::llama32_prefill()
    };
    load_compile_profile(&profile_path, fallback)
}
/// Which kind of graph a `Llama32Flow` builder produces.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Llama32Mode {
    /// Prompt processing over a `[batch, seq]` block of token ids.
    Prefill,
    /// One new token per batch row, reading past K/V from explicit inputs.
    Decode,
}
/// Per-layer context handed to a custom layer closure installed with `Llama32Flow::layer`.
///
/// The variant matches the builder's mode. Each variant carries what the built-in stage
/// factory needs, so a closure can wrap, replace, or fall back to
/// [`LlamaLayerCtx::default_stage`].
pub enum LlamaLayerCtx<'a> {
    /// Context for one prefill decoder layer.
    Prefill {
        /// Zero-based decoder layer index.
        index: usize,
        /// Decoder spec shared by every prefill layer of the flow.
        spec: &'a LlamaDecoderSpec,
        /// Sink that the flow's KV-tap stages write into; drained into extra outputs after
        /// the build when KV export is enabled.
        kv_sink: &'a SideOutputs,
        /// Whether KV export was requested on the builder. When a custom closure is
        /// installed, the builder does not insert the KV tap itself, so the closure must do it.
        export_kv: bool,
        /// Per-attention-head dimension (`cfg.head_dim()`).
        head_dim: usize,
        /// Norm epsilon taken from `cfg.rms_norm_eps`.
        eps: f32,
    },
    /// Context for one decode layer.
    Decode {
        /// Zero-based decoder layer index.
        index: usize,
        /// Decode-layer spec shared by every layer of the flow.
        spec: &'a LlamaDecodeLayerSpec,
        /// Sink for the layer's K/V side outputs; drained into extra outputs after the build.
        kv_out: &'a SideOutputs,
    },
}
impl LlamaLayerCtx<'_> {
    /// Zero-based index of the decoder layer this context describes.
    pub fn index(&self) -> usize {
        let (Self::Prefill { index, .. } | Self::Decode { index, .. }) = self;
        *index
    }

    /// Builds the stock stage for this layer.
    ///
    /// * Prefill: a `Sequence` containing an optional KV tap (only when `export_kv` is set),
    ///   followed by a `Named("layer{index}")` wrapper around an unfused `LlamaDecoder` stage.
    ///   NOTE(review): when no custom closure is installed, the prefill builder uses
    ///   `llama_prefill_layer_fused` instead, so this is not necessarily the same stage the
    ///   builder emits by default. Confirm before relying on equivalence.
    /// * Decode: a `Named("layer{index}")` wrapper around a `LlamaDecodeLayer` stage that
    ///   writes its K/V side outputs to `kv_out`.
    pub fn default_stage(&self) -> FlowStage {
        match *self {
            Self::Prefill {
                index,
                spec,
                kv_sink,
                export_kv,
                head_dim,
                eps,
            } => {
                // The tap (if any) is constructed before the decoder stage, mirroring the
                // order in which the two stages run.
                let kv_tap = export_kv.then(|| {
                    FlowStage::LlamaKvTap(rlx_flow::blocks::LlamaKvTapStage::layer(
                        index,
                        head_dim,
                        eps,
                        kv_sink.inner(),
                    ))
                });
                let decoder = FlowStage::Named {
                    name: format!("layer{index}"),
                    inner: Arc::new(FlowStage::LlamaDecoder(
                        rlx_flow::blocks::LlamaDecoderStage::layer(index, spec.clone()),
                    )),
                };
                FlowStage::Sequence(
                    kv_tap
                        .into_iter()
                        .chain(std::iter::once(decoder))
                        .collect(),
                )
            }
            Self::Decode {
                index,
                spec,
                kv_out,
            } => {
                let layer = rlx_flow::blocks::LlamaDecodeLayerStage::layer(
                    index,
                    spec.clone(),
                    kv_out.inner(),
                );
                FlowStage::Named {
                    name: format!("layer{index}"),
                    inner: Arc::new(FlowStage::LlamaDecodeLayer(layer)),
                }
            }
        }
    }
}
/// Per-layer stage factory installed via `Llama32Flow::layer`; called once per layer index.
type LayerFn = Arc<dyn Fn(LlamaLayerCtx<'_>) -> FlowStage + Send + Sync>;
/// Whole-flow transform installed via `Llama32Flow::patch_flow`.
type FlowPatchFn = Arc<dyn Fn(ModelFlow) -> ModelFlow + Send + Sync>;

/// Builder for Llama 3.2 prefill and decode graphs.
///
/// Construct with `new`, `for_prefill`, or `for_decode`, chain the configuration setters,
/// then call `build` with a weight loader. Cloning is cheap for the hooks (they are `Arc`s);
/// the stage lists are cloned element-wise.
#[derive(Clone)]
pub struct Llama32Flow<'a> {
    // --- model and mode ---
    /// Model hyper-parameters, borrowed for the builder's lifetime.
    cfg: &'a Llama32Config,
    /// Prefill or decode.
    mode: Llama32Mode,

    // --- shapes ---
    /// Batch size (both modes).
    batch: usize,
    /// Static prompt length (prefill).
    seq: usize,
    /// Static KV-cache length (decode).
    past_seq: usize,
    /// Symbolic prompt length (prefill).
    dynamic_seq: bool,
    /// Symbolic KV-cache length (decode).
    dynamic_past: bool,

    // --- outputs ---
    /// Emit LM-head logits instead of hidden states (prefill; decode always does).
    with_lm_head: bool,
    /// Append per-layer K/V as extra outputs (prefill; decode always does).
    with_kv_outputs: bool,
    /// Gather only the last token position before the final norm (prefill).
    last_logits_only: bool,
    /// Add an explicit `mask` input (decode).
    use_custom_mask: bool,

    // --- compilation and customization hooks ---
    /// Explicit compile profile; `None` selects the mode's built-in profile.
    profile: Option<CompileProfile>,
    /// Raw stages inserted after the token embedding, before the first layer.
    before_layers: Vec<FlowStage>,
    /// Raw stages inserted after the last layer.
    after_layers: Vec<FlowStage>,
    /// Optional replacement for the default per-layer stage.
    layer_fn: Option<LayerFn>,
    /// Optional transform applied to the assembled flow.
    flow_patch: Option<FlowPatchFn>,
}
impl fmt::Debug for Llama32Flow<'_> {
    /// Prints the scalar configuration and the profile. Stage lists are summarized by their
    /// length and hooks by their presence. The borrowed config is omitted, hence the
    /// non-exhaustive finish.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Exhaustive destructuring: adding a field to `Llama32Flow` forces a decision here.
        let Self {
            cfg: _,
            mode,
            batch,
            seq,
            past_seq,
            dynamic_seq,
            dynamic_past,
            with_lm_head,
            with_kv_outputs,
            last_logits_only,
            use_custom_mask,
            profile,
            before_layers,
            after_layers,
            layer_fn,
            flow_patch,
        } = self;
        let mut dbg = f.debug_struct("Llama32Flow");
        dbg.field("mode", mode)
            .field("batch", batch)
            .field("seq", seq)
            .field("past_seq", past_seq)
            .field("dynamic_seq", dynamic_seq)
            .field("dynamic_past", dynamic_past)
            .field("with_lm_head", with_lm_head)
            .field("with_kv_outputs", with_kv_outputs)
            .field("last_logits_only", last_logits_only)
            .field("use_custom_mask", use_custom_mask)
            .field("profile", profile)
            .field("before_layers", &before_layers.len())
            .field("after_layers", &after_layers.len())
            .field("layer_fn", &layer_fn.is_some())
            .field("flow_patch", &flow_patch.is_some());
        dbg.finish_non_exhaustive()
    }
}
/// Builder API and graph construction for [`Llama32Flow`].
impl<'a> Llama32Flow<'a> {
    /// Creates a builder with these defaults: prefill mode, `batch = 1`, `seq = 128`,
    /// `past_seq = 0`, static shapes, no LM head, no KV export, no custom mask, no explicit
    /// compile profile, no extra stages, and no layer/flow hooks.
    pub fn new(cfg: &'a Llama32Config) -> Self {
        Self {
            cfg,
            mode: Llama32Mode::Prefill,
            batch: 1,
            seq: 128,
            past_seq: 0,
            dynamic_seq: false,
            dynamic_past: false,
            with_lm_head: false,
            with_kv_outputs: false,
            last_logits_only: false,
            use_custom_mask: false,
            profile: None,
            before_layers: Vec::new(),
            after_layers: Vec::new(),
            layer_fn: None,
            flow_patch: None,
        }
    }
    /// Prefill builder for a static `[batch, seq]` prompt. It outputs hidden states (no LM
    /// head) unless configured further.
    pub fn for_prefill(cfg: &'a Llama32Config, batch: usize, seq: usize) -> Self {
        Self::new(cfg).prefill().batch(batch).seq(seq)
    }
    /// Decode builder for `batch` rows with a static KV-cache length of `past_seq` and the
    /// LM head enabled.
    pub fn for_decode(cfg: &'a Llama32Config, batch: usize, past_seq: usize) -> Self {
        Self::new(cfg)
            .decode()
            .batch(batch)
            .past(past_seq)
            .lm_head()
    }
    /// Switches to prefill mode (the default).
    pub fn prefill(mut self) -> Self {
        self.mode = Llama32Mode::Prefill;
        self
    }
    /// Switches to single-token decode mode.
    pub fn decode(mut self) -> Self {
        self.mode = Llama32Mode::Decode;
        self
    }
    /// Sets the batch size used by both modes.
    pub fn batch(mut self, batch: usize) -> Self {
        self.batch = batch;
        self
    }
    /// Sets the static prompt length for prefill. It is not used when `dynamic_seq` is
    /// enabled and is not read in decode mode.
    pub fn seq(mut self, seq: usize) -> Self {
        self.seq = seq;
        self
    }
    /// Sets the KV-cache length for decode: the static `past_k_*`/`past_v_*` input length
    /// and the custom-mask width (`past_seq + 1`). It is not read in prefill mode.
    pub fn past(mut self, past_seq: usize) -> Self {
        self.past_seq = past_seq;
        self
    }
    /// Prefill only: makes the sequence axis symbolic (`sym::SEQ`). Requires `batch == 1`;
    /// `build` fails otherwise.
    pub fn dynamic_seq(mut self) -> Self {
        self.dynamic_seq = true;
        self
    }
    /// Decode only: makes the past-length axis of the `past_k_*`/`past_v_*` inputs symbolic
    /// (`sym::PAST_SEQ`).
    ///
    /// NOTE(review): the optional `mask` input keeps the static width `past_seq + 1` even
    /// with this flag set. Confirm that is intended when combined with `custom_mask`.
    pub fn dynamic_past(mut self) -> Self {
        self.dynamic_past = true;
        self
    }
    /// Ends the flow with the LM head (`vocab_size` logits) instead of a `"hidden"` output.
    /// Decode mode always applies the LM head, regardless of this flag.
    pub fn lm_head(mut self) -> Self {
        self.with_lm_head = true;
        self
    }
    /// Prefill only: outputs final hidden states instead of logits, and also clears the
    /// last-token-only flag. It has no effect in decode mode.
    pub fn hidden_only(mut self) -> Self {
        self.with_lm_head = false;
        self.last_logits_only = false;
        self
    }
    /// Prefill only: enables the LM head and gathers only the last token position before
    /// the final norm. With `dynamic_seq`, an extra `last_token_idx` input is also declared
    /// (presumably consumed by the dynamic gather).
    pub fn last_token_logits(mut self) -> Self {
        self.with_lm_head = true;
        self.last_logits_only = true;
        self
    }
    /// Prefill only: taps each layer's K/V and appends them as extra HIR outputs.
    /// Decode mode always emits the per-layer KV outputs.
    pub fn export_kv(mut self) -> Self {
        self.with_kv_outputs = true;
        self
    }
    /// Decode only: adds a `mask` input of shape `[batch, past_seq + 1]` and sets
    /// `use_custom_mask` on the decode layer spec. The prefill decoder spec always uses
    /// `MaskKind::Causal`.
    pub fn custom_mask(mut self) -> Self {
        self.use_custom_mask = true;
        self
    }
    /// Uses `profile` instead of the mode's built-in compile profile.
    pub fn profile(mut self, profile: CompileProfile) -> Self {
        self.profile = Some(profile);
        self
    }
    /// Explicitly selects the built-in prefill compile profile.
    pub fn profile_prefill(mut self) -> Self {
        self.profile = Some(CompileProfile::llama32_prefill());
        self
    }
    /// Explicitly selects the built-in decode compile profile.
    pub fn profile_decode(mut self) -> Self {
        self.profile = Some(CompileProfile::llama32_decode());
        self
    }
    /// Loads the profile via [`llama32_profile_near_weights`] for `weights_path`.
    ///
    /// The mode is read at call time, so call this *after* `.decode()`/`.prefill()`.
    /// Otherwise the default profile passed along may be for the wrong mode.
    pub fn profile_near(mut self, weights_path: &Path) -> Self {
        let decode = self.mode == Llama32Mode::Decode;
        self.profile = Some(llama32_profile_near_weights(weights_path, decode));
        self
    }
    /// Appends raw stages that run after the token embedding and before the first layer.
    /// Repeated calls accumulate.
    pub fn before_layers(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
        self.before_layers.extend(stages);
        self
    }
    /// Appends raw stages that run after the last layer and before the final norm (in
    /// prefill, also before the optional last-token gather). Repeated calls accumulate.
    pub fn after_layers(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
        self.after_layers.extend(stages);
        self
    }
    /// Replaces the per-layer stage factory. A second call replaces the first closure.
    ///
    /// The closure fully replaces the builder's default for every layer. In prefill this
    /// includes the KV tap, so check `export_kv` in the context and add the tap yourself.
    /// NOTE(review): the prefill default without a closure is `llama_prefill_layer_fused`,
    /// while `LlamaLayerCtx::default_stage` builds an unfused `LlamaDecoder` stage. A closure
    /// `|ctx| ctx.default_stage()` may therefore not reproduce the default graph. Confirm.
    pub fn layer<F>(mut self, f: F) -> Self
    where
        F: Fn(LlamaLayerCtx<'_>) -> FlowStage + Send + Sync + 'static,
    {
        self.layer_fn = Some(Arc::new(f));
        self
    }
    /// Registers a transform applied to the assembled flow. A second call replaces the first.
    ///
    /// NOTE(review): in prefill the patch runs after `final_norm` (just before the LM head or
    /// `"hidden"` output). In decode it runs before `final_norm`. Confirm the difference is
    /// intended.
    pub fn patch_flow<F>(mut self, f: F) -> Self
    where
        F: Fn(ModelFlow) -> ModelFlow + Send + Sync + 'static,
    {
        self.flow_patch = Some(Arc::new(f));
        self
    }
    /// Builds the graph for the selected mode, pulling parameters from `weights`.
    ///
    /// # Errors
    /// Propagates the mode-specific build errors (see `build_prefill` / `build_decode`).
    pub fn build(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
        match self.mode {
            Llama32Mode::Prefill => self.build_prefill(weights),
            Llama32Mode::Decode => self.build_decode(weights),
        }
    }
    /// Assembles the prefill graph in this order:
    ///
    /// 1. RoPE tables.
    /// 2. Token embedding.
    /// 3. `before_layers`.
    /// 4. `num_hidden_layers` decoder layers.
    /// 5. `after_layers`.
    /// 6. Optional last-token gather.
    /// 7. Final norm.
    /// 8. Optional flow patch.
    /// 9. Either the LM head or a `"hidden"` output.
    ///
    /// Tapped K/V tensors are appended as extra outputs when KV export is enabled.
    ///
    /// # Errors
    /// Fails if `dynamic_seq` is combined with `batch != 1`, or if building the flow against
    /// `weights` fails.
    fn build_prefill(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
        if self.dynamic_seq && self.batch != 1 {
            anyhow::bail!("llama32: dynamic_seq prefill requires batch=1");
        }
        let cfg = self.cfg;
        // An explicit profile wins; otherwise use the built-in prefill profile.
        let profile = self.profile.unwrap_or_else(CompileProfile::llama32_prefill);
        let f = DType::F32;
        let h = cfg.hidden_size;
        let eps = cfg.rms_norm_eps as f32;
        let dh = cfg.head_dim();
        let hidden_shape = prefill_hidden_shape(self.batch, self.seq, h, self.dynamic_seq, f);
        let input_shape = prefill_input_shape(self.batch, self.seq, self.dynamic_seq);
        // `rope_freqs.weight` is optional: any `take` error is treated as "absent". Taking it
        // here also removes it from the loader before the flow build runs.
        let rope_factors = weights.take("rope_freqs.weight").ok().map(|(data, _)| data);
        let inv_freq = resolve_inv_freq(cfg, rope_factors.as_deref());
        let (cos_data, sin_data) = build_rope_tables(&inv_freq, cfg.max_position_embeddings);
        // One spec shared by every layer; prefill attention is always causal.
        let decoder_spec = LlamaDecoderSpec {
            num_heads: cfg.num_attention_heads,
            head_dim: dh,
            num_kv_heads: cfg.num_key_value_heads,
            eps,
            mask: MaskKind::Causal,
            hidden_shape: hidden_shape.clone(),
        };
        // Collects K/V from the per-layer taps; drained into extra outputs at the end.
        let kv_sink = SideOutputs::new();
        let mut flow = ModelFlow::new("llama32")
            .with_profile(profile)
            .input("input_ids", input_shape);
        // With a symbolic sequence length, the last position cannot be baked in, so it is
        // declared as a runtime input (same condition as the dynamic gather below).
        if self.dynamic_seq && self.with_lm_head && self.last_logits_only {
            flow = flow.input("last_token_idx", Shape::new(&[self.batch], DType::F32));
        }
        flow = flow
            .rope_tables(RopeTablesStage::param(
                cfg.max_position_embeddings,
                inv_freq.len(),
                cos_data,
                sin_data,
            ))
            .zero_beta_named("llama32.zero_beta.hidden", h)
            .token_embed()
            .raw_stages(self.before_layers.iter().cloned());
        let layer_fn = self.layer_fn.clone();
        let export = self.with_kv_outputs;
        flow = flow.repeat_layers(cfg.num_hidden_layers, {
            let spec = decoder_spec.clone();
            let sink = kv_sink.clone();
            move |i| {
                // A custom closure replaces the whole default, including the KV tap.
                if let Some(ref f) = layer_fn {
                    return f(LlamaLayerCtx::Prefill {
                        index: i,
                        spec: &spec,
                        kv_sink: &sink,
                        export_kv: export,
                        head_dim: dh,
                        eps,
                    });
                }
                // Default: optional KV tap, then the fused prefill layer.
                let mut stages = Vec::new();
                if export {
                    stages.push(FlowStage::LlamaKvTap(
                        rlx_flow::blocks::LlamaKvTapStage::layer(i, dh, eps, sink.inner()),
                    ));
                }
                stages.push(llama_prefill_layer_fused(i, spec.clone()));
                // Return a lone stage unwrapped rather than as a one-element Sequence.
                if stages.len() == 1 {
                    stages.into_iter().next().unwrap()
                } else {
                    FlowStage::Sequence(stages)
                }
            }
        });
        flow = flow.raw_stages(self.after_layers.iter().cloned());
        // The last-token gather happens before the final norm, so the norm and LM head
        // presumably run on one position per batch row.
        if self.with_lm_head && self.last_logits_only {
            flow = if self.dynamic_seq {
                flow.gather_last_token_dynamic(self.batch)
            } else {
                flow.gather_last_token_at(self.batch, self.seq)
            };
        }
        flow = flow.final_norm(eps);
        // Prefill applies the patch after `final_norm` (decode applies it before).
        if let Some(patch) = self.flow_patch {
            flow = patch(flow);
        }
        let mut built = if self.with_lm_head {
            flow.lm_head(cfg.vocab_size, h, cfg.tie_word_embeddings)
                .build(&mut WeightLoaderSource(weights))?
        } else {
            flow.output("hidden")
                .build(&mut WeightLoaderSource(weights))?
        };
        if self.with_kv_outputs {
            built = built.with_extra_hir_outputs(kv_sink.drain());
        }
        Ok(built)
    }
    /// Assembles the single-token decode graph.
    ///
    /// The runtime inputs are:
    /// * `input_ids` of shape `[batch, 1]`.
    /// * `rope_cos` / `rope_sin` of shape `[1, head_dim / 2]`.
    /// * An optional `mask`.
    /// * Per-layer `past_k_{i}` / `past_v_{i}`.
    ///
    /// The graph always ends with final norm and the LM head, and always appends the
    /// per-layer KV side outputs. `seq`, `dynamic_seq`, `with_lm_head`, `with_kv_outputs`
    /// and `last_logits_only` are not read here.
    ///
    /// # Errors
    /// Fails if building the flow against `weights` fails.
    fn build_decode(self, weights: &mut dyn WeightLoader) -> Result<BuiltModel> {
        let cfg = self.cfg;
        // An explicit profile wins; otherwise use the built-in decode profile.
        let profile = self.profile.unwrap_or_else(CompileProfile::llama32_decode);
        let f = DType::F32;
        let h = cfg.hidden_size;
        let eps = cfg.rms_norm_eps as f32;
        let dh = cfg.head_dim();
        let kv_dim = cfg.kv_proj_dim();
        // Width of the per-position RoPE cos/sin inputs.
        let half = dh / 2;
        // Exactly one new token per batch row.
        let hidden_shape = Shape::new(&[self.batch, 1, h], f);
        // Past K/V inputs are `[batch, past, kv_proj_dim]`; `past` is symbolic with `dynamic_past`.
        let past_kv_shape = if self.dynamic_past {
            Shape::from_dims(
                &[
                    Dim::Static(self.batch),
                    Dim::Dynamic(sym::PAST_SEQ),
                    Dim::Static(kv_dim),
                ],
                f,
            )
        } else {
            Shape::new(&[self.batch, self.past_seq, kv_dim], f)
        };
        let decode_spec = LlamaDecodeLayerSpec {
            num_heads: cfg.num_attention_heads,
            head_dim: dh,
            num_kv_heads: cfg.num_key_value_heads,
            kv_group_size: cfg.kv_group_size(),
            eps,
            use_custom_mask: self.use_custom_mask,
            hidden_shape,
        };
        // Collects each layer's K/V side outputs; always drained into extra outputs below.
        let kv_out = SideOutputs::new();
        let mut flow = ModelFlow::new("llama32_decode")
            .with_profile(profile)
            .input("input_ids", Shape::new(&[self.batch, 1], DType::F32))
            .input("rope_cos", Shape::new(&[1, half], f))
            .input("rope_sin", Shape::new(&[1, half], f));
        // NOTE(review): the mask width uses the static `past_seq` even when `dynamic_past`
        // is set. Confirm this is intended for dynamic-past graphs.
        if self.use_custom_mask {
            flow = flow.input("mask", Shape::new(&[self.batch, self.past_seq + 1], f));
        }
        for layer_idx in 0..cfg.num_hidden_layers {
            flow = flow
                .input(format!("past_k_{layer_idx}"), past_kv_shape.clone())
                .input(format!("past_v_{layer_idx}"), past_kv_shape.clone());
        }
        // `bind_decode_inputs` presumably wires the inputs declared above into the
        // per-layer stages. Its details live in rlx_flow.
        flow = flow
            .bind_decode_inputs(cfg.num_hidden_layers, self.use_custom_mask)
            .zero_beta_named("llama32.zero_beta.hidden", h)
            .token_embed()
            .raw_stages(self.before_layers.iter().cloned());
        let layer_fn = self.layer_fn.clone();
        flow = flow.repeat_layers(cfg.num_hidden_layers, {
            let spec = decode_spec.clone();
            let sink = kv_out.clone();
            move |i| {
                if let Some(ref f) = layer_fn {
                    return f(LlamaLayerCtx::Decode {
                        index: i,
                        spec: &spec,
                        kv_out: &sink,
                    });
                }
                // Without a custom closure, decode uses the context's default stage.
                LlamaLayerCtx::Decode {
                    index: i,
                    spec: &spec,
                    kv_out: &sink,
                }
                .default_stage()
            }
        });
        flow = flow.raw_stages(self.after_layers.iter().cloned());
        // Decode applies the patch before `final_norm` (prefill applies it after).
        if let Some(patch) = self.flow_patch {
            flow = patch(flow);
        }
        let built = flow
            .final_norm(eps)
            .lm_head(cfg.vocab_size, h, cfg.tie_word_embeddings)
            .build(&mut WeightLoaderSource(weights))?
            .with_extra_hir_outputs(kv_out.drain());
        Ok(built)
    }
}
/// Shape of the prefill hidden state, `[batch, seq, hidden]`, with element type `dtype`.
///
/// When `dynamic` is set, the sequence axis is the symbolic `sym::SEQ` dimension and `seq`
/// is ignored. Otherwise all three axes are static.
fn prefill_hidden_shape(
    batch: usize,
    seq: usize,
    hidden: usize,
    dynamic: bool,
    dtype: DType,
) -> Shape {
    if !dynamic {
        return Shape::new(&[batch, seq, hidden], dtype);
    }
    let dims = [
        Dim::Static(batch),
        Dim::Dynamic(sym::SEQ),
        Dim::Static(hidden),
    ];
    Shape::from_dims(&dims, dtype)
}
/// Shape of the prefill `input_ids` tensor, `[batch, seq]`; token ids are carried as `F32`.
///
/// When `dynamic` is set, the sequence axis is the symbolic `sym::SEQ` dimension and `seq`
/// is ignored.
fn prefill_input_shape(batch: usize, seq: usize, dynamic: bool) -> Shape {
    let dtype = DType::F32;
    if !dynamic {
        return Shape::new(&[batch, seq], dtype);
    }
    Shape::from_dims(&[Dim::Static(batch), Dim::Dynamic(sym::SEQ)], dtype)
}
/// Conversions from the plain option structs into [`Llama32Flow`] builders.
impl<'a> Llama32Flow<'a> {
    /// Translates [`Llama32PrefillOpts`] into a prefill builder.
    ///
    /// `o.with_lm_head` is authoritative. When it is `false` the flow outputs hidden states
    /// even if `o.last_logits_only` is set, because the last-token gather only applies
    /// together with the LM head. When both are `true` the builder uses
    /// [`Llama32Flow::last_token_logits`].
    fn from_prefill_opts(cfg: &'a Llama32Config, o: &Llama32PrefillOpts) -> Self {
        let mut f = Llama32Flow::new(cfg).prefill().batch(o.batch).seq(o.seq);
        if o.dynamic_seq {
            f = f.dynamic_seq();
        }
        if o.with_kv_outputs {
            f = f.export_kv();
        }
        // Map the two head flags as a pair. Calling `last_token_logits()` on
        // `last_logits_only` alone would silently force the LM head on even when the options
        // asked for hidden states only.
        f = match (o.with_lm_head, o.last_logits_only) {
            (true, true) => f.last_token_logits(),
            (true, false) => f.lm_head(),
            (false, _) => f.hidden_only(),
        };
        if let Some(p) = o.profile.clone() {
            f = f.profile(p);
        }
        f
    }

    /// Translates [`Llama32DecodeOpts`] into a decode builder. Decode builders always have
    /// the LM head enabled.
    fn from_decode_opts(cfg: &'a Llama32Config, o: &Llama32DecodeOpts) -> Self {
        let mut f = Llama32Flow::new(cfg)
            .decode()
            .batch(o.batch)
            .past(o.past_seq)
            .lm_head();
        if o.dynamic_past {
            f = f.dynamic_past();
        }
        if o.use_custom_mask {
            f = f.custom_mask();
        }
        if let Some(p) = o.profile.clone() {
            f = f.profile(p);
        }
        f
    }
}
/// Plain-struct options for building a Llama 3.2 prefill graph; converted into a
/// `Llama32Flow` builder by `build_llama32_prefill_built`.
#[derive(Debug, Clone)]
pub struct Llama32PrefillOpts {
    /// Batch size.
    pub batch: usize,
    /// Static prompt length (unused when `dynamic_seq` is set).
    pub seq: usize,
    /// Make the sequence axis symbolic; prefill then requires `batch == 1`.
    pub dynamic_seq: bool,
    /// End with the LM head (logits) instead of a `"hidden"` output.
    pub with_lm_head: bool,
    /// Append per-layer K/V tensors as extra outputs.
    pub with_kv_outputs: bool,
    /// Restrict the output to the last token position (see `Llama32Flow::last_token_logits`).
    pub last_logits_only: bool,
    /// Explicit compile profile; `None` selects the built-in prefill profile.
    pub profile: Option<CompileProfile>,
}
impl Llama32PrefillOpts {
    /// Static-shape prefill that outputs hidden states: every flag off and no explicit
    /// compile profile.
    pub fn static_prefill(batch: usize, seq: usize) -> Self {
        Self {
            profile: None,
            last_logits_only: false,
            with_kv_outputs: false,
            with_lm_head: false,
            dynamic_seq: false,
            seq,
            batch,
        }
    }
}
/// Plain-struct options for building a Llama 3.2 single-token decode graph; converted into
/// a `Llama32Flow` builder (LM head always enabled) by `build_llama32_decode_built`.
#[derive(Debug, Clone)]
pub struct Llama32DecodeOpts {
    /// Batch size.
    pub batch: usize,
    /// Static KV-cache length; also determines the custom-mask width (`past_seq + 1`).
    pub past_seq: usize,
    /// Make the past-length axis of the KV inputs symbolic.
    pub dynamic_past: bool,
    /// Add an explicit `mask` input.
    pub use_custom_mask: bool,
    /// Explicit compile profile; `None` selects the built-in decode profile.
    pub profile: Option<CompileProfile>,
}
/// Builds a prefill graph from `opts` and splits it with `BuiltModel::into_parts` into the
/// HIR module and a name → `Vec<f32>` map (presumably the parameter tensors).
///
/// # Errors
/// Propagates failures from building the model or from `into_parts`.
pub fn build_llama32_prefill_flow(
    cfg: &Llama32Config,
    weights: &mut dyn WeightLoader,
    opts: &Llama32PrefillOpts,
) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
    let built = build_llama32_prefill_built(cfg, weights, opts)?;
    built.into_parts()
}
/// Builds a prefill [`BuiltModel`] from the plain `opts` struct, using the weights in
/// `weights`.
///
/// # Errors
/// Propagates failures from `Llama32Flow::build`.
pub fn build_llama32_prefill_built(
    cfg: &Llama32Config,
    weights: &mut dyn WeightLoader,
    opts: &Llama32PrefillOpts,
) -> Result<BuiltModel> {
    let flow = Llama32Flow::from_prefill_opts(cfg, opts);
    flow.build(weights)
}
/// Builds a decode graph from `opts` and splits it with `BuiltModel::into_parts` into the
/// HIR module and a name → `Vec<f32>` map (presumably the parameter tensors).
///
/// # Errors
/// Propagates failures from building the model or from `into_parts`.
pub fn build_llama32_decode_flow(
    cfg: &Llama32Config,
    weights: &mut dyn WeightLoader,
    opts: &Llama32DecodeOpts,
) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
    let built = build_llama32_decode_built(cfg, weights, opts)?;
    built.into_parts()
}
/// Builds a decode model from `opts` and converts it with
/// `rlx_core::flow_util::graph_from_built` into a [`Graph`] plus a name → `Vec<f32>` map.
///
/// # Errors
/// Propagates failures from building the model or from the graph conversion.
pub fn build_llama32_decode_graph(
    cfg: &Llama32Config,
    weights: &mut dyn WeightLoader,
    opts: &Llama32DecodeOpts,
) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
    let built = build_llama32_decode_built(cfg, weights, opts)?;
    rlx_core::flow_util::graph_from_built(built)
}
/// Builds a single-token decode [`BuiltModel`] (LM head and KV outputs included) from the
/// plain `opts` struct, using the weights in `weights`.
///
/// # Errors
/// Propagates failures from `Llama32Flow::build`.
pub fn build_llama32_decode_built(
    cfg: &Llama32Config,
    weights: &mut dyn WeightLoader,
    opts: &Llama32DecodeOpts,
) -> Result<BuiltModel> {
    let flow = Llama32Flow::from_decode_opts(cfg, opts);
    flow.build(weights)
}