// ghostflow_nn/t5.rs

1//! T5 (Text-to-Text Transfer Transformer)
2//!
3//! Implements T5 as described in "Exploring the Limits of Transfer Learning"
4//! - Encoder-decoder architecture
5//! - Relative position embeddings
6//! - Text-to-text framework
7//! - Multiple task support (translation, summarization, QA, etc.)
8
9use ghostflow_core::Tensor;
10use crate::transformer::{TransformerEncoder, TransformerDecoderLayer};
11use crate::linear::Linear;
12use crate::norm::LayerNorm;
13use crate::Module;
14
/// T5 configuration
///
/// Hyperparameters for the T5 encoder-decoder architecture. Use the
/// preset constructors (`t5_small`, `t5_base`, `t5_large`, ...) for the
/// standard published sizes.
#[derive(Debug, Clone, PartialEq)]
pub struct T5Config {
    /// Vocabulary size
    pub vocab_size: usize,
    /// Model dimension
    pub d_model: usize,
    /// Key/value dimension per attention head
    pub d_kv: usize,
    /// Feed-forward (inner) dimension
    pub d_ff: usize,
    /// Number of encoder layers
    pub num_encoder_layers: usize,
    /// Number of decoder layers
    pub num_decoder_layers: usize,
    /// Number of attention heads
    pub num_heads: usize,
    /// Dropout rate
    pub dropout: f32,
    /// Use relative attention bias
    pub relative_attention: bool,
}
37
38impl Default for T5Config {
39    fn default() -> Self {
40        T5Config {
41            vocab_size: 32128,
42            d_model: 512,
43            d_kv: 64,
44            d_ff: 2048,
45            num_encoder_layers: 6,
46            num_decoder_layers: 6,
47            num_heads: 8,
48            dropout: 0.1,
49            relative_attention: true,
50        }
51    }
52}
53
54impl T5Config {
55    /// T5-Small (60M parameters)
56    pub fn t5_small() -> Self {
57        Self::default()
58    }
59    
60    /// T5-Base (220M parameters)
61    pub fn t5_base() -> Self {
62        T5Config {
63            d_model: 768,
64            d_kv: 64,
65            d_ff: 3072,
66            num_encoder_layers: 12,
67            num_decoder_layers: 12,
68            num_heads: 12,
69            ..Default::default()
70        }
71    }
72    
73    /// T5-Large (770M parameters)
74    pub fn t5_large() -> Self {
75        T5Config {
76            d_model: 1024,
77            d_kv: 64,
78            d_ff: 4096,
79            num_encoder_layers: 24,
80            num_decoder_layers: 24,
81            num_heads: 16,
82            ..Default::default()
83        }
84    }
85    
86    /// T5-3B (3B parameters)
87    pub fn t5_3b() -> Self {
88        T5Config {
89            d_model: 1024,
90            d_kv: 128,
91            d_ff: 16384,
92            num_encoder_layers: 24,
93            num_decoder_layers: 24,
94            num_heads: 32,
95            ..Default::default()
96        }
97    }
98    
99    /// T5-11B (11B parameters)
100    pub fn t5_11b() -> Self {
101        T5Config {
102            d_model: 1024,
103            d_kv: 128,
104            d_ff: 65536,
105            num_encoder_layers: 24,
106            num_decoder_layers: 24,
107            num_heads: 128,
108            ..Default::default()
109        }
110    }
111    
112    /// T5-Tiny (for testing)
113    pub fn t5_tiny() -> Self {
114        T5Config {
115            vocab_size: 1000,
116            d_model: 128,
117            d_kv: 16,
118            d_ff: 512,
119            num_encoder_layers: 2,
120            num_decoder_layers: 2,
121            num_heads: 4,
122            dropout: 0.1,
123            relative_attention: true,
124        }
125    }
126}
127
128/// T5 embeddings (shared between encoder and decoder)
129pub struct T5Embeddings {
130    /// Token embeddings
131    token_embeddings: Tensor,
132    /// Configuration
133    config: T5Config,
134}
135
136impl T5Embeddings {
137    /// Create new T5 embeddings
138    pub fn new(config: T5Config) -> Self {
139        let token_embeddings = Tensor::randn(&[config.vocab_size, config.d_model]);
140        
141        T5Embeddings {
142            token_embeddings,
143            config,
144        }
145    }
146    
147    /// Forward pass
148    pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor, String> {
149        let ids_data = input_ids.data_f32();
150        let embed_data = self.token_embeddings.data_f32();
151        
152        let dims = input_ids.dims();
153        if dims.len() != 2 {
154            return Err(format!("Expected 2D input_ids, got {}D", dims.len()));
155        }
156        
157        let batch_size = dims[0];
158        let seq_length = dims[1];
159        let d_model = self.config.d_model;
160        
161        let mut result = Vec::with_capacity(batch_size * seq_length * d_model);
162        
163        for &id in ids_data.iter() {
164            let idx = id as usize;
165            if idx >= self.config.vocab_size {
166                return Err(format!("Token ID {} out of vocabulary range", idx));
167            }
168            
169            let start = idx * d_model;
170            let end = start + d_model;
171            result.extend_from_slice(&embed_data[start..end]);
172        }
173        
174        Tensor::from_slice(&result, &[batch_size, seq_length, d_model])
175            .map_err(|e| format!("Failed to create embeddings: {:?}", e))
176    }
177}
178
179/// T5 Encoder
180pub struct T5Encoder {
181    /// Embeddings
182    embeddings: T5Embeddings,
183    /// Encoder layers
184    encoder: TransformerEncoder,
185    /// Final layer norm
186    final_layer_norm: LayerNorm,
187    /// Dropout
188    dropout: f32,
189}
190
191impl T5Encoder {
192    /// Create new T5 encoder
193    pub fn new(config: &T5Config, embeddings: T5Embeddings) -> Self {
194        let encoder = TransformerEncoder::new(
195            config.d_model,
196            config.num_heads,
197            config.d_ff,
198            config.num_encoder_layers,
199            config.dropout,
200        );
201        
202        let final_layer_norm = LayerNorm::new(&[config.d_model]);
203        
204        T5Encoder {
205            embeddings,
206            encoder,
207            final_layer_norm,
208            dropout: config.dropout,
209        }
210    }
211    
212    /// Forward pass
213    pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor, String> {
214        // Get embeddings
215        let hidden_states = self.embeddings.forward(input_ids)?;
216        
217        // Encoder layers
218        let hidden_states = self.encoder.forward(&hidden_states);
219        
220        // Final layer norm
221        let hidden_states = self.final_layer_norm.forward(&hidden_states);
222        
223        Ok(hidden_states)
224    }
225}
226
227/// T5 Decoder
228pub struct T5Decoder {
229    /// Embeddings (shared with encoder)
230    embeddings: T5Embeddings,
231    /// Decoder layers
232    layers: Vec<TransformerDecoderLayer>,
233    /// Final layer norm
234    final_layer_norm: LayerNorm,
235    /// Dropout
236    dropout: f32,
237}
238
239impl T5Decoder {
240    /// Create new T5 decoder
241    pub fn new(config: &T5Config, embeddings: T5Embeddings) -> Self {
242        let layers = (0..config.num_decoder_layers)
243            .map(|_| TransformerDecoderLayer::new(config.d_model, config.num_heads, config.d_ff, config.dropout))
244            .collect();
245        
246        let final_layer_norm = LayerNorm::new(&[config.d_model]);
247        
248        T5Decoder {
249            embeddings,
250            layers,
251            final_layer_norm,
252            dropout: config.dropout,
253        }
254    }
255    
256    /// Forward pass
257    pub fn forward(&self, decoder_input_ids: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor, String> {
258        // Get embeddings
259        let mut hidden_states = self.embeddings.forward(decoder_input_ids)?;
260        
261        // Decoder layers with cross-attention
262        for layer in &self.layers {
263            hidden_states = layer.forward_with_memory(&hidden_states, encoder_hidden_states, None, None);
264        }
265        
266        // Final layer norm
267        let hidden_states = self.final_layer_norm.forward(&hidden_states);
268        
269        Ok(hidden_states)
270    }
271}
272
273/// T5 Model (encoder-decoder)
274pub struct T5Model {
275    /// Configuration
276    config: T5Config,
277    /// Shared embeddings
278    shared_embeddings: T5Embeddings,
279    /// Encoder
280    encoder: T5Encoder,
281    /// Decoder
282    decoder: T5Decoder,
283}
284
impl T5Model {
    /// Create new T5 model
    ///
    /// NOTE(review): T5 shares one embedding table between encoder and
    /// decoder, but this constructor builds THREE independent, randomly
    /// initialized `T5Embeddings` (the `shared_embeddings` field plus one
    /// each for the encoder and decoder). The weights are therefore NOT
    /// actually shared — confirm whether true sharing (e.g. via `Rc`/`Arc`
    /// around the table) is intended.
    pub fn new(config: T5Config) -> Self {
        // Model-level table; currently unused by encoder/decoder below.
        let shared_embeddings = T5Embeddings::new(config.clone());

        // Encoder gets its own, independently initialized table.
        let encoder_embeddings = T5Embeddings::new(config.clone());
        let encoder = T5Encoder::new(&config, encoder_embeddings);

        // Decoder likewise gets its own independent table.
        let decoder_embeddings = T5Embeddings::new(config.clone());
        let decoder = T5Decoder::new(&config, decoder_embeddings);

        T5Model {
            config,
            shared_embeddings,
            encoder,
            decoder,
        }
    }

    /// Forward pass: encode `input_ids`, then decode `decoder_input_ids`
    /// with cross-attention over the encoder output.
    pub fn forward(&self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<T5Output, String> {
        // Encode
        let encoder_hidden_states = self.encoder.forward(input_ids)?;

        // Decode
        let decoder_hidden_states = self.decoder.forward(decoder_input_ids, &encoder_hidden_states)?;

        Ok(T5Output {
            last_hidden_state: decoder_hidden_states,
            encoder_last_hidden_state: encoder_hidden_states,
        })
    }
}
321
322/// T5 output
323pub struct T5Output {
324    /// Decoder last hidden state
325    pub last_hidden_state: Tensor,
326    /// Encoder last hidden state
327    pub encoder_last_hidden_state: Tensor,
328}
329
330/// T5 for Conditional Generation (translation, summarization, etc.)
331pub struct T5ForConditionalGeneration {
332    /// Base T5 model
333    t5: T5Model,
334    /// Language modeling head
335    lm_head: Linear,
336}
337
338impl T5ForConditionalGeneration {
339    /// Create new T5 for conditional generation
340    pub fn new(config: T5Config) -> Self {
341        let t5 = T5Model::new(config.clone());
342        let lm_head = Linear::new(config.d_model, config.vocab_size);
343        
344        T5ForConditionalGeneration {
345            t5,
346            lm_head,
347        }
348    }
349    
350    /// Forward pass
351    pub fn forward(&self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor, String> {
352        let output = self.t5.forward(input_ids, decoder_input_ids)?;
353        let logits = self.lm_head.forward(&output.last_hidden_state);
354        Ok(logits)
355    }
356    
357    /// Generate text (simplified greedy decoding)
358    pub fn generate(&self, input_ids: &Tensor, max_length: usize) -> Result<Vec<usize>, String> {
359        // Start with decoder start token (0)
360        let mut generated = vec![0usize];
361        
362        for _ in 0..max_length {
363            // Create decoder input tensor
364            let decoder_input = Tensor::from_slice(
365                &generated.iter().map(|&x| x as f32).collect::<Vec<_>>(),
366                &[1, generated.len()]
367            ).map_err(|e| format!("Failed to create decoder input: {:?}", e))?;
368            
369            // Forward pass
370            let logits = self.forward(input_ids, &decoder_input)?;
371            
372            // Get last token logits
373            let next_token = self.sample_next_token(&logits)?;
374            
375            // Check for end token (1)
376            if next_token == 1 {
377                break;
378            }
379            
380            generated.push(next_token);
381        }
382        
383        Ok(generated)
384    }
385    
386    /// Sample next token (greedy)
387    fn sample_next_token(&self, logits: &Tensor) -> Result<usize, String> {
388        let data = logits.data_f32();
389        let dims = logits.dims();
390        
391        if dims.len() != 3 {
392            return Err(format!("Expected 3D logits, got {}D", dims.len()));
393        }
394        
395        let seq_length = dims[1];
396        let vocab_size = dims[2];
397        
398        // Get last token logits
399        let start = (seq_length - 1) * vocab_size;
400        let end = start + vocab_size;
401        let last_logits = &data[start..end];
402        
403        // Greedy sampling (argmax)
404        let next_token = last_logits.iter()
405            .enumerate()
406            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
407            .map(|(idx, _)| idx)
408            .ok_or_else(|| "Failed to sample token".to_string())?;
409        
410        Ok(next_token)
411    }
412}
413
414/// T5 for Classification
415pub struct T5ForSequenceClassification {
416    /// Base T5 model
417    t5: T5Model,
418    /// Classification head
419    classifier: Linear,
420    /// Number of labels
421    num_labels: usize,
422}
423
424impl T5ForSequenceClassification {
425    /// Create new T5 for classification
426    pub fn new(config: T5Config, num_labels: usize) -> Self {
427        let t5 = T5Model::new(config.clone());
428        let classifier = Linear::new(config.d_model, num_labels);
429        
430        T5ForSequenceClassification {
431            t5,
432            classifier,
433            num_labels,
434        }
435    }
436    
437    /// Forward pass
438    pub fn forward(&self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor, String> {
439        let output = self.t5.forward(input_ids, decoder_input_ids)?;
440        
441        // Use first decoder token for classification
442        let first_token = self.extract_first_token(&output.last_hidden_state)?;
443        
444        let logits = self.classifier.forward(&first_token);
445        Ok(logits)
446    }
447    
448    /// Extract first token
449    fn extract_first_token(&self, hidden_states: &Tensor) -> Result<Tensor, String> {
450        let data = hidden_states.data_f32();
451        let dims = hidden_states.dims();
452        
453        if dims.len() != 3 {
454            return Err(format!("Expected 3D hidden states, got {}D", dims.len()));
455        }
456        
457        let batch_size = dims[0];
458        let d_model = dims[2];
459        
460        let mut result = Vec::with_capacity(batch_size * d_model);
461        
462        for b in 0..batch_size {
463            let start = b * dims[1] * d_model;
464            let end = start + d_model;
465            result.extend_from_slice(&data[start..end]);
466        }
467        
468        Tensor::from_slice(&result, &[batch_size, d_model])
469            .map_err(|e| format!("Failed to extract first token: {:?}", e))
470    }
471}
472
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_t5_config() {
        let small = T5Config::t5_small();
        assert_eq!(small.d_model, 512);
        assert_eq!(small.num_encoder_layers, 6);

        let base = T5Config::t5_base();
        assert_eq!(base.d_model, 768);
        assert_eq!(base.num_encoder_layers, 12);

        let large = T5Config::t5_large();
        assert_eq!(large.d_model, 1024);
        assert_eq!(large.num_encoder_layers, 24);
    }

    #[test]
    fn test_t5_embeddings() {
        let embed = T5Embeddings::new(T5Config::t5_tiny());

        let ids = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let out = embed.forward(&ids).unwrap();

        // batch=2, seq=2, d_model=128
        assert_eq!(out.dims(), &[2, 2, 128]);
    }

    #[test]
    fn test_t5_model() {
        let model = T5Model::new(T5Config::t5_tiny());

        let enc_ids = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let dec_ids = Tensor::from_slice(&[1.0, 2.0], &[2, 1]).unwrap();

        let out = model.forward(&enc_ids, &dec_ids).unwrap();

        // batch=2, seq=1, d_model=128
        assert_eq!(out.last_hidden_state.dims(), &[2, 1, 128]);
    }

    #[test]
    fn test_t5_for_conditional_generation() {
        let model = T5ForConditionalGeneration::new(T5Config::t5_tiny());

        let enc_ids = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let dec_ids = Tensor::from_slice(&[1.0, 2.0], &[2, 1]).unwrap();

        let logits = model.forward(&enc_ids, &dec_ids).unwrap();

        // batch=2, seq=1, vocab=1000
        assert_eq!(logits.dims(), &[2, 1, 1000]);
    }
}