// god_graph/transformer/model.rs

//! LLaMA model implementation
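//!
//! Architecture (as implemented below): token embeddings feed a stack of pre-norm
//! decoder layers (RMSNorm -> multi-head self-attention -> residual, then
//! RMSNorm -> SwiGLU feed-forward -> residual), followed by a final RMSNorm and an
//! LM-head projection that may be weight-tied with the token embeddings.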

use crate::tensor::DenseTensor;
use crate::tensor::traits::{TensorBase, TensorOps};
use super::layers::{MultiHeadAttention, FeedForward, RMSNorm, RoPE};
pub use super::loader::LlamaConfig;

/// LLaMA decoder layer
#[derive(Debug, Clone)]
pub struct LlamaDecoderLayer {
    /// Self-attention layer
    pub self_attn: MultiHeadAttention,
    /// Feed-forward network (SwiGLU)
    pub mlp: FeedForward,
    /// Input layer normalization
    pub input_layernorm: RMSNorm,
    /// Post-attention layer normalization
    pub post_attention_layernorm: RMSNorm,
}

impl LlamaDecoderLayer {
    /// Create a new LLaMA decoder layer
    pub fn new(
        self_attn: MultiHeadAttention,
        mlp: FeedForward,
        input_layernorm: RMSNorm,
        post_attention_layernorm: RMSNorm,
    ) -> Self {
        Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        }
    }

    /// Forward pass
    ///
    /// # Arguments
    /// * `x` - Input tensor [batch_size, seq_len, hidden_dim]
    /// * `mask` - Optional attention mask
    ///
    /// # Returns
    /// Output tensor [batch_size, seq_len, hidden_dim]
    pub fn forward(&self, x: &DenseTensor, mask: Option<&DenseTensor>) -> DenseTensor {
        // Pre-norm residual architecture (LLaMA uses pre-LN)

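        // Dataflow sketch (matches the numbered steps below):
        //   h   = x + SelfAttn(RMSNorm(x))
        //   out = h + SwiGLU(RMSNorm(h))
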
        // 1. Input normalization
        let normed = self.input_layernorm.forward(x);

        // 2. Self-attention with residual
        let attn_output = self.self_attn.forward_with_mask(&normed, mask);
        let hidden = x.add(&attn_output);

        // 3. Post-attention normalization
        let normed = self.post_attention_layernorm.forward(&hidden);

        // 4. FFN with residual
        let mlp_output = self.mlp.forward(&normed);

        hidden.add(&mlp_output)
    }

    /// Forward pass with KV cache
    ///
    /// # Arguments
    /// * `x` - Input tensor [batch_size, seq_len, hidden_dim]
    /// * `kv_cache` - Optional KV cache for this layer
    /// * `mask` - Optional attention mask
    ///
    /// # Returns
    /// Output tensor and updated KV cache
    pub fn forward_with_cache(
        &self,
        x: &DenseTensor,
        kv_cache: Option<(&DenseTensor, &DenseTensor)>,
        mask: Option<&DenseTensor>,
    ) -> (DenseTensor, Option<(DenseTensor, DenseTensor)>) {
        // For inference with a KV cache.
        // This is a simplified version; a full implementation would update the cache.
        let output = self.forward(x, mask);
        (output, kv_cache.map(|(k, v)| (k.clone(), v.clone())))
    }

    /// Get the number of parameters in this layer
    pub fn num_parameters(&self) -> usize {
        let mut total = 0;

        // Attention parameters
        total += self.self_attn.num_parameters();

        // MLP parameters
        total += self.mlp.num_parameters();

        // Layer norm parameters (2 * hidden_dim)
        total += self.input_layernorm.weight.shape().iter().product::<usize>();
        total += self.post_attention_layernorm.weight.shape().iter().product::<usize>();

        total
    }
}

/// Complete LLaMA model
#[derive(Debug, Clone)]
pub struct LlamaModel {
    /// Model configuration
    pub config: LlamaConfig,
    /// Token embeddings [vocab_size, hidden_dim]
    pub embed_tokens: DenseTensor,
    /// Decoder layers
    pub layers: Vec<LlamaDecoderLayer>,
    /// Final layer normalization
    pub norm: RMSNorm,
    /// LM head (optional, may be tied with embed_tokens)
    pub lm_head: Option<DenseTensor>,
    /// RoPE module
    pub rope: RoPE,
}

impl LlamaModel {
    /// Create a new LLaMA model
    pub fn new(
        config: LlamaConfig,
        embed_tokens: DenseTensor,
        layers: Vec<LlamaDecoderLayer>,
        norm: RMSNorm,
        lm_head: Option<DenseTensor>,
    ) -> Self {
        let rope = RoPE::new(
            config.head_dim(),
            config.max_position_embeddings,
            config.rope_theta,
        );

        Self {
            config,
            embed_tokens,
            layers,
            norm,
            lm_head,
            rope,
        }
    }

    /// Forward pass
    ///
    /// # Arguments
    /// * `input_ids` - Input token IDs [batch_size, seq_len]
    /// * `mask` - Optional attention mask [batch_size, seq_len, seq_len]
    ///
    /// # Returns
    /// Logits tensor [batch_size, seq_len, vocab_size]
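    ///
    /// # Example
    ///
    /// A minimal sketch with no decoder layers (mirroring the graph-builder example
    /// further below); it assumes the crate is importable as `god_graph`.
    ///
    /// ```no_run
    /// use god_graph::transformer::model::{LlamaModel, LlamaConfig};
    /// use god_graph::transformer::layers::RMSNorm;
    /// use god_graph::tensor::DenseTensor;
    ///
    /// let config = LlamaConfig::llama_7b();
    /// let embed_tokens = DenseTensor::ones(vec![config.vocab_size, config.hidden_size]);
    /// let norm = RMSNorm::default(config.hidden_size);
    /// let model = LlamaModel::new(config, embed_tokens, vec![], norm, None);
    ///
    /// // Two sequences of three tokens each -> logits of shape [2, 3, vocab_size].
    /// let logits = model.forward(&[vec![1, 2, 3], vec![4, 5, 6]], None);
    /// ```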
    pub fn forward(&self, input_ids: &[Vec<usize>], mask: Option<&DenseTensor>) -> DenseTensor {
        let batch_size = input_ids.len();
        let seq_len = input_ids[0].len();

        // 1. Get token embeddings
        let mut hidden = self.embed_tokens_batch(input_ids);

        // 2. Position indices for RoPE (currently unused in this forward pass)
        let _positions: Vec<usize> = (0..seq_len).collect();

        // 3. Pass through decoder layers
        for layer in &self.layers {
            hidden = layer.forward(&hidden, mask);
        }

        // 4. Final normalization
        hidden = self.norm.forward(&hidden);

        // 5. LM head projection
        // hidden: [batch, seq_len, hidden_dim], lm_head: [vocab_size, hidden_dim]
        // Need to compute: hidden @ lm_head.T for each (batch, seq) position
        let lm_head = self.lm_head.as_ref().unwrap_or(&self.embed_tokens);
        let lm_head_t = lm_head.transpose(None); // [hidden_dim, vocab_size]

        // Reshape hidden to [batch*seq_len, hidden_dim] for matmul
        let hidden_data = hidden.data().to_vec();
        let hidden_dim = self.config.hidden_size;
        let flat_hidden = DenseTensor::new(hidden_data, vec![batch_size * seq_len, hidden_dim]);

        // Matmul: [batch*seq, hidden] @ [hidden, vocab] = [batch*seq, vocab]
        let logits_flat = flat_hidden.matmul(&lm_head_t);

        // Reshape back to [batch, seq_len, vocab_size]
        let vocab_size = self.config.vocab_size;
        let logits_data = logits_flat.data().to_vec();

        DenseTensor::new(logits_data, vec![batch_size, seq_len, vocab_size])
    }

    /// Forward pass for a single sequence
    pub fn forward_single(&self, input_ids: &[usize], mask: Option<&DenseTensor>) -> DenseTensor {
        self.forward(&[input_ids.to_vec()], mask)
    }

    /// Embed tokens in batch
    fn embed_tokens_batch(&self, input_ids: &[Vec<usize>]) -> DenseTensor {
        let batch_size = input_ids.len();
        let seq_len = input_ids[0].len();
        let hidden_dim = self.config.hidden_size;

        let mut data = Vec::with_capacity(batch_size * seq_len * hidden_dim);

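        // Gather embedding rows: row `token_id` of `embed_tokens` is that token's
        // [hidden_dim] vector (assumes every token_id < vocab_size).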
        for batch in input_ids {
            for &token_id in batch {
                let start = token_id * hidden_dim;
                let end = start + hidden_dim;
                data.extend_from_slice(&self.embed_tokens.data()[start..end]);
            }
        }

        DenseTensor::new(data, vec![batch_size, seq_len, hidden_dim])
    }

    /// Get the hidden dimension
    pub fn hidden_dim(&self) -> usize {
        self.config.hidden_size
    }

    /// Get the vocabulary size
    pub fn vocab_size(&self) -> usize {
        self.config.vocab_size
    }

    /// Get number of layers
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }

    /// Get the number of parameters in the model
    pub fn num_parameters(&self) -> usize {
        let mut total = 0;

        // Embeddings
        total += self.embed_tokens.shape().iter().product::<usize>();

        // Each decoder layer
        for layer in &self.layers {
            total += layer.num_parameters();
        }

        // Final norm
        total += self.norm.weight.shape().iter().product::<usize>();

        // LM head
        if let Some(lm_head) = &self.lm_head {
            total += lm_head.shape().iter().product::<usize>();
        }

        total
    }

    /// Get model size in MiB (assuming 8 bytes per f64 parameter)
    pub fn size_mb(&self) -> f64 {
        (self.num_parameters() * 8) as f64 / (1024.0 * 1024.0)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::DenseTensor;
    use crate::tensor::traits::TensorBase;

    fn create_test_layer(config: &LlamaConfig) -> LlamaDecoderLayer {
        let hidden_dim = config.hidden_size;
        let num_heads = config.num_attention_heads;

        // Create attention weights
        let w_q = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
        let w_k = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
        let w_v = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
        let w_o = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
        let self_attn = MultiHeadAttention::standard(w_q, w_k, w_v, w_o, num_heads);

        // Create FFN (SwiGLU)
        let gate_proj = DenseTensor::ones(vec![hidden_dim, config.intermediate_size]);
        let up_proj = DenseTensor::ones(vec![hidden_dim, config.intermediate_size]);
        let down_proj = DenseTensor::ones(vec![config.intermediate_size, hidden_dim]);
        let mlp = FeedForward::swiglu(gate_proj, up_proj, down_proj);

        // Create norms
        let input_layernorm = RMSNorm::default(hidden_dim);
        let post_attention_layernorm = RMSNorm::default(hidden_dim);

        LlamaDecoderLayer::new(self_attn, mlp, input_layernorm, post_attention_layernorm)
    }

    #[test]
    fn test_decoder_layer() {
        let config = LlamaConfig::llama_7b();
        let layer = create_test_layer(&config);

        let batch_size = 2;
        let seq_len = 4;
        let x = DenseTensor::ones(vec![batch_size, seq_len, config.hidden_size]);

        let output = layer.forward(&x, None);

        assert_eq!(output.shape(), &[batch_size, seq_len, config.hidden_size]);
    }

    #[test]
    fn test_llama_model_creation() {
        let config = LlamaConfig::llama_7b();

        let embed_tokens = DenseTensor::ones(vec![config.vocab_size, config.hidden_size]);
        let layers = vec![create_test_layer(&config); config.num_hidden_layers];
        let norm = RMSNorm::default(config.hidden_size);
        let lm_head = None; // Tied with embeddings

        let model = LlamaModel::new(config, embed_tokens, layers, norm, lm_head);

        assert_eq!(model.num_layers(), 32);
        assert_eq!(model.vocab_size(), 32000);
        assert_eq!(model.hidden_dim(), 4096);
    }
}

// ============================================================================
// LlamaModel Graph Builder
// ============================================================================

use crate::transformer::graph_transformer::GraphTransformer;

/// LlamaModel graph builder for constructing graph-structured LLaMA models
///
/// This builder converts a standard LlamaModel into a graph-structured representation
/// that can leverage god_graph's graph algorithms for optimization and analysis.
///
/// # Example
///
/// ```no_run
/// use god_graph::transformer::model::{LlamaModel, LlamaConfig, LlamaModelGraphBuilder};
/// use god_graph::transformer::layers::RMSNorm;
/// use god_graph::tensor::DenseTensor;
///
/// let config = LlamaConfig::llama_7b();
/// let embed_tokens = DenseTensor::ones(vec![config.vocab_size, config.hidden_size]);
/// let layers = vec![]; // Add your layers here
/// let norm = RMSNorm::default(config.hidden_size);
/// let model = LlamaModel::new(config, embed_tokens, layers, norm, None);
///
/// let builder = LlamaModelGraphBuilder::new(&model);
/// let graph_transformer = builder.build_graph();
/// ```
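///
/// The resulting [`GraphTransformer`] can also be exported to Graphviz DOT via
/// [`LlamaModelGraphBuilder::export_to_dot`] for visual inspection.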
pub struct LlamaModelGraphBuilder<'a> {
    model: &'a LlamaModel,
}

impl<'a> LlamaModelGraphBuilder<'a> {
    /// Create a new graph builder from a LlamaModel
    pub fn new(model: &'a LlamaModel) -> Self {
        Self { model }
    }

    /// Build graph-structured transformer from the model
    pub fn build_graph(&self) -> GraphTransformer {
        let mut transformer = GraphTransformer::new(
            self.model.num_layers(),
            self.model.config.num_attention_heads,
            self.model.config.hidden_size,
        );

        // Build the graph structure.
        // Note: a full implementation would use the actual weights; for now we
        // only construct the graph topology.
        let dummy_input = vec![0; 1]; // Single token suffices for the graph structure
        transformer.build_graph(&dummy_input);

        transformer
    }

    /// Build graph with specific input sequence
    pub fn build_graph_for_input(&self, input_ids: &[usize]) -> GraphTransformer {
        let mut transformer = GraphTransformer::new(
            self.model.num_layers(),
            self.model.config.num_attention_heads,
            self.model.config.hidden_size,
        );

        transformer.build_graph(input_ids);
        transformer
    }

    /// Export graph to DOT format for visualization
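    ///
    /// A minimal usage sketch: write the DOT text to a file and render it with
    /// Graphviz (the output path below is illustrative).
    ///
    /// ```no_run
    /// # use god_graph::transformer::model::{LlamaModel, LlamaConfig, LlamaModelGraphBuilder};
    /// # use god_graph::transformer::layers::RMSNorm;
    /// # use god_graph::tensor::DenseTensor;
    /// # let config = LlamaConfig::llama_7b();
    /// # let embed_tokens = DenseTensor::ones(vec![config.vocab_size, config.hidden_size]);
    /// # let norm = RMSNorm::default(config.hidden_size);
    /// # let model = LlamaModel::new(config, embed_tokens, vec![], norm, None);
    /// let builder = LlamaModelGraphBuilder::new(&model);
    /// let transformer = builder.build_graph();
    /// let dot = builder.export_to_dot(&transformer);
    /// std::fs::write("llama_graph.dot", dot).unwrap(); // render with: dot -Tpng llama_graph.dot
    /// ```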
    pub fn export_to_dot(&self, transformer: &GraphTransformer) -> String {
        transformer.to_dot()
    }
}

#[cfg(test)]
mod graph_builder_tests {
    use super::*;
    use crate::transformer::layers::{MultiHeadAttention, FeedForward, RMSNorm};

    fn create_test_layer(config: &LlamaConfig) -> LlamaDecoderLayer {
        let hidden_dim = config.hidden_size;
        let num_heads = config.num_attention_heads;

        let w_q = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
        let w_k = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
        let w_v = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
        let w_o = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
        let self_attn = MultiHeadAttention::standard(w_q, w_k, w_v, w_o, num_heads);

        let gate_proj = DenseTensor::ones(vec![hidden_dim, config.intermediate_size]);
        let up_proj = DenseTensor::ones(vec![hidden_dim, config.intermediate_size]);
        let down_proj = DenseTensor::ones(vec![config.intermediate_size, hidden_dim]);
        let mlp = FeedForward::swiglu(gate_proj, up_proj, down_proj);

        let input_layernorm = RMSNorm::default(hidden_dim);
        let post_attention_layernorm = RMSNorm::default(hidden_dim);

        LlamaDecoderLayer::new(self_attn, mlp, input_layernorm, post_attention_layernorm)
    }

    #[test]
    fn test_llama_model_graph_builder() {
        let config = LlamaConfig::llama_7b();
        let embed_tokens = DenseTensor::ones(vec![config.vocab_size, config.hidden_size]);
        let layers = vec![create_test_layer(&config); 2]; // Use 2 layers for test
        let norm = RMSNorm::default(config.hidden_size);
        let lm_head = None;

        let model = LlamaModel::new(config.clone(), embed_tokens, layers, norm, lm_head);

        let builder = LlamaModelGraphBuilder::new(&model);
        let transformer = builder.build_graph();

        // Verify graph was built
        assert!(transformer.num_nodes() > 0);
        assert!(transformer.num_edges() > 0);
    }

    #[test]
    fn test_llama_model_graph_builder_with_input() {
        let config = LlamaConfig::llama_7b();
        let embed_tokens = DenseTensor::ones(vec![config.vocab_size, config.hidden_size]);
        let layers = vec![create_test_layer(&config); 1];
        let norm = RMSNorm::default(config.hidden_size);
        let lm_head = None;

        let model = LlamaModel::new(config.clone(), embed_tokens, layers, norm, lm_head);

        let builder = LlamaModelGraphBuilder::new(&model);
        let input_ids = vec![1, 2, 3, 4, 5];
        let mut transformer = builder.build_graph_for_input(&input_ids);

        // Verify graph structure
        assert!(transformer.num_nodes() > 0);
        assert!(transformer.num_edges() > 0);

        // Test forward pass
        let output = transformer.forward(&input_ids);
        assert!(!output.data().is_empty());
    }

    #[test]
    fn test_graph_export_to_dot() {
        let config = LlamaConfig::llama_7b();
        let embed_tokens = DenseTensor::ones(vec![config.vocab_size, config.hidden_size]);
        let layers = vec![create_test_layer(&config); 1];
        let norm = RMSNorm::default(config.hidden_size);
        let lm_head = None;

        let model = LlamaModel::new(config.clone(), embed_tokens, layers, norm, lm_head);

        let builder = LlamaModelGraphBuilder::new(&model);
        let transformer = builder.build_graph();
        let dot = builder.export_to_dot(&transformer);

        // Verify DOT format
        assert!(dot.contains("digraph Transformer"));
        assert!(dot.contains("rankdir=TB"));
    }
}