#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
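//! GGUF-quantized Qwen3 text model: loads weights from a GGUF file and runs the
//! forward pass with either eager or paged attention.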
use std::collections::HashMap;
use std::sync::Arc;
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module};
use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig};
use crate::attention::SdpaParams;
use crate::device_map::{DeviceMappedMask, DeviceMapper};
use crate::gguf::Content;
use crate::layers::{CausalMasker, MatMul, QRmsNorm, RotaryEmbedding, Sdpa};
use crate::layers_masker::PastKvLenCache;
use crate::paged_attention::{AttentionImplementation, PagedAttention};
use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata;
use crate::pipeline::{extract_logits, EitherCache, KvCache, NormalCache};
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;
use crate::utils::progress::{new_multi_progress, NiceProgressBar};
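/// Fallback context length used when the GGUF metadata has no `context_length` key.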
const DEFAULT_MAX_SEQ_LEN: u32 = 4096;
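/// SwiGLU-style gated MLP: the gate (`w1`) and up (`w3`) projections are combined
/// with SiLU, then projected back down by `w2`.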
struct Mlp {
feed_forward_w1: Arc<dyn QuantMethod>,
feed_forward_w2: Arc<dyn QuantMethod>,
feed_forward_w3: Arc<dyn QuantMethod>,
}
impl Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let w1 = MatMul.qmethod_matmul(xs, &*self.feed_forward_w1)?;
let w3 = MatMul.qmethod_matmul(xs, &*self.feed_forward_w3)?;
let y = crate::ops::mul_and_act(&w1, &w3, crate::layers::Activation::Silu)?;
MatMul.qmethod_matmul(&y, &*self.feed_forward_w2)
}
}
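/// Weights for a single transformer block, including the per-head q/k RMSNorm
/// that Qwen3 applies before rotary embeddings.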
struct LayerWeights {
attention_wq: Arc<dyn QuantMethod>,
attention_wk: Arc<dyn QuantMethod>,
attention_wv: Arc<dyn QuantMethod>,
attention_wo: Arc<dyn QuantMethod>,
attention_norm: QRmsNorm,
q_norm: QRmsNorm,
k_norm: QRmsNorm,
mlp: Mlp,
ffn_norm: QRmsNorm,
n_head: usize,
n_kv_head: usize,
head_dim: usize,
rotary: Arc<RotaryEmbedding>,
paged_attn: Option<PagedAttention>,
sdpa_params: SdpaParams,
dtype: DType,
}
impl LayerWeights {
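    /// Self-attention for one block: QKV projections, q/k norm, RoPE, then either
    /// paged attention or cached SDPA, followed by the output projection.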
fn forward_attn(
&self,
x: &Tensor,
mask: Option<&Tensor>,
start_offsets: &[usize],
kv_cache: &mut KvCache,
metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
) -> Result<Tensor> {
let (b_sz, seq_len, _) = x.dims3()?;
let q = MatMul.qmethod_matmul(x, &*self.attention_wq)?;
let k = MatMul.qmethod_matmul(x, &*self.attention_wk)?;
let v = MatMul.qmethod_matmul(x, &*self.attention_wv)?;
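        // For single-token decode the head/seq transpose is a no-op, so reshape
        // straight into (b, heads, seq, head_dim); longer prompts transpose explicitly.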
let (q, k, v) = if seq_len != 1 {
let q = q
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
.transpose(1, 2)?;
let k = k
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
.transpose(1, 2)?;
let v = v
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
.transpose(1, 2)?;
(q, k, v)
} else {
let q = q.reshape((b_sz, self.n_head, seq_len, self.head_dim))?;
let k = k.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
let v = v.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
(q, k, v)
};
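        // Qwen3 normalizes q and k per head before RoPE; flatten (b, heads, seq)
        // into rows so the RMSNorm acts over the head dimension only.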
let q_flat = q.flatten(0, 2)?;
let k_flat = k.flatten(0, 2)?;
let q_flat = self.q_norm.forward(&q_flat)?;
let k_flat = self.k_norm.forward(&k_flat)?;
let q = q_flat.reshape((b_sz, self.n_head, seq_len, self.head_dim))?;
let k = k_flat.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
let (q, k) = self.rotary.forward(&q, &k, start_offsets)?;
let (q, k, v) = (
q.to_dtype(self.dtype)?,
k.to_dtype(self.dtype)?,
v.to_dtype(self.dtype)?,
);
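        // Use paged attention when paged KV-cache metadata is supplied; otherwise
        // append to this layer's KV cache and run the regular SDPA kernel.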
let y = match &self.paged_attn {
Some(paged_attn) => {
let ((key_cache, value_cache), input_metadata) = metadata.unwrap();
paged_attn.forward(
&q,
&k,
&v,
mask,
Some(key_cache),
Some(value_cache),
input_metadata,
&self.sdpa_params,
None,
)?
}
None => {
let (k, v) = kv_cache.append(&k, &v)?;
Sdpa.run_attention(&q, &k, &v, mask, None, &self.sdpa_params)?
}
};
let y = if mask.is_some() {
y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))?
} else {
y.reshape((b_sz, seq_len, ()))?
};
let y = MatMul.qmethod_matmul(&y.to_dtype(x.dtype())?, &*self.attention_wo)?;
Ok(y)
}
}
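/// Complete GGUF-quantized Qwen3 model: embeddings, transformer blocks, final
/// norm, and the (possibly tied) output projection.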
pub struct ModelWeights {
tok_embeddings: Embedding,
layers: Vec<LayerWeights>,
norm: QRmsNorm,
output: Arc<dyn QuantMethod>,
pub device: Device,
pub cache: EitherCache,
pub max_seq_len: usize,
mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
dtype: DType,
}
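/// Hyperparameters read from the `qwen3.*` keys of the GGUF metadata.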
pub(crate) struct PropsGGUF {
pub head_count: usize,
pub head_count_kv: usize,
pub block_count: usize,
pub embedding_length: usize,
pub rms_norm_eps: f32,
pub max_seq_len: usize,
pub rope_freq_base: f32,
pub key_length: usize,
pub value_length: usize,
}
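/// Checks that the GGUF file declares `general.architecture = "qwen3"`.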
fn verify_qwen3_arch(
metadata: &HashMap<String, candle_core::quantized::gguf_file::Value>,
) -> Result<String> {
use crate::utils::gguf_metadata::TryValueInto;
let actual_arch: String = metadata
.get("general.architecture")
.cloned()
.try_value_into()?;
if actual_arch != "qwen3" {
candle_core::bail!("Expected `qwen3` architecture, got `{actual_arch}`.");
}
Ok(actual_arch)
}
impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
type Error = anyhow::Error;
fn try_from(c: ContentMetadata) -> std::result::Result<Self, Self::Error> {
let _ = verify_qwen3_arch(c.metadata)?;
let required = [
"attention.head_count",
"attention.head_count_kv",
"block_count",
"embedding_length",
"attention.layer_norm_rms_epsilon",
];
c.has_required_keys(&required)?;
let embed_len = c.get_value::<u32>("embedding_length")? as usize;
let head_count = c.get_value::<u32>("attention.head_count")? as usize;
let props = Self {
head_count,
head_count_kv: c.get_value::<u32>("attention.head_count_kv")? as usize,
block_count: c.get_value::<u32>("block_count")? as usize,
embedding_length: embed_len,
rms_norm_eps: c.get_value("attention.layer_norm_rms_epsilon")?,
max_seq_len: c
.get_value::<u64>("context_length")
.ok()
.unwrap_or(DEFAULT_MAX_SEQ_LEN as u64) as usize,
rope_freq_base: c.get_value("rope.freq_base").ok().unwrap_or(10_000_f32),
key_length: c
.get_value::<u32>("attention.key_length")
.ok()
.map(|x| x as usize)
.unwrap_or(embed_len / head_count),
value_length: c
.get_value::<u32>("attention.value_length")
.ok()
.map(|x| x as usize)
.unwrap_or(embed_len / head_count),
};
Ok(props)
}
}
impl ModelConfig::FromGGUF for ModelWeights {
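    /// Builds the model from GGUF content: reads the `qwen3.*` metadata, then loads
    /// each block's tensors onto the device chosen by the mapper.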
fn from_gguf<R: std::io::Seek + std::io::Read>(
mut ct: Content<'_, R>,
device: &Device,
mapper: Box<dyn DeviceMapper + Send + Sync>,
attention_mechanism: AttentionImplementation,
dtype: DType,
) -> Result<Self> {
let meta = ct.get_metadata();
let actual_arch = verify_qwen3_arch(meta)?;
let metadata = ContentMetadata {
path_prefix: &actual_arch,
metadata: meta,
};
let PropsGGUF {
head_count,
head_count_kv,
block_count,
embedding_length,
rms_norm_eps,
max_seq_len,
rope_freq_base,
key_length,
value_length,
} = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;
let qtok_embeddings = ct.tensor("token_embd.weight", device)?;
let tok_embeddings = qtok_embeddings.dequantize(device)?;
let norm = QRmsNorm::new(ct.tensor("output_norm.weight", device)?, rms_norm_eps)?;
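        // When no separate `output.weight` exists, the LM head is tied to the
        // token embedding weights.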
let output = if !ct.has_tensor("output.weight") {
ct.tensor("token_embd.weight", device)?
} else {
ct.tensor("output.weight", device)?
};
let mut layers = Vec::with_capacity(block_count);
let head_dim = key_length;
if key_length != value_length {
candle_core::bail!(
"Expected key_length == value_length, got {key_length} != {value_length}"
);
}
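        // Build one RotaryEmbedding per distinct device so layers mapped to the
        // same device share the cached cos/sin tables.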
let mut ropes = HashMap::new();
for layer_idx in 0..block_count {
let device = mapper.device_for(layer_idx, false).unwrap_or(device);
ropes.insert(
device.location(),
Arc::new(RotaryEmbedding::new(
rope_freq_base,
head_dim,
max_seq_len,
device,
true,
DType::F32,
)?),
);
}
for layer_idx in NiceProgressBar::<_, 'b'>(
0..block_count,
"Loading repeating layers",
&new_multi_progress(),
) {
let prefix = format!("blk.{layer_idx}");
let device = mapper.device_for(layer_idx, false).unwrap_or(device);
let rotary = ropes
.get(&device.location())
.expect("No RoPE for device location!")
.clone();
let attention_wq = ct.tensor(&format!("{prefix}.attn_q.weight"), device)?;
let attention_wk = ct.tensor(&format!("{prefix}.attn_k.weight"), device)?;
let attention_wv = ct.tensor(&format!("{prefix}.attn_v.weight"), device)?;
let attention_wo = ct.tensor(&format!("{prefix}.attn_output.weight"), device)?;
let feed_forward_w1 = ct.tensor(&format!("{prefix}.ffn_gate.weight"), device)?;
let feed_forward_w2 = ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?;
let feed_forward_w3 = ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?;
let mlp = Mlp {
feed_forward_w1: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
q_weight: Arc::new(feed_forward_w1),
b: None,
})?),
feed_forward_w2: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
q_weight: Arc::new(feed_forward_w2),
b: None,
})?),
feed_forward_w3: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
q_weight: Arc::new(feed_forward_w3),
b: None,
})?),
};
let q_norm = QRmsNorm::new(
ct.tensor(&format!("{prefix}.attn_q_norm.weight"), device)?,
rms_norm_eps,
)?;
let k_norm = QRmsNorm::new(
ct.tensor(&format!("{prefix}.attn_k_norm.weight"), device)?,
rms_norm_eps,
)?;
let attention_norm = ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?;
let ffn_norm = ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?;
let paged_attn = match &attention_mechanism {
AttentionImplementation::Eager => None,
AttentionImplementation::PagedAttention => {
Some(PagedAttention::new(head_dim, device, None)?)
}
};
layers.push(LayerWeights {
attention_wq: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
q_weight: Arc::new(attention_wq),
b: None,
})?),
attention_wk: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
q_weight: Arc::new(attention_wk),
b: None,
})?),
attention_wv: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
q_weight: Arc::new(attention_wv),
b: None,
})?),
attention_wo: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
q_weight: Arc::new(attention_wo),
b: None,
})?),
attention_norm: QRmsNorm::new(attention_norm, rms_norm_eps)?,
q_norm,
k_norm,
mlp,
ffn_norm: QRmsNorm::new(ffn_norm, rms_norm_eps)?,
n_head: head_count,
n_kv_head: head_count_kv,
head_dim,
rotary: rotary.clone(),
paged_attn,
sdpa_params: SdpaParams {
n_kv_groups: head_count / head_count_kv,
softcap: None,
softmax_scale: 1.0 / (head_dim as f32).sqrt(),
sliding_window: None,
sinks: None,
},
dtype,
})
}
Ok(Self {
tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
layers,
norm,
output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
q_weight: Arc::new(output),
b: None,
})?),
device: device.clone(),
cache: EitherCache::Normal(NormalCache::new(block_count, max_seq_len)),
max_seq_len,
mapper: Some(mapper),
dtype,
})
}
}
impl ModelWeights {
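    /// Full forward pass: embeds tokens, runs every block, applies the final norm,
    /// and returns logits for the positions selected by `context_lens`.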
pub fn forward(
&self,
x: &Tensor,
start_offsets: &[usize],
context_lens: Vec<(usize, usize)>,
metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
) -> Result<Tensor> {
let mut layer_in = self.tok_embeddings.forward(x)?;
let cache = &mut self.cache.normal().0;
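        // Causal mask built from either the paged-attention offsets or the KV cache;
        // with paged-attention metadata, only the first prompt chunk keeps the mask.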
let mask = CausalMasker.make_causal_mask_matrix(
x,
metadata
.as_ref()
.map(|(_, _)| &start_offsets as &dyn PastKvLenCache)
.unwrap_or(cache as &dyn PastKvLenCache),
self.dtype,
self.layers[0].n_head,
)?;
let mask = mask.filter(|_| {
metadata
.as_ref()
.map(|(_, meta)| meta.is_first_prompt_chunk)
.unwrap_or(true)
});
let mask = if let Some(ref mapper) = self.mapper {
DeviceMappedMask::new(mask, &**mapper)?
} else {
DeviceMappedMask::from_single(mask)
};
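        // Pre-norm transformer blocks: attention and MLP each read a normed input
        // and add their output back onto the residual stream.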
for (i, layer) in self.layers.iter().enumerate() {
if let Some(ref mapper) = self.mapper {
layer_in = mapper.map(layer_in, i)?;
}
let x = layer_in;
let residual = &x;
let x = layer.attention_norm.forward(&x)?;
let attn = layer.forward_attn(
&x,
mask.as_ref().map(|m| m.get(x.device())),
start_offsets,
&mut cache[i],
metadata
.as_ref()
.map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
)?;
let x = (attn + residual)?;
let residual = &x;
let x = layer.ffn_norm.forward(&x)?;
let x = layer.mlp.forward(&x)?;
let x = (x + residual)?;
layer_in = x;
}
let x = self.norm.forward(&layer_in)?;
let x = extract_logits(&x, context_lens)?;
MatMul.qmethod_matmul(&x.contiguous()?, &*self.output)
}
}