god_graph/transformer/loader/weight_mapper.rs

//! Weight mapper for converting Safetensors weights to model structures
//!
//! This module provides mapping from HuggingFace Safetensors weight names
//! to GodGraph model structures (LlamaModel, etc.).
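//!
//! For example, the HuggingFace tensor named
//! `model.layers.0.self_attn.q_proj.weight` is mapped into the query projection
//! of the first decoder layer, and `model.norm.weight` into the final RMSNorm.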

use std::collections::HashMap;
use crate::errors::{GraphError, GraphResult};
use crate::tensor::DenseTensor;
use crate::tensor::traits::TensorBase;
use crate::transformer::loader::config::LlamaConfig;
use crate::transformer::model::{LlamaDecoderLayer, LlamaModel as LlamaModelStruct};
use crate::transformer::layers::{
    MultiHeadAttention, FeedForward, RMSNorm, RoPE,
};

/// Weight mapper for LLaMA models
///
/// Maps Safetensors weight names to LlamaModel components
pub struct LlamaWeightMapper {
    config: LlamaConfig,
}

impl LlamaWeightMapper {
    /// Create a new LLaMA weight mapper
    pub fn new(config: LlamaConfig) -> Self {
        Self { config }
    }

    /// Get the model configuration
    pub fn config(&self) -> &LlamaConfig {
        &self.config
    }

    /// Build a complete LlamaModel from loaded tensors
    ///
    /// # Arguments
    ///
    /// * `tensors` - Map of tensor names to tensor data
    ///
    /// # Returns
    ///
    /// Complete LlamaModelStruct with all weights loaded
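    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `tensors` has already been read from a
    /// Safetensors file into a `HashMap<String, DenseTensor>`:
    ///
    /// ```ignore
    /// let mapper = LlamaWeightMapper::new(LlamaConfig::llama_7b());
    /// let model = mapper.build_model(&tensors)?;
    /// assert_eq!(model.layers.len(), mapper.config().num_hidden_layers);
    /// ```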
    pub fn build_model(
        &self,
        tensors: &HashMap<String, DenseTensor>,
    ) -> GraphResult<LlamaModelStruct> {
        // Extract token embeddings
        let embed_tokens = tensors
            .get("model.embed_tokens.weight")
            .ok_or_else(|| GraphError::NotFound("model.embed_tokens.weight".to_string()))?
            .clone();

        // Build decoder layers
        let mut layers = Vec::new();
        for layer_idx in 0..self.config.num_hidden_layers {
            let layer = self.build_layer(layer_idx, tensors)?;
            layers.push(layer);
        }

        // Extract final layer norm
        let norm = RMSNorm::new(
            tensors
                .get("model.norm.weight")
                .ok_or_else(|| GraphError::NotFound("model.norm.weight".to_string()))?
                .clone(),
            self.config.rms_norm_eps,
        );

        // Extract language model head
        let lm_head = tensors
            .get("lm_head.weight")
            .ok_or_else(|| GraphError::NotFound("lm_head.weight".to_string()))?
            .clone();

        Ok(LlamaModelStruct {
            embed_tokens: DenseTensor::new(
                embed_tokens.data().to_vec(),
                embed_tokens.shape().to_vec(),
            ),
            layers,
            norm,
            lm_head: Some(DenseTensor::new(
                lm_head.data().to_vec(),
                lm_head.shape().to_vec(),
            )),
            config: self.config.clone(),
            rope: RoPE::new(
                self.config.head_dim(),
                self.config.max_position_embeddings,
                self.config.rope_theta,
            ),
        })
    }

    /// Build a single decoder layer
    ///
    /// # Arguments
    ///
    /// * `layer_idx` - Layer index (0-based)
    /// * `tensors` - Map of tensor names to tensor data
    ///
    /// # Returns
    ///
    /// LlamaDecoderLayer with weights loaded
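    ///
    /// Looks up the following tensors under the `model.layers.{layer_idx}` prefix:
    /// `self_attn.{q,k,v,o}_proj.weight`, `mlp.{gate,up,down}_proj.weight`,
    /// `input_layernorm.weight`, and `post_attention_layernorm.weight`.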
    pub fn build_layer(
        &self,
        layer_idx: usize,
        tensors: &HashMap<String, DenseTensor>,
    ) -> GraphResult<LlamaDecoderLayer> {
        let prefix = format!("model.layers.{}", layer_idx);

        // Extract attention weights
        let q_proj = tensors
            .get(&format!("{}.self_attn.q_proj.weight", prefix))
            .ok_or_else(|| GraphError::NotFound(format!("{}.self_attn.q_proj.weight", prefix)))?
            .clone();

        let k_proj = tensors
            .get(&format!("{}.self_attn.k_proj.weight", prefix))
            .ok_or_else(|| GraphError::NotFound(format!("{}.self_attn.k_proj.weight", prefix)))?
            .clone();

        let v_proj = tensors
            .get(&format!("{}.self_attn.v_proj.weight", prefix))
            .ok_or_else(|| GraphError::NotFound(format!("{}.self_attn.v_proj.weight", prefix)))?
            .clone();

        let o_proj = tensors
            .get(&format!("{}.self_attn.o_proj.weight", prefix))
            .ok_or_else(|| GraphError::NotFound(format!("{}.self_attn.o_proj.weight", prefix)))?
            .clone();

        // Build multi-head attention
        let self_attn = MultiHeadAttention::new(
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            self.config.num_attention_heads,
            self.config.get_num_key_value_heads(),
        );

        // Extract FFN weights (SwiGLU)
        let gate_proj = tensors
            .get(&format!("{}.mlp.gate_proj.weight", prefix))
            .ok_or_else(|| GraphError::NotFound(format!("{}.mlp.gate_proj.weight", prefix)))?
            .clone();

        let up_proj = tensors
            .get(&format!("{}.mlp.up_proj.weight", prefix))
            .ok_or_else(|| GraphError::NotFound(format!("{}.mlp.up_proj.weight", prefix)))?
            .clone();

        let down_proj = tensors
            .get(&format!("{}.mlp.down_proj.weight", prefix))
            .ok_or_else(|| GraphError::NotFound(format!("{}.mlp.down_proj.weight", prefix)))?
            .clone();

        // Build SwiGLU feed-forward network
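        // (Assuming FeedForward::swiglu applies the usual gating:
        //  down_proj(silu(gate_proj(x)) * up_proj(x)).)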
        let mlp = FeedForward::swiglu(gate_proj, up_proj, down_proj);

        // Extract normalization weights
        let input_layernorm = RMSNorm::new(
            tensors
                .get(&format!("{}.input_layernorm.weight", prefix))
                .ok_or_else(|| GraphError::NotFound(format!("{}.input_layernorm.weight", prefix)))?
                .clone(),
            self.config.rms_norm_eps,
        );

        let post_attention_layernorm = RMSNorm::new(
            tensors
                .get(&format!("{}.post_attention_layernorm.weight", prefix))
                .ok_or_else(|| GraphError::NotFound(format!("{}.post_attention_layernorm.weight", prefix)))?
                .clone(),
            self.config.rms_norm_eps,
        );

        Ok(LlamaDecoderLayer::new(
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        ))
    }

    /// Extract a specific weight tensor by layer and component
    ///
    /// # Arguments
    ///
    /// * `layer_idx` - Layer index
    /// * `component` - Component path relative to the layer prefix, using the same
    ///   Safetensors naming as above (e.g., "self_attn.q_proj.weight", "mlp.gate_proj.weight")
    /// * `tensors` - Map of tensor names to tensor data
    ///
    /// # Returns
    ///
    /// The requested weight tensor
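    ///
    /// # Example
    ///
    /// A sketch, assuming `mapper` and a populated `tensors` map as in `build_model`:
    ///
    /// ```ignore
    /// let q = mapper.get_weight(0, "self_attn.q_proj.weight", &tensors)?;
    /// ```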
    pub fn get_weight<'a>(
        &self,
        layer_idx: usize,
        component: &str,
        tensors: &'a HashMap<String, DenseTensor>,
    ) -> GraphResult<&'a DenseTensor> {
        let name = format!("model.layers.{}.{}", layer_idx, component);
        tensors
            .get(&name)
            .ok_or(GraphError::NotFound(name))
    }
}

/// LLaMA model structure with loaded weights
#[derive(Debug, Clone)]
pub struct LlamaModel {
    /// Token embedding matrix
    pub embed_tokens: DenseTensor,
    /// Decoder layers
    pub layers: Vec<LlamaDecoderLayer>,
    /// Final layer normalization
    pub norm: RMSNorm,
    /// Language model head
    pub lm_head: DenseTensor,
    /// Model configuration
    pub config: LlamaConfig,
}

impl LlamaModel {
    /// Get the number of parameters in the model
    pub fn num_parameters(&self) -> usize {
        let mut total = 0;

        // Embeddings
        total += self.embed_tokens.shape().iter().product::<usize>();

        // Each decoder layer
        for layer in &self.layers {
            total += layer.num_parameters();
        }

        // Final norm
        total += self.norm.weight.shape().iter().product::<usize>();

        // LM head
        total += self.lm_head.shape().iter().product::<usize>();

        total
    }

    /// Get model size in MB (assuming f64)
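    ///
    /// For example, a model with 1,000,000 parameters stored as `f64`
    /// (8 bytes each) occupies 8,000,000 bytes ≈ 7.63 MB.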
    pub fn size_mb(&self) -> f64 {
        (self.num_parameters() * 8) as f64 / (1024.0 * 1024.0)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llama_weight_mapper_creation() {
        let config = LlamaConfig::llama_7b();
        let mapper = LlamaWeightMapper::new(config.clone());

        assert_eq!(mapper.config().vocab_size, config.vocab_size);
        assert_eq!(mapper.config().hidden_size, config.hidden_size);
    }

    #[test]
    fn test_llama_model_structure() {
        // Create a minimal mock model for testing
        let config = LlamaConfig {
            vocab_size: 100,
            hidden_size: 64,
            intermediate_size: 128,
            num_hidden_layers: 2,
            num_attention_heads: 8,
            num_key_value_heads: Some(8),
            max_position_embeddings: 512,
            rms_norm_eps: 1e-6,
            rope_theta: 10000.0,
            tie_word_embeddings: false,
            attention_bias: false,
        };

        let embed_tokens = DenseTensor::from_vec(
            vec![1.0; config.vocab_size * config.hidden_size],
            vec![config.vocab_size, config.hidden_size],
        );

        let lm_head = DenseTensor::from_vec(
            vec![1.0; config.vocab_size * config.hidden_size],
            vec![config.vocab_size, config.hidden_size],
        );

        let norm_weight = DenseTensor::from_vec(
            vec![1.0; config.hidden_size],
            vec![config.hidden_size],
        );

        let norm = RMSNorm::new(norm_weight, config.rms_norm_eps);

        // Create mock layers (this would normally come from the mapper)
        let layers = Vec::new(); // Empty for this test

        let rope = RoPE::new(
            config.head_dim(),
            config.max_position_embeddings,
            config.rope_theta,
        );

        let model = LlamaModelStruct {
            embed_tokens,
            layers,
            norm,
            lm_head: Some(lm_head),
            config: config.clone(),
            rope,
        };

        // Verify parameter count calculation
        assert!(model.num_parameters() > 0);
        assert!(model.size_mb() > 0.0);
    }
}