use crate::error::{RealizarError, Result};
use crate::gguf::qwen3_moe_load::load_qwen3_moe_layer;
use crate::gguf::{MappedGGUFModel, OwnedQuantizedModel, QuantizedGenerateConfig};
/// Greedy (argmax) token generation for a Qwen3-MoE GGUF model.
///
/// The per-layer MoE data is loaded once up front via `load_qwen3_moe_layer`,
/// then up to `gen_config.max_tokens` decode steps are run. Every step hands
/// the *entire* token sequence so far to `forward_qwen3_moe` and appends the
/// index of the largest logit. Only `max_tokens` and `stop_tokens` are read
/// from `gen_config` here.
///
/// The returned vector is the prompt followed by the generated tokens. A stop
/// token, when produced, is appended before generation halts, so it is part of
/// the output.
///
/// # Errors
///
/// Returns [`RealizarError::InvalidShape`] when:
/// - `input_tokens` is empty;
/// - the model's normalized architecture is not `qwen3_moe`;
/// - the expert count, experts-used count, or expert feed-forward length is
///   absent from the GGUF metadata;
/// - a forward pass yields an empty logits collection.
///
/// Any error from `load_qwen3_moe_layer` or `forward_qwen3_moe` is propagated
/// unchanged.
pub fn run_qwen3_moe_generate(
    mapped: &MappedGGUFModel,
    model: &OwnedQuantizedModel,
    input_tokens: &[u32],
    gen_config: &QuantizedGenerateConfig,
) -> Result<Vec<u32>> {
    // Guard: there must be at least one prompt token to condition on.
    if input_tokens.is_empty() {
        return Err(RealizarError::InvalidShape {
            reason: "run_qwen3_moe_generate: prompt cannot be empty".to_string(),
        });
    }

    // Guard: this entry point only serves the qwen3_moe architecture.
    let declared_arch = &model.config().architecture;
    let canonical_arch = crate::tensor_names::normalize_architecture(declared_arch);
    if canonical_arch != "qwen3_moe" {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "run_qwen3_moe_generate: arch '{}' (canonical '{}') is not qwen3_moe — \
                 caller should dispatch to run_gguf_generate instead",
                declared_arch, canonical_arch
            ),
        });
    }

    // Shared constructor for the "metadata key absent" error; the key is
    // reported under the architecture name exactly as declared by the model.
    let missing_key = |key: &str| RealizarError::InvalidShape {
        reason: format!(
            "run_qwen3_moe_generate: missing '{}.{}' in GGUF metadata",
            declared_arch, key
        ),
    };

    // MoE hyper-parameters, looked up in a fixed order so the first missing
    // key is the one reported.
    let gguf = &mapped.model;
    let num_experts = gguf
        .expert_count()
        .ok_or_else(|| missing_key("expert_count"))?;
    let num_experts_per_tok = gguf
        .expert_used_count()
        .ok_or_else(|| missing_key("expert_used_count"))?;
    let moe_intermediate = gguf
        .expert_feed_forward_length()
        .ok_or_else(|| missing_key("expert_feed_forward_length"))?;

    // Load every layer's MoE data, stopping at the first failure.
    let data = mapped.data();
    let moe_layers = (0..model.config().num_layers)
        .map(|layer_idx| load_qwen3_moe_layer(gguf, data, layer_idx))
        .collect::<Result<Vec<_>>>()?;

    let mut tokens = Vec::from(input_tokens);
    for _ in 0..gen_config.max_tokens {
        let logits = model.forward_qwen3_moe(
            &tokens,
            &moe_layers,
            num_experts,
            num_experts_per_tok,
            moe_intermediate,
            data,
        )?;

        // Argmax over the logits. `max_by` keeps the last of several equal
        // maxima, and incomparable pairs (NaN) are treated as equal.
        let (best_idx, _) = logits
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .ok_or_else(|| RealizarError::InvalidShape {
                reason: "run_qwen3_moe_generate: empty logits vector".to_string(),
            })?;
        let next_token = best_idx as u32;

        // The stop token itself is recorded before the loop exits.
        let hit_stop = gen_config.stop_tokens.contains(&next_token);
        tokens.push(next_token);
        if hit_stop {
            break;
        }
    }

    Ok(tokens)
}