use crate::gguf::qwen3_moe_load::Qwen3MoeQuantizedLayer;
impl OwnedQuantizedModelCuda {
#[allow(clippy::too_many_arguments)]
pub fn forward_qwen3_moe_cuda(
&mut self,
token_ids: &[u32],
moe_layers: &[Qwen3MoeQuantizedLayer],
num_experts: usize,
num_experts_per_tok: usize,
moe_intermediate: usize,
data: &[u8],
) -> Result<Vec<f32>> {
if token_ids.is_empty() {
return Err(RealizarError::InvalidShape {
reason: "forward_qwen3_moe_cuda: token_ids must not be empty".to_string(),
});
}
if moe_layers.len() != self.model.layers.len() {
return Err(RealizarError::InvalidShape {
reason: format!(
"forward_qwen3_moe_cuda: moe_layers.len() = {} but model has {} decoder layers",
moe_layers.len(),
self.model.layers.len()
),
});
}
if num_experts == 0 || num_experts_per_tok == 0 || moe_intermediate == 0 {
return Err(RealizarError::InvalidShape {
reason: format!(
"forward_qwen3_moe_cuda: incomplete MoE config — num_experts={num_experts}, \
num_experts_per_tok={num_experts_per_tok}, moe_intermediate={moe_intermediate}. \
Caller must supply all three from GGUF metadata."
),
});
}
if num_experts_per_tok > num_experts {
return Err(RealizarError::InvalidShape {
reason: format!(
"forward_qwen3_moe_cuda: num_experts_per_tok ({num_experts_per_tok}) \
exceeds num_experts ({num_experts})"
),
});
}
let hidden_dim = self.model.config.hidden_dim;
let intermediate = moe_intermediate;
let use_rmsnorm = self.model.config.constraints.uses_rmsnorm();
let mut hidden = self.model.embed(token_ids);
if self.model.config.constraints.uses_absolute_positions() {
if let Some(ref pos_emb) = self.model.position_embedding {
for s in 0..token_ids.len() {
let pos_start = s * hidden_dim;
let pos_end = pos_start + hidden_dim;
if pos_end <= pos_emb.len() {
let h_start = s * hidden_dim;
for i in 0..hidden_dim {
hidden[h_start + i] += pos_emb[pos_start + i];
}
}
}
}
}
for (layer_idx, layer) in self.model.layers.iter().enumerate() {
let normed = if use_rmsnorm {
crate::gguf::ops::rms_norm(&hidden, &layer.attn_norm_weight, self.model.config.eps)
} else {
crate::gguf::ops::layer_norm(
&hidden,
&layer.attn_norm_weight,
layer.attn_norm_bias.as_deref(),
self.model.config.eps,
)
};
let qkv_dim = layer.qkv_weight.out_dim();
let q_dim = layer.qkv_weight.q_dim_for_config(
self.model.config.num_heads,
self.model.config.num_kv_heads,
self.model.config.hidden_dim,
self.model.config.head_dim(),
);
let k_dim = layer.qkv_weight.k_dim_for_config(
self.model.config.num_heads,
self.model.config.num_kv_heads,
self.model.config.hidden_dim,
self.model.config.head_dim(),
);
let v_dim = layer.qkv_weight.v_dim_for_config(
self.model.config.num_heads,
self.model.config.num_kv_heads,
self.model.config.hidden_dim,
self.model.config.head_dim(),
);
let mut qkv = self.model.qkv_matmul(&normed, &layer.qkv_weight)?;
if let Some(ref bias) = layer.qkv_bias {
crate::gguf::ops::add_bias(&mut qkv, bias);
}
let seq_len = token_ids.len();
let mut q_all = Vec::with_capacity(seq_len * q_dim);
let mut k_all = Vec::with_capacity(seq_len * k_dim);
let mut v_all = Vec::with_capacity(seq_len * v_dim);
for s in 0..seq_len {
let qkv_start = s * qkv_dim;
let mut q = qkv[qkv_start..qkv_start + q_dim].to_vec();
let mut k = qkv[qkv_start + q_dim..qkv_start + q_dim + k_dim].to_vec();
let v = &qkv[qkv_start + q_dim + k_dim..qkv_start + q_dim + k_dim + v_dim];
if let Some(ref q_norm) = layer.attn_q_norm_weight {
crate::gguf::ops::apply_per_head_rms_norm(
&mut q,
q_norm,
self.model.config.num_heads,
self.model.config.eps,
);
}
if let Some(ref k_norm) = layer.attn_k_norm_weight {
crate::gguf::ops::apply_per_head_rms_norm(
&mut k,
k_norm,
self.model.config.num_kv_heads,
self.model.config.eps,
);
}
if self.model.config.constraints.uses_rope() {
self.model.apply_rope(&mut q, s, self.model.config.num_heads);
self.model.apply_rope(&mut k, s, self.model.config.num_kv_heads);
}
q_all.extend_from_slice(&q);
k_all.extend_from_slice(&k);
v_all.extend_from_slice(v);
}
let attn_out = self.model.causal_attention(&q_all, &k_all, &v_all, seq_len);
let mut attn_output = self.model.fused_matmul(&attn_out, &layer.attn_output_weight)?;
if let Some(ref bias) = layer.attn_output_bias {
crate::gguf::ops::add_bias(&mut attn_output, bias);
}
for i in 0..hidden.len() {
hidden[i] += attn_output[i];
}
let ffn_input = if let Some(ref ffn_norm) = layer.ffn_norm_weight {
if use_rmsnorm {
crate::gguf::ops::rms_norm(&hidden, ffn_norm, self.model.config.eps)
} else {
crate::gguf::ops::layer_norm(
&hidden,
ffn_norm,
layer.ffn_norm_bias.as_deref(),
self.model.config.eps,
)
}
} else {
hidden.clone()
};
let mut ffn_output = vec![0.0f32; seq_len * hidden_dim];
for s in 0..seq_len {
let pos_in = &ffn_input[s * hidden_dim..(s + 1) * hidden_dim];
let pos_out = moe_ffn_forward_layer_cuda(
&mut self.executor,
pos_in,
&moe_layers[layer_idx],
num_experts,
num_experts_per_tok,
intermediate,
hidden_dim,
data,
)?;
ffn_output[s * hidden_dim..(s + 1) * hidden_dim].copy_from_slice(&pos_out);
}
for i in 0..hidden.len() {
hidden[i] += ffn_output[i];
}
}
let normed = if use_rmsnorm {
crate::gguf::ops::rms_norm(&hidden, &self.model.output_norm_weight, self.model.config.eps)
} else {
crate::gguf::ops::layer_norm(
&hidden,
&self.model.output_norm_weight,
self.model.output_norm_bias.as_deref(),
self.model.config.eps,
)
};
let seq_len = token_ids.len();
let last_start = (seq_len - 1) * hidden_dim;
let last_hidden = &normed[last_start..last_start + hidden_dim];
let mut logits = self.model.fused_matmul(last_hidden, &self.model.lm_head_weight)?;
if let Some(ref bias) = self.model.lm_head_bias {
crate::gguf::ops::add_bias(&mut logits, bias);
}
Ok(logits)
}
}
#[cfg(test)]
mod tests {
#[test]
fn forward_qwen3_moe_cuda_stub_compiles_with_correct_signature() {
}
}