use crate::error::Result;
use crate::gguf::ops;
use crate::gguf::qwen3_moe_load::{moe_ffn_forward_layer, Qwen3MoeQuantizedLayer};
use crate::gguf::OwnedQuantizedModel;
impl OwnedQuantizedModel {
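    /// Forward pass for a Qwen3-MoE model, returning logits for the final
    /// position of `token_ids`.
    ///
    /// Dense per-layer weights (attention and norms) come from `self.layers`;
    /// routed expert weights come from `moe_layers`, one entry per decoder
    /// layer, backed by the raw tensor bytes in `data`. The MoE hyperparameters
    /// `num_experts`, `num_experts_per_tok`, and `moe_intermediate` must all be
    /// taken from GGUF metadata and be non-zero.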
#[allow(clippy::too_many_arguments)]
pub fn forward_qwen3_moe(
&self,
token_ids: &[u32],
moe_layers: &[Qwen3MoeQuantizedLayer],
num_experts: usize,
num_experts_per_tok: usize,
moe_intermediate: usize,
data: &[u8],
) -> Result<Vec<f32>> {
let hidden_dim = self.config.hidden_dim;
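        // One routed-expert table is required for every decoder layer.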
if moe_layers.len() != self.layers.len() {
return Err(crate::error::RealizarError::InvalidShape {
reason: format!(
"forward_qwen3_moe: moe_layers.len() = {} but model has {} decoder layers",
moe_layers.len(),
self.layers.len()
),
});
}
        if num_experts == 0 || num_experts_per_tok == 0 || moe_intermediate == 0 {
            return Err(crate::error::RealizarError::InvalidShape {
                reason: format!(
                    "forward_qwen3_moe: incomplete MoE config: num_experts={num_experts}, \
                     num_experts_per_tok={num_experts_per_tok}, moe_intermediate={moe_intermediate}. \
                     Caller must supply all three from GGUF metadata."
                ),
            });
        }
        // Guard against empty input: the final-logits slice below indexes
        // position seq_len - 1, which would underflow for an empty prompt.
        if token_ids.is_empty() {
            return Err(crate::error::RealizarError::InvalidShape {
                reason: "forward_qwen3_moe: token_ids is empty".to_string(),
            });
        }
let mut hidden = self.embed(token_ids);
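        // Learned absolute position embeddings apply only to architectures
        // that use them; Qwen3 itself relies on RoPE, applied per layer below.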
if self.config.constraints.uses_absolute_positions() {
if let Some(ref pos_emb) = self.position_embedding {
                for s in 0..token_ids.len() {
                    let start = s * hidden_dim;
                    if start + hidden_dim <= pos_emb.len() {
                        for i in 0..hidden_dim {
                            hidden[start + i] += pos_emb[start + i];
                        }
                    }
                }
}
}
        let use_rmsnorm = self.config.constraints.uses_rmsnorm();
        let seq_len = token_ids.len();
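        // Decoder stack: each layer runs pre-norm self-attention followed by a
        // routed MoE feed-forward block, with residuals around both.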
for (layer_idx, layer) in self.layers.iter().enumerate() {
let normed = if use_rmsnorm {
ops::rms_norm(&hidden, &layer.attn_norm_weight, self.config.eps)
} else {
ops::layer_norm(
&hidden,
&layer.attn_norm_weight,
layer.attn_norm_bias.as_deref(),
self.config.eps,
)
};
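            // Row layout of the fused QKV projection. With grouped-query
            // attention, Q spans num_heads heads while K and V span
            // num_kv_heads heads, so the three splits can differ in width.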
let qkv_dim = layer.qkv_weight.out_dim();
let q_dim = layer.qkv_weight.q_dim_for_config(
self.config.num_heads,
self.config.num_kv_heads,
self.config.hidden_dim,
self.config.head_dim(),
);
let k_dim = layer.qkv_weight.k_dim_for_config(
self.config.num_heads,
self.config.num_kv_heads,
self.config.hidden_dim,
self.config.head_dim(),
);
let v_dim = layer.qkv_weight.v_dim_for_config(
self.config.num_heads,
self.config.num_kv_heads,
self.config.hidden_dim,
self.config.head_dim(),
);
let mut qkv = self.qkv_matmul(&normed, &layer.qkv_weight)?;
if let Some(ref bias) = layer.qkv_bias {
ops::add_bias(&mut qkv, bias);
}
let mut q_all = Vec::with_capacity(seq_len * q_dim);
let mut k_all = Vec::with_capacity(seq_len * k_dim);
let mut v_all = Vec::with_capacity(seq_len * v_dim);
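            // Per position: split the fused QKV row, apply the optional
            // per-head RMSNorm to Q and K (Qwen3's q_norm/k_norm), then RoPE.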
for s in 0..seq_len {
let qkv_start = s * qkv_dim;
let mut q = qkv[qkv_start..qkv_start + q_dim].to_vec();
let mut k = qkv[qkv_start + q_dim..qkv_start + q_dim + k_dim].to_vec();
let v = &qkv[qkv_start + q_dim + k_dim..qkv_start + q_dim + k_dim + v_dim];
if let Some(ref q_norm) = layer.attn_q_norm_weight {
ops::apply_per_head_rms_norm(
&mut q,
q_norm,
self.config.num_heads,
self.config.eps,
);
}
if let Some(ref k_norm) = layer.attn_k_norm_weight {
ops::apply_per_head_rms_norm(
&mut k,
k_norm,
self.config.num_kv_heads,
self.config.eps,
);
}
if self.config.constraints.uses_rope() {
self.apply_rope(&mut q, s, self.config.num_heads);
self.apply_rope(&mut k, s, self.config.num_kv_heads);
}
q_all.extend_from_slice(&q);
k_all.extend_from_slice(&k);
v_all.extend_from_slice(v);
}
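            // Causal self-attention over the whole sequence, then the output
            // projection back to hidden_dim (plus bias when present).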
let attn_out = self.causal_attention(&q_all, &k_all, &v_all, seq_len);
let mut attn_output = self.fused_matmul(&attn_out, &layer.attn_output_weight)?;
if let Some(ref bias) = layer.attn_output_bias {
ops::add_bias(&mut attn_output, bias);
}
            // Residual connection around attention.
            for (h, a) in hidden.iter_mut().zip(attn_output.iter()) {
                *h += *a;
            }
let ffn_input = if let Some(ref ffn_norm) = layer.ffn_norm_weight {
if use_rmsnorm {
ops::rms_norm(&hidden, ffn_norm, self.config.eps)
} else {
ops::layer_norm(
&hidden,
ffn_norm,
layer.ffn_norm_bias.as_deref(),
self.config.eps,
)
}
} else {
hidden.clone()
};
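            // Routed MoE feed-forward: each position is dispatched to
            // num_experts_per_tok of the num_experts experts, whose quantized
            // weights are read on demand from the raw tensor buffer `data`.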
let mut ffn_output = vec![0.0f32; seq_len * hidden_dim];
for s in 0..seq_len {
let pos_in = &ffn_input[s * hidden_dim..(s + 1) * hidden_dim];
let pos_out = moe_ffn_forward_layer(
pos_in,
&moe_layers[layer_idx],
num_experts,
num_experts_per_tok,
                    moe_intermediate,
hidden_dim,
data,
)?;
ffn_output[s * hidden_dim..(s + 1) * hidden_dim].copy_from_slice(&pos_out);
}
            // Residual connection around the MoE block.
            for (h, f) in hidden.iter_mut().zip(ffn_output.iter()) {
                *h += *f;
            }
}
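        // Final norm over the full sequence, then logits from the last
        // position only.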
let normed = if use_rmsnorm {
ops::rms_norm(&hidden, &self.output_norm_weight, self.config.eps)
} else {
ops::layer_norm(
&hidden,
&self.output_norm_weight,
self.output_norm_bias.as_deref(),
self.config.eps,
)
};
let last_start = (seq_len - 1) * hidden_dim;
let last_hidden = &normed[last_start..last_start + hidden_dim];
let mut logits = self.fused_matmul(last_hidden, &self.lm_head_weight)?;
if let Some(ref bias) = self.lm_head_bias {
ops::add_bias(&mut logits, bias);
}
Ok(logits)
}
}