// god_graph/transformer/loader/weight_mapper.rs

use std::collections::HashMap;

use crate::errors::{GraphError, GraphResult};
use crate::tensor::DenseTensor;
use crate::tensor::traits::TensorBase;
use crate::transformer::loader::config::LlamaConfig;
use crate::transformer::model::{LlamaDecoderLayer, LlamaModel as LlamaModelStruct};
use crate::transformer::layers::{MultiHeadAttention, FeedForward, RMSNorm, RoPE};

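/// Maps Hugging Face-style Llama checkpoint tensor names (e.g. `model.embed_tokens.weight`,
/// `model.layers.{i}.self_attn.q_proj.weight`) onto this crate's model structures.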
pub struct LlamaWeightMapper {
    config: LlamaConfig,
}

impl LlamaWeightMapper {
    pub fn new(config: LlamaConfig) -> Self {
        Self { config }
    }

    pub fn config(&self) -> &LlamaConfig {
        &self.config
    }

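    /// Assembles a [`LlamaModelStruct`] from a flat map of checkpoint tensors keyed by
    /// Hugging Face-style names, returning [`GraphError::NotFound`] for any missing tensor.
    ///
    /// Hypothetical usage, assuming the tensors were already read into a `HashMap`:
    ///
    /// ```ignore
    /// let mapper = LlamaWeightMapper::new(LlamaConfig::llama_7b());
    /// let model = mapper.build_model(&tensors)?;
    /// ```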
    pub fn build_model(
        &self,
        tensors: &HashMap<String, DenseTensor>,
    ) -> GraphResult<LlamaModelStruct> {
        let embed_tokens = tensors
            .get("model.embed_tokens.weight")
            .ok_or_else(|| GraphError::NotFound("model.embed_tokens.weight".to_string()))?
            .clone();

        let mut layers = Vec::new();
        for layer_idx in 0..self.config.num_hidden_layers {
            let layer = self.build_layer(layer_idx, tensors)?;
            layers.push(layer);
        }

        let norm = RMSNorm::new(
            tensors
                .get("model.norm.weight")
                .ok_or_else(|| GraphError::NotFound("model.norm.weight".to_string()))?
                .clone(),
            self.config.rms_norm_eps,
        );

        let lm_head = tensors
            .get("lm_head.weight")
            .ok_or_else(|| GraphError::NotFound("lm_head.weight".to_string()))?
            .clone();

        Ok(LlamaModelStruct {
            embed_tokens: DenseTensor::new(
                embed_tokens.data().to_vec(),
                embed_tokens.shape().to_vec(),
            ),
            layers,
            norm,
            lm_head: Some(DenseTensor::new(
                lm_head.data().to_vec(),
                lm_head.shape().to_vec(),
            )),
            config: self.config.clone(),
            rope: RoPE::new(
                self.config.head_dim(),
                self.config.max_position_embeddings,
                self.config.rope_theta,
            ),
        })
    }

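    /// Builds decoder layer `layer_idx` from the tensor map, wiring its attention
    /// projections, SwiGLU feed-forward weights, and the two RMSNorm weights.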
    pub fn build_layer(
        &self,
        layer_idx: usize,
        tensors: &HashMap<String, DenseTensor>,
    ) -> GraphResult<LlamaDecoderLayer> {
        let prefix = format!("model.layers.{}", layer_idx);

        let q_proj = tensors
            .get(&format!("{}.self_attn.q_proj.weight", prefix))
            .ok_or_else(|| GraphError::NotFound(format!("{}.self_attn.q_proj.weight", prefix)))?
            .clone();

        let k_proj = tensors
            .get(&format!("{}.self_attn.k_proj.weight", prefix))
            .ok_or_else(|| GraphError::NotFound(format!("{}.self_attn.k_proj.weight", prefix)))?
            .clone();

        let v_proj = tensors
            .get(&format!("{}.self_attn.v_proj.weight", prefix))
            .ok_or_else(|| GraphError::NotFound(format!("{}.self_attn.v_proj.weight", prefix)))?
            .clone();

        let o_proj = tensors
            .get(&format!("{}.self_attn.o_proj.weight", prefix))
            .ok_or_else(|| GraphError::NotFound(format!("{}.self_attn.o_proj.weight", prefix)))?
            .clone();

        let self_attn = MultiHeadAttention::new(
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            self.config.num_attention_heads,
            self.config.get_num_key_value_heads(),
        );

        let gate_proj = tensors
            .get(&format!("{}.mlp.gate_proj.weight", prefix))
            .ok_or_else(|| GraphError::NotFound(format!("{}.mlp.gate_proj.weight", prefix)))?
            .clone();

        let up_proj = tensors
            .get(&format!("{}.mlp.up_proj.weight", prefix))
            .ok_or_else(|| GraphError::NotFound(format!("{}.mlp.up_proj.weight", prefix)))?
            .clone();

        let down_proj = tensors
            .get(&format!("{}.mlp.down_proj.weight", prefix))
            .ok_or_else(|| GraphError::NotFound(format!("{}.mlp.down_proj.weight", prefix)))?
            .clone();

        let mlp = FeedForward::swiglu(gate_proj, up_proj, down_proj);

        let input_layernorm = RMSNorm::new(
            tensors
                .get(&format!("{}.input_layernorm.weight", prefix))
                .ok_or_else(|| GraphError::NotFound(format!("{}.input_layernorm.weight", prefix)))?
                .clone(),
            self.config.rms_norm_eps,
        );

        let post_attention_layernorm = RMSNorm::new(
            tensors
                .get(&format!("{}.post_attention_layernorm.weight", prefix))
                .ok_or_else(|| GraphError::NotFound(format!("{}.post_attention_layernorm.weight", prefix)))?
                .clone(),
            self.config.rms_norm_eps,
        );

        Ok(LlamaDecoderLayer::new(
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        ))
    }

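    /// Looks up a single per-layer weight by component name, e.g.
    /// `get_weight(0, "self_attn.q_proj.weight", &tensors)` resolves
    /// `model.layers.0.self_attn.q_proj.weight`.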
    pub fn get_weight<'a>(
        &self,
        layer_idx: usize,
        component: &str,
        tensors: &'a HashMap<String, DenseTensor>,
    ) -> GraphResult<&'a DenseTensor> {
        let name = format!("model.layers.{}.{}", layer_idx, component);
        tensors
            .get(&name)
            .ok_or(GraphError::NotFound(name))
    }
}

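/// Flat container for loaded Llama weights with simple bookkeeping helpers
/// (parameter counts and approximate size), distinct from the runtime
/// [`LlamaModelStruct`] returned by `LlamaWeightMapper::build_model`.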
#[derive(Debug, Clone)]
pub struct LlamaModel {
    pub embed_tokens: DenseTensor,
    pub layers: Vec<LlamaDecoderLayer>,
    pub norm: RMSNorm,
    pub lm_head: DenseTensor,
    pub config: LlamaConfig,
}

impl LlamaModel {
    pub fn num_parameters(&self) -> usize {
        let mut total = 0;

        total += self.embed_tokens.shape().iter().product::<usize>();

        for layer in &self.layers {
            total += layer.num_parameters();
        }

        total += self.norm.weight.shape().iter().product::<usize>();

        total += self.lm_head.shape().iter().product::<usize>();

        total
    }

    pub fn size_mb(&self) -> f64 {
        // Assumes 8 bytes per parameter.
        (self.num_parameters() * 8) as f64 / (1024.0 * 1024.0)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llama_weight_mapper_creation() {
        let config = LlamaConfig::llama_7b();
        let mapper = LlamaWeightMapper::new(config.clone());

        assert_eq!(mapper.config().vocab_size, config.vocab_size);
        assert_eq!(mapper.config().hidden_size, config.hidden_size);
    }

    #[test]
    fn test_llama_model_structure() {
        let config = LlamaConfig {
            vocab_size: 100,
            hidden_size: 64,
            intermediate_size: 128,
            num_hidden_layers: 2,
            num_attention_heads: 8,
            num_key_value_heads: Some(8),
            max_position_embeddings: 512,
            rms_norm_eps: 1e-6,
            rope_theta: 10000.0,
            tie_word_embeddings: false,
            attention_bias: false,
        };

        let embed_tokens = DenseTensor::from_vec(
            vec![1.0; config.vocab_size * config.hidden_size],
            vec![config.vocab_size, config.hidden_size],
        );

        let lm_head = DenseTensor::from_vec(
            vec![1.0; config.vocab_size * config.hidden_size],
            vec![config.vocab_size, config.hidden_size],
        );

        let norm_weight = DenseTensor::from_vec(
            vec![1.0; config.hidden_size],
            vec![config.hidden_size],
        );

        let norm = RMSNorm::new(norm_weight, config.rms_norm_eps);

        let layers = Vec::new();
        let rope = RoPE::new(
            config.head_dim(),
            config.max_position_embeddings,
            config.rope_theta,
        );

        let model = LlamaModelStruct {
            embed_tokens,
            layers,
            norm,
            lm_head: Some(lm_head),
            config: config.clone(),
            rope,
        };

        assert!(model.num_parameters() > 0);
        assert!(model.size_mb() > 0.0);
    }
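
    // A minimal sketch of exercising `get_weight`, assuming `DenseTensor::from_vec`
    // behaves as in the test above; the tensor shape here is arbitrary.
    #[test]
    fn test_get_weight_lookup() {
        let config = LlamaConfig::llama_7b();
        let mapper = LlamaWeightMapper::new(config);

        let mut tensors = HashMap::new();
        tensors.insert(
            "model.layers.0.self_attn.q_proj.weight".to_string(),
            DenseTensor::from_vec(vec![1.0; 4], vec![2, 2]),
        );

        // Present for layer 0, absent for layer 1.
        assert!(mapper
            .get_weight(0, "self_attn.q_proj.weight", &tensors)
            .is_ok());
        assert!(mapper
            .get_weight(1, "self_attn.q_proj.weight", &tensors)
            .is_err());
    }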
}