use std::collections::HashMap;
use crate::errors::{GraphError, GraphResult};
use crate::tensor::DenseTensor;
use crate::tensor::traits::TensorBase;
use crate::transformer::loader::config::LlamaConfig;
use crate::transformer::model::{LlamaDecoderLayer, LlamaModel as LlamaModelStruct};
use crate::transformer::layers::{
MultiHeadAttention, FeedForward, RMSNorm, RoPE,
};
pub struct LlamaWeightMapper {
config: LlamaConfig,
}
impl LlamaWeightMapper {
pub fn new(config: LlamaConfig) -> Self {
Self { config }
}
pub fn config(&self) -> &LlamaConfig {
&self.config
}
pub fn build_model(
&self,
tensors: &HashMap<String, DenseTensor>,
) -> GraphResult<LlamaModelStruct> {
let embed_tokens = tensors
.get("model.embed_tokens.weight")
.ok_or_else(|| GraphError::NotFound("model.embed_tokens.weight".to_string()))?
.clone();
let mut layers = Vec::new();
for layer_idx in 0..self.config.num_hidden_layers {
let layer = self.build_layer(layer_idx, tensors)?;
layers.push(layer);
}
let norm = RMSNorm::new(
tensors
.get("model.norm.weight")
.ok_or_else(|| GraphError::NotFound("model.norm.weight".to_string()))?
.clone(),
self.config.rms_norm_eps,
);
let lm_head = tensors
.get("lm_head.weight")
.ok_or_else(|| GraphError::NotFound("lm_head.weight".to_string()))?
.clone();
Ok(LlamaModelStruct {
embed_tokens: DenseTensor::new(
embed_tokens.data().to_vec(),
embed_tokens.shape().to_vec(),
),
layers,
norm,
lm_head: Some(DenseTensor::new(
lm_head.data().to_vec(),
lm_head.shape().to_vec(),
)),
config: self.config.clone(),
rope: RoPE::new(
self.config.head_dim(),
self.config.max_position_embeddings,
self.config.rope_theta,
),
})
}
pub fn build_layer(
&self,
layer_idx: usize,
tensors: &HashMap<String, DenseTensor>,
) -> GraphResult<LlamaDecoderLayer> {
let prefix = format!("model.layers.{}", layer_idx);
let q_proj = tensors
.get(&format!("{}.self_attn.q_proj.weight", prefix))
.ok_or_else(|| GraphError::NotFound(format!("{}.self_attn.q_proj.weight", prefix)))?
.clone();
let k_proj = tensors
.get(&format!("{}.self_attn.k_proj.weight", prefix))
.ok_or_else(|| GraphError::NotFound(format!("{}.self_attn.k_proj.weight", prefix)))?
.clone();
let v_proj = tensors
.get(&format!("{}.self_attn.v_proj.weight", prefix))
.ok_or_else(|| GraphError::NotFound(format!("{}.self_attn.v_proj.weight", prefix)))?
.clone();
let o_proj = tensors
.get(&format!("{}.self_attn.o_proj.weight", prefix))
.ok_or_else(|| GraphError::NotFound(format!("{}.self_attn.o_proj.weight", prefix)))?
.clone();
let self_attn = MultiHeadAttention::new(
q_proj,
k_proj,
v_proj,
o_proj,
self.config.num_attention_heads,
self.config.get_num_key_value_heads(),
);
let gate_proj = tensors
.get(&format!("{}.mlp.gate_proj.weight", prefix))
.ok_or_else(|| GraphError::NotFound(format!("{}.mlp.gate_proj.weight", prefix)))?
.clone();
let up_proj = tensors
.get(&format!("{}.mlp.up_proj.weight", prefix))
.ok_or_else(|| GraphError::NotFound(format!("{}.mlp.up_proj.weight", prefix)))?
.clone();
let down_proj = tensors
.get(&format!("{}.mlp.down_proj.weight", prefix))
.ok_or_else(|| GraphError::NotFound(format!("{}.mlp.down_proj.weight", prefix)))?
.clone();
let mlp = FeedForward::swiglu(gate_proj, up_proj, down_proj);
let input_layernorm = RMSNorm::new(
tensors
.get(&format!("{}.input_layernorm.weight", prefix))
.ok_or_else(|| GraphError::NotFound(format!("{}.input_layernorm.weight", prefix)))?
.clone(),
self.config.rms_norm_eps,
);
let post_attention_layernorm = RMSNorm::new(
tensors
.get(&format!("{}.post_attention_layernorm.weight", prefix))
.ok_or_else(|| GraphError::NotFound(format!("{}.post_attention_layernorm.weight", prefix)))?
.clone(),
self.config.rms_norm_eps,
);
Ok(LlamaDecoderLayer::new(
self_attn,
mlp,
input_layernorm,
post_attention_layernorm,
))
}
pub fn get_weight<'a>(
&self,
layer_idx: usize,
component: &str,
tensors: &'a HashMap<String, DenseTensor>,
) -> GraphResult<&'a DenseTensor> {
let name = format!("model.layers.{}.{}", layer_idx, component);
tensors
.get(&name)
.ok_or(GraphError::NotFound(name))
}
}
#[derive(Debug, Clone)]
pub struct LlamaModel {
pub embed_tokens: DenseTensor,
pub layers: Vec<LlamaDecoderLayer>,
pub norm: RMSNorm,
pub lm_head: DenseTensor,
pub config: LlamaConfig,
}
impl LlamaModel {
pub fn num_parameters(&self) -> usize {
let mut total = 0;
total += self.embed_tokens.shape().iter().product::<usize>();
for layer in &self.layers {
total += layer.num_parameters();
}
total += self.norm.weight.shape().iter().product::<usize>();
total += self.lm_head.shape().iter().product::<usize>();
total
}
pub fn size_mb(&self) -> f64 {
(self.num_parameters() * 8) as f64 / (1024.0 * 1024.0)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_llama_weight_mapper_creation() {
let config = LlamaConfig::llama_7b();
let mapper = LlamaWeightMapper::new(config.clone());
assert_eq!(mapper.config().vocab_size, config.vocab_size);
assert_eq!(mapper.config().hidden_size, config.hidden_size);
}
#[test]
fn test_llama_model_structure() {
let config = LlamaConfig {
vocab_size: 100,
hidden_size: 64,
intermediate_size: 128,
num_hidden_layers: 2,
num_attention_heads: 8,
num_key_value_heads: Some(8),
max_position_embeddings: 512,
rms_norm_eps: 1e-6,
rope_theta: 10000.0,
tie_word_embeddings: false,
attention_bias: false,
};
let embed_tokens = DenseTensor::from_vec(
vec![1.0; config.vocab_size * config.hidden_size],
vec![config.vocab_size, config.hidden_size],
);
let lm_head = DenseTensor::from_vec(
vec![1.0; config.vocab_size * config.hidden_size],
vec![config.vocab_size, config.hidden_size],
);
let norm_weight = DenseTensor::from_vec(
vec![1.0; config.hidden_size],
vec![config.hidden_size],
);
let norm = RMSNorm::new(norm_weight, config.rms_norm_eps);
let layers = Vec::new();
let rope = RoPE::new(
config.head_dim(),
config.max_position_embeddings,
config.rope_theta,
);
let model = LlamaModelStruct {
embed_tokens,
layers,
norm,
lm_head: Some(lm_head),
config: config.clone(),
rope,
};
assert!(model.num_parameters() > 0);
assert!(model.size_mb() > 0.0);
}
}