use ghostflow_core::Tensor;
use crate::transformer::TransformerEncoder;
use crate::linear::Linear;
use crate::norm::LayerNorm;
use crate::Module;

/// Configuration for a GPT-style, decoder-only transformer language model.
#[derive(Debug, Clone)]
pub struct GPTConfig {
    /// Size of the token vocabulary.
    pub vocab_size: usize,
    /// Maximum sequence length (number of learned position embeddings).
    pub context_length: usize,
    /// Dimensionality of the token and position embeddings.
    pub embed_dim: usize,
    /// Number of transformer blocks.
    pub num_layers: usize,
    /// Number of attention heads per block; `embed_dim` should be divisible by this.
    pub num_heads: usize,
    /// Hidden dimension of the feed-forward sublayer (conventionally 4 * embed_dim).
    pub ff_dim: usize,
    /// Dropout probability applied during training.
    pub dropout: f32,
    /// Whether linear layers include bias terms.
    pub bias: bool,
}

impl Default for GPTConfig {
    fn default() -> Self {
        // Defaults correspond to the GPT-2 small configuration (~124M parameters).
        GPTConfig {
            vocab_size: 50257,
            context_length: 1024,
            embed_dim: 768,
            num_layers: 12,
            num_heads: 12,
            ff_dim: 3072,
            dropout: 0.1,
            bias: true,
        }
    }
}

impl GPTConfig {
    /// GPT-2 small (~124M parameters); identical to the default configuration.
    pub fn gpt2_small() -> Self {
        Self::default()
    }

    /// GPT-2 medium (~355M parameters).
    pub fn gpt2_medium() -> Self {
        GPTConfig {
            embed_dim: 1024,
            num_layers: 24,
            num_heads: 16,
            ff_dim: 4096,
            ..Default::default()
        }
    }

    /// GPT-2 large (~774M parameters).
    pub fn gpt2_large() -> Self {
        GPTConfig {
            embed_dim: 1280,
            num_layers: 36,
            num_heads: 20,
            ff_dim: 5120,
            ..Default::default()
        }
    }

    /// GPT-2 XL (~1.5B parameters).
    pub fn gpt2_xl() -> Self {
        GPTConfig {
            embed_dim: 1600,
            num_layers: 48,
            num_heads: 25,
            ff_dim: 6400,
            ..Default::default()
        }
    }

    /// GPT-3-style small preset: 2048-token context, no dropout, no bias.
    pub fn gpt3_small() -> Self {
        GPTConfig {
            vocab_size: 50257,
            context_length: 2048,
            embed_dim: 768,
            num_layers: 12,
            num_heads: 12,
            ff_dim: 3072,
            dropout: 0.0,
            bias: false,
        }
    }

    /// GPT-3-style medium preset.
    pub fn gpt3_medium() -> Self {
        GPTConfig {
            vocab_size: 50257,
            context_length: 2048,
            embed_dim: 1024,
            num_layers: 24,
            num_heads: 16,
            ff_dim: 4096,
            dropout: 0.0,
            bias: false,
        }
    }

    /// GPT-3-style large preset.
    pub fn gpt3_large() -> Self {
        GPTConfig {
            vocab_size: 50257,
            context_length: 2048,
            embed_dim: 1280,
            num_layers: 36,
            num_heads: 20,
            ff_dim: 5120,
            dropout: 0.0,
            bias: false,
        }
    }

    /// GPT-3-style XL preset.
    pub fn gpt3_xl() -> Self {
        GPTConfig {
            vocab_size: 50257,
            context_length: 2048,
            embed_dim: 1536,
            num_layers: 48,
            num_heads: 24,
            ff_dim: 6144,
            dropout: 0.0,
            bias: false,
        }
    }

    /// Tiny configuration intended for tests and quick experiments.
    pub fn gpt_tiny() -> Self {
        GPTConfig {
            vocab_size: 1000,
            context_length: 128,
            embed_dim: 128,
            num_layers: 2,
            num_heads: 2,
            ff_dim: 512,
            dropout: 0.1,
            bias: true,
        }
    }
}

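// Illustrative usage (not part of the API): presets can be customized with
// struct update syntax, e.g. a GPT-2-small-shaped model with a longer context:
//
//     let config = GPTConfig { context_length: 2048, ..GPTConfig::gpt2_small() };
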
/// Token and learned position embeddings for a GPT model.
pub struct GPTEmbeddings {
    /// Embedding matrix of shape [vocab_size, embed_dim].
    token_embeddings: Tensor,
    /// Learned position embedding matrix of shape [context_length, embed_dim].
    position_embeddings: Tensor,
    /// Dropout probability (stored from the config; not applied in `forward` yet).
    dropout: f32,
    config: GPTConfig,
}

173
174impl GPTEmbeddings {
175 pub fn new(config: GPTConfig) -> Self {
177 let token_embeddings = Tensor::randn(&[config.vocab_size, config.embed_dim]);
178 let position_embeddings = Tensor::randn(&[config.context_length, config.embed_dim]);
179
180 GPTEmbeddings {
181 token_embeddings,
182 position_embeddings,
183 dropout: config.dropout,
184 config,
185 }
186 }
187
    /// Looks up token embeddings, adds position embeddings, and returns a
    /// tensor of shape [batch, seq, embed_dim].
    pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor, String> {
        let dims = input_ids.dims();
        if dims.len() != 2 {
            return Err(format!("Expected 2D input_ids, got {}D", dims.len()));
        }

        let seq_length = dims[1];

        if seq_length > self.config.context_length {
            return Err(format!(
                "Sequence length {} exceeds context length {}",
                seq_length, self.config.context_length
            ));
        }

        let token_embeds = self.get_token_embeddings(input_ids)?;
        let position_embeds = self.get_position_embeddings(seq_length)?;

        self.sum_embeddings(&token_embeds, &position_embeds)
    }

    /// Gathers one embedding row per token ID, producing [batch, seq, embed_dim].
    fn get_token_embeddings(&self, input_ids: &Tensor) -> Result<Tensor, String> {
        let ids_data = input_ids.data_f32();
        let embed_data = self.token_embeddings.data_f32();

        let dims = input_ids.dims();
        let batch_size = dims[0];
        let seq_length = dims[1];
        let embed_dim = self.config.embed_dim;

        let mut result = Vec::with_capacity(batch_size * seq_length * embed_dim);

        for &id in ids_data.iter() {
            // Token IDs are stored as f32 and cast back to indices here.
            let idx = id as usize;
            if idx >= self.config.vocab_size {
                return Err(format!("Token ID {} out of vocabulary range", idx));
            }

            let start = idx * embed_dim;
            let end = start + embed_dim;
            result.extend_from_slice(&embed_data[start..end]);
        }

        Tensor::from_slice(&result, &[batch_size, seq_length, embed_dim])
            .map_err(|e| format!("Failed to create token embeddings: {:?}", e))
    }

    /// Returns the first `seq_length` rows of the position table: [seq, embed_dim].
    fn get_position_embeddings(&self, seq_length: usize) -> Result<Tensor, String> {
        let embed_data = self.position_embeddings.data_f32();
        let embed_dim = self.config.embed_dim;

        let result = embed_data[..seq_length * embed_dim].to_vec();

        Tensor::from_slice(&result, &[seq_length, embed_dim])
            .map_err(|e| format!("Failed to create position embeddings: {:?}", e))
    }

    /// Adds position embeddings to token embeddings, broadcasting the
    /// [seq, embed_dim] position table across the batch dimension.
    fn sum_embeddings(&self, token: &Tensor, position: &Tensor) -> Result<Tensor, String> {
        let token_data = token.data_f32();
        let pos_data = position.data_f32();

        let dims = token.dims();
        let batch_size = dims[0];
        let seq_length = dims[1];
        let embed_dim = dims[2];

        let mut result = Vec::with_capacity(token_data.len());

        for b in 0..batch_size {
            for s in 0..seq_length {
                for e in 0..embed_dim {
                    let token_idx = b * seq_length * embed_dim + s * embed_dim + e;
                    let pos_idx = s * embed_dim + e;
                    result.push(token_data[token_idx] + pos_data[pos_idx]);
                }
            }
        }

        Tensor::from_slice(&result, &[batch_size, seq_length, embed_dim])
            .map_err(|e| format!("Failed to sum embeddings: {:?}", e))
    }
}
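
// Illustrative shape check (mirrors the unit test below): for `input_ids` of
// shape [batch, seq], `GPTEmbeddings::forward` yields [batch, seq, embed_dim].
//
//     let emb = GPTEmbeddings::new(GPTConfig::gpt_tiny());
//     let ids = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
//     let out = emb.forward(&ids).unwrap(); // dims: [2, 2, 128]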

/// GPT backbone: embeddings, a stack of transformer blocks, and a final layer norm.
pub struct GPTModel {
    config: GPTConfig,
    embeddings: GPTEmbeddings,
    transformer: TransformerEncoder,
    /// Final layer norm applied to the transformer output (`ln_f` in GPT-2).
    ln_f: LayerNorm,
}

impl GPTModel {
    pub fn new(config: GPTConfig) -> Self {
        let embeddings = GPTEmbeddings::new(config.clone());

        let transformer = TransformerEncoder::new(
            config.embed_dim,
            config.num_heads,
            config.ff_dim,
            config.num_layers,
            config.dropout,
        );

        let ln_f = LayerNorm::new(&[config.embed_dim]);

        GPTModel {
            config,
            embeddings,
            transformer,
            ln_f,
        }
    }

    /// Runs the full backbone: embeddings -> transformer stack -> final layer norm.
    /// Returns hidden states of shape [batch, seq, embed_dim].
    pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor, String> {
        let hidden_states = self.embeddings.forward(input_ids)?;
        let hidden_states = self.transformer.forward(&hidden_states);
        let hidden_states = self.ln_f.forward(&hidden_states);
        Ok(hidden_states)
    }
}
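
// Data flow through the backbone (shapes), assuming `input_ids` of [batch, seq]:
//
//     embeddings.forward  -> [batch, seq, embed_dim]
//     transformer.forward -> [batch, seq, embed_dim]
//     ln_f.forward        -> [batch, seq, embed_dim]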

/// GPT with a language-modeling head that projects hidden states to vocabulary logits.
pub struct GPTForCausalLM {
    gpt: GPTModel,
    lm_head: Linear,
}

impl GPTForCausalLM {
    pub fn new(config: GPTConfig) -> Self {
        let gpt = GPTModel::new(config.clone());
        let lm_head = Linear::new(config.embed_dim, config.vocab_size);

        GPTForCausalLM { gpt, lm_head }
    }

    /// Returns logits of shape [batch, seq, vocab_size].
    pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor, String> {
        let hidden_states = self.gpt.forward(input_ids)?;
        let logits = self.lm_head.forward(&hidden_states);
        Ok(logits)
    }

    /// Autoregressively extends a single prompt (batch size 1 is assumed) by
    /// `max_new_tokens` tokens, re-running the full forward pass each step.
    pub fn generate(
        &self,
        input_ids: &Tensor,
        max_new_tokens: usize,
        temperature: f32,
    ) -> Result<Vec<usize>, String> {
        let mut current_ids = input_ids
            .data_f32()
            .iter()
            .map(|&x| x as usize)
            .collect::<Vec<_>>();

        for _ in 0..max_new_tokens {
            let input_tensor = Tensor::from_slice(
                &current_ids.iter().map(|&x| x as f32).collect::<Vec<_>>(),
                &[1, current_ids.len()],
            )
            .map_err(|e| format!("Failed to create input tensor: {:?}", e))?;

            let logits = self.forward(&input_tensor)?;
            let last_logits = self.extract_last_token_logits(&logits)?;
            let next_token = self.sample_token(&last_logits, temperature)?;

            current_ids.push(next_token);
        }

        Ok(current_ids)
    }

    /// Extracts the logits for the final position of the (single) sequence,
    /// returning a 1D tensor of shape [vocab_size].
    fn extract_last_token_logits(&self, logits: &Tensor) -> Result<Tensor, String> {
        let data = logits.data_f32();
        let dims = logits.dims();

        if dims.len() != 3 {
            return Err(format!("Expected 3D logits, got {}D", dims.len()));
        }

        let seq_length = dims[1];
        let vocab_size = dims[2];

        // Offset of the last position within the first batch element.
        let start = (seq_length - 1) * vocab_size;
        let end = start + vocab_size;
        let last_logits = data[start..end].to_vec();

        Tensor::from_slice(&last_logits, &[vocab_size])
            .map_err(|e| format!("Failed to extract last token logits: {:?}", e))
    }

    /// Selects the next token. Despite the name, this is greedy decoding:
    /// after the temperature-scaled softmax, the argmax token is returned, so
    /// `temperature` does not change the result (argmax is scale-invariant).
    fn sample_token(&self, logits: &Tensor, temperature: f32) -> Result<usize, String> {
        let data = logits.data_f32();

        let scaled: Vec<f32> = data.iter().map(|&x| x / temperature).collect();

        // Numerically stable softmax: subtract the max before exponentiating.
        let max_val = scaled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exp_vals: Vec<f32> = scaled.iter().map(|&x| (x - max_val).exp()).collect();
        let sum: f32 = exp_vals.iter().sum();
        let probs: Vec<f32> = exp_vals.iter().map(|&x| x / sum).collect();

        let next_token = probs
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(idx, _)| idx)
            .ok_or_else(|| "Failed to sample token".to_string())?;

        Ok(next_token)
    }
}
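
// For true temperature sampling, the next token would be drawn from `probs`
// rather than taken by argmax. A minimal sketch, assuming the `rand` crate
// were added as a dependency (it is not one here); the function name is
// hypothetical:
//
//     fn sample_multinomial(probs: &[f32]) -> usize {
//         use rand::Rng;
//         // Draw u in [0, 1) and walk the CDF until it is exceeded.
//         let u: f32 = rand::thread_rng().gen();
//         let mut cum = 0.0;
//         for (i, &p) in probs.iter().enumerate() {
//             cum += p;
//             if u < cum {
//                 return i;
//             }
//         }
//         probs.len() - 1 // fallback for floating-point rounding
//     }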

/// GPT with a classification head applied to the last token's hidden state.
pub struct GPTForSequenceClassification {
    gpt: GPTModel,
    classifier: Linear,
    num_labels: usize,
}

impl GPTForSequenceClassification {
    pub fn new(config: GPTConfig, num_labels: usize) -> Self {
        let gpt = GPTModel::new(config.clone());
        let classifier = Linear::new(config.embed_dim, num_labels);

        GPTForSequenceClassification {
            gpt,
            classifier,
            num_labels,
        }
    }

    /// Returns classification logits of shape [batch, num_labels], computed
    /// from the hidden state at the final position of each sequence.
    pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor, String> {
        let hidden_states = self.gpt.forward(input_ids)?;
        let last_hidden = self.extract_last_token(&hidden_states)?;
        let logits = self.classifier.forward(&last_hidden);
        Ok(logits)
    }

    /// Slices out the last position's hidden state for every batch element,
    /// producing [batch, embed_dim].
    fn extract_last_token(&self, hidden_states: &Tensor) -> Result<Tensor, String> {
        let data = hidden_states.data_f32();
        let dims = hidden_states.dims();

        if dims.len() != 3 {
            return Err(format!("Expected 3D hidden states, got {}D", dims.len()));
        }

        let batch_size = dims[0];
        let seq_length = dims[1];
        let embed_dim = dims[2];

        let mut result = Vec::with_capacity(batch_size * embed_dim);

        for b in 0..batch_size {
            let start = b * seq_length * embed_dim + (seq_length - 1) * embed_dim;
            let end = start + embed_dim;
            result.extend_from_slice(&data[start..end]);
        }

        Tensor::from_slice(&result, &[batch_size, embed_dim])
            .map_err(|e| format!("Failed to extract last token: {:?}", e))
    }
}
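
// Illustrative usage (mirrors the unit test below): a 2-way classifier over
// two sequences of length 2 yields logits of shape [2, 2]:
//
//     let model = GPTForSequenceClassification::new(GPTConfig::gpt_tiny(), 2);
//     let ids = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
//     let logits = model.forward(&ids).unwrap(); // dims: [2, 2]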

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpt_config() {
        let config = GPTConfig::gpt2_small();
        assert_eq!(config.embed_dim, 768);
        assert_eq!(config.num_layers, 12);

        let config = GPTConfig::gpt2_xl();
        assert_eq!(config.embed_dim, 1600);
        assert_eq!(config.num_layers, 48);

        let config = GPTConfig::gpt3_large();
        assert_eq!(config.embed_dim, 1280);
        assert_eq!(config.context_length, 2048);
    }

    #[test]
    fn test_gpt_embeddings() {
        let config = GPTConfig::gpt_tiny();
        let embeddings = GPTEmbeddings::new(config);

        let input_ids = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let output = embeddings.forward(&input_ids).unwrap();

        assert_eq!(output.dims(), &[2, 2, 128]);
    }

    #[test]
    fn test_gpt_model() {
        let config = GPTConfig::gpt_tiny();
        let gpt = GPTModel::new(config);

        let input_ids = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let output = gpt.forward(&input_ids).unwrap();

        assert_eq!(output.dims(), &[2, 2, 128]);
    }

    #[test]
    fn test_gpt_for_causal_lm() {
        let config = GPTConfig::gpt_tiny();
        let gpt = GPTForCausalLM::new(config.clone());

        let input_ids = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let output = gpt.forward(&input_ids).unwrap();

        assert_eq!(output.dims(), &[2, 2, 1000]);
    }

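    // A sketch of a decoding test (assumption: the tiny model is fast enough
    // for unit tests). `generate` starts from the 2-token prompt and should
    // append exactly `max_new_tokens` ids.
    #[test]
    fn test_gpt_generate() {
        let config = GPTConfig::gpt_tiny();
        let gpt = GPTForCausalLM::new(config);

        let input_ids = Tensor::from_slice(&[1.0, 2.0], &[1, 2]).unwrap();
        let output = gpt.generate(&input_ids, 3, 1.0).unwrap();

        assert_eq!(output.len(), 5);
    }
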
    #[test]
    fn test_gpt_for_classification() {
        let config = GPTConfig::gpt_tiny();
        let gpt = GPTForSequenceClassification::new(config, 2);

        let input_ids = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let output = gpt.forward(&input_ids).unwrap();

        assert_eq!(output.dims(), &[2, 2]);
    }
}