use crate::config::GemmaConfig;
use anyhow::{Result, anyhow};
use rlx_core::weight_loader::WeightLoader;
use rlx_ir::Graph;
use rlx_ir::hir::HirModule;
use std::collections::HashMap;
/// Builds a Gemma prefill [`Graph`] for a fixed `batch` x `seq` shape.
///
/// Assembles a `GemmaPrefillOpts` with `dynamic_seq` and `last_logits_only`
/// switched off and no profile, passes it to
/// `crate::flow::build_gemma_prefill_built`, and converts the built value with
/// `rlx_core::flow_util::graph_from_built`.
///
/// `with_lm_head` and `with_kv_outputs` are forwarded into the options
/// unchanged.
///
/// # Errors
///
/// Propagates any error reported by the prefill builder or by the conversion
/// into a graph.
///
/// NOTE(review): unlike the HIR builders in this file, `validate_cfg` is not
/// called here — confirm the flow layer validates `cfg` on its own.
pub fn build_gemma_graph_sized(
    cfg: &GemmaConfig,
    weights: &mut dyn WeightLoader,
    batch: usize,
    seq: usize,
    with_lm_head: bool,
    with_kv_outputs: bool,
) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
    let prefill_opts = crate::flow::GemmaPrefillOpts {
        // Shape: fixed, taken directly from the caller.
        batch,
        seq,
        dynamic_seq: false,
        // Output selection.
        with_lm_head,
        with_kv_outputs,
        last_logits_only: false,
        profile: None,
    };
    let built = crate::flow::build_gemma_prefill_built(cfg, weights, &prefill_opts)?;
    rlx_core::flow_util::graph_from_built(built)
}
/// Builds a fixed-shape Gemma prefill [`Graph`] with `last_logits_only` set.
///
/// The LM head is always enabled (`with_lm_head: true`); `dynamic_seq` is off
/// and no profile is supplied. The options go to
/// `crate::flow::build_gemma_prefill_built`, whose result is converted with
/// `rlx_core::flow_util::graph_from_built`.
///
/// # Errors
///
/// Propagates any error reported by the prefill builder or by the conversion
/// into a graph.
///
/// NOTE(review): unlike the HIR builders in this file, `validate_cfg` is not
/// called here — confirm the flow layer validates `cfg` on its own.
pub fn build_gemma_graph_sized_last_logits(
    cfg: &GemmaConfig,
    weights: &mut dyn WeightLoader,
    batch: usize,
    seq: usize,
    with_kv_outputs: bool,
) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
    let prefill_opts = crate::flow::GemmaPrefillOpts {
        // Shape: fixed, taken directly from the caller.
        batch,
        seq,
        dynamic_seq: false,
        // Output selection: LM head on, last-logits-only on.
        with_lm_head: true,
        last_logits_only: true,
        with_kv_outputs,
        profile: None,
    };
    let built = crate::flow::build_gemma_prefill_built(cfg, weights, &prefill_opts)?;
    rlx_core::flow_util::graph_from_built(built)
}
/// Builds a Gemma prefill [`HirModule`] with `dynamic_seq` enabled.
///
/// `max_seq` is passed as the `seq` option. The LM head and
/// `last_logits_only` are always on; no profile is supplied. The work itself is
/// done by `crate::flow::build_gemma_prefill_flow`.
///
/// # Errors
///
/// * Whatever `validate_cfg` rejects (checked first).
/// * `batch` other than 1.
/// * Any error from the underlying flow builder.
pub fn build_gemma_prefill_hir_dynamic_ext(
    cfg: &GemmaConfig,
    weights: &mut dyn WeightLoader,
    batch: usize,
    max_seq: usize,
    with_kv_outputs: bool,
) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
    // Config validation runs before the batch check, so a bad config is
    // reported in preference to a bad batch size.
    validate_cfg(cfg)?;

    if batch == 1 {
        let prefill_opts = crate::flow::GemmaPrefillOpts {
            batch,
            seq: max_seq,
            dynamic_seq: true,
            with_lm_head: true,
            last_logits_only: true,
            with_kv_outputs,
            profile: None,
        };
        crate::flow::build_gemma_prefill_flow(cfg, weights, &prefill_opts)
    } else {
        Err(anyhow!("gemma: dynamic_seq prefill requires batch=1"))
    }
}
/// Builds a fixed-shape Gemma decode [`Graph`] with the custom mask disabled.
///
/// Convenience wrapper: forwards every argument to
/// [`build_gemma_decode_graph_sized_ext`] with `use_custom_mask = false`.
///
/// # Errors
///
/// Returns whatever error the `_ext` variant returns.
pub fn build_gemma_decode_graph_sized(
    cfg: &GemmaConfig,
    weights: &mut dyn WeightLoader,
    batch: usize,
    past_seq: usize,
) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
    let use_custom_mask = false;
    build_gemma_decode_graph_sized_ext(cfg, weights, batch, past_seq, use_custom_mask)
}
/// Builds a fixed-shape Gemma decode [`Graph`], optionally with a custom mask.
///
/// Assembles a `GemmaDecodeOpts` with `dynamic_past` off and no profile, then
/// delegates to `crate::flow::build_gemma_decode_graph`. `use_custom_mask` is
/// forwarded into the options unchanged.
///
/// # Errors
///
/// Propagates any error from the underlying decode graph builder.
///
/// NOTE(review): unlike the HIR builders in this file, `validate_cfg` is not
/// called here — confirm the flow layer validates `cfg` on its own.
pub fn build_gemma_decode_graph_sized_ext(
    cfg: &GemmaConfig,
    weights: &mut dyn WeightLoader,
    batch: usize,
    past_seq: usize,
    use_custom_mask: bool,
) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
    let decode_opts = crate::flow::GemmaDecodeOpts {
        batch,
        past_seq,
        use_custom_mask,
        dynamic_past: false,
        profile: None,
    };
    crate::flow::build_gemma_decode_graph(cfg, weights, &decode_opts)
}
/// Builds a fixed-shape Gemma decode [`HirModule`] with the custom mask
/// disabled.
///
/// Convenience wrapper: forwards every argument to
/// [`build_gemma_decode_hir_sized_ext`] with `use_custom_mask = false`.
///
/// # Errors
///
/// Returns whatever error the `_ext` variant returns.
pub fn build_gemma_decode_hir_sized(
    cfg: &GemmaConfig,
    weights: &mut dyn WeightLoader,
    batch: usize,
    past_seq: usize,
) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
    let use_custom_mask = false;
    build_gemma_decode_hir_sized_ext(cfg, weights, batch, past_seq, use_custom_mask)
}
/// Builds a fixed-shape Gemma decode [`HirModule`], optionally with a custom
/// mask.
///
/// After validating `cfg`, assembles a `GemmaDecodeOpts` with `dynamic_past`
/// off and no profile, then delegates to
/// `crate::flow::build_gemma_decode_flow`.
///
/// # Errors
///
/// * Whatever `validate_cfg` rejects.
/// * Any error from the underlying decode flow builder.
pub fn build_gemma_decode_hir_sized_ext(
    cfg: &GemmaConfig,
    weights: &mut dyn WeightLoader,
    batch: usize,
    past_seq: usize,
    use_custom_mask: bool,
) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
    validate_cfg(cfg)?;

    let decode_opts = crate::flow::GemmaDecodeOpts {
        batch,
        past_seq,
        use_custom_mask,
        dynamic_past: false,
        profile: None,
    };
    crate::flow::build_gemma_decode_flow(cfg, weights, &decode_opts)
}
/// Builds a Gemma decode [`HirModule`] with `dynamic_past` enabled.
///
/// `max_past_seq` is passed as the `past_seq` option. The custom mask is
/// always disabled and no profile is supplied. The work itself is done by
/// `crate::flow::build_gemma_decode_flow`.
///
/// # Errors
///
/// * Whatever `validate_cfg` rejects.
/// * Any error from the underlying decode flow builder.
pub fn build_gemma_decode_hir_dynamic_ext(
    cfg: &GemmaConfig,
    weights: &mut dyn WeightLoader,
    batch: usize,
    max_past_seq: usize,
) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
    validate_cfg(cfg)?;

    let decode_opts = crate::flow::GemmaDecodeOpts {
        batch,
        past_seq: max_past_seq,
        use_custom_mask: false,
        dynamic_past: true,
        profile: None,
    };
    crate::flow::build_gemma_decode_flow(cfg, weights, &decode_opts)
}
/// Placeholder for building a Gemma prefill graph from packed (quantized)
/// weights.
///
/// Not implemented: after validating `cfg`, this always returns an error. All
/// other parameters are currently unused (hence the leading underscores).
///
/// # Errors
///
/// * Whatever `validate_cfg` rejects (checked first).
/// * Otherwise, always a "not implemented yet" error.
// The wide parameter list is kept as-is; the lint is silenced rather than
// reshaping a public signature.
#[allow(clippy::too_many_arguments)]
pub fn build_gemma_graph_sized_packed(
    cfg: &GemmaConfig,
    _weights: &mut rlx_core::weight_loader::GgufLoader,
    _batch: usize,
    _seq: usize,
    _with_lm_head: bool,
    _last_logits_only: bool,
    _packed: &mut HashMap<String, (Vec<u8>, rlx_ir::quant::QuantScheme, Vec<usize>)>,
) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
    // A config error takes precedence over the not-implemented error.
    validate_cfg(cfg)?;

    let not_implemented = anyhow!(
        "packed gemma prefill graphs are not implemented yet; use standard GGUF drain + GemmaFlow"
    );
    Err(not_implemented)
}
/// Checks the parts of `cfg` that the builders in this file rely on.
///
/// # Errors
///
/// * `num_attention_heads` is not a multiple of `num_key_value_heads`
///   (checked first).
/// * `attention_bias` is set.
///
/// NOTE(review): if these fields are std unsigned integers,
/// `0.is_multiple_of(0)` is `true`, so a config with both head counts at zero
/// would pass this check — confirm that is rejected elsewhere.
fn validate_cfg(cfg: &GemmaConfig) -> Result<()> {
    let attn_heads = cfg.num_attention_heads;
    let kv_heads = cfg.num_key_value_heads;

    if !attn_heads.is_multiple_of(kv_heads) {
        Err(anyhow!(
            "num_attention_heads ({}) must be divisible by num_key_value_heads ({})",
            attn_heads,
            kv_heads
        ))
    } else if cfg.attention_bias {
        Err(anyhow!("attention_bias=true not yet wired for gemma"))
    } else {
        Ok(())
    }
}