// ghostflow_nn/gpt.rs

//! GPT (Generative Pre-trained Transformer)
//!
//! Implements GPT-style autoregressive language models:
//! - Token embeddings
//! - Position embeddings
//! - Causal (masked) self-attention
//! - Transformer decoder blocks (built here from a `TransformerEncoder` stack)
//! - Language modeling head
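//!
//! A minimal end-to-end sketch, assuming this file is exposed as the `gpt`
//! module of the `ghostflow_nn` crate (the paths below are an assumption,
//! not part of this file):
//!
//! ```no_run
//! use ghostflow_core::Tensor;
//! use ghostflow_nn::gpt::{GPTConfig, GPTForCausalLM};
//!
//! // Tiny config keeps the sketch cheap.
//! let model = GPTForCausalLM::new(GPTConfig::gpt_tiny());
//!
//! // Token ids live in an f32 tensor of shape [batch, seq].
//! let input_ids = Tensor::from_slice(&[1.0, 2.0, 3.0], &[1, 3]).unwrap();
//! let logits = model.forward(&input_ids).unwrap();
//! assert_eq!(logits.dims(), &[1, 3, 1000]); // [batch, seq, vocab]
//! ```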

use ghostflow_core::Tensor;
use crate::transformer::TransformerEncoder;
use crate::linear::Linear;
use crate::norm::LayerNorm;
use crate::Module;

/// GPT configuration
#[derive(Debug, Clone)]
pub struct GPTConfig {
    /// Vocabulary size
    pub vocab_size: usize,
    /// Context length (maximum sequence length)
    pub context_length: usize,
    /// Embedding dimension
    pub embed_dim: usize,
    /// Number of transformer layers
    pub num_layers: usize,
    /// Number of attention heads
    pub num_heads: usize,
    /// Feed-forward hidden dimension
    pub ff_dim: usize,
    /// Dropout probability
    pub dropout: f32,
    /// Use bias in linear layers
    pub bias: bool,
}

impl Default for GPTConfig {
    fn default() -> Self {
        GPTConfig {
            vocab_size: 50257,
            context_length: 1024,
            embed_dim: 768,
            num_layers: 12,
            num_heads: 12,
            ff_dim: 3072,
            dropout: 0.1,
            bias: true,
        }
    }
}

impl GPTConfig {
    /// GPT-2 Small (117M parameters)
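    ///
    /// Presets are plain `GPTConfig` values, so individual fields can be
    /// overridden with struct update syntax. A usage sketch (the
    /// `ghostflow_nn::gpt` path is an assumed export):
    ///
    /// ```no_run
    /// # use ghostflow_nn::gpt::GPTConfig;
    /// let config = GPTConfig {
    ///     context_length: 512,
    ///     ..GPTConfig::gpt2_small()
    /// };
    /// assert_eq!(config.embed_dim, 768);
    /// assert_eq!(config.context_length, 512);
    /// ```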
    pub fn gpt2_small() -> Self {
        Self::default()
    }

    /// GPT-2 Medium (345M parameters)
    pub fn gpt2_medium() -> Self {
        GPTConfig {
            embed_dim: 1024,
            num_layers: 24,
            num_heads: 16,
            ff_dim: 4096,
            ..Default::default()
        }
    }

    /// GPT-2 Large (774M parameters)
    pub fn gpt2_large() -> Self {
        GPTConfig {
            embed_dim: 1280,
            num_layers: 36,
            num_heads: 20,
            ff_dim: 5120,
            ..Default::default()
        }
    }

    /// GPT-2 XL (1.5B parameters)
    pub fn gpt2_xl() -> Self {
        GPTConfig {
            embed_dim: 1600,
            num_layers: 48,
            num_heads: 25,
            ff_dim: 6400,
            ..Default::default()
        }
    }

    /// GPT-3 Small (125M parameters)
    pub fn gpt3_small() -> Self {
        GPTConfig {
            vocab_size: 50257,
            context_length: 2048,
            embed_dim: 768,
            num_layers: 12,
            num_heads: 12,
            ff_dim: 3072,
            dropout: 0.0,
            bias: false,
        }
    }

    /// GPT-3 Medium (350M parameters)
    pub fn gpt3_medium() -> Self {
        GPTConfig {
            vocab_size: 50257,
            context_length: 2048,
            embed_dim: 1024,
            num_layers: 24,
            num_heads: 16,
            ff_dim: 4096,
            dropout: 0.0,
            bias: false,
        }
    }
    /// GPT-3 Large (760M parameters)
    pub fn gpt3_large() -> Self {
        GPTConfig {
            vocab_size: 50257,
            context_length: 2048,
            embed_dim: 1536,
            num_layers: 24,
            num_heads: 16,
            ff_dim: 6144,
            dropout: 0.0,
            bias: false,
        }
    }

    /// GPT-3 XL (1.3B parameters)
    pub fn gpt3_xl() -> Self {
        GPTConfig {
            vocab_size: 50257,
            context_length: 2048,
            embed_dim: 2048,
            num_layers: 24,
            num_heads: 24,
            ff_dim: 8192,
            dropout: 0.0,
            bias: false,
        }
    }

    /// GPT-Tiny (for testing)
    pub fn gpt_tiny() -> Self {
        GPTConfig {
            vocab_size: 1000,
            context_length: 128,
            embed_dim: 128,
            num_layers: 2,
            num_heads: 2,
            ff_dim: 512,
            dropout: 0.1,
            bias: true,
        }
    }
}

/// GPT embeddings (token + position)
pub struct GPTEmbeddings {
    /// Token embeddings
    token_embeddings: Tensor,
    /// Position embeddings
    position_embeddings: Tensor,
    /// Dropout probability (stored; not yet applied in the forward pass)
    dropout: f32,
    /// Configuration
    config: GPTConfig,
}

impl GPTEmbeddings {
    /// Create new GPT embeddings
    pub fn new(config: GPTConfig) -> Self {
        let token_embeddings = Tensor::randn(&[config.vocab_size, config.embed_dim]);
        let position_embeddings = Tensor::randn(&[config.context_length, config.embed_dim]);

        GPTEmbeddings {
            token_embeddings,
            position_embeddings,
            dropout: config.dropout,
            config,
        }
    }

    /// Forward pass
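    ///
    /// Takes `[batch, seq]` token ids (stored as f32) and returns the
    /// `[batch, seq, embed_dim]` sum of token and position embeddings.
    /// A shape sketch (the `ghostflow_nn::gpt` path is an assumed export):
    ///
    /// ```no_run
    /// # use ghostflow_core::Tensor;
    /// # use ghostflow_nn::gpt::{GPTConfig, GPTEmbeddings};
    /// let embeddings = GPTEmbeddings::new(GPTConfig::gpt_tiny());
    /// let ids = Tensor::from_slice(&[5.0, 7.0], &[1, 2]).unwrap();
    /// let out = embeddings.forward(&ids).unwrap();
    /// assert_eq!(out.dims(), &[1, 2, 128]); // [batch, seq, embed_dim]
    /// ```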
    pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor, String> {
        let dims = input_ids.dims();
        if dims.len() != 2 {
            return Err(format!("Expected 2D input_ids, got {}D", dims.len()));
        }

        let seq_length = dims[1];

        if seq_length > self.config.context_length {
            return Err(format!("Sequence length {} exceeds context length {}",
                             seq_length, self.config.context_length));
        }

        // Get token embeddings
        let token_embeds = self.get_token_embeddings(input_ids)?;

        // Get position embeddings
        let position_embeds = self.get_position_embeddings(seq_length)?;

        // Sum embeddings
        self.sum_embeddings(&token_embeds, &position_embeds)
    }

    /// Get token embeddings by lookup
    fn get_token_embeddings(&self, input_ids: &Tensor) -> Result<Tensor, String> {
        let ids_data = input_ids.data_f32();
        let embed_data = self.token_embeddings.data_f32();

        let dims = input_ids.dims();
        let batch_size = dims[0];
        let seq_length = dims[1];
        let embed_dim = self.config.embed_dim;

        let mut result = Vec::with_capacity(batch_size * seq_length * embed_dim);

        // Token ids are stored as f32 (the tensor element type); cast each
        // one back to usize and copy out the matching embedding row.
        for &id in ids_data.iter() {
            let idx = id as usize;
            if idx >= self.config.vocab_size {
                return Err(format!("Token ID {} out of vocabulary range", idx));
            }

            let start = idx * embed_dim;
            let end = start + embed_dim;
            result.extend_from_slice(&embed_data[start..end]);
        }

        Tensor::from_slice(&result, &[batch_size, seq_length, embed_dim])
            .map_err(|e| format!("Failed to create token embeddings: {:?}", e))
    }

    /// Get position embeddings (the first `seq_length` rows of the table)
    fn get_position_embeddings(&self, seq_length: usize) -> Result<Tensor, String> {
        let embed_data = self.position_embeddings.data_f32();
        let embed_dim = self.config.embed_dim;

        let result = embed_data[..seq_length * embed_dim].to_vec();

        Tensor::from_slice(&result, &[seq_length, embed_dim])
            .map_err(|e| format!("Failed to create position embeddings: {:?}", e))
    }

    /// Sum token and position embeddings (positions broadcast over the batch)
    fn sum_embeddings(&self, token: &Tensor, position: &Tensor) -> Result<Tensor, String> {
        let token_data = token.data_f32();
        let pos_data = position.data_f32();

        let dims = token.dims();
        let batch_size = dims[0];
        let seq_length = dims[1];
        let embed_dim = dims[2];

        let mut result = Vec::with_capacity(token_data.len());

        for b in 0..batch_size {
            for s in 0..seq_length {
                for e in 0..embed_dim {
                    let token_idx = b * seq_length * embed_dim + s * embed_dim + e;
                    let pos_idx = s * embed_dim + e;
                    result.push(token_data[token_idx] + pos_data[pos_idx]);
                }
            }
        }

        Tensor::from_slice(&result, &[batch_size, seq_length, embed_dim])
            .map_err(|e| format!("Failed to sum embeddings: {:?}", e))
    }
}

/// GPT model
pub struct GPTModel {
    /// Configuration
    config: GPTConfig,
    /// Embeddings
    embeddings: GPTEmbeddings,
    /// Transformer blocks
    transformer: TransformerEncoder,
    /// Final layer norm
    ln_f: LayerNorm,
}

impl GPTModel {
    /// Create new GPT model
    pub fn new(config: GPTConfig) -> Self {
        let embeddings = GPTEmbeddings::new(config.clone());

        let transformer = TransformerEncoder::new(
            config.embed_dim,
            config.num_heads,
            config.ff_dim,
            config.num_layers,
            config.dropout,
        );

        let ln_f = LayerNorm::new(&[config.embed_dim]);

        GPTModel {
            config,
            embeddings,
            transformer,
            ln_f,
        }
    }

    /// Forward pass
    pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor, String> {
        // Get embeddings
        let hidden_states = self.embeddings.forward(input_ids)?;

        // Transformer blocks. Causal masking is assumed to be applied inside
        // `TransformerEncoder`; if it performs unmasked attention, this model
        // is not truly autoregressive as written.
        let hidden_states = self.transformer.forward(&hidden_states);

        // Final layer norm
        let hidden_states = self.ln_f.forward(&hidden_states);

        Ok(hidden_states)
    }
}

/// GPT for Language Modeling
pub struct GPTForCausalLM {
    /// Base GPT model
    gpt: GPTModel,
    /// Language modeling head
    lm_head: Linear,
}

impl GPTForCausalLM {
    /// Create new GPT for causal language modeling
    pub fn new(config: GPTConfig) -> Self {
        let gpt = GPTModel::new(config.clone());
        let lm_head = Linear::new(config.embed_dim, config.vocab_size);

        GPTForCausalLM {
            gpt,
            lm_head,
        }
    }

    /// Forward pass
    pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor, String> {
        let hidden_states = self.gpt.forward(input_ids)?;
        let logits = self.lm_head.forward(&hidden_states);
        Ok(logits)
    }

    /// Generate text autoregressively
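    ///
    /// Appends `max_new_tokens` greedily decoded ids to the prompt and
    /// returns the full sequence. A usage sketch (the `ghostflow_nn::gpt`
    /// path is an assumption about how this module is exported):
    ///
    /// ```no_run
    /// # use ghostflow_core::Tensor;
    /// # use ghostflow_nn::gpt::{GPTConfig, GPTForCausalLM};
    /// let model = GPTForCausalLM::new(GPTConfig::gpt_tiny());
    /// let prompt = Tensor::from_slice(&[1.0, 2.0, 3.0], &[1, 3]).unwrap();
    /// let ids = model.generate(&prompt, 5, 1.0).unwrap();
    /// assert_eq!(ids.len(), 8); // 3 prompt tokens + 5 generated
    /// ```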
    pub fn generate(&self, input_ids: &Tensor, max_new_tokens: usize,
                    temperature: f32) -> Result<Vec<usize>, String> {
        let mut current_ids = input_ids.data_f32().iter().map(|&x| x as usize).collect::<Vec<_>>();

        for _ in 0..max_new_tokens {
            // Crop to the last `context_length` tokens so the model's
            // window is never exceeded as the sequence grows.
            let context_length = self.gpt.config.context_length;
            let start = current_ids.len().saturating_sub(context_length);
            let window = &current_ids[start..];

            // Get logits for the current window
            let input_tensor = Tensor::from_slice(
                &window.iter().map(|&x| x as f32).collect::<Vec<_>>(),
                &[1, window.len()]
            ).map_err(|e| format!("Failed to create input tensor: {:?}", e))?;

            let logits = self.forward(&input_tensor)?;

            // Get logits for last token
            let last_logits = self.extract_last_token_logits(&logits)?;

            // Apply temperature and sample
            let next_token = self.sample_token(&last_logits, temperature)?;

            current_ids.push(next_token);
        }

        Ok(current_ids)
    }

    /// Extract logits for the last position (assumes batch size 1, as
    /// produced by `generate`)
    fn extract_last_token_logits(&self, logits: &Tensor) -> Result<Tensor, String> {
        let data = logits.data_f32();
        let dims = logits.dims();

        if dims.len() != 3 {
            return Err(format!("Expected 3D logits, got {}D", dims.len()));
        }

        let seq_length = dims[1];
        let vocab_size = dims[2];

        // Extract last token: [vocab_size]
        let start = (seq_length - 1) * vocab_size;
        let end = start + vocab_size;
        let last_logits = data[start..end].to_vec();

        Tensor::from_slice(&last_logits, &[vocab_size])
            .map_err(|e| format!("Failed to extract last token logits: {:?}", e))
    }

    /// Sample next token from logits
    fn sample_token(&self, logits: &Tensor, temperature: f32) -> Result<usize, String> {
        let data = logits.data_f32();

        // Apply temperature
        let scaled: Vec<f32> = data.iter().map(|&x| x / temperature).collect();

        // Softmax (numerically stabilized by subtracting the max)
        let max_val = scaled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exp_vals: Vec<f32> = scaled.iter().map(|&x| (x - max_val).exp()).collect();
        let sum: f32 = exp_vals.iter().sum();
        let probs: Vec<f32> = exp_vals.iter().map(|&x| x / sum).collect();

        // Greedy decoding for now: take the argmax. Until real sampling is
        // implemented, temperature has no effect on the chosen token.
        let next_token = probs.iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(idx, _)| idx)
            .ok_or_else(|| "Failed to sample token".to_string())?;

        Ok(next_token)
    }
}

/// GPT for Text Classification
pub struct GPTForSequenceClassification {
    /// Base GPT model
    gpt: GPTModel,
    /// Classification head
    classifier: Linear,
    /// Number of labels
    num_labels: usize,
}

impl GPTForSequenceClassification {
    /// Create new GPT for classification
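    ///
    /// A construction sketch (the `ghostflow_nn::gpt` path is an assumed
    /// export, and the label count here is arbitrary):
    ///
    /// ```no_run
    /// # use ghostflow_nn::gpt::{GPTConfig, GPTForSequenceClassification};
    /// // Binary classifier on top of the tiny test config.
    /// let model = GPTForSequenceClassification::new(GPTConfig::gpt_tiny(), 2);
    /// ```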
    pub fn new(config: GPTConfig, num_labels: usize) -> Self {
        let gpt = GPTModel::new(config.clone());
        let classifier = Linear::new(config.embed_dim, num_labels);

        GPTForSequenceClassification {
            gpt,
            classifier,
            num_labels,
        }
    }

    /// Forward pass
    pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor, String> {
        let hidden_states = self.gpt.forward(input_ids)?;

        // Use the last token's hidden state as the sequence summary
        let last_hidden = self.extract_last_token(&hidden_states)?;

        let logits = self.classifier.forward(&last_hidden);
        Ok(logits)
    }

    /// Extract last token hidden state
    fn extract_last_token(&self, hidden_states: &Tensor) -> Result<Tensor, String> {
        let data = hidden_states.data_f32();
        let dims = hidden_states.dims();

        if dims.len() != 3 {
            return Err(format!("Expected 3D hidden states, got {}D", dims.len()));
        }

        let batch_size = dims[0];
        let seq_length = dims[1];
        let embed_dim = dims[2];

        let mut result = Vec::with_capacity(batch_size * embed_dim);

        for b in 0..batch_size {
            let start = b * seq_length * embed_dim + (seq_length - 1) * embed_dim;
            let end = start + embed_dim;
            result.extend_from_slice(&data[start..end]);
        }

        Tensor::from_slice(&result, &[batch_size, embed_dim])
            .map_err(|e| format!("Failed to extract last token: {:?}", e))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpt_config() {
        let config = GPTConfig::gpt2_small();
        assert_eq!(config.embed_dim, 768);
        assert_eq!(config.num_layers, 12);

        let config = GPTConfig::gpt2_xl();
        assert_eq!(config.embed_dim, 1600);
        assert_eq!(config.num_layers, 48);

        let config = GPTConfig::gpt3_large();
        assert_eq!(config.embed_dim, 1536);
        assert_eq!(config.context_length, 2048);
    }
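
    #[test]
    fn test_custom_config() {
        // Presets are plain data, so struct update syntax can override
        // individual fields; this sketch just exercises that property.
        let config = GPTConfig {
            context_length: 256,
            ..GPTConfig::gpt_tiny()
        };
        assert_eq!(config.context_length, 256);
        assert_eq!(config.embed_dim, 128);
    }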

    #[test]
    fn test_gpt_embeddings() {
        let config = GPTConfig::gpt_tiny();
        let embeddings = GPTEmbeddings::new(config);

        let input_ids = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let output = embeddings.forward(&input_ids).unwrap();

        assert_eq!(output.dims(), &[2, 2, 128]); // batch=2, seq=2, embed=128
    }
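
    #[test]
    fn test_embeddings_reject_long_sequence() {
        // gpt_tiny has context_length 128, so a 129-token input must be
        // rejected by the length check in GPTEmbeddings::forward.
        let embeddings = GPTEmbeddings::new(GPTConfig::gpt_tiny());
        let ids = vec![0.0f32; 129];
        let input_ids = Tensor::from_slice(&ids, &[1, 129]).unwrap();
        assert!(embeddings.forward(&input_ids).is_err());
    }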

    #[test]
    fn test_gpt_model() {
        let config = GPTConfig::gpt_tiny();
        let gpt = GPTModel::new(config);

        let input_ids = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let output = gpt.forward(&input_ids).unwrap();

        assert_eq!(output.dims(), &[2, 2, 128]); // batch=2, seq=2, embed=128
    }

    #[test]
    fn test_gpt_for_causal_lm() {
        let config = GPTConfig::gpt_tiny();
        let gpt = GPTForCausalLM::new(config);

        let input_ids = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let output = gpt.forward(&input_ids).unwrap();

        assert_eq!(output.dims(), &[2, 2, 1000]); // batch=2, seq=2, vocab=1000
    }
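
    #[test]
    fn test_generate_appends_tokens() {
        // Greedy generation should return prompt + max_new_tokens ids,
        // all within the tiny vocabulary.
        let gpt = GPTForCausalLM::new(GPTConfig::gpt_tiny());
        let prompt = Tensor::from_slice(&[1.0, 2.0], &[1, 2]).unwrap();
        let ids = gpt.generate(&prompt, 3, 1.0).unwrap();
        assert_eq!(ids.len(), 5); // 2 prompt tokens + 3 generated
        assert!(ids.iter().all(|&id| id < 1000));
    }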

    #[test]
    fn test_gpt_for_classification() {
        let config = GPTConfig::gpt_tiny();
        let gpt = GPTForSequenceClassification::new(config, 2);

        let input_ids = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let output = gpt.forward(&input_ids).unwrap();

        assert_eq!(output.dims(), &[2, 2]); // batch=2, num_labels=2
    }
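
    #[test]
    fn test_sample_token_is_greedy() {
        // sample_token currently takes the argmax, so the largest logit
        // wins regardless of temperature.
        let gpt = GPTForCausalLM::new(GPTConfig::gpt_tiny());
        let logits = Tensor::from_slice(&[0.1, 2.0, -1.0, 0.5], &[4]).unwrap();
        assert_eq!(gpt.sample_token(&logits, 1.0).unwrap(), 1);
        assert_eq!(gpt.sample_token(&logits, 0.5).unwrap(), 1);
    }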
}