use std::path::Path;
use crate::gguf::{GgufFile, MetadataValue};
use crate::tensor::{DType, Tensor};
use super::Architecture;
use super::config::{ActivationType, ModelConfig, RopeConfig, RopeScalingType, RopeType};
use super::deltanet::{BetaAlphaProjection, DeltaNetConfig, DeltaNetLayer};
use super::mamba::{MambaConfig, MambaLayer};
use super::error::{ModelError, ModelResult};
use super::layers::{
Attention, AttentionLayer, FeedForward, FfnLayer, LayerNorm, Linear, NormLayer,
NoGateFeedForward, RMSNorm, TransformerLayer,
};
use super::bert::{BertLayer, BertModel};
use super::llama::LlamaModel;
use super::moe::{MoeConfig, MoeExpert, MoeLayer, MoeRouter};
/// Loads a model from a GGUF file: detects the architecture, parses the
/// metadata into a [`ModelConfig`], and materializes tensors into layers.
pub struct ModelLoader {
    // Parsed GGUF container (metadata plus raw tensor bytes).
    gguf: GgufFile,
    // Architecture detected from the `general.architecture` metadata key.
    architecture: Architecture,
    // Hyperparameters parsed from the GGUF metadata.
    config: ModelConfig,
}
impl ModelLoader {
/// Open a GGUF file and prepare a loader for it.
///
/// # Errors
/// Fails when the file cannot be opened/parsed, when the
/// `general.architecture` key is missing, when the architecture is not
/// recognized, or when required config metadata is absent.
pub fn load<P: AsRef<Path>>(path: P) -> ModelResult<Self> {
    let gguf = GgufFile::open(path)?;
    let arch_name = gguf
        .data
        .get_string("general.architecture")
        .ok_or_else(|| ModelError::MissingMetadata("general.architecture".into()))?;
    let architecture = Architecture::from_gguf_str(arch_name);
    match architecture {
        // Bail out early on architectures we cannot build.
        Architecture::Unknown => {
            Err(ModelError::UnsupportedArchitecture(arch_name.to_string()))
        }
        _ => {
            let config = Self::parse_config(&gguf, &architecture)?;
            Ok(Self {
                gguf,
                architecture,
                config,
            })
        }
    }
}
/// Parse model hyperparameters out of the GGUF metadata for `architecture`.
///
/// Keys that formats commonly omit fall back to defaults (e.g. vocab size is
/// inferred from the tokenizer token list or the embedding tensor); keys that
/// are required for the given architecture produce
/// `ModelError::MissingMetadata`.
fn parse_config(gguf: &GgufFile, architecture: &Architecture) -> ModelResult<ModelConfig> {
    let arch = architecture.as_str();
    let get_u32 = |key: &str| -> ModelResult<u32> {
        gguf.data
            .get_u32(key)
            .ok_or_else(|| ModelError::MissingMetadata(key.into()))
    };
    let get_f32_or =
        |key: &str, default: f32| -> f32 { gguf.data.get_f32(key).unwrap_or(default) };
    // Vocab size: explicit metadata, then tokenizer token count, then the
    // embedding tensor's second dim, then a llama-style default.
    let vocab_size = get_u32(&format!("{}.vocab_size", arch))
        .or_else(|_| get_u32("tokenizer.ggml.vocab_size"))
        .map(|v| v as usize)
        .unwrap_or_else(|_| {
            if let Some(tokens) = gguf.data.metadata.get("tokenizer.ggml.tokens")
                && let MetadataValue::Array(arr) = tokens
            {
                return arr.values.len();
            }
            if let Some(emb_info) = gguf.data.get_tensor("token_embd.weight") {
                if emb_info.dims.len() == 2 {
                    return emb_info.dims[1] as usize;
                }
            }
            32000
        });
    let hidden_size = get_u32(&format!("{}.embedding_length", arch))? as usize;
    let num_layers = get_u32(&format!("{}.block_count", arch))? as usize;
    // Head counts are optional for pure-SSM (Mamba) models, required otherwise.
    let (num_heads, num_kv_heads, head_dim) =
        if matches!(architecture, Architecture::Mamba | Architecture::Mamba2) {
            let nh = get_u32(&format!("{}.attention.head_count", arch)).unwrap_or(1) as usize;
            let nkv = get_u32(&format!("{}.attention.head_count_kv", arch))
                .unwrap_or(nh as u32) as usize;
            let hd = get_u32(&format!("{}.attention.key_length", arch))
                .unwrap_or_else(|_| (hidden_size / nh.max(1)) as u32) as usize;
            (nh, nkv, hd)
        } else {
            let nh = get_u32(&format!("{}.attention.head_count", arch))? as usize;
            let nkv = get_u32(&format!("{}.attention.head_count_kv", arch))
                .unwrap_or(nh as u32) as usize;
            // FIX: the fallback was eagerly evaluated via `unwrap_or`, so a
            // corrupt file with head_count == 0 panicked on divide-by-zero
            // even when key_length was present. Evaluate lazily and clamp,
            // matching the Mamba branch above.
            let hd = get_u32(&format!("{}.attention.key_length", arch))
                .map(|v| v as usize)
                .unwrap_or_else(|_| hidden_size / nh.max(1));
            (nh, nkv, hd)
        };
    let intermediate_size = get_u32(&format!("{}.feed_forward_length", arch))
        .unwrap_or_else(|_| {
            if matches!(architecture, Architecture::Mamba | Architecture::Mamba2) {
                hidden_size as u32
            } else {
                // SwiGLU convention: 2/3 of the classic 4*hidden FFN width.
                (hidden_size * 4 * 2 / 3) as u32
            }
        }) as usize;
    let max_seq_len = get_u32(&format!("{}.context_length", arch)).unwrap_or(2048) as usize;
    // Prefer the RMS epsilon key; fall back to the plain LayerNorm epsilon.
    let norm_eps = gguf
        .data
        .get_f32(&format!("{}.attention.layer_norm_rms_epsilon", arch))
        .or_else(|| gguf.data.get_f32(&format!("{}.attention.layer_norm_epsilon", arch)))
        .unwrap_or(1e-5);
    let freq_base = get_f32_or(&format!("{}.rope.freq_base", arch), 10000.0);
    let freq_scale = get_f32_or(&format!("{}.rope.scale_linear", arch), 1.0);
    // Families using the NeoX RoPE layout rather than the original one.
    let rope_type = match architecture {
        Architecture::Qwen2
        | Architecture::Qwen2Moe
        | Architecture::Qwen3
        | Architecture::Qwen35
        | Architecture::Qwen35Moe
        | Architecture::Qwen3Moe
        | Architecture::Qwen3Next
        | Architecture::GPTNeoX
        | Architecture::Falcon
        | Architecture::Phi
        | Architecture::Phi2
        | Architecture::Phi3
        | Architecture::PhiMoe
        | Architecture::GPTJ
        | Architecture::StableLM => RopeType::NeoX,
        _ => RopeType::Normal,
    };
    // MoE metadata: all zero for dense models.
    let num_experts = get_u32(&format!("{}.expert_count", arch)).unwrap_or(0) as usize;
    let num_experts_per_token =
        get_u32(&format!("{}.expert_used_count", arch)).unwrap_or(0) as usize;
    let expert_intermediate_size =
        get_u32(&format!("{}.expert_feed_forward_length", arch)).unwrap_or(0) as usize;
    // Per-head K/V widths default to head_dim when not stated explicitly.
    let key_length =
        get_u32(&format!("{}.attention.key_length", arch)).unwrap_or(head_dim as u32) as usize;
    let value_length = get_u32(&format!("{}.attention.value_length", arch))
        .unwrap_or(head_dim as u32) as usize;
    let rope_n_dims = get_u32(&format!("{}.rope.dimension_count", arch))
        .unwrap_or(head_dim as u32) as usize;
    let rope_config = RopeConfig {
        freq_base,
        freq_scale,
        n_dims: rope_n_dims,
        scaling_type: RopeScalingType::None,
        original_max_position_embeddings: max_seq_len,
        rope_type,
    };
    // Structural flags derived from the architecture enum itself.
    let has_combined_qkv = architecture.has_combined_qkv();
    let uses_layer_norm = architecture.uses_layer_norm();
    let uses_gelu = architecture.uses_gelu();
    let has_ffn_gate = !architecture.has_no_gate_ffn();
    // A value of 0.0 disables softcapping / sliding-window attention.
    let attn_logit_softcap =
        get_f32_or(&format!("{}.attn_logit_softcapping", arch), 0.0);
    let final_logit_softcap =
        get_f32_or(&format!("{}.final_logit_softcapping", arch), 0.0);
    let sliding_window =
        get_u32(&format!("{}.attention.sliding_window", arch)).unwrap_or(0) as usize;
    // Families whose attention projections carry bias terms.
    let attention_bias = matches!(
        architecture,
        Architecture::Qwen
            | Architecture::Qwen2
            | Architecture::Qwen2Moe
            | Architecture::Phi2
            | Architecture::Phi3
            | Architecture::PhiMoe
            | Architecture::GPTNeoX
            | Architecture::GPTJ
            | Architecture::Falcon
            | Architecture::BLOOM
            | Architecture::MPT
            | Architecture::OPT
            | Architecture::GPT2
            | Architecture::StableLM
            | Architecture::Baichuan
    );
    // Families whose FFN projections carry bias terms.
    let mlp_bias = matches!(
        architecture,
        Architecture::GPT2
            | Architecture::GPTJ
            | Architecture::GPTNeoX
            | Architecture::BLOOM
            | Architecture::OPT
            | Architecture::StableLM
            | Architecture::Phi2
            | Architecture::Phi3
    );
    // Families that add attention and FFN outputs to the same residual.
    let use_parallel_residual = matches!(
        architecture,
        Architecture::GPTNeoX
            | Architecture::GPTJ
            | Architecture::StableLM
            | Architecture::Phi
            | Architecture::Phi2
            | Architecture::CodeShell
    );
    let hidden_act = if architecture.uses_gelu() {
        ActivationType::GELU
    } else {
        ActivationType::SiLU
    };
    Ok(ModelConfig {
        vocab_size,
        hidden_size,
        intermediate_size,
        num_layers,
        num_heads,
        num_kv_heads,
        head_dim,
        max_seq_len,
        norm_eps,
        rope_config,
        use_parallel_residual,
        hidden_act,
        attention_bias,
        mlp_bias,
        // NOTE(review): read as a string compared to "true"; GGUF usually
        // stores this flag as a bool — confirm `get_string` also surfaces
        // bool metadata, otherwise this silently stays false.
        tie_word_embeddings: gguf
            .data
            .get_string("general.tie_word_embeddings")
            .map(|s| s == "true")
            .unwrap_or(false),
        num_experts,
        num_experts_per_token,
        expert_intermediate_size,
        key_length,
        value_length,
        ssm_d_inner: get_u32(&format!("{}.ssm.inner_size", arch)).unwrap_or(0) as usize,
        ssm_d_state: get_u32(&format!("{}.ssm.state_size", arch)).unwrap_or(0) as usize,
        // Mamba files may omit the group count; treat that as one group.
        ssm_n_group: {
            let g = get_u32(&format!("{}.ssm.group_count", arch)).unwrap_or(0) as usize;
            if g == 0 && matches!(architecture, Architecture::Mamba | Architecture::Mamba2) {
                1
            } else {
                g
            }
        },
        ssm_dt_rank: get_u32(&format!("{}.ssm.time_step_rank", arch)).unwrap_or(0) as usize,
        ssm_conv_kernel: get_u32(&format!("{}.ssm.conv_kernel", arch)).unwrap_or(0) as usize,
        attn_logit_softcap,
        final_logit_softcap,
        sliding_window,
        has_combined_qkv,
        uses_layer_norm,
        uses_gelu,
        has_ffn_gate,
    })
}
/// Borrow the parsed model configuration.
pub fn config(&self) -> &ModelConfig {
    &self.config
}
/// Mutably borrow the configuration, so callers can adjust parsed values
/// before building the model.
pub fn config_mut(&mut self) -> &mut ModelConfig {
    &mut self.config
}
/// The architecture detected from `general.architecture`.
pub fn architecture(&self) -> Architecture {
    self.architecture
}
/// Assemble a decoder-style [`LlamaModel`] from the loaded file, consuming
/// the loader.
///
/// The output head reuses the token embedding when `tie_word_embeddings` is
/// set or when the file carries no `output.weight` tensor.
///
/// # Errors
/// Returns `ModelError::MissingTensor` (or construction errors) when a
/// required tensor is absent or malformed.
pub fn build_model(self) -> ModelResult<LlamaModel> {
    let token_embedding = self.load_tensor("token_embd.weight")?;
    let mut layers = Vec::with_capacity(self.config.num_layers);
    for i in 0..self.config.num_layers {
        let layer = self.load_transformer_layer(i)?;
        layers.push(layer);
    }
    let recurrent_count = layers.iter().filter(|l| l.is_recurrent()).count();
    if recurrent_count > 0 {
        tracing::info!(
            "Model has {}/{} DeltaNet recurrent layers",
            recurrent_count,
            layers.len()
        );
    }
    let norm_weight =
        self.apply_gemma_norm_weight_offset(self.load_tensor("output_norm.weight")?)?;
    // A bias tensor signals classic LayerNorm; otherwise RMSNorm.
    let norm = if let Some(bias) = self.try_load_tensor("output_norm.bias") {
        NormLayer::Layer(LayerNorm::new(norm_weight, bias, self.config.norm_eps)?)
    } else {
        NormLayer::RMS(RMSNorm::new(norm_weight, self.config.norm_eps)?)
    };
    let output_bias = self.try_load_tensor("output.bias");
    // FIX: previously `try_load_tensor("output.weight").is_none()` was used
    // as a probe and the tensor was then loaded a second time — copying a
    // potentially huge weight twice. Load it at most once via `if let`.
    let output = if self.config.tie_word_embeddings {
        Linear::new(token_embedding.clone(), output_bias)?
    } else if let Some(output_weight) = self.try_load_tensor("output.weight") {
        Linear::new(output_weight, output_bias)?
    } else {
        // No separate output head in the file: fall back to tied embeddings.
        Linear::new(token_embedding.clone(), output_bias)?
    };
    LlamaModel::new(
        self.config,
        token_embedding,
        layers,
        norm,
        output,
        self.architecture,
    )
}
/// Assemble an encoder-only (BERT-style) model, consuming the loader.
///
/// Position and token-type embeddings are optional. Each block's Q/K/V may
/// be stored either as three separate tensors or as one fused `attn_qkv`
/// tensor; the fused form is converted to f32 and sliced apart. Norms are
/// probed under both BERT-style (`attn_output_norm`, `layer_output_norm`)
/// and llama-style (`attn_norm`, `ffn_norm`) names.
pub fn build_bert_model(self) -> ModelResult<BertModel> {
    let token_embedding = self.load_tensor("token_embd.weight")?;
    let position_embedding = self.try_load_tensor("position_embd.weight");
    let token_type_embedding = self.try_load_tensor("token_types.weight");
    // Embedding norm: LayerNorm when a bias tensor exists, else RMSNorm.
    let embed_norm = if let Some(w) = self.try_load_tensor("token_embd_norm.weight") {
        if let Some(b) = self.try_load_tensor("token_embd_norm.bias") {
            Some(NormLayer::Layer(LayerNorm::new(w, b, self.config.norm_eps)?))
        } else {
            Some(NormLayer::RMS(RMSNorm::new(w, self.config.norm_eps)?))
        }
    } else {
        None
    };
    let mut layers = Vec::with_capacity(self.config.num_layers);
    for i in 0..self.config.num_layers {
        let prefix = format!("blk.{}", i);
        let attn_norm_w = self
            .try_load_tensor(&format!("{}.attn_output_norm.weight", prefix))
            .or_else(|| self.try_load_tensor(&format!("{}.attn_norm.weight", prefix)))
            .ok_or_else(|| {
                ModelError::MissingTensor(format!("{}.attn_norm.weight", prefix))
            })?;
        let attn_norm_b = self
            .try_load_tensor(&format!("{}.attn_output_norm.bias", prefix))
            .or_else(|| self.try_load_tensor(&format!("{}.attn_norm.bias", prefix)));
        let attn_norm = if let Some(b) = attn_norm_b {
            NormLayer::Layer(LayerNorm::new(attn_norm_w, b, self.config.norm_eps)?)
        } else {
            NormLayer::RMS(RMSNorm::new(attn_norm_w, self.config.norm_eps)?)
        };
        let (wq, wk, wv) =
            if let Some(qkv) = self.try_load_tensor(&format!("{}.attn_qkv.weight", prefix)) {
                let num_heads = self.config.num_heads;
                let head_dim = self.config.head_dim;
                let hidden = self.config.hidden_size;
                // NOTE(review): K/V sizes use `num_heads`, not `num_kv_heads`
                // (unlike `load_combined_qkv_attention`). Equivalent for
                // classic BERT where the counts match — confirm if a GQA
                // encoder ever reaches this path.
                let q_size = num_heads * head_dim;
                let k_size = num_heads * head_dim;
                let v_size = num_heads * head_dim;
                // Currently unused; documents the fused projection width.
                let total = q_size + k_size + v_size;
                // Materialize the fused weight as f32 so it can be sliced.
                let qkv_f32 = if qkv.dtype() == DType::F32 {
                    qkv.as_f32()?.to_vec()
                } else {
                    let backend = crate::backend::default_backend();
                    let mut deq = Tensor::zeros(vec![qkv.numel()], DType::F32);
                    backend
                        .dequantize(&qkv, &mut deq)
                        .map_err(|e| ModelError::ConfigError(format!("Dequant QKV: {}", e)))?;
                    deq.as_f32()?.to_vec()
                };
                // Flat offsets of each projection in the fused buffer
                // (assumes Q, K, V stored back to back, each hidden-major).
                let q_start = 0;
                let k_start_off = q_size * hidden;
                let v_start_off = (q_size + k_size) * hidden;
                let qkv_bias = self.try_load_tensor(&format!("{}.attn_qkv.bias", prefix));
                // Split the fused bias the same way, when present.
                let (qb, kb, vb) = if let Some(ref b) = qkv_bias {
                    let bd = b.as_f32()?;
                    (
                        Some(Tensor::from_f32(&bd[..q_size], vec![q_size])?),
                        Some(Tensor::from_f32(
                            &bd[q_size..q_size + k_size],
                            vec![k_size],
                        )?),
                        Some(Tensor::from_f32(&bd[q_size + k_size..], vec![v_size])?),
                    )
                } else {
                    (None, None, None)
                };
                (
                    Linear::new(
                        Tensor::from_f32(&qkv_f32[q_start..q_start + q_size * hidden], vec![hidden, q_size])?,
                        qb,
                    )?,
                    Linear::new(
                        Tensor::from_f32(&qkv_f32[k_start_off..k_start_off + k_size * hidden], vec![hidden, k_size])?,
                        kb,
                    )?,
                    Linear::new(
                        Tensor::from_f32(&qkv_f32[v_start_off..v_start_off + v_size * hidden], vec![hidden, v_size])?,
                        vb,
                    )?,
                )
            } else {
                // Separate Q/K/V tensors, each with an optional bias.
                let qb = self.try_load_tensor(&format!("{}.attn_q.bias", prefix));
                let kb = self.try_load_tensor(&format!("{}.attn_k.bias", prefix));
                let vb = self.try_load_tensor(&format!("{}.attn_v.bias", prefix));
                (
                    Linear::new(
                        self.load_tensor(&format!("{}.attn_q.weight", prefix))?,
                        qb,
                    )?,
                    Linear::new(
                        self.load_tensor(&format!("{}.attn_k.weight", prefix))?,
                        kb,
                    )?,
                    Linear::new(
                        self.load_tensor(&format!("{}.attn_v.weight", prefix))?,
                        vb,
                    )?,
                )
            };
        let wo_bias = self.try_load_tensor(&format!("{}.attn_output.bias", prefix));
        let wo = Linear::new(
            self.load_tensor(&format!("{}.attn_output.weight", prefix))?,
            wo_bias,
        )?;
        let ffn_norm_w = self
            .try_load_tensor(&format!("{}.layer_output_norm.weight", prefix))
            .or_else(|| self.try_load_tensor(&format!("{}.ffn_norm.weight", prefix)))
            .ok_or_else(|| {
                ModelError::MissingTensor(format!("{}.ffn_norm.weight", prefix))
            })?;
        let ffn_norm_b = self
            .try_load_tensor(&format!("{}.layer_output_norm.bias", prefix))
            .or_else(|| self.try_load_tensor(&format!("{}.ffn_norm.bias", prefix)));
        let ffn_norm = if let Some(b) = ffn_norm_b {
            NormLayer::Layer(LayerNorm::new(ffn_norm_w, b, self.config.norm_eps)?)
        } else {
            NormLayer::RMS(RMSNorm::new(ffn_norm_w, self.config.norm_eps)?)
        };
        let ffn_up_bias = self.try_load_tensor(&format!("{}.ffn_up.bias", prefix));
        let ffn_up = Linear::new(
            self.load_tensor(&format!("{}.ffn_up.weight", prefix))?,
            ffn_up_bias,
        )?;
        let ffn_down_bias = self.try_load_tensor(&format!("{}.ffn_down.bias", prefix));
        let ffn_down = Linear::new(
            self.load_tensor(&format!("{}.ffn_down.weight", prefix))?,
            ffn_down_bias,
        )?;
        layers.push(BertLayer {
            attn_norm,
            wq,
            wk,
            wv,
            wo,
            num_heads: self.config.num_heads,
            head_dim: self.config.head_dim,
            ffn_norm,
            ffn_up,
            ffn_down,
        });
    }
    BertModel::new(
        self.config,
        token_embedding,
        position_embedding,
        token_type_embedding,
        embed_norm,
        layers,
        self.architecture,
    )
}
/// Derive the DeltaNet geometry from the parsed SSM metadata.
///
/// Returns `None` for models without SSM metadata and for Mamba/Mamba2
/// architectures (which use [`MambaConfig`] instead).
pub fn deltanet_config(&self) -> Option<DeltaNetConfig> {
    let is_mamba = matches!(self.architecture, Architecture::Mamba | Architecture::Mamba2);
    if is_mamba || !self.config.has_ssm() {
        return None;
    }
    let cfg = &self.config;
    let num_v_heads = cfg.ssm_dt_rank;
    let num_k_heads = cfg.ssm_n_group.max(1);
    let head_k_dim = cfg.ssm_d_state;
    // Q and K share the same total width.
    let kq_dim = num_k_heads * head_k_dim;
    Some(DeltaNetConfig {
        d_inner: cfg.ssm_d_inner,
        d_state: cfg.ssm_d_state,
        num_v_heads,
        num_k_heads,
        head_v_dim: cfg.ssm_d_inner / num_v_heads.max(1),
        head_k_dim,
        conv_kernel: cfg.ssm_conv_kernel,
        // Fused projection width: Q + K + V (V spans d_inner).
        qkv_dim: kq_dim + kq_dim + cfg.ssm_d_inner,
    })
}
/// Select the recurrent-layer configuration for SSM-bearing models.
///
/// Returns `None` for pure-attention models, a Mamba configuration for
/// Mamba/Mamba2 architectures, and a DeltaNet configuration otherwise.
pub fn recurrent_config(&self) -> Option<super::deltanet::RecurrentConfig> {
    if !self.config.has_ssm() {
        return None;
    }
    if matches!(self.architecture, Architecture::Mamba | Architecture::Mamba2) {
        return Some(super::deltanet::RecurrentConfig::Mamba(MambaConfig {
            d_inner: self.config.ssm_d_inner,
            d_state: self.config.ssm_d_state,
            dt_rank: self.config.ssm_dt_rank,
            // A zero kernel from absent metadata would be invalid; clamp.
            conv_kernel: self.config.ssm_conv_kernel.max(1),
        }));
    }
    // Idiom: `if let Some(x) { Some(f(x)) } else { None }` is just `map`
    // (Clippy `manual_map`).
    self.deltanet_config()
        .map(super::deltanet::RecurrentConfig::DeltaNet)
}
/// Load decoder block `layer_idx`: pre-attention norm, attention (or
/// recurrent) sub-layer, optional post-norms, FFN norm, and one of the FFN
/// variants (MoE / identity for pure-Mamba / no-gate / gated dense).
fn load_transformer_layer(&self, layer_idx: usize) -> ModelResult<TransformerLayer> {
    let prefix = format!("blk.{}", layer_idx);
    let is_mamba = matches!(self.architecture, Architecture::Mamba | Architecture::Mamba2);
    // Pre-attention norm, probed under "attn_norm" then plain "norm".
    let attn_norm_weight = self
        .try_load_tensor(&format!("{}.attn_norm.weight", prefix))
        .or_else(|| self.try_load_tensor(&format!("{}.norm.weight", prefix)))
        .ok_or_else(|| ModelError::MissingTensor(format!("{}.attn_norm.weight", prefix)))?;
    let attn_norm_weight = self.apply_gemma_norm_weight_offset(attn_norm_weight)?;
    let attn_norm_bias = self
        .try_load_tensor(&format!("{}.attn_norm.bias", prefix))
        .or_else(|| self.try_load_tensor(&format!("{}.norm.bias", prefix)));
    // A bias tensor means classic LayerNorm; otherwise RMSNorm.
    let attn_norm = if let Some(bias) = attn_norm_bias {
        NormLayer::Layer(LayerNorm::new(attn_norm_weight, bias, self.config.norm_eps)?)
    } else {
        NormLayer::RMS(RMSNorm::new(attn_norm_weight, self.config.norm_eps)?)
    };
    let attn_layer = self.load_attention_layer(layer_idx)?;
    // Optional post-attention norm (architectures with sandwich norms).
    let post_attn_norm =
        if let Some(w) = self.try_load_tensor(&format!("{}.post_attention_norm.weight", prefix))
        {
            let w = self.apply_gemma_norm_weight_offset(w)?;
            let b = self.try_load_tensor(&format!("{}.post_attention_norm.bias", prefix));
            Some(if let Some(bias) = b {
                NormLayer::Layer(LayerNorm::new(w, bias, self.config.norm_eps)?)
            } else {
                NormLayer::RMS(RMSNorm::new(w, self.config.norm_eps)?)
            })
        } else {
            None
        };
    let ffn_norm_weight = self.try_load_tensor(&format!("{}.ffn_norm.weight", prefix));
    let ffn_norm_bias = self.try_load_tensor(&format!("{}.ffn_norm.bias", prefix));
    let ffn_norm = if let Some(w) = ffn_norm_weight {
        let w = self.apply_gemma_norm_weight_offset(w)?;
        if let Some(bias) = ffn_norm_bias {
            NormLayer::Layer(LayerNorm::new(w, bias, self.config.norm_eps)?)
        } else {
            NormLayer::RMS(RMSNorm::new(w, self.config.norm_eps)?)
        }
    } else if post_attn_norm.is_some() || is_mamba || self.config.use_parallel_residual {
        // Layouts that legitimately lack an FFN norm get an identity
        // RMSNorm (all-ones weight) so the forward pass stays uniform.
        let hidden = self.config.hidden_size;
        NormLayer::RMS(RMSNorm::new(
            Tensor::from_f32(&vec![1.0f32; hidden], vec![hidden])?,
            self.config.norm_eps,
        )?)
    } else {
        return Err(ModelError::MissingTensor(format!(
            "{}.ffn_norm.weight",
            prefix
        )));
    };
    let ffn_layer = if self.config.is_moe() {
        self.load_moe_layer(layer_idx)?
    } else if is_mamba
        && self.try_load_tensor(&format!("{}.ffn_up.weight", prefix)).is_none()
    {
        // Pure-Mamba blocks carry no FFN at all.
        FfnLayer::Identity
    } else if !self.config.has_ffn_gate {
        let up_tensor = self.load_tensor(&format!("{}.ffn_up.weight", prefix))?;
        let up_out_dim = up_tensor.shape()[up_tensor.ndim() - 1];
        let intermediate = self.config.intermediate_size;
        if up_out_dim == 2 * intermediate {
            // ffn_up is twice the expected width: a fused gate+up tensor.
            // NOTE(review): the split assumes the gate half occupies the
            // first `hidden * intermediate` flat elements — confirm this
            // matches the exporter's layout.
            let hidden = self.config.hidden_size;
            let up_f32 = if up_tensor.dtype() == DType::F32 {
                up_tensor.as_f32()?.to_vec()
            } else {
                let backend = crate::backend::default_backend();
                let mut deq = Tensor::zeros(vec![up_tensor.numel()], DType::F32);
                backend
                    .dequantize(&up_tensor, &mut deq)
                    .map_err(|e| ModelError::ConfigError(format!("Dequant ffn_up: {}", e)))?;
                deq.as_f32()?.to_vec()
            };
            let gate_data = &up_f32[..hidden * intermediate];
            let up_data = &up_f32[hidden * intermediate..];
            let w_gate = Linear::new(
                Tensor::from_f32(gate_data, vec![hidden, intermediate])?,
                None,
            )?;
            let w_up = Linear::new(
                Tensor::from_f32(up_data, vec![hidden, intermediate])?,
                None,
            )?;
            let w_down = Linear::new(
                self.load_tensor(&format!("{}.ffn_down.weight", prefix))?,
                None,
            )?;
            FfnLayer::Dense(FeedForward::new(w_gate, w_up, w_down))
        } else {
            // Plain two-matrix FFN (no gate), with optional biases.
            let w_up = Linear::new(
                up_tensor,
                self.try_load_tensor(&format!("{}.ffn_up.bias", prefix)),
            )?;
            let w_down = Linear::new(
                self.load_tensor(&format!("{}.ffn_down.weight", prefix))?,
                self.try_load_tensor(&format!("{}.ffn_down.bias", prefix)),
            )?;
            FfnLayer::NoGate(NoGateFeedForward::new(
                w_up,
                w_down,
                self.config.uses_gelu,
            ))
        }
    } else {
        // Standard gated (SwiGLU-style) FFN with separate gate/up/down.
        let w_gate = Linear::new(
            self.load_tensor(&format!("{}.ffn_gate.weight", prefix))?,
            None,
        )?;
        let w_up = Linear::new(
            self.load_tensor(&format!("{}.ffn_up.weight", prefix))?,
            None,
        )?;
        let w_down = Linear::new(
            self.load_tensor(&format!("{}.ffn_down.weight", prefix))?,
            None,
        )?;
        FfnLayer::Dense(FeedForward::new(w_gate, w_up, w_down))
    };
    // Optional post-FFN norm, mirroring post_attn_norm.
    let post_ffn_norm =
        if let Some(w) = self.try_load_tensor(&format!("{}.post_ffw_norm.weight", prefix)) {
            let w = self.apply_gemma_norm_weight_offset(w)?;
            let b = self.try_load_tensor(&format!("{}.post_ffw_norm.bias", prefix));
            Some(if let Some(bias) = b {
                NormLayer::Layer(LayerNorm::new(w, bias, self.config.norm_eps)?)
            } else {
                NormLayer::RMS(RMSNorm::new(w, self.config.norm_eps)?)
            })
        } else {
            None
        };
    Ok(TransformerLayer {
        attn_norm,
        attn_layer,
        post_attn_norm,
        ffn_norm,
        ffn_layer,
        post_ffn_norm,
        layer_idx,
        use_parallel_residual: self.config.use_parallel_residual,
    })
}
/// Dispatch on the tensors present in block `layer_idx` to decide which
/// kind of mixing sub-layer it is: full attention (split Q/K/V), fused-QKV
/// attention or DeltaNet, or a Mamba SSM layer.
fn load_attention_layer(&self, layer_idx: usize) -> ModelResult<AttentionLayer> {
    let prefix = format!("blk.{}", layer_idx);
    // Split Q projection present: standard full attention.
    if let Some(wq_weight) = self.try_load_tensor(&format!("{}.attn_q.weight", prefix)) {
        let attn = self.load_full_attention(layer_idx, wq_weight)?;
        return Ok(AttentionLayer::FullAttention(attn));
    }
    // Fused QKV: DeltaNet on SSM models, fused attention otherwise.
    if let Some(qkv_weight) = self.try_load_tensor(&format!("{}.attn_qkv.weight", prefix)) {
        return if self.config.has_ssm() {
            let dn = self.load_deltanet_layer(layer_idx)?;
            Ok(AttentionLayer::DeltaNet(Box::new(dn)))
        } else {
            let attn = self.load_combined_qkv_attention(layer_idx, qkv_weight)?;
            Ok(AttentionLayer::FullAttention(attn))
        };
    }
    // SSM input projection present: Mamba layer.
    if self.config.has_ssm()
        && self.try_load_tensor(&format!("{}.ssm_in.weight", prefix)).is_some()
    {
        let mamba = self.load_mamba_layer(layer_idx)?;
        return Ok(AttentionLayer::Mamba(Box::new(mamba)));
    }
    Err(ModelError::MissingTensor(format!(
        "{}.attn_q.weight or {}.attn_qkv.weight or {}.ssm_in.weight",
        prefix, prefix, prefix
    )))
}
fn load_full_attention(
&self,
layer_idx: usize,
wq_weight: Tensor,
) -> ModelResult<Attention> {
let prefix = format!("blk.{}", layer_idx);
let use_neox_rope = matches!(self.config.rope_config.rope_type, RopeType::NeoX);
let kl = self.config.key_length;
let vl = self.config.value_length;
let rope_dims = self.config.rope_config.n_dims;
let wq_bias = self.try_load_tensor(&format!("{}.attn_q.bias", prefix));
let actual_q_out = wq_weight.shape()[1];
let has_attention_gate = actual_q_out == self.config.num_heads * (kl + vl);
let wq = Linear::new(wq_weight, wq_bias)?;
let wk_bias = self.try_load_tensor(&format!("{}.attn_k.bias", prefix));
let wk = Linear::new(
self.load_tensor(&format!("{}.attn_k.weight", prefix))?,
wk_bias,
)?;
let wv_bias = self.try_load_tensor(&format!("{}.attn_v.bias", prefix));
let wv = Linear::new(
self.load_tensor(&format!("{}.attn_v.weight", prefix))?,
wv_bias,
)?;
let wo_bias = self.try_load_tensor(&format!("{}.attn_output.bias", prefix));
let wo = Linear::new(
self.load_tensor(&format!("{}.attn_output.weight", prefix))?,
wo_bias,
)?;
let mut attention = Attention::with_kv_dims(
wq, wk, wv, wo,
self.config.num_heads,
self.config.num_kv_heads,
self.config.head_dim,
kl, vl, rope_dims,
use_neox_rope,
has_attention_gate,
);
if self.architecture.uses_qk_norm()
&& let (Some(q_norm_w), Some(k_norm_w)) = (
self.try_load_tensor(&format!("{}.attn_q_norm.weight", prefix)),
self.try_load_tensor(&format!("{}.attn_k_norm.weight", prefix)),
)
{
let q_norm = RMSNorm::new(q_norm_w, self.config.norm_eps)?;
let k_norm = RMSNorm::new(k_norm_w, self.config.norm_eps)?;
attention.set_qk_norms(q_norm, k_norm);
}
if self.config.attn_logit_softcap > 0.0 {
attention.set_attn_logit_softcap(self.config.attn_logit_softcap);
}
if matches!(self.architecture, Architecture::Qwen3Next) {
attention.set_rope_partial_at_end(true);
}
Ok(attention)
}
/// Build a full-attention sub-layer from a fused `attn_qkv` weight by
/// slicing it into separate Q/K/V projections.
///
/// The fused weight (and fused bias, when present) is materialized as f32
/// first — dequantizing if needed — then split at
/// `num_heads * head_dim` / `num_kv_heads * head_dim` boundaries.
fn load_combined_qkv_attention(
    &self,
    layer_idx: usize,
    qkv_weight: Tensor,
) -> ModelResult<Attention> {
    let prefix = format!("blk.{}", layer_idx);
    let use_neox_rope = matches!(self.config.rope_config.rope_type, RopeType::NeoX);
    let kl = self.config.key_length;
    let vl = self.config.value_length;
    let rope_dims = self.config.rope_config.n_dims;
    let num_heads = self.config.num_heads;
    let num_kv_heads = self.config.num_kv_heads;
    let head_dim = self.config.head_dim;
    let in_features = qkv_weight.shape()[0];
    let q_size = num_heads * head_dim;
    let k_size = num_kv_heads * head_dim;
    let v_size = num_kv_heads * head_dim;
    let qkv_bias = self.try_load_tensor(&format!("{}.attn_qkv.bias", prefix));
    // FIX: the F32 and quantized paths previously duplicated ~50 lines of
    // identical slicing logic (and computed an unused `total_out`). Convert
    // to f32 once, then run a single split path — the same pattern
    // `build_bert_model` already uses.
    let qkv_f32: Vec<f32> = if qkv_weight.dtype() == DType::F32 {
        qkv_weight.as_f32()?.to_vec()
    } else {
        let backend = crate::backend::default_backend();
        let mut dequant = Tensor::zeros(vec![qkv_weight.numel()], DType::F32);
        backend
            .dequantize(&qkv_weight, &mut dequant)
            .map_err(|e| ModelError::ConfigError(format!("Failed to dequantize QKV: {}", e)))?;
        dequant.as_f32()?.to_vec()
    };
    // Flat offsets of Q, K, V within the fused buffer.
    let q_start = 0;
    let k_start = q_size * in_features;
    let v_start = (q_size + k_size) * in_features;
    let q_tensor = Tensor::from_f32(
        &qkv_f32[q_start..q_start + q_size * in_features],
        vec![in_features, q_size],
    )?;
    let k_tensor = Tensor::from_f32(
        &qkv_f32[k_start..k_start + k_size * in_features],
        vec![in_features, k_size],
    )?;
    let v_tensor = Tensor::from_f32(
        &qkv_f32[v_start..v_start + v_size * in_features],
        vec![in_features, v_size],
    )?;
    // Split the fused bias the same way, when present.
    let (q_bias, k_bias, v_bias) = if let Some(ref bias) = qkv_bias {
        let b = bias.as_f32()?;
        let qb = Tensor::from_f32(&b[..q_size], vec![q_size])?;
        let kb = Tensor::from_f32(&b[q_size..q_size + k_size], vec![k_size])?;
        let vb = Tensor::from_f32(&b[q_size + k_size..], vec![v_size])?;
        (Some(qb), Some(kb), Some(vb))
    } else {
        (None, None, None)
    };
    let wq = Linear::new(q_tensor, q_bias)?;
    let wk = Linear::new(k_tensor, k_bias)?;
    let wv = Linear::new(v_tensor, v_bias)?;
    let wo_bias = self.try_load_tensor(&format!("{}.attn_output.bias", prefix));
    let wo = Linear::new(
        self.load_tensor(&format!("{}.attn_output.weight", prefix))?,
        wo_bias,
    )?;
    Ok(Attention::with_kv_dims(
        wq, wk, wv, wo,
        num_heads, num_kv_heads, head_dim,
        kl, vl, rope_dims,
        use_neox_rope, false,
    ))
}
fn load_deltanet_layer(&self, layer_idx: usize) -> ModelResult<DeltaNetLayer> {
let prefix = format!("blk.{}", layer_idx);
let cfg = &self.config;
let d_inner = cfg.ssm_d_inner;
let d_state = cfg.ssm_d_state;
let num_v_heads = cfg.ssm_dt_rank;
let num_k_heads = cfg.ssm_n_group;
let head_v_dim = d_inner / num_v_heads;
let head_k_dim = d_state;
let conv_kernel = cfg.ssm_conv_kernel;
let q_dim = num_k_heads * head_k_dim;
let k_dim = num_k_heads * head_k_dim;
let qkv_dim = q_dim + k_dim + d_inner;
let dn_config = DeltaNetConfig {
d_inner,
d_state,
num_v_heads,
num_k_heads,
head_v_dim,
head_k_dim,
conv_kernel,
qkv_dim,
};
let attn_qkv = Linear::new(
self.load_tensor(&format!("{}.attn_qkv.weight", prefix))?,
None,
)?;
let attn_gate = Linear::new(
self.load_tensor(&format!("{}.attn_gate.weight", prefix))?,
None,
)?;
let ssm_ba = if let Some(ba_weight) =
self.try_load_tensor(&format!("{}.ssm_ba.weight", prefix))
{
BetaAlphaProjection::Combined(Linear::new(ba_weight, None)?)
} else {
let beta_w = self.load_tensor(&format!("{}.ssm_beta.weight", prefix))?;
let alpha_w = self.load_tensor(&format!("{}.ssm_alpha.weight", prefix))?;
BetaAlphaProjection::Separate {
beta: Linear::new(beta_w, None)?,
alpha: Linear::new(alpha_w, None)?,
}
};
let ssm_conv1d_weight = self.load_tensor(&format!("{}.ssm_conv1d.weight", prefix))?;
let ssm_a = self.load_tensor(&format!("{}.ssm_a", prefix))?;
let ssm_dt_bias = self.load_tensor(&format!("{}.ssm_dt.bias", prefix))?;
let ssm_norm_weight = self.load_tensor(&format!("{}.ssm_norm.weight", prefix))?;
let ssm_norm = RMSNorm::new(ssm_norm_weight, cfg.norm_eps)?;
let ssm_out = Linear::new(
self.load_tensor(&format!("{}.ssm_out.weight", prefix))?,
None,
)?;
tracing::info!("Layer {}: loaded DeltaNet (d_inner={}, d_state={}, v_heads={}, k_heads={}, conv={})",
layer_idx, d_inner, d_state, num_v_heads, num_k_heads, conv_kernel);
Ok(DeltaNetLayer {
config: dn_config,
attn_qkv,
attn_gate,
ssm_ba,
ssm_conv1d_weight,
ssm_a,
ssm_dt_bias,
ssm_norm,
ssm_out,
})
}
/// Load a Mamba SSM layer for block `layer_idx`.
///
/// `ssm_conv1d.bias`, `ssm_d`, and `ssm_norm.weight` are optional; every
/// other tensor is required and produces `ModelError::MissingTensor` when
/// absent.
fn load_mamba_layer(&self, layer_idx: usize) -> ModelResult<MambaLayer> {
    let prefix = format!("blk.{}", layer_idx);
    let cfg = &self.config;
    let d_inner = cfg.ssm_d_inner;
    let d_state = cfg.ssm_d_state;
    let dt_rank = cfg.ssm_dt_rank;
    // A zero kernel from absent metadata would be invalid; clamp to 1.
    let conv_kernel = cfg.ssm_conv_kernel.max(1);
    let mamba_config = MambaConfig {
        d_inner,
        d_state,
        dt_rank,
        conv_kernel,
    };
    let ssm_in = Linear::new(
        self.load_tensor(&format!("{}.ssm_in.weight", prefix))?,
        None,
    )?;
    let ssm_conv1d_weight = self.load_tensor(&format!("{}.ssm_conv1d.weight", prefix))?;
    let ssm_conv1d_bias = self.try_load_tensor(&format!("{}.ssm_conv1d.bias", prefix));
    let ssm_x = Linear::new(
        self.load_tensor(&format!("{}.ssm_x.weight", prefix))?,
        None,
    )?;
    let ssm_dt = Linear::new(
        self.load_tensor(&format!("{}.ssm_dt.weight", prefix))?,
        None,
    )?;
    let ssm_dt_bias = self.load_tensor(&format!("{}.ssm_dt.bias", prefix))?;
    let ssm_a = self.load_tensor(&format!("{}.ssm_a", prefix))?;
    let ssm_d = self.try_load_tensor(&format!("{}.ssm_d", prefix));
    // Idiom: Option<Result<T, E>> -> Result<Option<T>, E> via `transpose`
    // instead of a manual match.
    let ssm_norm = self
        .try_load_tensor(&format!("{}.ssm_norm.weight", prefix))
        .map(|w| RMSNorm::new(w, cfg.norm_eps))
        .transpose()?;
    let ssm_out = Linear::new(
        self.load_tensor(&format!("{}.ssm_out.weight", prefix))?,
        None,
    )?;
    tracing::info!(
        "Layer {}: loaded Mamba SSM (d_inner={}, d_state={}, dt_rank={}, conv={})",
        layer_idx, d_inner, d_state, dt_rank, conv_kernel
    );
    Ok(MambaLayer {
        ssm_in,
        ssm_conv1d_weight,
        ssm_conv1d_bias,
        ssm_x,
        ssm_dt,
        ssm_dt_bias,
        ssm_a,
        ssm_d,
        ssm_norm,
        ssm_out,
        config: mamba_config,
    })
}
/// Build a mixture-of-experts FFN for block `layer_idx`.
///
/// Expert weights are stored as 3D batched tensors (`ffn_*_exps.weight`)
/// and sliced per expert. Optional shared experts (`*_shexp`) and a
/// shared-expert gate are picked up when present.
fn load_moe_layer(&self, layer_idx: usize) -> ModelResult<FfnLayer> {
    let prefix = format!("blk.{}", layer_idx);
    let num_experts = self.config.num_experts;
    let hidden_dim = self.config.hidden_size;
    let expert_ffn_dim = if self.config.expert_intermediate_size > 0 {
        self.config.expert_intermediate_size
    } else {
        // FIX: guard with .max(1) — if `expert_used_count` metadata is
        // absent, num_experts_per_token is 0 and this division panicked.
        self.config.intermediate_size / self.config.num_experts_per_token.max(1)
    };
    let router_weight = self.load_tensor(&format!("{}.ffn_gate_inp.weight", prefix))?;
    let router = MoeRouter::from_weight(
        router_weight,
        self.config.num_experts_per_token,
        // NOTE(review): presumably a "normalize" flag mirroring
        // `normalize_router_logits: false` below — confirm against MoeRouter.
        false,
    );
    let gate_exps = self.load_tensor(&format!("{}.ffn_gate_exps.weight", prefix))?;
    let up_exps = self.load_tensor(&format!("{}.ffn_up_exps.weight", prefix))?;
    let down_exps = self.load_tensor(&format!("{}.ffn_down_exps.weight", prefix))?;
    let mut experts = Vec::with_capacity(num_experts);
    for e in 0..num_experts {
        let mut gate_proj = self.extract_expert_tensor(&gate_exps, e)?;
        let mut up_proj = self.extract_expert_tensor(&up_exps, e)?;
        let mut down_proj = self.extract_expert_tensor(&down_exps, e)?;
        // Rename the slices so logs and debugging show per-expert names.
        gate_proj.set_name(format!("blk.{}.ffn_gate.{}.weight", layer_idx, e));
        up_proj.set_name(format!("blk.{}.ffn_up.{}.weight", layer_idx, e));
        down_proj.set_name(format!("blk.{}.ffn_down.{}.weight", layer_idx, e));
        experts.push(MoeExpert {
            gate_proj,
            up_proj,
            down_proj,
        });
    }
    let mut shared_experts = Vec::new();
    if let (Some(mut gate_shexp), Some(mut up_shexp), Some(mut down_shexp)) = (
        self.try_load_tensor(&format!("{}.ffn_gate_shexp.weight", prefix)),
        self.try_load_tensor(&format!("{}.ffn_up_shexp.weight", prefix)),
        self.try_load_tensor(&format!("{}.ffn_down_shexp.weight", prefix)),
    ) {
        gate_shexp.set_name(format!("blk.{}.ffn_gate_shexp.0.weight", layer_idx));
        up_shexp.set_name(format!("blk.{}.ffn_up_shexp.0.weight", layer_idx));
        down_shexp.set_name(format!("blk.{}.ffn_down_shexp.0.weight", layer_idx));
        shared_experts.push(MoeExpert {
            gate_proj: gate_shexp,
            up_proj: up_shexp,
            down_proj: down_shexp,
        });
    }
    // The shared-expert gate is kept in f32: BF16 is decoded manually,
    // any other dtype is zeroed with a warning (best-effort, by design).
    let shared_expert_gate = self
        .try_load_tensor(&format!("{}.ffn_gate_inp_shexp.weight", prefix))
        .map(|t| {
            if t.dtype() == DType::F32 {
                t
            } else {
                let raw = t.data();
                let f32_vals: Vec<f32> = match t.dtype() {
                    DType::BF16 => raw
                        .chunks_exact(2)
                        .map(|c| {
                            // BF16 is the top 16 bits of an f32.
                            let bits = u16::from_le_bytes([c[0], c[1]]);
                            f32::from_bits((bits as u32) << 16)
                        })
                        .collect(),
                    _ => {
                        tracing::warn!("Unsupported dtype for shared expert gate, zeroing");
                        vec![0.0f32; t.numel()]
                    }
                };
                let shape = t.shape().to_vec();
                // Element count matches the source tensor, so this cannot fail.
                Tensor::from_f32(&f32_vals, shape)
                    .expect("shared expert gate: f32 conversion preserves element count")
            }
        });
    if shared_expert_gate.is_some() {
        tracing::debug!("Layer {}: loaded shared expert gate", layer_idx);
    }
    let num_shared = shared_experts.len();
    let moe_config = MoeConfig {
        num_experts,
        num_experts_per_token: self.config.num_experts_per_token,
        expert_hidden_dim: expert_ffn_dim,
        num_shared_experts: num_shared,
        // Inference-only: no auxiliary load-balancing loss.
        aux_loss_coef: 0.0,
        normalize_router_logits: false,
    };
    let mut moe_layer = MoeLayer::new(hidden_dim, moe_config);
    moe_layer.router = router;
    moe_layer.experts = experts;
    moe_layer.shared_experts = shared_experts;
    moe_layer.shared_expert_gate = shared_expert_gate;
    Ok(FfnLayer::Moe(moe_layer))
}
/// Slice expert `expert_idx` out of a 3D batched expert tensor of shape
/// `[ne0, ne1, num_experts]`, preserving the storage dtype.
///
/// Quantized tensors are sliced on raw bytes at block granularity;
/// non-quantized tensors are sliced on their f32 values.
///
/// # Errors
/// `ConfigError` when the tensor is not 3D, the index is out of range, or a
/// quantized expert's element count is not block-aligned.
fn extract_expert_tensor(
    &self,
    batched: &Tensor,
    expert_idx: usize,
) -> ModelResult<Tensor> {
    let shape = batched.shape();
    if shape.len() != 3 {
        return Err(ModelError::ConfigError(format!(
            "Expected 3D batched expert tensor, got shape {:?}",
            shape
        )));
    }
    let ne0 = shape[0];
    let ne1 = shape[1];
    let num_experts = shape[2];
    let expert_numel = ne0 * ne1;
    if expert_idx >= num_experts {
        return Err(ModelError::ConfigError(format!(
            "Expert index {} out of bounds ({})",
            expert_idx, num_experts
        )));
    }
    let per_expert_shape = vec![ne0, ne1];
    if batched.dtype().is_quantized() {
        // Quantized data can only be cut on block boundaries.
        let block_size = batched.dtype().block_size();
        let block_bytes = batched.dtype().block_bytes();
        if !expert_numel.is_multiple_of(block_size) {
            return Err(ModelError::ConfigError(format!(
                "Expert tensor elements ({}) not aligned to block size ({})",
                expert_numel, block_size
            )));
        }
        let blocks_per_expert = expert_numel / block_size;
        let bytes_per_expert = blocks_per_expert * block_bytes;
        // Experts are laid out contiguously along the last dimension.
        let byte_offset = expert_idx * bytes_per_expert;
        let raw_data = batched.data();
        let expert_bytes = &raw_data[byte_offset..byte_offset + bytes_per_expert];
        let mut tensor =
            Tensor::new(expert_bytes.to_vec(), per_expert_shape, batched.dtype())?;
        tensor.set_name(format!("expert.{}", expert_idx));
        Ok(tensor)
    } else {
        let f32_data = batched.as_f32()?;
        let offset = expert_idx * expert_numel;
        let expert_slice = &f32_data[offset..offset + expert_numel];
        let mut tensor = Tensor::from_f32(expert_slice, per_expert_shape)?;
        tensor.set_name(format!("expert.{}", expert_idx));
        Ok(tensor)
    }
}
/// Best-effort tensor lookup: returns `None` when the tensor is absent or
/// cannot be constructed (construction errors are deliberately swallowed;
/// callers use this to probe for optional tensors).
fn try_load_tensor(&self, name: &str) -> Option<Tensor> {
    let info = self.gguf.data.get_tensor(name)?;
    let bytes = self.gguf.tensor_data(name)?;
    let dims: Vec<usize> = info.dims.iter().map(|&d| d as usize).collect();
    match Tensor::new(bytes.to_vec(), dims, DType::from(info.dtype)) {
        Ok(mut tensor) => {
            tensor.set_name(name);
            Some(tensor)
        }
        Err(_) => None,
    }
}
/// Hook for norm-weight adjustments on architectures that store offset norm
/// weights (named after the Gemma convention).
///
/// Currently an identity pass-through, kept so every norm-weight load site
/// funnels through a single adjustment point.
/// NOTE(review): no offset is actually applied here — confirm Gemma-style
/// weights are corrected elsewhere before relying on this path.
fn apply_gemma_norm_weight_offset(&self, weight: Tensor) -> ModelResult<Tensor> {
    Ok(weight)
}
/// Load a required tensor by name.
///
/// # Errors
/// `ModelError::MissingTensor` when the name is absent from the file's
/// tensor table or its data region; construction errors propagate as-is.
fn load_tensor(&self, name: &str) -> ModelResult<Tensor> {
    let info = self
        .gguf
        .data
        .get_tensor(name)
        .ok_or_else(|| ModelError::MissingTensor(name.into()))?;
    let bytes = self
        .gguf
        .tensor_data(name)
        .ok_or_else(|| ModelError::MissingTensor(name.into()))?;
    let dims: Vec<usize> = info.dims.iter().map(|&d| d as usize).collect();
    let mut tensor = Tensor::new(bytes.to_vec(), dims, DType::from(info.dtype))?;
    tensor.set_name(name);
    Ok(tensor)
}
}
/// Convenience entry point: open a GGUF file and build a decoder model,
/// rejecting architectures outside the llama-like family.
pub fn load_llama_model<P: AsRef<Path>>(path: P) -> ModelResult<LlamaModel> {
    let loader = ModelLoader::load(path)?;
    if loader.architecture().is_llama_like() {
        loader.build_model()
    } else {
        Err(ModelError::UnsupportedArchitecture(
            loader.architecture().to_string(),
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Sanity-checks the architecture-family predicate that gates
    // `load_llama_model`: decoder families are llama-like, encoder (BERT)
    // and pure-SSM (Mamba) families are not.
    #[test]
    fn test_architecture_detection() {
        assert!(Architecture::Llama.is_llama_like());
        assert!(Architecture::Mistral.is_llama_like());
        assert!(Architecture::GPT2.is_llama_like());
        assert!(!Architecture::Bert.is_llama_like());
        assert!(!Architecture::Mamba.is_llama_like());
    }
}