use crate::config::VoxtralConfig;
use crate::weights::LanguageModelPrefixLoader;
use anyhow::Result;
use rlx_core::flow_bridge::WeightLoaderSource;
use rlx_core::weight_map::WeightMap;
use rlx_flow::blocks::{LlamaDecoderSpec, RopeTablesStage, llama_prefill_layer_fused};
use rlx_flow::{BuiltModel, CompileProfile, FlowStage, ModelFlow, SideOutputs};
use rlx_ir::op::MaskKind;
use rlx_ir::{DType, Shape};
use rlx_llama32::flow::{Llama32DecodeOpts, build_llama32_decode_built};
use rlx_llama32::rope::{build_rope_tables, resolve_inv_freq};
/// Builds the Voxtral language-model prefill graph.
///
/// The graph takes an `inputs_embeds` input of shape `[batch, seq, hidden_size]` (F32).
/// It runs `num_hidden_layers` fused Llama prefill layers with a causal mask, applies
/// the final norm and projects through the LM head to `vocab_size` outputs. All weights
/// except the optional `rope_freqs.weight` tensor are read through
/// [`LanguageModelPrefixLoader`].
///
/// * `with_kv_outputs`: when `true`, each layer is preceded by a KV tap stage. The
///   collected side outputs are attached to the returned model as extra HIR outputs.
/// * `last_logits_only`: when `true`, the flow gathers the last token
///   (`gather_last_token_at(batch, seq)`) before the final norm and LM head.
///
/// # Errors
///
/// Propagates any error returned while building the flow against the weight loader.
pub fn build_voxtral_prefill_built(
    cfg: &VoxtralConfig,
    weights: &mut WeightMap,
    batch: usize,
    seq: usize,
    with_kv_outputs: bool,
    last_logits_only: bool,
) -> Result<BuiltModel> {
    let llama = cfg.llama_config();
    let prefill_profile = CompileProfile::llama32_prefill();
    let hidden = llama.hidden_size;
    let norm_eps = llama.rms_norm_eps as f32;
    let head_dim = llama.head_dim();
    let embeds_shape = Shape::new(&[batch, seq, hidden], DType::F32);

    // `rope_freqs.weight` is optional: a failed `take` simply means "no factors".
    // `take` also removes the entry from `weights`.
    // NOTE(review): this key is looked up on the raw map rather than through the
    // prefix loader used for every other weight; confirm the checkpoint stores it unprefixed.
    let rope_factors = match weights.take("rope_freqs.weight") {
        Ok((data, _)) => Some(data),
        Err(_) => None,
    };
    let inv_freq = resolve_inv_freq(llama, rope_factors.as_deref());
    // Tables span the full `max_position_embeddings` range, not just `seq`.
    let (cos_table, sin_table) = build_rope_tables(&inv_freq, llama.max_position_embeddings);

    let layer_spec = LlamaDecoderSpec {
        num_heads: llama.num_attention_heads,
        head_dim,
        num_kv_heads: llama.num_key_value_heads,
        eps: norm_eps,
        mask: MaskKind::Causal,
        hidden_shape: embeds_shape.clone(),
    };
    // Collector handed to the per-layer KV taps (via `inner()`) and drained into
    // extra HIR outputs once the model is built.
    let kv_sink = SideOutputs::new();

    let flow = ModelFlow::new("voxtral_prefill")
        .with_profile(prefill_profile)
        .input("inputs_embeds", embeds_shape)
        .rope_tables(RopeTablesStage::param(
            llama.max_position_embeddings,
            inv_freq.len(),
            cos_table,
            sin_table,
        ))
        .zero_beta_named("voxtral.zero_beta.hidden", hidden);

    let tap_sink = kv_sink.clone();
    let flow = flow.repeat_layers(llama.num_hidden_layers, move |layer| {
        if !with_kv_outputs {
            return llama_prefill_layer_fused(layer, layer_spec.clone());
        }
        // The KV tap stage is built first and runs ahead of the fused layer.
        let tap = FlowStage::LlamaKvTap(rlx_flow::blocks::LlamaKvTapStage::layer(
            layer,
            head_dim,
            norm_eps,
            tap_sink.inner(),
        ));
        FlowStage::Sequence(vec![tap, llama_prefill_layer_fused(layer, layer_spec.clone())])
    });

    let flow = if last_logits_only {
        flow.gather_last_token_at(batch, seq)
    } else {
        flow
    };
    let flow = flow.final_norm(norm_eps);

    let mut loader = LanguageModelPrefixLoader::new(weights);
    let built = flow
        .lm_head(llama.vocab_size, hidden, false)
        .build(&mut WeightLoaderSource(&mut loader))?;

    if with_kv_outputs {
        Ok(built.with_extra_hir_outputs(kv_sink.drain()))
    } else {
        Ok(built)
    }
}
/// Builds the Voxtral language-model decode graph.
///
/// This is a thin wrapper around [`build_llama32_decode_built`]:
///
/// * It uses the Llama config embedded in `cfg`.
/// * Weights are read through [`LanguageModelPrefixLoader`].
/// * The custom-mask option is disabled (`use_custom_mask: false`).
/// * No explicit compile profile is passed (`profile: None`).
///
/// `batch`, `past_seq` and `dynamic_past` are forwarded unchanged in the decode options.
///
/// # Errors
///
/// Returns any error reported by [`build_llama32_decode_built`].
pub fn build_voxtral_decode_built(
    cfg: &VoxtralConfig,
    weights: &mut WeightMap,
    batch: usize,
    past_seq: usize,
    dynamic_past: bool,
) -> Result<BuiltModel> {
    let decode_opts = Llama32DecodeOpts {
        profile: None,
        use_custom_mask: false,
        batch,
        past_seq,
        dynamic_past,
    };
    let mut loader = LanguageModelPrefixLoader::new(weights);
    build_llama32_decode_built(cfg.llama_config(), &mut loader, &decode_opts)
}