use crate::types::{Error, Result};
use burn::module::Param;
use burn::nn::attention::MultiHeadAttention;
use burn::nn::{Embedding, LayerNorm, Linear};
use burn::tensor::backend::Backend;
use burn::tensor::{Tensor, TensorData};
use safetensors::{Dtype, SafeTensors};
use std::collections::HashMap;
use std::path::Path;
/// Zero-copy reader over a SafeTensors file; borrows the underlying byte
/// buffer for lifetime `'a`.
pub struct WeightLoader<'a> {
// Parsed SafeTensors view over the borrowed bytes.
tensors: SafeTensors<'a>,
}
/// A tensor whose payload has been decoded to f32, regardless of its
/// on-disk dtype.
#[derive(Debug)]
pub struct LoadedTensor {
/// Element values converted to f32, in the file's storage order.
pub data: Vec<f32>,
/// Dimensions as recorded in the SafeTensors header.
pub shape: Vec<usize>,
}
/// A tensor kept as raw bytes, with its original dtype preserved
/// (used for quantized weights that must not be converted to f32).
#[derive(Debug)]
pub struct RawTensor {
/// Unconverted byte payload exactly as stored in the file.
pub data: Vec<u8>,
/// Dimensions as recorded in the SafeTensors header.
pub shape: Vec<usize>,
/// On-disk element dtype.
pub dtype: Dtype,
}
impl LoadedTensor {
    /// Builds a `D`-dimensional Burn tensor from the decoded f32 data.
    ///
    /// # Errors
    /// Returns `Error::LoadError` when the stored shape's rank or any
    /// dimension disagrees with `expected_shape`.
    pub fn to_tensor<B: Backend, const D: usize>(
        &self,
        device: &B::Device,
        expected_shape: [usize; D],
    ) -> Result<Tensor<B, D>> {
        if self.shape.len() != D {
            return Err(Error::LoadError(format!(
                "Shape dimension mismatch: expected {}, got {}",
                D,
                self.shape.len()
            )));
        }
        for dim in 0..D {
            let (expected, actual) = (expected_shape[dim], self.shape[dim]);
            if expected != actual {
                return Err(Error::LoadError(format!(
                    "Shape mismatch at dim {}: expected {}, got {}",
                    dim, expected, actual
                )));
            }
        }
        // The buffer is cloned because TensorData takes ownership of it.
        Ok(Tensor::from_data(
            TensorData::new(self.data.clone(), expected_shape),
            device,
        ))
    }

