use candle_core::quantized::gguf_file;
use candle_core::{DType, Device as CandleDevice, Result, Tensor, D};
use candle_nn::kv_cache::ConcatKvCache;
use candle_nn::{Embedding, Linear, Module};
use candle_transformers::models::quantized_qwen3::{Gguf, RotaryEmbedding};
use candle_transformers::models::with_tracing::QMatMul;
use candle_transformers::quantized_nn::RmsNorm;
use candle_transformers::utils::repeat_kv;
use std::sync::Arc;
#[allow(unused_imports)]
use candle_core::quantized::QTensor;
/// Dense SwiGLU feed-forward block: `down(silu(gate(x)) * up(x))`.
///
/// Used for layers that are not mixture-of-experts (see `MoeOrMlp`).
struct Mlp {
    gate: QMatMul,
    up: QMatMul,
    down: QMatMul,
}
impl Module for Mlp {
    /// SwiGLU: gate and up projections of the same input, silu-gated
    /// elementwise product, then the down projection.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let gated = self.gate.forward(xs)?;
        let lifted = self.up.forward(xs)?;
        let activated = (candle_nn::ops::silu(&gated)? * lifted)?;
        self.down.forward(&activated)
    }
}
/// A straightforward (unoptimized) mixture-of-experts feed-forward layer.
///
/// Routing indices are pulled to host memory and each token is processed in
/// a loop with small batched matmuls over dequantized expert weights.
struct NaiveMoe {
    // Router: one logit per expert.
    gate: Linear,
    // Quantized expert weights, stacked along dim 0 (one slice per expert).
    gate_experts: Arc<QTensor>,
    up_experts: Arc<QTensor>,
    down_experts: Arc<QTensor>,
    // Lazily dequantized F16 copies of (gate, up, down) expert weights.
    cache: Option<(Tensor, Tensor, Tensor)>,
    num_experts: usize,
    // Top-k experts selected per token.
    num_experts_per_tok: usize,
    // If true, renormalize the selected routing weights to sum to 1.
    norm_topk_prob: bool,
}
impl NaiveMoe {
fn ensure_cache(&mut self, device: &CandleDevice) -> Result<()> {
if self.cache.is_some() {
return Ok(());
}
let gate_w = self.gate_experts.dequantize(device)?.to_dtype(DType::F16)?;
let up_w = self.up_experts.dequantize(device)?.to_dtype(DType::F16)?;
let down_w = self.down_experts.dequantize(device)?.to_dtype(DType::F16)?;
self.cache = Some((gate_w, up_w, down_w));
Ok(())
}
fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
let (num_tokens, hidden_dim) = xs.dims2()?;
let device = xs.device();
let router_logits = self.gate.forward(xs)?;
let routing_weights =
candle_nn::ops::softmax_last_dim(&router_logits.to_dtype(DType::F32)?)?;
let topk_ids = routing_weights
.arg_sort_last_dim(false)?
.narrow(D::Minus1, 0, self.num_experts_per_tok)?
.contiguous()?;
let mut topk_weights = routing_weights.gather(&topk_ids, D::Minus1)?;
if self.norm_topk_prob {
topk_weights = topk_weights.broadcast_div(&topk_weights.sum_keepdim(D::Minus1)?)?;
}
let topk_ids_vec: Vec<u32> = topk_ids.flatten_all()?.to_vec1()?;
let topk_weights_vec: Vec<f32> = topk_weights.flatten_all()?.to_vec1()?;
self.ensure_cache(device)?;
let (gate_w, up_w, down_w) = self.cache.as_ref().unwrap();
let mut outputs = Vec::with_capacity(num_tokens);
for t in 0..num_tokens {
let token = xs.narrow(0, t, 1)?.contiguous()?.to_dtype(DType::F16)?;
let mut expert_ids = Vec::with_capacity(self.num_experts_per_tok);
let mut weights = Vec::with_capacity(self.num_experts_per_tok);
for k in 0..self.num_experts_per_tok {
let idx = t * self.num_experts_per_tok + k;
let eid = topk_ids_vec[idx] as usize;
if eid < self.num_experts {
expert_ids.push(eid);
weights.push(topk_weights_vec[idx]);
}
}
if expert_ids.is_empty() {
outputs.push(Tensor::zeros((1, hidden_dim), DType::F32, device)?);
continue;
}
let gate_selected: Vec<Tensor> = expert_ids
.iter()
.map(|&eid| gate_w.narrow(0, eid, 1))
.collect::<Result<Vec<_>>>()?;
let gate_batch = Tensor::cat(&gate_selected, 0)?;
let up_selected: Vec<Tensor> = expert_ids
.iter()
.map(|&eid| up_w.narrow(0, eid, 1))
.collect::<Result<Vec<_>>>()?;
let up_batch = Tensor::cat(&up_selected, 0)?;
let k = expert_ids.len();
let token_k = token.unsqueeze(0)?.expand((k, 1, hidden_dim))?;
let gate_t = gate_batch.transpose(1, 2)?; let gate_out = token_k.matmul(&gate_t)?;
let up_t = up_batch.transpose(1, 2)?;
let up_out = token_k.matmul(&up_t)?;
let activated = candle_nn::ops::silu(&gate_out)?.mul(&up_out)?;
let down_selected: Vec<Tensor> = expert_ids
.iter()
.map(|&eid| down_w.narrow(0, eid, 1))
.collect::<Result<Vec<_>>>()?;
let down_batch = Tensor::cat(&down_selected, 0)?; let down_t = down_batch.transpose(1, 2)?; let expert_outs = activated.matmul(&down_t)?;
let expert_outs_f32 = expert_outs.to_dtype(DType::F32)?;
let weights_t = Tensor::from_slice(&weights, (k, 1, 1), device)?;
let weighted = expert_outs_f32.broadcast_mul(&weights_t)?; let combined = weighted.sum(0)?;
outputs.push(combined);
}
Tensor::cat(&outputs, 0)
}
}
/// A layer's feed-forward: either a mixture-of-experts or a dense MLP.
enum MoeOrMlp {
    Moe(NaiveMoe),
    Mlp(Mlp),
}
impl MoeOrMlp {
fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
match self {
Self::Mlp(m) => m.forward(xs),
Self::Moe(m) => m.forward(xs),
}
}
}
/// Grouped-query attention with Qwen3-style q/k RMS norm and RoPE.
struct Attention {
    q_proj: QMatMul,
    k_proj: QMatMul,
    v_proj: QMatMul,
    o_proj: QMatMul,
    // RMS norms applied to q and k after the head split, before RoPE.
    q_norm: RmsNorm,
    k_norm: RmsNorm,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    // Rotary embedding table shared across layers.
    rotary: Arc<RotaryEmbedding>,
    // Append-only KV cache (concatenates along the sequence dim).
    kv_cache: ConcatKvCache,
}
impl Attention {
    /// Attention forward pass.
    ///
    /// `x` is `(batch, seq, hidden)`; `offset` is the number of positions
    /// already in the KV cache (sets the rotary phase). `mask`, when given,
    /// is broadcast-added to the attention scores.
    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
        let (b, l, _) = x.dims3()?;
        // Split a projection into heads: (b, heads, l, head_dim).
        let split_heads = |t: Tensor, heads: usize| -> Result<Tensor> {
            t.reshape((b, l, heads, self.head_dim))?
                .transpose(1, 2)?
                .contiguous()
        };
        let q = self.q_proj.forward(x)?;
        let k = self.k_proj.forward(x)?;
        let v = self.v_proj.forward(x)?;
        let q = split_heads(q, self.num_heads)?;
        let k = split_heads(k, self.num_kv_heads)?;
        let v = split_heads(v, self.num_kv_heads)?;
        // QK-norm then rotary position embedding.
        let q = self.q_norm.forward(&q)?;
        let k = self.k_norm.forward(&k)?;
        let (q, k) = self.rotary.apply(&q, &k, offset)?;
        // Extend the cache, then repeat kv heads to match the query heads.
        let (k, v) = self.kv_cache.append(&k, &v)?;
        let groups = self.num_heads / self.num_kv_heads;
        let k = repeat_kv(k, groups)?;
        let v = repeat_kv(v, groups)?;
        // Scaled dot-product attention.
        let scale = 1.0 / (self.head_dim as f64).sqrt();
        let scores = (q.matmul(&k.t()?)? * scale)?;
        let scores = if let Some(m) = mask {
            scores.broadcast_add(m)?
        } else {
            scores
        };
        let probs = candle_nn::ops::softmax_last_dim(&scores)?;
        let ctx = probs.matmul(&v)?;
        let ctx = ctx.transpose(1, 2)?.reshape((b, l, ()))?;
        self.o_proj.forward(&ctx)
    }
}
/// One transformer block: pre-norm attention + pre-norm feed-forward,
/// each with a residual connection (see `Qwen3MoeModel::forward`).
struct Layer {
    attn: Attention,
    attn_norm: RmsNorm,
    mlp: MoeOrMlp,
    ffn_norm: RmsNorm,
}
/// Qwen3-MoE causal language model loaded from a GGUF file.
pub struct Qwen3MoeModel {
    embeddings: Embedding,
    layers: Vec<Layer>,
    // Final RMS norm before the LM head.
    norm: RmsNorm,
    // LM head (may be tied to the token embedding — see `from_gguf`).
    output: QMatMul,
    // Compute dtype used for the rotary table and the causal mask.
    dtype: DType,
    device: CandleDevice,
}
impl Qwen3MoeModel {
    /// Load a Qwen3-MoE model from GGUF content.
    ///
    /// Geometry comes from the `qwen3moe.*` metadata keys. Each block gets a
    /// MoE feed-forward when `expert_count > 0`, otherwise a dense MLP; the
    /// LM head falls back to the tied token embedding when `output.weight`
    /// is absent from the file.
    pub fn from_gguf<R: std::io::Seek + std::io::Read>(
        ct: gguf_file::Content,
        reader: &mut R,
        device: &CandleDevice,
    ) -> Result<Self> {
        let dtype = DType::F32;
        let mut gg = Gguf::new(ct, reader, device.clone());
        let arch = "qwen3moe";
        let metadata = gg.metadata().clone();
        let md_get = |key: &str| -> Result<gguf_file::Value> {
            metadata
                .get(key)
                .cloned()
                .ok_or_else(|| candle_core::Error::Msg(format!("cannot find {key} in metadata")))
        };
        let head_count = md_get(&format!("{arch}.attention.head_count"))?.to_u32()? as usize;
        let head_count_kv = md_get(&format!("{arch}.attention.head_count_kv"))?.to_u32()? as usize;
        let head_dim = md_get(&format!("{arch}.attention.key_length"))?.to_u32()? as usize;
        let block_count = md_get(&format!("{arch}.block_count"))?.to_u32()? as usize;
        let embedding_length = md_get(&format!("{arch}.embedding_length"))?.to_u32()? as usize;
        let rms_norm_eps =
            md_get(&format!("{arch}.attention.layer_norm_rms_epsilon"))?.to_f32()? as f64;
        let context_length = md_get(&format!("{arch}.context_length"))?.to_u32()? as usize;
        let rope_freq_base = md_get(&format!("{arch}.rope.freq_base"))?.to_f32()? as f64;
        let num_experts = md_get(&format!("{arch}.expert_count"))?.to_u32()? as usize;
        let num_experts_per_tok = md_get(&format!("{arch}.expert_used_count"))?.to_u32()? as usize;
        let tok_embd = gg.tensor("token_embd.weight")?.dequantize(device)?;
        let embeddings = Embedding::new(tok_embd, embedding_length);
        // Tied embeddings: reuse token_embd when there is no output head.
        let output = match gg.qmatmul("output.weight") {
            Ok(v) => v,
            _ => gg.qmatmul("token_embd.weight")?,
        };
        let norm = gg.rms_norm("output_norm.weight", rms_norm_eps)?;
        let rotary = Arc::new(RotaryEmbedding::new(
            dtype,
            head_dim,
            context_length,
            rope_freq_base,
            device,
        )?);
        let mut layers = Vec::with_capacity(block_count);
        for i in 0..block_count {
            let pfx = format!("blk.{i}");
            let q_proj = gg.qmatmul(&format!("{pfx}.attn_q.weight"))?;
            let k_proj = gg.qmatmul(&format!("{pfx}.attn_k.weight"))?;
            let v_proj = gg.qmatmul(&format!("{pfx}.attn_v.weight"))?;
            let o_proj = gg.qmatmul(&format!("{pfx}.attn_output.weight"))?;
            let q_norm = gg.rms_norm(&format!("{pfx}.attn_q_norm.weight"), rms_norm_eps)?;
            let k_norm = gg.rms_norm(&format!("{pfx}.attn_k_norm.weight"), rms_norm_eps)?;
            let attn = Attention {
                q_proj,
                k_proj,
                v_proj,
                o_proj,
                q_norm,
                k_norm,
                num_heads: head_count,
                num_kv_heads: head_count_kv,
                head_dim,
                rotary: rotary.clone(),
                kv_cache: ConcatKvCache::new(2),
            };
            let attn_norm = gg.rms_norm(&format!("{pfx}.attn_norm.weight"), rms_norm_eps)?;
            let ffn_norm = gg.rms_norm(&format!("{pfx}.ffn_norm.weight"), rms_norm_eps)?;
            let mlp = if num_experts > 0 {
                // Router weights stay dense in F32; expert weights stay
                // quantized until first use (see NaiveMoe::ensure_cache).
                let gate_ws = gg
                    .tensor(&format!("{pfx}.ffn_gate_inp.weight"))?
                    .dequantize(device)?
                    .to_dtype(DType::F32)?;
                let gate = Linear::new(gate_ws, None);
                let gate_experts = Arc::new(gg.tensor(&format!("{pfx}.ffn_gate_exps.weight"))?);
                let up_experts = Arc::new(gg.tensor(&format!("{pfx}.ffn_up_exps.weight"))?);
                let down_experts = Arc::new(gg.tensor(&format!("{pfx}.ffn_down_exps.weight"))?);
                MoeOrMlp::Moe(NaiveMoe {
                    gate,
                    gate_experts,
                    up_experts,
                    down_experts,
                    cache: None,
                    num_experts,
                    num_experts_per_tok,
                    // NOTE(review): hard-coded on; confirm against the GGUF
                    // metadata if a norm_topk_prob key is ever present.
                    norm_topk_prob: true,
                })
            } else {
                let gate = gg.qmatmul(&format!("{pfx}.ffn_gate.weight"))?;
                let up = gg.qmatmul(&format!("{pfx}.ffn_up.weight"))?;
                let down = gg.qmatmul(&format!("{pfx}.ffn_down.weight"))?;
                MoeOrMlp::Mlp(Mlp { gate, up, down })
            };
            layers.push(Layer {
                attn,
                attn_norm,
                mlp,
                ffn_norm,
            });
        }
        Ok(Self {
            embeddings,
            layers,
            norm,
            output,
            dtype,
            device: device.clone(),
        })
    }
    /// Run the model on token ids `x` of shape `(batch, seq)`.
    ///
    /// `offset` is the number of positions already in the KV cache. Returns
    /// F32 logits for the LAST position only, shape `(batch, vocab)`.
    pub fn forward(&mut self, x: &Tensor, offset: usize) -> Result<Tensor> {
        let mut xs = self.embeddings.forward(x)?;
        let (b, l) = x.dims2()?;
        // Single-token decode needs no mask: the one query may attend to
        // everything already in the cache.
        let mask = if l == 1 {
            None
        } else {
            Some(self.causal_mask(b, l, offset)?)
        };
        for layer in self.layers.iter_mut() {
            // Attention sub-block: pre-norm + residual.
            let residual = xs.clone();
            let x = layer.attn_norm.forward(&xs)?;
            let x = layer.attn.forward(&x, mask.as_ref(), offset)?;
            let x = (x + residual)?;
            // Feed-forward sub-block: the MoE path expects a 2-D
            // (tokens, hidden) input, so flatten batch and sequence.
            let residual = x.clone();
            let ffn_in = layer.ffn_norm.forward(&x)?;
            let (fb, fl, fh) = ffn_in.dims3()?;
            let ffn_flat = ffn_in.reshape((fb * fl, fh))?;
            let ffn_out = layer.mlp.forward(&ffn_flat)?;
            let ffn_out = ffn_out.reshape((fb, fl, fh))?;
            xs = (ffn_out + residual)?;
        }
        // Only the last position is needed for next-token prediction.
        let xs = xs.narrow(1, l - 1, 1)?;
        let xs = self.norm.forward(&xs)?;
        self.output.forward(&xs)?.to_dtype(DType::F32)?.squeeze(1)
    }
    /// Drop all cached keys/values (start a fresh sequence).
    pub fn clear_kv_cache(&mut self) {
        for layer in &mut self.layers {
            layer.attn.kv_cache = ConcatKvCache::new(2);
        }
    }
    /// Build the additive causal mask of shape `(b, 1, tgt, tgt + offset)`:
    /// 0 where query position `i` may attend to key position `j`
    /// (`j <= i + offset`), `-inf` elsewhere.
    fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> Result<Tensor> {
        let minf = f32::NEG_INFINITY;
        let mask: Vec<f32> = (0..tgt)
            .flat_map(|i| {
                (0..(tgt + offset)).map(move |j| if j <= i + offset { 0.0 } else { minf })
            })
            .collect();
        // BUG FIX: the data vec holds a single (tgt, tgt + offset) plane,
        // but the original passed shape (b, 1, tgt, tgt + offset) straight
        // to `from_slice`, which fails for b > 1 (element count mismatch).
        // Build the plane with batch dim 1 and expand to b; the attention
        // `broadcast_add` consumes it identically for b == 1.
        Tensor::from_slice(&mask, (1, 1, tgt, tgt + offset), &self.device)?
            .to_dtype(self.dtype)?
            .expand((b, 1, tgt, tgt + offset))
    }
}