use std::path::Path;
use crate::gguf::{GgufFile, MetadataValue};
use crate::tensor::{DType, Tensor};
use super::Architecture;
use super::config::{ActivationType, ModelConfig, RopeConfig, RopeScalingType, RopeType};
use super::deltanet::{BetaAlphaProjection, DeltaNetConfig, DeltaNetLayer};
use super::mamba::{MambaConfig, MambaLayer};
use super::error::{ModelError, ModelResult};
use super::layers::{
Attention, AttentionLayer, FeedForward, FfnLayer, LayerNorm, Linear, NormLayer,
NoGateFeedForward, RMSNorm, TransformerLayer,
};
use super::bert::{BertLayer, BertModel};
use super::llama::LlamaModel;
use super::moe::{MoeConfig, MoeExpert, MoeLayer, MoeRouter};
/// Abstraction over anything that can supply a model configuration and named tensors.
///
/// The model/layer builders in this file accept a `&dyn ModelSource`, and
/// `ModelLoader` implements it on top of a GGUF file.
pub trait ModelSource {
/// Returns the model configuration.
fn config(&self) -> &ModelConfig;
/// Returns the model configuration for in-place modification.
fn config_mut(&mut self) -> &mut ModelConfig;
/// Returns the architecture of the model.
fn architecture(&self) -> Architecture;
/// Loads the tensor called `name`, reporting failure as an error.
fn load_tensor(&self, name: &str) -> ModelResult<Tensor>;
/// Loads the tensor called `name`, yielding `None` instead of an error on failure.
fn try_load_tensor(&self, name: &str) -> Option<Tensor>;
}
/// Loader that owns an opened GGUF file together with the architecture and
/// configuration parsed from its metadata.
pub struct ModelLoader {
// The opened GGUF file; metadata and tensor bytes are read from it.
gguf: GgufFile,
// Architecture resolved from the `general.architecture` metadata key.
architecture: Architecture,
// Configuration produced by `parse_config` at load time.
config: ModelConfig,
}
impl ModelLoader {
/// Opens the GGUF file at `path` and parses its architecture and configuration.
///
/// # Errors
/// * Propagates any failure from `GgufFile::open`.
/// * `ModelError::MissingMetadata` when `general.architecture` is absent.
/// * `ModelError::UnsupportedArchitecture` when the architecture string maps to
///   `Architecture::Unknown`.
/// * Propagates any failure from `parse_config`.
pub fn load<P: AsRef<Path>>(path: P) -> ModelResult<Self> {
    let gguf = GgufFile::open(path)?;

    let Some(arch_name) = gguf.data.get_string("general.architecture") else {
        return Err(ModelError::MissingMetadata("general.architecture".into()));
    };

    let architecture = match Architecture::from_gguf_str(arch_name) {
        Architecture::Unknown => {
            return Err(ModelError::UnsupportedArchitecture(arch_name.to_string()));
        }
        recognised => recognised,
    };

    let config = Self::parse_config(&gguf, &architecture)?;
    Ok(Self {
        gguf,
        architecture,
        config,
    })
}
/// Derives a [`ModelConfig`] from the metadata (and a few tensor descriptors) of `gguf`.
///
/// Architecture-scoped keys are looked up as `"{arch}.<key>"`, where `arch` is
/// `architecture.as_str()`. Most keys are optional; their fallbacks are spelled out
/// inline below.
///
/// # Errors
/// * [`ModelError::MissingMetadata`] when `{arch}.embedding_length`, `{arch}.block_count`
///   or (for non-Mamba architectures) `{arch}.attention.head_count` cannot be read.
/// * [`ModelError::ConfigError`] when the head count is zero and
///   `{arch}.attention.key_length` is missing, so no head dimension can be derived.
fn parse_config(gguf: &GgufFile, architecture: &Architecture) -> ModelResult<ModelConfig> {
    let arch = architecture.as_str();
    let is_mamba = matches!(architecture, Architecture::Mamba | Architecture::Mamba2);

    let get_u32 = |key: &str| -> ModelResult<u32> {
        gguf.data
            .get_u32(key)
            .ok_or_else(|| ModelError::MissingMetadata(key.into()))
    };
    let get_f32_or =
        |key: &str, default: f32| -> f32 { gguf.data.get_f32(key).unwrap_or(default) };

    // Vocabulary size: explicit metadata first, then the tokenizer token list, then the
    // second dimension of the embedding tensor, and finally a hard-coded 32000.
    let vocab_size = get_u32(&format!("{}.vocab_size", arch))
        .or_else(|_| get_u32("tokenizer.ggml.vocab_size"))
        .map(|v| v as usize)
        .unwrap_or_else(|_| {
            if let Some(tokens) = gguf.data.metadata.get("tokenizer.ggml.tokens")
                && let MetadataValue::Array(arr) = tokens
            {
                return arr.values.len();
            }
            if let Some(emb_info) = gguf.data.get_tensor("token_embd.weight") {
                if emb_info.dims.len() == 2 {
                    return emb_info.dims[1] as usize;
                }
            }
            32000
        });

    let hidden_size = get_u32(&format!("{}.embedding_length", arch))? as usize;
    let num_layers = get_u32(&format!("{}.block_count", arch))? as usize;

    let (num_heads, num_kv_heads, head_dim) = if is_mamba {
        // Mamba models may carry no attention metadata at all; default to one head.
        let nh = get_u32(&format!("{}.attention.head_count", arch)).unwrap_or(1) as usize;
        let nkv = get_u32(&format!("{}.attention.head_count_kv", arch))
            .unwrap_or(nh as u32) as usize;
        let hd = get_u32(&format!("{}.attention.key_length", arch))
            .unwrap_or_else(|_| (hidden_size / nh.max(1)) as u32) as usize;
        (nh, nkv, hd)
    } else {
        let nh = get_u32(&format!("{}.attention.head_count", arch))? as usize;
        let nkv = get_u32(&format!("{}.attention.head_count_kv", arch))
            .unwrap_or(nh as u32) as usize;
        // Only divide when key_length is really absent, and never by zero: a
        // head_count of 0 in the file must surface as an error, not a panic.
        let hd = match get_u32(&format!("{}.attention.key_length", arch)) {
            Ok(v) => v as usize,
            Err(_) if nh == 0 => {
                return Err(ModelError::ConfigError(format!(
                    "{}.attention.head_count is 0 and {}.attention.key_length is missing; \
                     cannot derive the head dimension",
                    arch, arch
                )));
            }
            Err(_) => hidden_size / nh,
        };
        (nh, nkv, hd)
    };

    let intermediate_size = get_u32(&format!("{}.feed_forward_length", arch))
        .unwrap_or_else(|_| {
            if is_mamba {
                hidden_size as u32
            } else {
                (hidden_size * 4 * 2 / 3) as u32
            }
        }) as usize;
    let max_seq_len = get_u32(&format!("{}.context_length", arch)).unwrap_or(2048) as usize;

    // Prefer the RMS epsilon key, then the LayerNorm epsilon key, then 1e-5.
    let norm_eps = gguf
        .data
        .get_f32(&format!("{}.attention.layer_norm_rms_epsilon", arch))
        .or_else(|| gguf.data.get_f32(&format!("{}.attention.layer_norm_epsilon", arch)))
        .unwrap_or(1e-5);

    let freq_base = get_f32_or(&format!("{}.rope.freq_base", arch), 10000.0);
    let freq_scale = get_f32_or(&format!("{}.rope.scale_linear", arch), 1.0);
    let rope_type = match architecture {
        Architecture::Qwen2
        | Architecture::Qwen2Moe
        | Architecture::Qwen3
        | Architecture::Qwen35
        | Architecture::Qwen35Moe
        | Architecture::Qwen3Moe
        | Architecture::Qwen3Next
        | Architecture::GPTNeoX
        | Architecture::Falcon
        | Architecture::Phi
        | Architecture::Phi2
        | Architecture::Phi3
        | Architecture::PhiMoe
        | Architecture::GPTJ
        | Architecture::StableLM
        | Architecture::Gemma
        | Architecture::Gemma2
        | Architecture::Gemma3
        | Architecture::Gemma3N
        | Architecture::Gemma4
        | Architecture::GemmaEmbedding => RopeType::NeoX,
        _ => RopeType::Normal,
    };

    let num_experts = get_u32(&format!("{}.expert_count", arch)).unwrap_or(0) as usize;
    let num_experts_per_token =
        get_u32(&format!("{}.expert_used_count", arch)).unwrap_or(0) as usize;
    let expert_intermediate_size =
        get_u32(&format!("{}.expert_feed_forward_length", arch)).unwrap_or(0) as usize;

    // `head_dim` is already `{arch}.attention.key_length` when that key exists, and the
    // key-length fallback is `head_dim`, so the two are always equal.
    let key_length = head_dim;
    let value_length = get_u32(&format!("{}.attention.value_length", arch))
        .unwrap_or(head_dim as u32) as usize;
    let rope_n_dims = get_u32(&format!("{}.rope.dimension_count", arch))
        .unwrap_or(head_dim as u32) as usize;

    // Keep only strictly positive Int32 entries; an empty result means "no sections".
    let mrope_sections = if let Some(MetadataValue::Array(arr)) =
        gguf.data.metadata.get(&format!("{}.rope.dimension_sections", arch))
    {
        let sections: Vec<usize> = arr
            .values
            .iter()
            .filter_map(|v| match v {
                MetadataValue::Int32(n) if *n > 0 => Some(*n as usize),
                _ => None,
            })
            .collect();
        if sections.is_empty() { None } else { Some(sections) }
    } else {
        None
    };

    let rope_config = RopeConfig {
        freq_base,
        freq_scale,
        n_dims: rope_n_dims,
        scaling_type: RopeScalingType::None,
        original_max_position_embeddings: max_seq_len,
        rope_type,
        mrope_sections,
    };

    let has_combined_qkv = architecture.has_combined_qkv();
    let uses_layer_norm = architecture.uses_layer_norm();
    let uses_gelu = architecture.uses_gelu();
    let has_ffn_gate = !architecture.has_no_gate_ffn();
    let attn_logit_softcap = get_f32_or(&format!("{}.attn_logit_softcapping", arch), 0.0);
    let final_logit_softcap = get_f32_or(&format!("{}.final_logit_softcapping", arch), 0.0);
    let sliding_window =
        get_u32(&format!("{}.attention.sliding_window", arch)).unwrap_or(0) as usize;

    let attention_bias = matches!(
        architecture,
        Architecture::Qwen
            | Architecture::Qwen2
            | Architecture::Qwen2Moe
            | Architecture::Phi2
            | Architecture::Phi3
            | Architecture::PhiMoe
            | Architecture::GPTNeoX
            | Architecture::GPTJ
            | Architecture::Falcon
            | Architecture::BLOOM
            | Architecture::MPT
            | Architecture::OPT
            | Architecture::GPT2
            | Architecture::StableLM
            | Architecture::Baichuan
    );
    let mlp_bias = matches!(
        architecture,
        Architecture::GPT2
            | Architecture::GPTJ
            | Architecture::GPTNeoX
            | Architecture::BLOOM
            | Architecture::OPT
            | Architecture::StableLM
            | Architecture::Phi2
            | Architecture::Phi3
    );
    let use_parallel_residual = matches!(
        architecture,
        Architecture::GPTNeoX
            | Architecture::GPTJ
            | Architecture::StableLM
            | Architecture::Phi
            | Architecture::Phi2
            | Architecture::CodeShell
    );
    let hidden_act = if uses_gelu {
        ActivationType::GELU
    } else {
        ActivationType::SiLU
    };

    let mut config = ModelConfig {
        vocab_size,
        hidden_size,
        intermediate_size,
        num_layers,
        num_heads,
        num_kv_heads,
        head_dim,
        max_seq_len,
        norm_eps,
        rope_config,
        use_parallel_residual,
        hidden_act,
        attention_bias,
        mlp_bias,
        tie_word_embeddings: gguf
            .data
            .get_string("general.tie_word_embeddings")
            .map(|s| s == "true")
            .unwrap_or(false),
        num_experts,
        num_experts_per_token,
        expert_intermediate_size,
        key_length,
        value_length,
        ssm_d_inner: get_u32(&format!("{}.ssm.inner_size", arch)).unwrap_or(0) as usize,
        ssm_d_state: get_u32(&format!("{}.ssm.state_size", arch)).unwrap_or(0) as usize,
        ssm_n_group: {
            // Mamba models without an explicit group count use a single group.
            let g = get_u32(&format!("{}.ssm.group_count", arch)).unwrap_or(0) as usize;
            if g == 0 && is_mamba { 1 } else { g }
        },
        ssm_dt_rank: get_u32(&format!("{}.ssm.time_step_rank", arch)).unwrap_or(0) as usize,
        ssm_conv_kernel: get_u32(&format!("{}.ssm.conv_kernel", arch)).unwrap_or(0) as usize,
        attn_logit_softcap,
        final_logit_softcap,
        sliding_window,
        has_combined_qkv,
        uses_layer_norm,
        uses_gelu,
        has_ffn_gate,
        attention_layer_configs: None,
        kv_source_layer: None,
    };

    if architecture.has_heterogeneous_attention() {
        let global_head_dim = config.head_dim;
        let global_kv_heads = config.num_kv_heads;
        let global_rope_freq_base = config.rope_config.freq_base;

        // When a `rope_freqs.weight` tensor exists, every entry below 1e10 counts as one
        // active rotary pair. Data that cannot be viewed as f32 (bad alignment or length)
        // is ignored instead of panicking, and the metadata-derived value is used.
        let rope_freqs: Option<&[f32]> = gguf
            .tensor_data("rope_freqs.weight")
            .and_then(|data| bytemuck::try_cast_slice(data).ok());
        let global_rope_dims = match rope_freqs {
            Some(floats) => floats.iter().filter(|&&v| v < 1e10).count() * 2,
            // Same value as `{arch}.rope.dimension_count` falling back to the head dim.
            None => rope_n_dims,
        };

        let swa_head_dim = get_u32(&format!("{}.attention.key_length_swa", arch))
            .unwrap_or(global_head_dim as u32) as usize;
        let swa_kv_heads = get_u32(&format!("{}.attention.head_count_kv_swa", arch))
            .unwrap_or(global_kv_heads as u32) as usize;
        let swa_rope_freq_base =
            get_f32_or(&format!("{}.rope.freq_base_swa", arch), global_rope_freq_base);
        let swa_rope_dims = get_u32(&format!("{}.rope.dimension_count_swa", arch))
            .unwrap_or(swa_head_dim as u32) as usize;
        let sliding_window = config.sliding_window;

        // `true` marks a sliding-window layer. Without explicit metadata every sixth
        // layer (index % 6 == 5) is treated as a global-attention layer.
        let swa_pattern: Vec<bool> = if let Some(MetadataValue::Array(arr)) = gguf
            .data
            .metadata
            .get(&format!("{}.attention.sliding_window_pattern", arch))
        {
            arr.values
                .iter()
                .map(|v| matches!(v, MetadataValue::Bool(true)))
                .collect()
        } else {
            (0..config.num_layers).map(|i| i % 6 != 5).collect()
        };

        let layer_configs = ModelConfig::build_attention_layer_configs_from_pattern(
            &swa_pattern,
            swa_head_dim,
            swa_kv_heads,
            swa_rope_freq_base,
            swa_rope_dims,
            sliding_window,
            global_head_dim,
            global_kv_heads,
            global_rope_freq_base,
            global_rope_dims,
        );

        let shared_layers =
            get_u32(&format!("{}.attention.shared_kv_layers", arch)).unwrap_or(0) as usize;
        if shared_layers > 0 {
            config.kv_source_layer = Some(ModelConfig::build_kv_source_mapping(
                config.num_layers,
                shared_layers,
                &layer_configs,
            ));
        }
        config.attention_layer_configs = Some(layer_configs);
    }

    Ok(config)
}
/// Returns the configuration parsed from the GGUF metadata.
pub fn config(&self) -> &ModelConfig {
&self.config
}
/// Returns the configuration for in-place modification.
pub fn config_mut(&mut self) -> &mut ModelConfig {
&mut self.config
}
/// Returns the architecture detected when the file was loaded.
pub fn architecture(&self) -> Architecture {
self.architecture
}
/// Consumes the loader and builds a `LlamaModel` by delegating to `build_llama_model`.
pub fn build_model(self) -> ModelResult<LlamaModel> {
build_llama_model(&self)
}
pub fn build_bert_model(self) -> ModelResult<BertModel> {
let token_embedding = self.load_tensor("token_embd.weight")?;
let position_embedding = self.try_load_tensor("position_embd.weight");
let token_type_embedding = self.try_load_tensor("token_types.weight");
let embed_norm = if let Some(w) = self.try_load_tensor("token_embd_norm.weight") {
if let Some(b) = self.try_load_tensor("token_embd_norm.bias") {
Some(NormLayer::Layer(LayerNorm::new(w, b, self.config.norm_eps)?))
} else {
Some(NormLayer::RMS(RMSNorm::new(w, self.config.norm_eps)?))
}
} else {
None
};
let mut layers = Vec::with_capacity(self.config.num_layers);
for i in 0..self.config.num_layers {
let prefix = format!("blk.{}", i);
let attn_norm_w = self
.try_load_tensor(&format!("{}.attn_output_norm.weight", prefix))
.or_else(|| self.try_load_tensor(&format!("{}.attn_norm.weight", prefix)))
.ok_or_else(|| {
ModelError::MissingTensor(format!("{}.attn_norm.weight", prefix))
})?;
let attn_norm_b = self
.try_load_tensor(&format!("{}.attn_output_norm.bias", prefix))
.or_else(|| self.try_load_tensor(&format!("{}.attn_norm.bias", prefix)));
let attn_norm = if let Some(b) = attn_norm_b {
NormLayer::Layer(LayerNorm::new(attn_norm_w, b, self.config.norm_eps)?)
} else {
NormLayer::RMS(RMSNorm::new(attn_norm_w, self.config.norm_eps)?)
};
let (wq, wk, wv) =
if let Some(qkv) = self.try_load_tensor(&format!("{}.attn_qkv.weight", prefix)) {
let num_heads = self.config.num_heads;
let head_dim = self.config.head_dim;
let hidden = self.config.hidden_size;
let q_size = num_heads * head_dim;
let k_size = num_heads * head_dim;
let v_size = num_heads * head_dim;
let qkv_f32 = if qkv.dtype() == DType::F32 {
qkv.as_f32()?.to_vec()
} else {
let backend = crate::backend::default_backend();
let mut deq = Tensor::zeros(vec![qkv.numel()], DType::F32);
backend
.dequantize(&qkv, &mut deq)
.map_err(|e| ModelError::ConfigError(format!("Dequant QKV: {}", e)))?;
deq.as_f32()?.to_vec()
};
let q_start = 0;
let k_start_off = q_size * hidden;
let v_start_off = (q_size + k_size) * hidden;
let qkv_bias = self.try_load_tensor(&format!("{}.attn_qkv.bias", prefix));
let (qb, kb, vb) = if let Some(ref b) = qkv_bias {
let bd = b.as_f32()?;
(
Some(Tensor::from_f32(&bd[..q_size], vec![q_size])?),
Some(Tensor::from_f32(
&bd[q_size..q_size + k_size],
vec![k_size],
)?),
Some(Tensor::from_f32(&bd[q_size + k_size..], vec![v_size])?),
)
} else {
(None, None, None)
};
(
Linear::new(
Tensor::from_f32(&qkv_f32[q_start..q_start + q_size * hidden], vec![hidden, q_size])?,
qb,
)?,
Linear::new(
Tensor::from_f32(&qkv_f32[k_start_off..k_start_off + k_size * hidden], vec![hidden, k_size])?,
kb,
)?,
Linear::new(
Tensor::from_f32(&qkv_f32[v_start_off..v_start_off + v_size * hidden], vec![hidden, v_size])?,
vb,
)?,
)
} else {
let qb = self.try_load_tensor(&format!("{}.attn_q.bias", prefix));
let kb = self.try_load_tensor(&format!("{}.attn_k.bias", prefix));
let vb = self.try_load_tensor(&format!("{}.attn_v.bias", prefix));
(
Linear::new(
self.load_tensor(&format!("{}.attn_q.weight", prefix))?,
qb,
)?,
Linear::new(
self.load_tensor(&format!("{}.attn_k.weight", prefix))?,
kb,
)?,
Linear::new(
self.load_tensor(&format!("{}.attn_v.weight", prefix))?,
vb,
)?,
)
};
let wo_bias = self.try_load_tensor(&format!("{}.attn_output.bias", prefix));
let wo = Linear::new(
self.load_tensor(&format!("{}.attn_output.weight", prefix))?,
wo_bias,
)?;
let ffn_norm_w = self
.try_load_tensor(&format!("{}.layer_output_norm.weight", prefix))
.or_else(|| self.try_load_tensor(&format!("{}.ffn_norm.weight", prefix)))
.ok_or_else(|| {
ModelError::MissingTensor(format!("{}.ffn_norm.weight", prefix))
})?;
let ffn_norm_b = self
.try_load_tensor(&format!("{}.layer_output_norm.bias", prefix))
.or_else(|| self.try_load_tensor(&format!("{}.ffn_norm.bias", prefix)));
let ffn_norm = if let Some(b) = ffn_norm_b {
NormLayer::Layer(LayerNorm::new(ffn_norm_w, b, self.config.norm_eps)?)
} else {
NormLayer::RMS(RMSNorm::new(ffn_norm_w, self.config.norm_eps)?)
};
let ffn_up_bias = self.try_load_tensor(&format!("{}.ffn_up.bias", prefix));
let ffn_up = Linear::new(
self.load_tensor(&format!("{}.ffn_up.weight", prefix))?,
ffn_up_bias,
)?;
let ffn_down_bias = self.try_load_tensor(&format!("{}.ffn_down.bias", prefix));
let ffn_down = Linear::new(
self.load_tensor(&format!("{}.ffn_down.weight", prefix))?,
ffn_down_bias,
)?;
layers.push(BertLayer {
attn_norm,
wq,
wk,
wv,
wo,
num_heads: self.config.num_heads,
head_dim: self.config.head_dim,
ffn_norm,
ffn_up,
ffn_down,
});
}
BertModel::new(
self.config,
token_embedding,
position_embedding,
token_type_embedding,
embed_norm,
layers,
self.architecture,
)
}
/// Returns the DeltaNet configuration for this loader, if any, by delegating to
/// `deltanet_config_from_source`.
pub fn deltanet_config(&self) -> Option<DeltaNetConfig> {
deltanet_config_from_source(self)
}
/// Describes the recurrent (state-space) part of the model, if it has one.
///
/// Returns `None` when the configuration reports no SSM. Mamba/Mamba2 architectures
/// yield a `RecurrentConfig::Mamba` built from the `ssm_*` fields (with the conv kernel
/// clamped to at least 1); every other architecture yields `RecurrentConfig::DeltaNet`
/// when `deltanet_config` produces a value, and `None` otherwise.
pub fn recurrent_config(&self) -> Option<super::deltanet::RecurrentConfig> {
    let cfg = &self.config;
    if !cfg.has_ssm() {
        return None;
    }

    match self.architecture {
        Architecture::Mamba | Architecture::Mamba2 => {
            let mamba = MambaConfig {
                d_inner: cfg.ssm_d_inner,
                d_state: cfg.ssm_d_state,
                dt_rank: cfg.ssm_dt_rank,
                conv_kernel: cfg.ssm_conv_kernel.max(1),
            };
            Some(super::deltanet::RecurrentConfig::Mamba(mamba))
        }
        _ => self
            .deltanet_config()
            .map(super::deltanet::RecurrentConfig::DeltaNet),
    }
}
fn gguf_try_load_tensor(&self, name: &str) -> Option<Tensor> {
let tensor_info = self.gguf.data.get_tensor(name)?;
let tensor_data = self.gguf.tensor_data(name)?;
let shape: Vec<usize> = tensor_info.dims.iter().map(|&d| d as usize).collect();
let dtype = DType::from(tensor_info.dtype);
Tensor::new(tensor_data.to_vec(), shape, dtype)
.ok()
.map(|mut t| {
t.set_name(name);
t
})
}
/// Copies the tensor called `name` out of the GGUF file into an owned, named `Tensor`.
///
/// # Errors
/// `ModelError::MissingTensor` when the descriptor or the data for `name` is absent;
/// otherwise whatever `Tensor::new` reports.
fn gguf_load_tensor(&self, name: &str) -> ModelResult<Tensor> {
    let missing = || ModelError::MissingTensor(name.into());

    let Some(info) = self.gguf.data.get_tensor(name) else {
        return Err(missing());
    };
    let Some(bytes) = self.gguf.tensor_data(name) else {
        return Err(missing());
    };

    let dims: Vec<usize> = info.dims.iter().map(|&extent| extent as usize).collect();
    let element_type = DType::from(info.dtype);

    let mut loaded = Tensor::new(bytes.to_vec(), dims, element_type)?;
    loaded.set_name(name);
    Ok(loaded)
}
}
/// `ModelSource` backed by the loader's GGUF file: configuration and architecture come
/// from the loader's fields, tensors from the `gguf_*` helpers.
impl ModelSource for ModelLoader {
fn config(&self) -> &ModelConfig {
&self.config
}
fn config_mut(&mut self) -> &mut ModelConfig {
&mut self.config
}
fn architecture(&self) -> Architecture {
self.architecture
}
// Strict load: failures are returned as errors.
fn load_tensor(&self, name: &str) -> ModelResult<Tensor> {
self.gguf_load_tensor(name)
}
// Lenient load: failures are reported as `None`.
fn try_load_tensor(&self, name: &str) -> Option<Tensor> {
self.gguf_try_load_tensor(name)
}
}
/// Builds a [`LlamaModel`] from `source`: token embedding, all transformer layers, the
/// final norm, the output projection and the optional per-layer embedding tensors.
///
/// The output projection reuses the token embedding when `tie_word_embeddings` is set
/// or when `output.weight` cannot be loaded. `output.weight` is fetched at most once
/// (via `try_load_tensor`); the previous implementation loaded it a second time after
/// the existence check, copying the full tensor twice.
///
/// # Errors
/// Propagates missing-tensor errors for `token_embd.weight` / `output_norm.weight`,
/// any layer-loading error, a `ModelError::ConfigError` when `per_layer_model_proj`
/// cannot be dequantised, and errors from the layer constructors.
pub fn build_llama_model(source: &dyn ModelSource) -> ModelResult<LlamaModel> {
    let token_embedding = source.load_tensor("token_embd.weight")?;
    let config = source.config();

    let mut layers = Vec::with_capacity(config.num_layers);
    for i in 0..config.num_layers {
        layers.push(load_transformer_layer(source, i)?);
    }
    let recurrent_count = layers.iter().filter(|l| l.is_recurrent()).count();
    if recurrent_count > 0 {
        tracing::info!(
            "Model has {}/{} DeltaNet recurrent layers",
            recurrent_count,
            layers.len()
        );
    }

    let norm_weight = apply_gemma_norm_weight_offset(source.load_tensor("output_norm.weight")?)?;
    let norm = if let Some(bias) = source.try_load_tensor("output_norm.bias") {
        NormLayer::Layer(LayerNorm::new(norm_weight, bias, config.norm_eps)?)
    } else {
        NormLayer::RMS(RMSNorm::new(norm_weight, config.norm_eps)?)
    };

    let output_bias = source.try_load_tensor("output.bias");
    // Tied embeddings never touch `output.weight`; otherwise load it exactly once and
    // fall back to the token embedding when it is unavailable.
    let output_weight = if config.tie_word_embeddings {
        None
    } else {
        source.try_load_tensor("output.weight")
    };
    let output = Linear::new(
        output_weight.unwrap_or_else(|| token_embedding.clone()),
        output_bias,
    )?;

    let per_layer_token_embd = source.try_load_tensor("per_layer_token_embd.weight");
    // The projection is kept in F32: quantised weights are dequantised into a flat
    // buffer and reshaped back to the original shape.
    let per_layer_model_proj = source
        .try_load_tensor("per_layer_model_proj.weight")
        .map(|w| {
            if w.dtype() != DType::F32 {
                let backend = crate::backend::default_backend();
                let mut deq = Tensor::zeros(vec![w.numel()], DType::F32);
                backend.dequantize(&w, &mut deq).map_err(|e| {
                    ModelError::ConfigError(format!(
                        "Failed to dequantize per_layer_model_proj: {e}"
                    ))
                })?;
                let shape = w.shape().to_vec();
                let deq = deq.reshape(shape)?;
                Linear::new(deq, None)
            } else {
                Linear::new(w, None)
            }
        })
        .transpose()?;
    let per_layer_proj_norm = source
        .try_load_tensor("per_layer_proj_norm.weight")
        .map(|w| RMSNorm::new(w, config.norm_eps))
        .transpose()?;

    // The per-layer embedding width is taken from the projection norm; 0 means absent.
    let n_epl = per_layer_proj_norm
        .as_ref()
        .map(|n| n.hidden_size)
        .unwrap_or(0);
    if n_epl > 0 {
        tracing::info!(
            "Gemma 4 PLIE active: n_epl={}, n_layers={}, total_pl_dim={}",
            n_epl,
            config.num_layers,
            n_epl * config.num_layers
        );
    }

    LlamaModel::new(
        config.clone(),
        token_embedding,
        layers,
        norm,
        output,
        source.architecture(),
        per_layer_token_embd,
        per_layer_model_proj,
        per_layer_proj_norm,
        n_epl,
    )
}
fn load_transformer_layer(source: &dyn ModelSource, layer_idx: usize) -> ModelResult<TransformerLayer> {
let prefix = format!("blk.{}", layer_idx);
let config = source.config();
let arch = source.architecture();
let is_mamba = matches!(arch, Architecture::Mamba | Architecture::Mamba2);
let attn_norm_weight = source
.try_load_tensor(&format!("{}.attn_norm.weight", prefix))
.or_else(|| source.try_load_tensor(&format!("{}.norm.weight", prefix)))
.ok_or_else(|| ModelError::MissingTensor(format!("{}.attn_norm.weight", prefix)))?;
let attn_norm_weight = apply_gemma_norm_weight_offset(attn_norm_weight)?;
let attn_norm_bias = source
.try_load_tensor(&format!("{}.attn_norm.bias", prefix))
.or_else(|| source.try_load_tensor(&format!("{}.norm.bias", prefix)));
let attn_norm = if let Some(bias) = attn_norm_bias {
NormLayer::Layer(LayerNorm::new(attn_norm_weight, bias, config.norm_eps)?)
} else {
NormLayer::RMS(RMSNorm::new(attn_norm_weight, config.norm_eps)?)
};
let attn_layer = load_attention_layer(source, layer_idx)?;
let ffn_norm_weight = source.try_load_tensor(&format!("{}.ffn_norm.weight", prefix));
let ffn_norm_bias = source.try_load_tensor(&format!("{}.ffn_norm.bias", prefix));
let post_attn_w =
source.try_load_tensor(&format!("{}.post_attention_norm.weight", prefix));
let post_attn_b = source.try_load_tensor(&format!("{}.post_attention_norm.bias", prefix));
let (ffn_norm, post_attn_norm) = if let Some(w) = ffn_norm_weight {
let w = apply_gemma_norm_weight_offset(w)?;
let ffn = if let Some(bias) = ffn_norm_bias {
NormLayer::Layer(LayerNorm::new(w, bias, config.norm_eps)?)
} else {
NormLayer::RMS(RMSNorm::new(w, config.norm_eps)?)
};
let pan = post_attn_w
.map(|w| -> ModelResult<NormLayer> {
let w = apply_gemma_norm_weight_offset(w)?;
Ok(if let Some(bias) = post_attn_b {
NormLayer::Layer(LayerNorm::new(w, bias, config.norm_eps)?)
} else {
NormLayer::RMS(RMSNorm::new(w, config.norm_eps)?)
})
})
.transpose()?;
(ffn, pan)
} else if let Some(w) = post_attn_w {
let w = apply_gemma_norm_weight_offset(w)?;
let ffn = if let Some(bias) = post_attn_b {
NormLayer::Layer(LayerNorm::new(w, bias, config.norm_eps)?)
} else {
NormLayer::RMS(RMSNorm::new(w, config.norm_eps)?)
};
(ffn, None)
} else if is_mamba || config.use_parallel_residual {
let hidden = config.hidden_size;
let ffn = NormLayer::RMS(RMSNorm::new(
Tensor::from_f32(&vec![1.0f32; hidden], vec![hidden])?,
config.norm_eps,
)?);
(ffn, None)
} else {
return Err(ModelError::MissingTensor(format!(
"{}.ffn_norm.weight",
prefix
)));
};
let ffn_layer = if config.is_moe() {
load_moe_layer(source, layer_idx)?
} else if is_mamba
&& source.try_load_tensor(&format!("{}.ffn_up.weight", prefix)).is_none()
{
FfnLayer::Identity
} else if !config.has_ffn_gate {
let up_tensor = source.load_tensor(&format!("{}.ffn_up.weight", prefix))?;
let up_out_dim = up_tensor.shape()[up_tensor.ndim() - 1];
let intermediate = config.intermediate_size;
if up_out_dim == 2 * intermediate {
let hidden = config.hidden_size;
let up_f32 = if up_tensor.dtype() == DType::F32 {
up_tensor.as_f32()?.to_vec()
} else {
let backend = crate::backend::default_backend();
let mut deq = Tensor::zeros(vec![up_tensor.numel()], DType::F32);
backend
.dequantize(&up_tensor, &mut deq)
.map_err(|e| ModelError::ConfigError(format!("Dequant ffn_up: {}", e)))?;
deq.as_f32()?.to_vec()
};
let gate_data = &up_f32[..hidden * intermediate];
let up_data = &up_f32[hidden * intermediate..];
let w_gate = Linear::new(
Tensor::from_f32(gate_data, vec![hidden, intermediate])?,
None,
)?;
let w_up = Linear::new(
Tensor::from_f32(up_data, vec![hidden, intermediate])?,
None,
)?;
let w_down = Linear::new(
source.load_tensor(&format!("{}.ffn_down.weight", prefix))?,
None,
)?;
let mut ffn = FeedForward::new(w_gate, w_up, w_down);
ffn.use_gelu = config.uses_gelu;
FfnLayer::Dense(ffn)
} else {
let w_up = Linear::new(
up_tensor,
source.try_load_tensor(&format!("{}.ffn_up.bias", prefix)),
)?;
let w_down = Linear::new(
source.load_tensor(&format!("{}.ffn_down.weight", prefix))?,
source.try_load_tensor(&format!("{}.ffn_down.bias", prefix)),
)?;
FfnLayer::NoGate(NoGateFeedForward::new(
w_up,
w_down,
config.uses_gelu,
))
}
} else {
let w_gate = Linear::new(
source.load_tensor(&format!("{}.ffn_gate.weight", prefix))?,
None,
)?;
let w_up = Linear::new(
source.load_tensor(&format!("{}.ffn_up.weight", prefix))?,
None,
)?;
let w_down = Linear::new(
source.load_tensor(&format!("{}.ffn_down.weight", prefix))?,
None,
)?;
let mut ffn = FeedForward::new(w_gate, w_up, w_down);
ffn.use_gelu = config.uses_gelu;
FfnLayer::Dense(ffn)
};
let post_ffn_norm =
if let Some(w) = source.try_load_tensor(&format!("{}.post_ffw_norm.weight", prefix)) {
let w = apply_gemma_norm_weight_offset(w)?;
let b = source.try_load_tensor(&format!("{}.post_ffw_norm.bias", prefix));
Some(if let Some(bias) = b {
NormLayer::Layer(LayerNorm::new(w, bias, config.norm_eps)?)
} else {
NormLayer::RMS(RMSNorm::new(w, config.norm_eps)?)
})
} else {
None
};
let rope_freq_base_override = config
.attention_layer_configs
.as_ref()
.map(|cfgs| cfgs[layer_idx].rope_freq_base)
.unwrap_or(0.0);
let plie_inp_gate = source
.try_load_tensor(&format!("{}.inp_gate.weight", prefix))
.map(|w| Linear::new(w, None))
.transpose()?;
let plie_proj = source
.try_load_tensor(&format!("{}.proj.weight", prefix))
.map(|w| Linear::new(w, None))
.transpose()?;
let plie_post_norm = source
.try_load_tensor(&format!("{}.post_norm.weight", prefix))
.map(|w| RMSNorm::new(w, config.norm_eps))
.transpose()?;
let layer_output_scale = source
.try_load_tensor(&format!("{}.layer_output_scale.weight", prefix))
.and_then(|t| {
if t.dtype() == crate::tensor::DType::F32 {
t.as_f32().ok().map(|d| d[0])
} else {
let raw = t.data();
if raw.len() >= 2 {
let bits = u16::from_le_bytes([raw[0], raw[1]]);
Some(f32::from_bits((bits as u32) << 16))
} else {
None
}
}
});
Ok(TransformerLayer {
attn_norm,
attn_layer,
post_attn_norm,
ffn_norm,
ffn_layer,
post_ffn_norm,
layer_idx,
use_parallel_residual: config.use_parallel_residual,
rope_freq_base_override,
plie_inp_gate,
plie_proj,
plie_post_norm,
layer_output_scale,
})
}
/// Picks and loads the token-mixing sub-layer of block `layer_idx` based on which
/// tensors exist, checked in this order:
///
/// 1. `attn_q.weight` → full attention with separate Q/K/V;
/// 2. `attn_qkv.weight` → DeltaNet when the config has an SSM, otherwise full attention
///    with a fused QKV tensor;
/// 3. `ssm_in.weight` (only when the config has an SSM) → Mamba.
///
/// Returns `ModelError::MissingTensor` when none of them is present.
fn load_attention_layer(source: &dyn ModelSource, layer_idx: usize) -> ModelResult<AttentionLayer> {
    let prefix = format!("blk.{}", layer_idx);
    let config = source.config();

    if let Some(q_weight) = source.try_load_tensor(&format!("{}.attn_q.weight", prefix)) {
        let attention = load_full_attention(source, layer_idx, q_weight)?;
        return Ok(AttentionLayer::FullAttention(attention));
    }

    if let Some(fused_qkv) = source.try_load_tensor(&format!("{}.attn_qkv.weight", prefix)) {
        if config.has_ssm() {
            let deltanet = load_deltanet_layer(source, layer_idx)?;
            return Ok(AttentionLayer::DeltaNet(Box::new(deltanet)));
        }
        let attention = load_combined_qkv_attention(source, layer_idx, fused_qkv)?;
        return Ok(AttentionLayer::FullAttention(attention));
    }

    let has_mamba_input = config.has_ssm()
        && source.try_load_tensor(&format!("{}.ssm_in.weight", prefix)).is_some();
    if has_mamba_input {
        let mamba_layer = load_mamba_layer(source, layer_idx)?;
        return Ok(AttentionLayer::Mamba(Box::new(mamba_layer)));
    }

    Err(ModelError::MissingTensor(format!(
        "{}.attn_q.weight or {}.attn_qkv.weight or {}.ssm_in.weight",
        prefix, prefix, prefix
    )))
}
fn load_full_attention(
source: &dyn ModelSource,
layer_idx: usize,
wq_weight: Tensor,
) -> ModelResult<Attention> {
let prefix = format!("blk.{}", layer_idx);
let config = source.config();
let arch = source.architecture();
let use_neox_rope = matches!(config.rope_config.rope_type, RopeType::NeoX);
let (num_kv_heads, head_dim, kl, vl, rope_dims) =
if let Some(ref layer_configs) = config.attention_layer_configs {
let lc = &layer_configs[layer_idx];
(lc.num_kv_heads, lc.head_dim, lc.head_dim, lc.head_dim, lc.rope_dims)
} else {
let kl = config.key_length;
let vl = config.value_length;
let rope_dims = config.rope_config.n_dims;
(config.num_kv_heads, config.head_dim, kl, vl, rope_dims)
};
let wq_bias = source.try_load_tensor(&format!("{}.attn_q.bias", prefix));
let actual_q_out = wq_weight.shape()[1];
let has_attention_gate = actual_q_out == config.num_heads * (kl + vl);
let wq = Linear::new(wq_weight, wq_bias)?;
let wk_bias = source.try_load_tensor(&format!("{}.attn_k.bias", prefix));
let wk = Linear::new(
source.load_tensor(&format!("{}.attn_k.weight", prefix))?,
wk_bias,
)?;
let wv_bias = source.try_load_tensor(&format!("{}.attn_v.bias", prefix));
let wv = Linear::new(
source.load_tensor(&format!("{}.attn_v.weight", prefix))?,
wv_bias,
)?;
let wo_bias = source.try_load_tensor(&format!("{}.attn_output.bias", prefix));
let wo = Linear::new(
source.load_tensor(&format!("{}.attn_output.weight", prefix))?,
wo_bias,
)?;
let mut attention = Attention::with_kv_dims(
wq, wk, wv, wo,
config.num_heads,
num_kv_heads,
head_dim,
kl, vl, rope_dims,
use_neox_rope,
has_attention_gate,
);
if arch.uses_qk_norm()
&& let (Some(q_norm_w), Some(k_norm_w)) = (
source.try_load_tensor(&format!("{}.attn_q_norm.weight", prefix)),
source.try_load_tensor(&format!("{}.attn_k_norm.weight", prefix)),
)
{
let q_norm = RMSNorm::new(q_norm_w, config.norm_eps)?;
let k_norm = RMSNorm::new(k_norm_w, config.norm_eps)?;
attention.set_qk_norms(q_norm, k_norm);
}
if config.attn_logit_softcap > 0.0 {
attention.set_attn_logit_softcap(config.attn_logit_softcap);
}
if matches!(arch, Architecture::Qwen3Next | Architecture::Qwen35Moe) {
attention.set_rope_partial_at_end(true);
}
if let Some(ref sections) = config.rope_config.mrope_sections {
attention.mrope_sections = Some(sections.clone());
}
if let Some(ref layer_configs) = config.attention_layer_configs {
let lc = &layer_configs[layer_idx];
if lc.sliding_window > 0 {
attention.set_sliding_window(lc.sliding_window);
}
if lc.rope_dims < lc.head_dim {
attention.set_rope_freq_dim(lc.head_dim);
}
attention.normalize_v = true;
attention.scale = 1.0;
}
Ok(attention)
}
/// Builds the full-attention module of block `layer_idx` from a fused `attn_qkv` tensor.
///
/// The fused weight is viewed as f32 (dequantised first when it is not already F32) and
/// cut into consecutive Q, K and V parts of `num_heads * head_dim`,
/// `num_kv_heads * head_dim` and `num_kv_heads * head_dim` output features, each
/// spanning `in_features` (the first dimension of `qkv_weight`) inputs. A fused
/// `attn_qkv.bias`, when present, is cut the same way.
///
/// The F32 and quantised paths previously duplicated the whole split; they now share
/// one implementation and differ only in where the f32 data comes from.
///
/// # Errors
/// * `ModelError::ConfigError` when dequantisation fails, or, instead of the former
///   slice panics, when `qkv_weight` has no dimensions or the fused weight/bias holds
///   too few elements for the split.
/// * `ModelError::MissingTensor` (via `load_tensor`) when `attn_output.weight` is absent.
/// * Any error raised by the tensor/layer constructors.
fn load_combined_qkv_attention(
    source: &dyn ModelSource,
    layer_idx: usize,
    qkv_weight: Tensor,
) -> ModelResult<Attention> {
    let prefix = format!("blk.{}", layer_idx);
    let config = source.config();
    let use_neox_rope = matches!(config.rope_config.rope_type, RopeType::NeoX);
    let kl = config.key_length;
    let vl = config.value_length;
    let rope_dims = config.rope_config.n_dims;
    let num_heads = config.num_heads;
    let num_kv_heads = config.num_kv_heads;
    let head_dim = config.head_dim;

    let in_features = qkv_weight.shape().first().copied().ok_or_else(|| {
        ModelError::ConfigError(format!("{}.attn_qkv.weight has no dimensions", prefix))
    })?;
    let q_size = num_heads * head_dim;
    let k_size = num_kv_heads * head_dim;
    let v_size = num_kv_heads * head_dim;

    let qkv_bias = source.try_load_tensor(&format!("{}.attn_qkv.bias", prefix));

    // F32 weights are borrowed directly; anything else is dequantised into a scratch
    // tensor that must outlive the borrowed slice, hence the deferred binding.
    let dequantized: Tensor;
    let qkv_f32 = if qkv_weight.dtype() == DType::F32 {
        qkv_weight.as_f32()?
    } else {
        let backend = crate::backend::default_backend();
        let mut scratch = Tensor::zeros(vec![qkv_weight.numel()], DType::F32);
        backend
            .dequantize(&qkv_weight, &mut scratch)
            .map_err(|e| ModelError::ConfigError(format!("Failed to dequantize QKV: {}", e)))?;
        dequantized = scratch;
        dequantized.as_f32()?
    };

    let k_start = q_size * in_features;
    let v_start = (q_size + k_size) * in_features;
    let v_end = v_start + v_size * in_features;
    // Validate before slicing so a truncated tensor yields an error, not a panic.
    if qkv_f32.len() < v_end {
        return Err(ModelError::ConfigError(format!(
            "{}.attn_qkv.weight has {} elements, expected at least {}",
            prefix,
            qkv_f32.len(),
            v_end
        )));
    }

    let q_tensor = Tensor::from_f32(&qkv_f32[..k_start], vec![in_features, q_size])?;
    let k_tensor = Tensor::from_f32(&qkv_f32[k_start..v_start], vec![in_features, k_size])?;
    let v_tensor = Tensor::from_f32(&qkv_f32[v_start..v_end], vec![in_features, v_size])?;

    let (q_bias, k_bias, v_bias) = if let Some(ref bias) = qkv_bias {
        let b = bias.as_f32()?;
        if b.len() < q_size + k_size {
            return Err(ModelError::ConfigError(format!(
                "{}.attn_qkv.bias has {} elements, too few to split into Q/K/V",
                prefix,
                b.len()
            )));
        }
        let qb = Tensor::from_f32(&b[..q_size], vec![q_size])?;
        let kb = Tensor::from_f32(&b[q_size..q_size + k_size], vec![k_size])?;
        // The V bias takes everything after Q and K.
        let vb = Tensor::from_f32(&b[q_size + k_size..], vec![v_size])?;
        (Some(qb), Some(kb), Some(vb))
    } else {
        (None, None, None)
    };

    let wq = Linear::new(q_tensor, q_bias)?;
    let wk = Linear::new(k_tensor, k_bias)?;
    let wv = Linear::new(v_tensor, v_bias)?;
    let wo_bias = source.try_load_tensor(&format!("{}.attn_output.bias", prefix));
    let wo = Linear::new(
        source.load_tensor(&format!("{}.attn_output.weight", prefix))?,
        wo_bias,
    )?;

    // Fused-QKV attention never carries an attention gate.
    Ok(Attention::with_kv_dims(
        wq,
        wk,
        wv,
        wo,
        num_heads,
        num_kv_heads,
        head_dim,
        kl,
        vl,
        rope_dims,
        use_neox_rope,
        false,
    ))
}
fn load_deltanet_layer(source: &dyn ModelSource, layer_idx: usize) -> ModelResult<DeltaNetLayer> {
let prefix = format!("blk.{}", layer_idx);
let cfg = source.config();
let d_inner = cfg.ssm_d_inner;
let d_state = cfg.ssm_d_state;
let num_v_heads = cfg.ssm_dt_rank;
let num_k_heads = cfg.ssm_n_group;
let head_v_dim = d_inner / num_v_heads;
let head_k_dim = d_state;
let conv_kernel = cfg.ssm_conv_kernel;
let q_dim = num_k_heads * head_k_dim;
let k_dim = num_k_heads * head_k_dim;
let qkv_dim = q_dim + k_dim + d_inner;
let dn_config = DeltaNetConfig {
d_inner,
d_state,
num_v_heads,
num_k_heads,
head_v_dim,
head_k_dim,
conv_kernel,
qkv_dim,
};
let attn_qkv = Linear::new(
source.load_tensor(&format!("{}.attn_qkv.weight", prefix))?,
None,
)?;
let attn_gate = Linear::new(
source.load_tensor(&format!("{}.attn_gate.weight", prefix))?,
None,
)?;
let ssm_ba = if let Some(ba_weight) =
source.try_load_tensor(&format!("{}.ssm_ba.weight", prefix))
{
BetaAlphaProjection::Combined(Linear::new(ba_weight, None)?)
} else {
let beta_w = source.load_tensor(&format!("{}.ssm_beta.weight", prefix))?;
let alpha_w = source.load_tensor(&format!("{}.ssm_alpha.weight", prefix))?;
BetaAlphaProjection::Separate {
beta: Linear::new(beta_w, None)?,
alpha: Linear::new(alpha_w, None)?,
}
};
let ssm_conv1d_weight = source.load_tensor(&format!("{}.ssm_conv1d.weight", prefix))?;
let ssm_a = source.load_tensor(&format!("{}.ssm_a", prefix))?;
let ssm_dt_bias = source.load_tensor(&format!("{}.ssm_dt.bias", prefix))?;
let ssm_norm_weight = source.load_tensor(&format!("{}.ssm_norm.weight", prefix))?;
let ssm_norm = RMSNorm::new(ssm_norm_weight, cfg.norm_eps)?;
let ssm_out = Linear::new(
source.load_tensor(&format!("{}.ssm_out.weight", prefix))?,
None,
)?;
tracing::info!("Layer {}: loaded DeltaNet (d_inner={}, d_state={}, v_heads={}, k_heads={}, conv={})",
layer_idx, d_inner, d_state, num_v_heads, num_k_heads, conv_kernel);
Ok(DeltaNetLayer {
config: dn_config,
attn_qkv,
attn_gate,
ssm_ba,
ssm_conv1d_weight,
ssm_a,
ssm_dt_bias,
ssm_norm,
ssm_out,
})
}
/// Loads the Mamba SSM block stored under `blk.<layer_idx>`.
///
/// Required tensors: `ssm_in.weight`, `ssm_conv1d.weight`, `ssm_x.weight`,
/// `ssm_dt.weight`, `ssm_dt.bias`, `ssm_a` and `ssm_out.weight`.
/// Optional tensors: `ssm_conv1d.bias`, `ssm_d` and `ssm_norm.weight`; each is
/// stored as `None` when absent.
///
/// # Errors
///
/// Forwards any error raised while loading a required tensor or while
/// constructing a `Linear` / `RMSNorm` from it.
fn load_mamba_layer(source: &dyn ModelSource, layer_idx: usize) -> ModelResult<MambaLayer> {
    let cfg = source.config();
    // Every tensor of this layer is named `blk.<layer_idx>.<suffix>`.
    let tensor_name = |suffix: &str| format!("blk.{}.{}", layer_idx, suffix);

    let mamba_config = MambaConfig {
        d_inner: cfg.ssm_d_inner,
        d_state: cfg.ssm_d_state,
        dt_rank: cfg.ssm_dt_rank,
        // The stored kernel width is never allowed to be zero.
        conv_kernel: cfg.ssm_conv_kernel.max(1),
    };

    // Tensors are fetched in a fixed order so that the first missing one is
    // always the one reported.
    let in_proj = Linear::new(source.load_tensor(&tensor_name("ssm_in.weight"))?, None)?;
    let conv_weight = source.load_tensor(&tensor_name("ssm_conv1d.weight"))?;
    let conv_bias = source.try_load_tensor(&tensor_name("ssm_conv1d.bias"));
    let x_proj = Linear::new(source.load_tensor(&tensor_name("ssm_x.weight"))?, None)?;
    let dt_proj = Linear::new(source.load_tensor(&tensor_name("ssm_dt.weight"))?, None)?;
    let dt_bias = source.load_tensor(&tensor_name("ssm_dt.bias"))?;
    let a_param = source.load_tensor(&tensor_name("ssm_a"))?;
    let d_param = source.try_load_tensor(&tensor_name("ssm_d"));
    // The norm is optional; when the weight exists a construction failure is
    // still an error.
    let norm = source
        .try_load_tensor(&tensor_name("ssm_norm.weight"))
        .map(|weight| RMSNorm::new(weight, cfg.norm_eps))
        .transpose()?;
    let out_proj = Linear::new(source.load_tensor(&tensor_name("ssm_out.weight"))?, None)?;

    tracing::info!(
        "Layer {}: loaded Mamba SSM (d_inner={}, d_state={}, dt_rank={}, conv={})",
        layer_idx,
        mamba_config.d_inner,
        mamba_config.d_state,
        mamba_config.dt_rank,
        mamba_config.conv_kernel
    );

    Ok(MambaLayer {
        ssm_in: in_proj,
        ssm_conv1d_weight: conv_weight,
        ssm_conv1d_bias: conv_bias,
        ssm_x: x_proj,
        ssm_dt: dt_proj,
        ssm_dt_bias: dt_bias,
        ssm_a: a_param,
        ssm_d: d_param,
        ssm_norm: norm,
        ssm_out: out_proj,
        config: mamba_config,
    })
}
fn load_moe_layer(source: &dyn ModelSource, layer_idx: usize) -> ModelResult<FfnLayer> {
let prefix = format!("blk.{}", layer_idx);
let config = source.config();
let num_experts = config.num_experts;
let hidden_dim = config.hidden_size;
let expert_ffn_dim = if config.expert_intermediate_size > 0 {
config.expert_intermediate_size
} else {
config.intermediate_size / config.num_experts_per_token
};
let router_weight = source.load_tensor(&format!("{}.ffn_gate_inp.weight", prefix))?;
let router = MoeRouter::from_weight(
router_weight,
config.num_experts_per_token,
false, );
let gate_exps = source.load_tensor(&format!("{}.ffn_gate_exps.weight", prefix))?;
let up_exps = source.load_tensor(&format!("{}.ffn_up_exps.weight", prefix))?;
let down_exps = source.load_tensor(&format!("{}.ffn_down_exps.weight", prefix))?;
let mut experts = Vec::with_capacity(num_experts);
for e in 0..num_experts {
let mut gate_proj = extract_expert_tensor(&gate_exps, e)?;
let mut up_proj = extract_expert_tensor(&up_exps, e)?;
let mut down_proj = extract_expert_tensor(&down_exps, e)?;
gate_proj.set_name(format!("blk.{}.ffn_gate.{}.weight", layer_idx, e));
up_proj.set_name(format!("blk.{}.ffn_up.{}.weight", layer_idx, e));
down_proj.set_name(format!("blk.{}.ffn_down.{}.weight", layer_idx, e));
experts.push(MoeExpert {
gate_proj,
up_proj,
down_proj,
use_gelu: config.uses_gelu,
});
}
let mut shared_experts = Vec::new();
if let (Some(mut gate_shexp), Some(mut up_shexp), Some(mut down_shexp)) = (
source.try_load_tensor(&format!("{}.ffn_gate_shexp.weight", prefix)),
source.try_load_tensor(&format!("{}.ffn_up_shexp.weight", prefix)),
source.try_load_tensor(&format!("{}.ffn_down_shexp.weight", prefix)),
) {
gate_shexp.set_name(format!("blk.{}.ffn_gate_shexp.0.weight", layer_idx));
up_shexp.set_name(format!("blk.{}.ffn_up_shexp.0.weight", layer_idx));
down_shexp.set_name(format!("blk.{}.ffn_down_shexp.0.weight", layer_idx));
shared_experts.push(MoeExpert {
gate_proj: gate_shexp,
up_proj: up_shexp,
down_proj: down_shexp,
use_gelu: config.uses_gelu,
});
}
let shared_expert_gate =
source.try_load_tensor(&format!("{}.ffn_gate_inp_shexp.weight", prefix))
.map(|t| {
if t.dtype() == DType::F32 {
t
} else {
let raw = t.data();
let f32_vals: Vec<f32> = match t.dtype() {
DType::BF16 => {
raw.chunks_exact(2)
.map(|c| {
let bits = u16::from_le_bytes([c[0], c[1]]);
f32::from_bits((bits as u32) << 16)
})
.collect()
}
_ => {
tracing::warn!("Unsupported dtype for shared expert gate, zeroing");
vec![0.0f32; t.numel()]
}
};
let shape = t.shape().to_vec();
Tensor::from_f32(&f32_vals, shape).unwrap()
}
});
if shared_expert_gate.is_some() {
tracing::debug!("Layer {}: loaded shared expert gate", layer_idx);
}
let num_shared = shared_experts.len();
let moe_config = MoeConfig {
num_experts,
num_experts_per_token: config.num_experts_per_token,
expert_hidden_dim: expert_ffn_dim,
num_shared_experts: num_shared,
aux_loss_coef: 0.0,
normalize_router_logits: false,
};
let mut moe_layer = MoeLayer::new(hidden_dim, moe_config);
moe_layer.router = router;
moe_layer.experts = experts;
moe_layer.shared_experts = shared_experts;
moe_layer.shared_expert_gate = shared_expert_gate;
Ok(FfnLayer::Moe(moe_layer))
}
/// Copies the weights of one expert out of a batched expert tensor.
///
/// `batched` must have a 3-element shape `[ne0, ne1, num_experts]`; the result
/// has shape `[ne0, ne1]` and is named `expert.<expert_idx>`. Experts are
/// taken to be stored back to back, so expert `i` starts at `i` times the
/// per-expert size.
///
/// For quantized dtypes the raw block bytes are copied unchanged and the
/// dtype is preserved; otherwise the data is read through `as_f32` and the
/// result is built with `Tensor::from_f32`.
///
/// # Errors
///
/// Returns `ModelError::ConfigError` when the shape is not 3D, when
/// `expert_idx` is out of range, when a quantized expert is not a whole number
/// of blocks, or when the backing data is too short to contain the requested
/// expert. Errors from `as_f32` and the tensor constructors are forwarded.
fn extract_expert_tensor(
    batched: &Tensor,
    expert_idx: usize,
) -> ModelResult<Tensor> {
    let shape = batched.shape();
    if shape.len() != 3 {
        return Err(ModelError::ConfigError(format!(
            "Expected 3D batched expert tensor, got shape {:?}",
            shape
        )));
    }
    let ne0 = shape[0];
    let ne1 = shape[1];
    let num_experts = shape[2];
    let expert_numel = ne0 * ne1;
    if expert_idx >= num_experts {
        return Err(ModelError::ConfigError(format!(
            "Expert index {} out of bounds ({})",
            expert_idx, num_experts
        )));
    }
    let per_expert_shape = vec![ne0, ne1];

    let mut tensor = if batched.dtype().is_quantized() {
        let block_size = batched.dtype().block_size();
        let block_bytes = batched.dtype().block_bytes();
        // A zero block size is rejected as well, so the division below cannot
        // panic.
        if block_size == 0 || !expert_numel.is_multiple_of(block_size) {
            return Err(ModelError::ConfigError(format!(
                "Expert tensor elements ({}) not aligned to block size ({})",
                expert_numel, block_size
            )));
        }
        let blocks_per_expert = expert_numel / block_size;
        let bytes_per_expert = blocks_per_expert * block_bytes;
        let byte_offset = expert_idx * bytes_per_expert;
        let raw_data = batched.data();
        // Checked slicing: data shorter than the shape implies yields an
        // error instead of an out-of-bounds panic.
        let expert_bytes = raw_data
            .get(byte_offset..byte_offset + bytes_per_expert)
            .ok_or_else(|| {
                ModelError::ConfigError(format!(
                    "Expert {} needs bytes {}..{} but tensor data holds only {} bytes",
                    expert_idx,
                    byte_offset,
                    byte_offset + bytes_per_expert,
                    raw_data.len()
                ))
            })?;
        Tensor::new(expert_bytes.to_vec(), per_expert_shape, batched.dtype())?
    } else {
        let f32_data = batched.as_f32()?;
        let offset = expert_idx * expert_numel;
        let expert_slice = f32_data
            .get(offset..offset + expert_numel)
            .ok_or_else(|| {
                ModelError::ConfigError(format!(
                    "Expert {} needs elements {}..{} but tensor data holds only {} elements",
                    expert_idx,
                    offset,
                    offset + expert_numel,
                    f32_data.len()
                ))
            })?;
        Tensor::from_f32(expert_slice, per_expert_shape)?
    };
    tensor.set_name(format!("expert.{}", expert_idx));
    Ok(tensor)
}
/// Returns the given norm weight unchanged.
///
/// Despite the name, no offset is applied here: the tensor is handed back as
/// is, wrapped in `Ok`, and this function never produces an error.
// NOTE(review): presumably the Gemma-style offset is already baked into the
// stored weights, which would make this a deliberate no-op — confirm against
// the model conversion pipeline before changing it.
fn apply_gemma_norm_weight_offset(weight: Tensor) -> ModelResult<Tensor> {
Ok(weight)
}
/// Derives the DeltaNet geometry from a model source's configuration.
///
/// Returns `None` when the config reports no SSM parameters (`has_ssm()` is
/// false) or when the architecture is `Mamba` / `Mamba2`.
///
/// Otherwise `ssm_dt_rank` is used as the value head count, `ssm_n_group`
/// (clamped to at least 1) as the key head count and `ssm_d_state` as the
/// per-head key width. The divisor used for the per-head value width is also
/// clamped to at least 1, so this function never divides by zero.
pub fn deltanet_config_from_source(source: &dyn ModelSource) -> Option<DeltaNetConfig> {
    let config = source.config();
    let arch = source.architecture();

    let is_mamba = matches!(arch, Architecture::Mamba | Architecture::Mamba2);
    let applicable = config.has_ssm() && !is_mamba;
    if !applicable {
        return None;
    }

    let inner = config.ssm_d_inner;
    let state = config.ssm_d_state;
    let value_heads = config.ssm_dt_rank;
    let key_heads = config.ssm_n_group.max(1);
    // Q and K have the same width, so it is computed once and counted twice.
    let qk_width = key_heads * state;

    Some(DeltaNetConfig {
        d_inner: inner,
        d_state: state,
        num_v_heads: value_heads,
        num_k_heads: key_heads,
        head_v_dim: inner / value_heads.max(1),
        head_k_dim: state,
        conv_kernel: config.ssm_conv_kernel,
        qkv_dim: qk_width + qk_width + inner,
    })
}
/// Opens the model file at `path` and builds a `LlamaModel` from it.
///
/// # Errors
///
/// Forwards any error from `ModelLoader::load` or `build_model`, and returns
/// `ModelError::UnsupportedArchitecture` (carrying the architecture's string
/// form) when the detected architecture is not LLaMA-like.
pub fn load_llama_model<P: AsRef<Path>>(path: P) -> ModelResult<LlamaModel> {
    let loader = ModelLoader::load(path)?;
    if loader.architecture().is_llama_like() {
        return loader.build_model();
    }
    let unsupported = loader.architecture().to_string();
    Err(ModelError::UnsupportedArchitecture(unsupported))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks which architectures report themselves as LLaMA-like.
    #[test]
    fn test_architecture_detection() {
        // Route every check through one predicate; the architecture name
        // still shows up in the assertion text on failure.
        let llama_like = |arch: Architecture| arch.is_llama_like();

        assert!(llama_like(Architecture::Llama));
        assert!(llama_like(Architecture::Mistral));
        assert!(llama_like(Architecture::GPT2));

        assert!(!llama_like(Architecture::Bert));
        assert!(!llama_like(Architecture::Mamba));
    }
}