    /// Builds a 2D tensor with the two axes swapped relative to storage:
    /// a stored `[rows, cols]` tensor comes back as `[cols, rows]`.
    ///
    /// # Errors
    /// Returns `Error::LoadError` when the stored tensor is not 2D.
    pub fn to_tensor_transposed<B: Backend>(
        &self,
        device: &B::Device,
    ) -> Result<Tensor<B, 2>> {
        let [rows, cols] = match self.shape.as_slice() {
            &[r, c] => [r, c],
            other => {
                return Err(Error::LoadError(format!(
                    "Transpose requires 2D tensor, got {}D",
                    other.len()
                )))
            }
        };
        let mut transposed = vec![0.0f32; self.data.len()];
        for r in 0..rows {
            for c in 0..cols {
                transposed[c * rows + r] = self.data[r * cols + c];
            }
        }
        Ok(Tensor::from_data(
            TensorData::new(transposed, [cols, rows]),
            device,
        ))
    }
}
impl<'a> WeightLoader<'a> {
pub fn from_bytes(bytes: &'a [u8]) -> Result<Self> {
let tensors = SafeTensors::deserialize(bytes)
.map_err(|e| Error::LoadError(format!("Failed to deserialize SafeTensors: {}", e)))?;
Ok(Self { tensors })
}
pub fn tensor_names(&self) -> Vec<String> {
self.tensors.names().into_iter().map(|s| s.to_string()).collect()
}
pub fn load_tensor(&self, name: &str) -> Result<LoadedTensor> {
let tensor_view = self.tensors.tensor(name).map_err(|e| {
Error::LoadError(format!("Failed to load tensor '{}': {}", name, e))
})?;
let shape: Vec<usize> = tensor_view.shape().to_vec();
let dtype = tensor_view.dtype();
let raw_data = tensor_view.data();
let data = convert_to_f32(&raw_data, dtype)?;
Ok(LoadedTensor { data, shape })
}
pub fn load_raw_tensor(&self, name: &str) -> Result<RawTensor> {
let tensor_view = self.tensors.tensor(name).map_err(|e| {
Error::LoadError(format!("Failed to load tensor '{}': {}", name, e))
})?;
let shape = tensor_view.shape().to_vec();
let dtype = tensor_view.dtype();
let data = tensor_view.data().to_vec();
Ok(RawTensor { data, shape, dtype })
}
pub fn has_tensor(&self, name: &str) -> bool {
self.tensors.tensor(name).is_ok()
}
pub fn is_awq_model(&self) -> bool {
self.has_tensor("model.layers.0.self_attn.q_proj.qweight")
}
}
fn convert_to_f32(data: &[u8], dtype: Dtype) -> Result<Vec<f32>> {
match dtype {
Dtype::F32 => {
let floats: Vec<f32> = data
.chunks_exact(4)
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
.collect();
Ok(floats)
}
Dtype::F16 => {
let floats: Vec<f32> = data
.chunks_exact(2)
.map(|chunk| {
let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
half::f16::from_bits(bits).to_f32()
})
.collect();
Ok(floats)
}
Dtype::BF16 => {
let floats: Vec<f32> = data
.chunks_exact(2)
.map(|chunk| {
let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
half::bf16::from_bits(bits).to_f32()
})
.collect();
Ok(floats)
}
Dtype::F64 => {
let floats: Vec<f32> = data
.chunks_exact(8)
.map(|chunk| {
let arr: [u8; 8] = chunk.try_into().unwrap();
f64::from_le_bytes(arr) as f32
})
.collect();
Ok(floats)
}
_ => Err(Error::LoadError(format!(
"Unsupported dtype for weight loading: {:?}",
dtype
))),
}
}
/// Loads a `Linear` layer: the stored 2D weight is transposed into Burn's
/// layout via `to_tensor_transposed`, and the bias is attached only when
/// `bias_name` is given and the tensor actually exists in the file.
pub fn load_linear<B: Backend>(
    loader: &WeightLoader,
    weight_name: &str,
    bias_name: Option<&str>,
    device: &B::Device,
) -> Result<Linear<B>> {
    let weight = Param::from_tensor(
        loader
            .load_tensor(weight_name)?
            .to_tensor_transposed::<B>(device)?,
    );
    let bias = match bias_name {
        Some(name) if loader.has_tensor(name) => {
            let loaded = loader.load_tensor(name)?;
            let len = loaded.shape[0];
            Some(Param::from_tensor(loaded.to_tensor::<B, 1>(device, [len])?))
        }
        _ => None,
    };
    Ok(Linear { weight, bias })
}
pub fn load_embedding<B: Backend>(
loader: &WeightLoader,
weight_name: &str,
device: &B::Device,
) -> Result<Embedding<B>> {
let weight_tensor = loader.load_tensor(weight_name)?;
let shape = [weight_tensor.shape[0], weight_tensor.shape[1]];
let weight = weight_tensor.to_tensor::<B, 2>(device, shape)?;
let weight_param = Param::from_tensor(weight);
Ok(Embedding { weight: weight_param })
}
/// Loads a `LayerNorm` of width `d_model`: the module is first built from a
/// config (to get `epsilon` wired in), then gamma — and beta, when present —
/// are overwritten with the checkpoint values.
pub fn load_layer_norm<B: Backend>(
    loader: &WeightLoader,
    weight_name: &str,
    bias_name: Option<&str>,
    d_model: usize,
    epsilon: f64,
    device: &B::Device,
) -> Result<LayerNorm<B>> {
    use burn::nn::LayerNormConfig;
    let mut norm = LayerNormConfig::new(d_model)
        .with_epsilon(epsilon)
        .init(device);
    norm.gamma = Param::from_tensor(
        loader
            .load_tensor(weight_name)?
            .to_tensor::<B, 1>(device, [d_model])?,
    );
    // Beta stays at its initialized value unless the named bias tensor exists.
    if let Some(name) = bias_name.filter(|n| loader.has_tensor(n)) {
        let beta = loader
            .load_tensor(name)?
            .to_tensor::<B, 1>(device, [d_model])?;
        norm.beta = Some(Param::from_tensor(beta));
    }
    Ok(norm)
}
/// Assembles a `MultiHeadAttention` module from four pre-trained projections.
///
/// The module is first built zero-initialized (that initializer output is
/// discarded anyway), then each projection is replaced by weights loaded via
/// `load_linear`.
pub fn load_mha<B: Backend>(
    loader: &WeightLoader,
    query_weight: &str,
    query_bias: Option<&str>,
    key_weight: &str,
    key_bias: Option<&str>,
    value_weight: &str,
    value_bias: Option<&str>,
    output_weight: &str,
    output_bias: Option<&str>,
    d_model: usize,
    n_heads: usize,
    dropout: f64,
    device: &B::Device,
) -> Result<MultiHeadAttention<B>> {
    use burn::nn::attention::MultiHeadAttentionConfig;
    use burn::nn::Initializer;
    let mut attention = MultiHeadAttentionConfig::new(d_model, n_heads)
        .with_dropout(dropout)
        .with_initializer(Initializer::Zeros)
        .init(device);
    attention.query = load_linear(loader, query_weight, query_bias, device)?;
    attention.key = load_linear(loader, key_weight, key_bias, device)?;
    attention.value = load_linear(loader, value_weight, value_bias, device)?;
    attention.output = load_linear(loader, output_weight, output_bias, device)?;
    Ok(attention)
}
/// Tensor-name mapping tables for the checkpoint layouts this loader understands.
pub mod mappings {
/// Model families with distinct tensor-naming conventions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Architecture {
/// Llama-style naming (also matches Mistral, Qwen, DeepSeek, Mixtral).
Llama,
/// GPT-2 / GPT-Neo naming.
Gpt2,
/// BERT encoder naming; also the fallback for unrecognized model types.
Bert,
/// ChatGLM naming.
Glm,
}
impl Architecture {
/// Infers the architecture from a model-type string by case-insensitive
/// substring matching. Anything unrecognized falls back to `Bert`.
pub fn from_model_type(model_type: &str) -> Self {
let lower = model_type.to_lowercase();
if lower.contains("llama")
|| lower.contains("mistral")
|| lower.contains("qwen")
|| lower.contains("deepseek")
|| lower.contains("mixtral") {
Architecture::Llama
} else if lower.contains("gpt2") || lower.contains("gpt-neo") {
Architecture::Gpt2
} else if lower.contains("glm") || lower.contains("chatglm") {
Architecture::Glm
} else {
Architecture::Bert
}
}
/// Name of the token-embedding weight tensor.
pub fn embedding_weight(&self) -> &'static str {
match self {
Architecture::Llama => "model.embed_tokens.weight",
Architecture::Gpt2 => "transformer.wte.weight",
Architecture::Bert => "embeddings.word_embeddings.weight",
Architecture::Glm => "transformer.embedding.word_embeddings.weight",
}
}
/// Tensor-name prefix for transformer layer `layer_idx`.
pub fn layer_prefix(&self, layer_idx: usize) -> String {
match self {
Architecture::Llama => format!("model.layers.{}", layer_idx),
Architecture::Gpt2 => format!("transformer.h.{}", layer_idx),
Architecture::Bert => format!("encoder.layer.{}", layer_idx),
Architecture::Glm => format!("transformer.encoder.layers.{}", layer_idx),
}
}
/// Attention projection tensor names under `layer_prefix`.
///
/// NOTE(review): for Gpt2 and Glm the q/k/v entries all name the same
/// fused QKV tensor (`c_attn` / `query_key_value`); the consumer must
/// slice that tensor per projection — loading it whole three times would
/// be wrong. TODO confirm how callers consume these names.
/// NOTE(review): GPT-2 checkpoints store `c_attn`/`c_proj` as Conv1D
/// weights (transposed relative to typical Linear layout) — verify this
/// interacts correctly with the loader's transposition.
pub fn attention_weights(&self, layer_prefix: &str) -> AttentionWeights {
match self {
Architecture::Llama => AttentionWeights {
q_proj_weight: format!("{}.self_attn.q_proj.weight", layer_prefix),
k_proj_weight: format!("{}.self_attn.k_proj.weight", layer_prefix),
v_proj_weight: format!("{}.self_attn.v_proj.weight", layer_prefix),
o_proj_weight: format!("{}.self_attn.o_proj.weight", layer_prefix),
q_proj_bias: None,
k_proj_bias: None,
v_proj_bias: None,
o_proj_bias: None,
},
Architecture::Gpt2 => AttentionWeights {
q_proj_weight: format!("{}.attn.c_attn.weight", layer_prefix),
k_proj_weight: format!("{}.attn.c_attn.weight", layer_prefix),
v_proj_weight: format!("{}.attn.c_attn.weight", layer_prefix),
o_proj_weight: format!("{}.attn.c_proj.weight", layer_prefix),
q_proj_bias: Some(format!("{}.attn.c_attn.bias", layer_prefix)),
k_proj_bias: Some(format!("{}.attn.c_attn.bias", layer_prefix)),
v_proj_bias: Some(format!("{}.attn.c_attn.bias", layer_prefix)),
o_proj_bias: Some(format!("{}.attn.c_proj.bias", layer_prefix)),
},
Architecture::Bert => AttentionWeights {
q_proj_weight: format!("{}.attention.self.query.weight", layer_prefix),
k_proj_weight: format!("{}.attention.self.key.weight", layer_prefix),
v_proj_weight: format!("{}.attention.self.value.weight", layer_prefix),
o_proj_weight: format!("{}.attention.output.dense.weight", layer_prefix),
q_proj_bias: Some(format!("{}.attention.self.query.bias", layer_prefix)),
k_proj_bias: Some(format!("{}.attention.self.key.bias", layer_prefix)),
v_proj_bias: Some(format!("{}.attention.self.value.bias", layer_prefix)),
o_proj_bias: Some(format!("{}.attention.output.dense.bias", layer_prefix)),
},
Architecture::Glm => AttentionWeights {
q_proj_weight: format!("{}.self_attention.query_key_value.weight", layer_prefix),
k_proj_weight: format!("{}.self_attention.query_key_value.weight", layer_prefix),
v_proj_weight: format!("{}.self_attention.query_key_value.weight", layer_prefix),
o_proj_weight: format!("{}.self_attention.dense.weight", layer_prefix),
q_proj_bias: Some(format!("{}.self_attention.query_key_value.bias", layer_prefix)),
k_proj_bias: Some(format!("{}.self_attention.query_key_value.bias", layer_prefix)),
v_proj_bias: Some(format!("{}.self_attention.query_key_value.bias", layer_prefix)),
o_proj_bias: Some(format!("{}.self_attention.dense.bias", layer_prefix)),
},
}
}
/// Feed-forward tensor names under `layer_prefix`.
///
/// NOTE(review): for Gpt2, Bert, and Glm the gate and up entries name the
/// same tensor (those architectures have no gated FFN); presumably the
/// consumer treats them accordingly — confirm against the model builder.
pub fn ffn_weights(&self, layer_prefix: &str) -> FfnWeights {
match self {
Architecture::Llama => FfnWeights {
gate_proj_weight: format!("{}.mlp.gate_proj.weight", layer_prefix),
up_proj_weight: format!("{}.mlp.up_proj.weight", layer_prefix),
down_proj_weight: format!("{}.mlp.down_proj.weight", layer_prefix),
gate_proj_bias: None,
up_proj_bias: None,
down_proj_bias: None,
},
Architecture::Gpt2 => FfnWeights {
gate_proj_weight: format!("{}.mlp.c_fc.weight", layer_prefix),
up_proj_weight: format!("{}.mlp.c_fc.weight", layer_prefix),
down_proj_weight: format!("{}.mlp.c_proj.weight", layer_prefix),
gate_proj_bias: Some(format!("{}.mlp.c_fc.bias", layer_prefix)),
up_proj_bias: Some(format!("{}.mlp.c_fc.bias", layer_prefix)),
down_proj_bias: Some(format!("{}.mlp.c_proj.bias", layer_prefix)),
},
Architecture::Bert => FfnWeights {
gate_proj_weight: format!("{}.intermediate.dense.weight", layer_prefix),
up_proj_weight: format!("{}.intermediate.dense.weight", layer_prefix),
down_proj_weight: format!("{}.output.dense.weight", layer_prefix),
gate_proj_bias: Some(format!("{}.intermediate.dense.bias", layer_prefix)),
up_proj_bias: Some(format!("{}.intermediate.dense.bias", layer_prefix)),
down_proj_bias: Some(format!("{}.output.dense.bias", layer_prefix)),
},
Architecture::Glm => FfnWeights {
gate_proj_weight: format!("{}.mlp.dense_h_to_4h.weight", layer_prefix),
up_proj_weight: format!("{}.mlp.dense_h_to_4h.weight", layer_prefix),
down_proj_weight: format!("{}.mlp.dense_4h_to_h.weight", layer_prefix),
gate_proj_bias: Some(format!("{}.mlp.dense_h_to_4h.bias", layer_prefix)),
up_proj_bias: Some(format!("{}.mlp.dense_h_to_4h.bias", layer_prefix)),
down_proj_bias: Some(format!("{}.mlp.dense_4h_to_h.bias", layer_prefix)),
},
}
}
/// Per-layer normalization tensor names under `layer_prefix`.
/// Llama has no norm biases (RMSNorm-style); the others carry biases.
pub fn layer_norm_weights(&self, layer_prefix: &str) -> LayerNormWeights {
match self {
Architecture::Llama => LayerNormWeights {
attention_norm_weight: format!("{}.input_layernorm.weight", layer_prefix),
ffn_norm_weight: format!("{}.post_attention_layernorm.weight", layer_prefix),
attention_norm_bias: None,
ffn_norm_bias: None,
},
Architecture::Gpt2 => LayerNormWeights {
attention_norm_weight: format!("{}.ln_1.weight", layer_prefix),
ffn_norm_weight: format!("{}.ln_2.weight", layer_prefix),
attention_norm_bias: Some(format!("{}.ln_1.bias", layer_prefix)),
ffn_norm_bias: Some(format!("{}.ln_2.bias", layer_prefix)),
},
Architecture::Bert => LayerNormWeights {
attention_norm_weight: format!("{}.attention.output.LayerNorm.weight", layer_prefix),
ffn_norm_weight: format!("{}.output.LayerNorm.weight", layer_prefix),
attention_norm_bias: Some(format!("{}.attention.output.LayerNorm.bias", layer_prefix)),
ffn_norm_bias: Some(format!("{}.output.LayerNorm.bias", layer_prefix)),
},
Architecture::Glm => LayerNormWeights {
attention_norm_weight: format!("{}.input_layernorm.weight", layer_prefix),
ffn_norm_weight: format!("{}.post_attention_layernorm.weight", layer_prefix),
attention_norm_bias: Some(format!("{}.input_layernorm.bias", layer_prefix)),
ffn_norm_bias: Some(format!("{}.post_attention_layernorm.bias", layer_prefix)),
},
}
}
/// Name of the model's final normalization weight.
///
/// NOTE(review): for Bert this maps to `embeddings.LayerNorm.weight`,
/// which is the embedding-layer norm, not an encoder-final norm — confirm
/// this is intentional for how the caller uses it.
pub fn final_norm_weight(&self) -> &'static str {
match self {
Architecture::Llama => "model.norm.weight",
Architecture::Gpt2 => "transformer.ln_f.weight",
Architecture::Bert => "embeddings.LayerNorm.weight",
Architecture::Glm => "transformer.encoder.final_layernorm.weight",
}
}
/// Name of the LM-head (output projection) weight.
pub fn lm_head_weight(&self) -> &'static str {
match self {
Architecture::Llama => "lm_head.weight",
Architecture::Gpt2 => "lm_head.weight",
Architecture::Bert => "cls.predictions.decoder.weight",
Architecture::Glm => "transformer.output_layer.weight",
}
}
}
/// Tensor names for one layer's attention projections; `None` bias entries
/// mean the architecture stores no bias for that projection.
pub struct AttentionWeights {
pub q_proj_weight: String,
pub k_proj_weight: String,
pub v_proj_weight: String,
pub o_proj_weight: String,
pub q_proj_bias: Option<String>,
pub k_proj_bias: Option<String>,
pub v_proj_bias: Option<String>,
pub o_proj_bias: Option<String>,
}
/// Tensor names for one layer's feed-forward projections.
pub struct FfnWeights {
pub gate_proj_weight: String,
pub up_proj_weight: String,
pub down_proj_weight: String,
pub gate_proj_bias: Option<String>,
pub up_proj_bias: Option<String>,
pub down_proj_bias: Option<String>,
}
/// Tensor names for one layer's pre/post-attention norms.
pub struct LayerNormWeights {
pub attention_norm_weight: String,
pub ffn_norm_weight: String,
pub attention_norm_bias: Option<String>,
pub ffn_norm_bias: Option<String>,
}
}
/// Support for multi-file (sharded) checkpoints, driven by the
/// `model.safetensors.index.json` index format.
pub mod shards {
use super::*;
use std::fs;
/// Parsed shard index: which shard file holds each tensor.
#[derive(Debug)]
pub struct ShardIndex {
/// Tensor name -> shard file name, from the index's `weight_map`.
pub tensor_to_shard: HashMap<String, String>,
/// Deduplicated shard file names, in `weight_map` iteration order.
pub shard_files: Vec<String>,
}
impl ShardIndex {
/// Reads and parses the index JSON at `index_path`.
///
/// # Errors
/// Returns `Error::LoadError` on I/O failure, invalid JSON, a missing
/// `weight_map` object, or a non-string shard-file value.
pub fn from_index_file(index_path: &Path) -> Result<Self> {
let content = fs::read_to_string(index_path).map_err(|e| {
Error::LoadError(format!("Failed to read shard index: {}", e))
})?;
let index: serde_json::Value = serde_json::from_str(&content).map_err(|e| {
Error::LoadError(format!("Failed to parse shard index: {}", e))
})?;
let weight_map = index
.get("weight_map")
.and_then(|v| v.as_object())
.ok_or_else(|| Error::LoadError("Missing weight_map in index".to_string()))?;
let mut tensor_to_shard = HashMap::new();
let mut shard_files = Vec::new();
for (tensor_name, shard_file) in weight_map {
let shard = shard_file
.as_str()
.ok_or_else(|| Error::LoadError("Invalid shard file name".to_string()))?
.to_string();
tensor_to_shard.insert(tensor_name.clone(), shard.clone());
// Linear dedup; shard counts are small enough that this is fine.
if !shard_files.contains(&shard) {
shard_files.push(shard);
}
}
Ok(Self {
tensor_to_shard,
shard_files,
})
}
/// True when `model_dir` contains a sharded-checkpoint index file.
pub fn is_sharded(model_dir: &Path) -> bool {
model_dir.join("model.safetensors.index.json").exists()
}
/// Absolute paths of all shard files, resolved against `model_dir`.
pub fn shard_paths(&self, model_dir: &Path) -> Vec<std::path::PathBuf> {
self.shard_files
.iter()
.map(|f| model_dir.join(f))
.collect()
}
}
/// Loads a single tensor from whichever shard the index assigns it to.
///
/// NOTE(review): this reads the *entire* shard file from disk on every call;
/// callers loading many tensors should batch lookups per shard file.
///
/// # Errors
/// Returns `Error::LoadError` when the tensor is absent from the index, the
/// shard file cannot be read, or the shard fails to parse.
pub fn load_tensor_from_shards(
model_dir: &Path,
index: &ShardIndex,
tensor_name: &str,
) -> Result<LoadedTensor> {
let shard_file = index.tensor_to_shard.get(tensor_name).ok_or_else(|| {
Error::LoadError(format!("Tensor '{}' not found in shard index", tensor_name))
})?;
let shard_path = model_dir.join(shard_file);
let bytes = fs::read(&shard_path).map_err(|e| {
Error::LoadError(format!("Failed to read shard file '{}': {}", shard_file, e))
})?;
let loader = WeightLoader::from_bytes(&bytes)?;
loader.load_tensor(tensor_name)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The little-endian encoding of 1.0f32 decodes to exactly 1.0.
    #[test]
    fn test_convert_f32() {
        let bytes: Vec<u8> = vec![0x00, 0x00, 0x80, 0x3f];
        let decoded = convert_to_f32(&bytes, Dtype::F32).unwrap();
        assert_eq!(decoded, vec![1.0f32]);
    }

    /// Column-major traversal of a 2x3 buffer yields the 3x2 transpose.
    #[test]
    fn test_tensor_transpose() {
        let tensor = LoadedTensor {
            data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            shape: vec![2, 3],
        };
        let (rows, cols) = (2usize, 3usize);
        let src = &tensor.data;
        let transposed: Vec<f32> = (0..cols)
            .flat_map(|c| (0..rows).map(move |r| src[r * cols + c]))
            .collect();
        assert_eq!(transposed, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    /// Substring-based detection maps model-type strings to architectures.
    #[test]
    fn test_architecture_detection() {
        use mappings::Architecture;
        let cases = [
            ("llama", Architecture::Llama),
            ("Qwen2ForCausalLM", Architecture::Llama),
            ("bert", Architecture::Bert),
            ("chatglm", Architecture::Glm),
        ];
        for (model_type, expected) in cases {
            assert_eq!(Architecture::from_model_type(model_type), expected);
        }
    }
}