use std::path::Path;
use crate::gguf::{GgufFile, MetadataValue};
use crate::tensor::{DType, Tensor};
use super::config::{ActivationType, ModelConfig, RopeConfig, RopeScalingType, RopeType};
use super::error::{ModelError, ModelResult};
use super::layers::{Attention, FeedForward, Linear, RMSNorm, TransformerLayer};
use super::llama::LlamaModel;
use super::Architecture;
pub struct ModelLoader {
gguf: GgufFile,
architecture: Architecture,
config: ModelConfig,
}
impl ModelLoader {
pub fn load<P: AsRef<Path>>(path: P) -> ModelResult<Self> {
let gguf = GgufFile::open(path)?;
let arch_str = gguf
.data
.get_string("general.architecture")
.ok_or_else(|| ModelError::MissingMetadata("general.architecture".into()))?;
let architecture = Architecture::from_gguf_str(arch_str);
if matches!(architecture, Architecture::Unknown) {
return Err(ModelError::UnsupportedArchitecture(arch_str.to_string()));
}
let config = Self::parse_config(&gguf, &architecture)?;
Ok(Self {
gguf,
architecture,
config,
})
}
fn parse_config(gguf: &GgufFile, architecture: &Architecture) -> ModelResult<ModelConfig> {
let arch = architecture.as_str();
let get_u32 = |key: &str| -> ModelResult<u32> {
gguf.data
.get_u32(key)
.ok_or_else(|| ModelError::MissingMetadata(key.into()))
};
let get_f32_or = |key: &str, default: f32| -> f32 {
gguf.data.get_f32(key).unwrap_or(default)
};
let vocab_size = get_u32(&format!("{}.vocab_size", arch))
.or_else(|_| get_u32("tokenizer.ggml.vocab_size"))
.map(|v| v as usize)
.unwrap_or_else(|_| {
if let Some(tokens) = gguf.data.metadata.get("tokenizer.ggml.tokens") {
if let MetadataValue::Array(arr) = tokens {
return arr.values.len();
}
}
if let Some(emb_info) = gguf.data.get_tensor("token_embd.weight") {
if emb_info.dims.len() == 2 {
return emb_info.dims[1] as usize;
}
}
32000
});
let hidden_size = get_u32(&format!("{}.embedding_length", arch))? as usize;
let num_layers = get_u32(&format!("{}.block_count", arch))? as usize;
let num_heads = get_u32(&format!("{}.attention.head_count", arch))? as usize;
let num_kv_heads = get_u32(&format!("{}.attention.head_count_kv", arch))
.unwrap_or(num_heads as u32) as usize;
let head_dim = hidden_size / num_heads;
let intermediate_size = get_u32(&format!("{}.feed_forward_length", arch))
.unwrap_or((hidden_size * 4 * 2 / 3) as u32) as usize;
let max_seq_len = get_u32(&format!("{}.context_length", arch))
.unwrap_or(2048) as usize;
let norm_eps = get_f32_or(&format!("{}.attention.layer_norm_rms_epsilon", arch), 1e-5);
let freq_base = get_f32_or(&format!("{}.rope.freq_base", arch), 10000.0);
let freq_scale = get_f32_or(&format!("{}.rope.scale_linear", arch), 1.0);
let rope_type = match architecture {
Architecture::Qwen2 => RopeType::NeoX,
_ => RopeType::Normal,
};
let rope_config = RopeConfig {
freq_base,
freq_scale,
n_dims: head_dim,
scaling_type: RopeScalingType::None,
original_max_position_embeddings: max_seq_len,
rope_type,
};
Ok(ModelConfig {
vocab_size,
hidden_size,
intermediate_size,
num_layers,
num_heads,
num_kv_heads,
head_dim,
max_seq_len,
norm_eps,
rope_config,
use_parallel_residual: false,
hidden_act: ActivationType::SiLU,
attention_bias: false,
mlp_bias: false,
tie_word_embeddings: gguf.data.get_string("general.tie_word_embeddings")
.map(|s| s == "true")
.unwrap_or(false),
})
}
pub fn config(&self) -> &ModelConfig {
&self.config
}
pub fn architecture(&self) -> Architecture {
self.architecture
}
pub fn build_model(self) -> ModelResult<LlamaModel> {
let token_embedding = self.load_tensor("token_embd.weight")?;
let mut layers = Vec::with_capacity(self.config.num_layers);
for i in 0..self.config.num_layers {
let layer = self.load_transformer_layer(i)?;
layers.push(layer);
}
let norm_weight = self.load_tensor("output_norm.weight")?;
let norm = RMSNorm::new(norm_weight, self.config.norm_eps)?;
let output = if self.config.tie_word_embeddings || self.try_load_tensor("output.weight").is_none() {
Linear::new(token_embedding.clone(), None)?
} else {
let output_weight = self.load_tensor("output.weight")?;
Linear::new(output_weight, None)?
};
LlamaModel::new(
self.config,
token_embedding,
layers,
norm,
output,
self.architecture,
)
}
fn load_transformer_layer(&self, layer_idx: usize) -> ModelResult<TransformerLayer> {
let prefix = format!("blk.{}", layer_idx);
let attn_norm_weight = self.load_tensor(&format!("{}.attn_norm.weight", prefix))?;
let attn_norm = RMSNorm::new(attn_norm_weight, self.config.norm_eps)?;
let wq_bias = self.try_load_tensor(&format!("{}.attn_q.bias", prefix));
let wq = Linear::new(
self.load_tensor(&format!("{}.attn_q.weight", prefix))?,
wq_bias,
)?;
let wk_bias = self.try_load_tensor(&format!("{}.attn_k.bias", prefix));
let wk = Linear::new(
self.load_tensor(&format!("{}.attn_k.weight", prefix))?,
wk_bias,
)?;
let wv_bias = self.try_load_tensor(&format!("{}.attn_v.bias", prefix));
let wv = Linear::new(
self.load_tensor(&format!("{}.attn_v.weight", prefix))?,
wv_bias,
)?;
let wo = Linear::new(
self.load_tensor(&format!("{}.attn_output.weight", prefix))?,
None,
)?;
let use_neox_rope = matches!(self.config.rope_config.rope_type, RopeType::NeoX);
let attention = Attention::with_rope_type(
wq,
wk,
wv,
wo,
self.config.num_heads,
self.config.num_kv_heads,
self.config.head_dim,
use_neox_rope,
);
let ffn_norm_weight = self.load_tensor(&format!("{}.ffn_norm.weight", prefix))?;
let ffn_norm = RMSNorm::new(ffn_norm_weight, self.config.norm_eps)?;
let w_gate = Linear::new(
self.load_tensor(&format!("{}.ffn_gate.weight", prefix))?,
None,
)?;
let w_up = Linear::new(
self.load_tensor(&format!("{}.ffn_up.weight", prefix))?,
None,
)?;
let w_down = Linear::new(
self.load_tensor(&format!("{}.ffn_down.weight", prefix))?,
None,
)?;
let ffn = FeedForward::new(w_gate, w_up, w_down);
Ok(TransformerLayer {
attn_norm,
attention,
ffn_norm,
ffn,
layer_idx,
})
}
fn try_load_tensor(&self, name: &str) -> Option<Tensor> {
let tensor_info = self.gguf.data.get_tensor(name)?;
let tensor_data = self.gguf.tensor_data(name)?;
let shape: Vec<usize> = tensor_info.dims.iter().map(|&d| d as usize).collect();
let dtype = DType::from(tensor_info.dtype);
Tensor::new(tensor_data.to_vec(), shape, dtype).ok().map(|mut t| {
t.set_name(name);
t
})
}
fn load_tensor(&self, name: &str) -> ModelResult<Tensor> {
let tensor_info = self
.gguf
.data
.get_tensor(name)
.ok_or_else(|| ModelError::MissingTensor(name.into()))?;
let tensor_data = self
.gguf
.tensor_data(name)
.ok_or_else(|| ModelError::MissingTensor(name.into()))?;
let shape: Vec<usize> = tensor_info.dims.iter().map(|&d| d as usize).collect();
let dtype = DType::from(tensor_info.dtype);
let mut tensor = Tensor::new(tensor_data.to_vec(), shape, dtype)?;
tensor.set_name(name);
Ok(tensor)
}
}
pub fn load_llama_model<P: AsRef<Path>>(path: P) -> ModelResult<LlamaModel> {
let loader = ModelLoader::load(path)?;
if !loader.architecture().is_llama_like() {
return Err(ModelError::UnsupportedArchitecture(
loader.architecture().to_string(),
));
}
loader.build_model()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_architecture_detection() {
assert!(Architecture::Llama.is_llama_like());
assert!(Architecture::Mistral.is_llama_like());
assert!(!Architecture::GPT2.is_llama_like());
}
}