use ferrum_kernels::backend::cpu::CpuBackend;
use ferrum_kernels::backend::Backend;
use ferrum_kernels::Linear;
use ferrum_quantization::gguf::{GgufFile, GgufLinear};
use ferrum_types::{FerrumError, Result};
use crate::moe::dispatch::{moe_forward_cpu, ExpertStack};
use crate::moe::router::route;
use crate::moe_config::Qwen3MoeConfig;
/// One Qwen3 mixture-of-experts layer: a router projection plus the stack of
/// experts that tokens are dispatched to.
///
/// Instances are normally built with `load_from_gguf`, which fills the scalar
/// fields from a `Qwen3MoeConfig` and checks the router shape against it.
pub struct Qwen3MoeLayer<B: Backend> {
/// Linear projection that produces the router logits. `load_from_gguf`
/// verifies that it maps `hidden_size` inputs to `num_experts` outputs.
pub router: Box<dyn Linear<B>>,
/// Expert weights for this layer, loaded via `ExpertStack::load_from_gguf`.
pub experts: ExpertStack<B>,
/// Experts per token, copied from `cfg.num_experts_per_tok`; forwarded to
/// `route` and `moe_forward_cpu`.
pub top_k: usize,
/// Copied from `cfg.norm_topk_prob` and forwarded to `route`.
/// NOTE(review): presumably toggles renormalisation of the selected top-k
/// probabilities — confirm against `route`.
pub norm_topk_prob: bool,
/// Width of one input row; `forward_cpu` requires
/// `x.len() == batch * hidden_size`.
pub hidden_size: usize,
/// Copied from `cfg.expert_intermediate_size`; forwarded to the expert
/// loader and to `moe_forward_cpu`.
pub expert_intermediate: usize,
/// Number of experts; also the per-row width of the router logits buffer
/// allocated in `forward_cpu`.
pub num_experts: usize,
}
impl<B: Backend> Qwen3MoeLayer<B> {
    /// Builds the MoE layer for block `layer_idx` from a GGUF file.
    ///
    /// Steps, in order:
    /// 1. Read the router weight `blk.{layer_idx}.ffn_gate_inp.weight` (on the
    ///    candle CPU device) and wrap it in a `GgufLinear`.
    /// 2. Check the router shape against `cfg` (`hidden_size` inputs,
    ///    `num_experts` outputs).
    /// 3. Load the expert stack through `ExpertStack::load_from_gguf`.
    ///
    /// The shape check runs *before* the experts are loaded so that a file
    /// with a mis-shaped router is rejected without first reading every
    /// expert tensor of the layer.
    ///
    /// # Errors
    ///
    /// Returns a `FerrumError::model` error when:
    /// - the router tensor is absent from the file,
    /// - reading the tensor or converting it to a `GgufLinear` fails,
    /// - the router's `in_features` differs from `cfg.base.hidden_size`,
    /// - the router's `out_features` differs from `cfg.num_experts`.
    ///
    /// Any error from `ExpertStack::load_from_gguf` is propagated unchanged.
    pub fn load_from_gguf(gguf: &GgufFile, layer_idx: usize, cfg: &Qwen3MoeConfig) -> Result<Self> {
        let router_name = format!("blk.{layer_idx}.ffn_gate_inp.weight");
        if !gguf.has_tensor(&router_name) {
            return Err(FerrumError::model(format!(
                "Qwen3MoeLayer: router tensor '{router_name}' not in GGUF"
            )));
        }

        let router_qt = gguf
            .read_tensor(&router_name, &candle_core::Device::Cpu)
            .map_err(|e| FerrumError::model(format!("read router: {e}")))?;
        let router = GgufLinear::<B>::from_qtensor(&router_qt)
            .map_err(|e| FerrumError::model(format!("router from_qtensor: {e}")))?;
        let router: Box<dyn Linear<B>> = Box::new(router);

        // Fail fast: the router is a single small tensor, whereas the expert
        // stack is the bulk of the layer. Validate the cheap part first.
        if router.in_features() != cfg.base.hidden_size {
            return Err(FerrumError::model(format!(
                "router in_features {} != hidden_size {}",
                router.in_features(),
                cfg.base.hidden_size
            )));
        }
        if router.out_features() != cfg.num_experts {
            return Err(FerrumError::model(format!(
                "router out_features {} != num_experts {}",
                router.out_features(),
                cfg.num_experts
            )));
        }

        let experts = ExpertStack::<B>::load_from_gguf(
            gguf,
            layer_idx,
            cfg.num_experts,
            cfg.base.hidden_size,
            cfg.expert_intermediate_size,
        )?;

        Ok(Self {
            router,
            experts,
            top_k: cfg.num_experts_per_tok,
            norm_topk_prob: cfg.norm_topk_prob,
            hidden_size: cfg.base.hidden_size,
            expert_intermediate: cfg.expert_intermediate_size,
            num_experts: cfg.num_experts,
        })
    }
}
impl Qwen3MoeLayer<CpuBackend> {
    /// Runs the layer on the CPU backend for `batch` rows of input.
    ///
    /// `x` must hold exactly `batch * hidden_size` values. The router is
    /// evaluated over the whole batch, its logits are passed to `route`, and
    /// the routing result together with the original `x` is handed to
    /// `moe_forward_cpu`, which writes into `out`.
    ///
    /// # Errors
    ///
    /// Returns a `FerrumError::model` error when `x.len()` does not equal
    /// `batch * hidden_size`; otherwise returns whatever `moe_forward_cpu`
    /// returns.
    pub fn forward_cpu(&self, x: &[f32], batch: usize, out: &mut Vec<f32>) -> Result<()> {
        let expected_len = batch * self.hidden_size;
        if x.len() != expected_len {
            let msg = format!(
                "Qwen3MoeLayer::forward_cpu: x len {} != batch*hidden = {}*{} = {}",
                x.len(),
                batch,
                self.hidden_size,
                expected_len
            );
            return Err(FerrumError::model(msg));
        }

        // One logit per (row, expert) pair, filled in by the router below.
        let mut logits: Vec<f32> = vec![0.0; batch * self.num_experts];
        let mut cpu_ctx = <CpuBackend as Backend>::new_context();
        // The router is called with an owned buffer, so copy the input slice.
        let input: Vec<f32> = x.to_vec();
        // NOTE(review): whatever `forward` returns is discarded here — confirm
        // that it is infallible.
        self.router.forward(&mut cpu_ctx, &input, &mut logits, batch);

        let routing = route(
            &logits,
            batch,
            self.num_experts,
            self.top_k,
            self.norm_topk_prob,
        );

        moe_forward_cpu(
            x,
            batch,
            self.hidden_size,
            self.expert_intermediate,
            self.top_k,
            &routing,
            &self.experts,
            out,
        )
    }
}