1use ghostflow_core::Tensor;
10use crate::transformer::{TransformerEncoder, TransformerDecoderLayer};
11use crate::linear::Linear;
12use crate::norm::LayerNorm;
13use crate::Module;
14
/// Hyperparameters for a T5 encoder-decoder model.
///
/// The `Default` impl matches the t5-small preset; the `t5_*` constructors
/// provide the larger published configurations.
#[derive(Debug, Clone)]
pub struct T5Config {
    /// Number of rows in the token embedding table.
    pub vocab_size: usize,
    /// Hidden size of the model (embedding / residual width).
    pub d_model: usize,
    /// Per-head key/value projection width.
    pub d_kv: usize,
    /// Inner width of the position-wise feed-forward layers.
    pub d_ff: usize,
    /// Number of encoder layers.
    pub num_encoder_layers: usize,
    /// Number of decoder layers.
    pub num_decoder_layers: usize,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Dropout probability.
    pub dropout: f32,
    /// Whether relative position attention biases are enabled.
    /// NOTE(review): this flag is not read anywhere in this file — confirm
    /// it is consumed by the transformer layers.
    pub relative_attention: bool,
}
37
impl Default for T5Config {
    /// Defaults corresponding to the t5-small configuration
    /// (512 hidden, 8 heads, 6+6 layers, 32128-token vocabulary).
    fn default() -> Self {
        T5Config {
            vocab_size: 32128,
            d_model: 512,
            d_kv: 64,
            d_ff: 2048,
            num_encoder_layers: 6,
            num_decoder_layers: 6,
            num_heads: 8,
            dropout: 0.1,
            relative_attention: true,
        }
    }
}
53
54impl T5Config {
55 pub fn t5_small() -> Self {
57 Self::default()
58 }
59
60 pub fn t5_base() -> Self {
62 T5Config {
63 d_model: 768,
64 d_kv: 64,
65 d_ff: 3072,
66 num_encoder_layers: 12,
67 num_decoder_layers: 12,
68 num_heads: 12,
69 ..Default::default()
70 }
71 }
72
73 pub fn t5_large() -> Self {
75 T5Config {
76 d_model: 1024,
77 d_kv: 64,
78 d_ff: 4096,
79 num_encoder_layers: 24,
80 num_decoder_layers: 24,
81 num_heads: 16,
82 ..Default::default()
83 }
84 }
85
86 pub fn t5_3b() -> Self {
88 T5Config {
89 d_model: 1024,
90 d_kv: 128,
91 d_ff: 16384,
92 num_encoder_layers: 24,
93 num_decoder_layers: 24,
94 num_heads: 32,
95 ..Default::default()
96 }
97 }
98
99 pub fn t5_11b() -> Self {
101 T5Config {
102 d_model: 1024,
103 d_kv: 128,
104 d_ff: 65536,
105 num_encoder_layers: 24,
106 num_decoder_layers: 24,
107 num_heads: 128,
108 ..Default::default()
109 }
110 }
111
112 pub fn t5_tiny() -> Self {
114 T5Config {
115 vocab_size: 1000,
116 d_model: 128,
117 d_kv: 16,
118 d_ff: 512,
119 num_encoder_layers: 2,
120 num_decoder_layers: 2,
121 num_heads: 4,
122 dropout: 0.1,
123 relative_attention: true,
124 }
125 }
126}
127
/// Token-id to vector lookup table.
pub struct T5Embeddings {
    // Embedding matrix of shape [vocab_size, d_model], randomly initialized
    // in `T5Embeddings::new`.
    token_embeddings: Tensor,
    // Owned copy of the model configuration (used for vocab_size / d_model).
    config: T5Config,
}
135
136impl T5Embeddings {
137 pub fn new(config: T5Config) -> Self {
139 let token_embeddings = Tensor::randn(&[config.vocab_size, config.d_model]);
140
141 T5Embeddings {
142 token_embeddings,
143 config,
144 }
145 }
146
147 pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor, String> {
149 let ids_data = input_ids.data_f32();
150 let embed_data = self.token_embeddings.data_f32();
151
152 let dims = input_ids.dims();
153 if dims.len() != 2 {
154 return Err(format!("Expected 2D input_ids, got {}D", dims.len()));
155 }
156
157 let batch_size = dims[0];
158 let seq_length = dims[1];
159 let d_model = self.config.d_model;
160
161 let mut result = Vec::with_capacity(batch_size * seq_length * d_model);
162
163 for &id in ids_data.iter() {
164 let idx = id as usize;
165 if idx >= self.config.vocab_size {
166 return Err(format!("Token ID {} out of vocabulary range", idx));
167 }
168
169 let start = idx * d_model;
170 let end = start + d_model;
171 result.extend_from_slice(&embed_data[start..end]);
172 }
173
174 Tensor::from_slice(&result, &[batch_size, seq_length, d_model])
175 .map_err(|e| format!("Failed to create embeddings: {:?}", e))
176 }
177}
178
/// T5 encoder stack: embeddings -> transformer encoder -> final layer norm.
pub struct T5Encoder {
    // Token embedding lookup for encoder inputs.
    embeddings: T5Embeddings,
    // Stack of self-attention encoder layers.
    encoder: TransformerEncoder,
    // Normalization applied after the last encoder layer (T5 style).
    final_layer_norm: LayerNorm,
    // Dropout probability. NOTE(review): stored but never read in `forward`;
    // presumably dropout is applied inside the encoder layers — confirm.
    dropout: f32,
}
190
191impl T5Encoder {
192 pub fn new(config: &T5Config, embeddings: T5Embeddings) -> Self {
194 let encoder = TransformerEncoder::new(
195 config.d_model,
196 config.num_heads,
197 config.d_ff,
198 config.num_encoder_layers,
199 config.dropout,
200 );
201
202 let final_layer_norm = LayerNorm::new(&[config.d_model]);
203
204 T5Encoder {
205 embeddings,
206 encoder,
207 final_layer_norm,
208 dropout: config.dropout,
209 }
210 }
211
212 pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor, String> {
214 let hidden_states = self.embeddings.forward(input_ids)?;
216
217 let hidden_states = self.encoder.forward(&hidden_states);
219
220 let hidden_states = self.final_layer_norm.forward(&hidden_states);
222
223 Ok(hidden_states)
224 }
225}
226
/// T5 decoder stack with cross-attention over the encoder output.
pub struct T5Decoder {
    // Token embedding lookup for decoder inputs.
    embeddings: T5Embeddings,
    // One decoder layer per `num_decoder_layers`.
    layers: Vec<TransformerDecoderLayer>,
    // Normalization applied after the last decoder layer (T5 style).
    final_layer_norm: LayerNorm,
    // Dropout probability. NOTE(review): stored but never read in `forward`;
    // presumably dropout is applied inside the decoder layers — confirm.
    dropout: f32,
}
238
239impl T5Decoder {
240 pub fn new(config: &T5Config, embeddings: T5Embeddings) -> Self {
242 let layers = (0..config.num_decoder_layers)
243 .map(|_| TransformerDecoderLayer::new(config.d_model, config.num_heads, config.d_ff, config.dropout))
244 .collect();
245
246 let final_layer_norm = LayerNorm::new(&[config.d_model]);
247
248 T5Decoder {
249 embeddings,
250 layers,
251 final_layer_norm,
252 dropout: config.dropout,
253 }
254 }
255
256 pub fn forward(&self, decoder_input_ids: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor, String> {
258 let mut hidden_states = self.embeddings.forward(decoder_input_ids)?;
260
261 for layer in &self.layers {
263 hidden_states = layer.forward_with_memory(&hidden_states, encoder_hidden_states, None, None);
264 }
265
266 let hidden_states = self.final_layer_norm.forward(&hidden_states);
268
269 Ok(hidden_states)
270 }
271}
272
/// Full T5 encoder-decoder backbone (no task head).
pub struct T5Model {
    // Model hyperparameters.
    config: T5Config,
    // NOTE(review): despite the name, this table is never read — the encoder
    // and decoder each construct their own independent, randomly initialized
    // embeddings in `T5Model::new`, so weights are NOT tied as in the
    // reference T5 architecture. Confirm whether sharing is intended.
    shared_embeddings: T5Embeddings,
    // Encoder stack.
    encoder: T5Encoder,
    // Decoder stack with cross-attention over the encoder output.
    decoder: T5Decoder,
}
284
impl T5Model {
    /// Constructs the encoder and decoder from `config`.
    ///
    /// NOTE(review): three independent embedding tables are created here
    /// (shared, encoder, decoder), each with its own random initialization;
    /// only the encoder's and decoder's copies are ever used. Reference T5
    /// ties all three — confirm whether that is intended.
    pub fn new(config: T5Config) -> Self {
        let shared_embeddings = T5Embeddings::new(config.clone());

        let encoder_embeddings = T5Embeddings::new(config.clone());
        let encoder = T5Encoder::new(&config, encoder_embeddings);

        let decoder_embeddings = T5Embeddings::new(config.clone());
        let decoder = T5Decoder::new(&config, decoder_embeddings);

        T5Model {
            config,
            shared_embeddings,
            encoder,
            decoder,
        }
    }

    /// Runs the encoder on `input_ids`, then the decoder on
    /// `decoder_input_ids` with cross-attention over the encoder output.
    /// Returns both final hidden states.
    pub fn forward(&self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<T5Output, String> {
        let encoder_hidden_states = self.encoder.forward(input_ids)?;

        let decoder_hidden_states = self.decoder.forward(decoder_input_ids, &encoder_hidden_states)?;

        Ok(T5Output {
            last_hidden_state: decoder_hidden_states,
            encoder_last_hidden_state: encoder_hidden_states,
        })
    }
}
321
/// Output bundle of `T5Model::forward`.
pub struct T5Output {
    /// Final decoder hidden states, `[batch, dec_seq, d_model]`.
    pub last_hidden_state: Tensor,
    /// Final encoder hidden states, `[batch, enc_seq, d_model]`.
    pub encoder_last_hidden_state: Tensor,
}
329
/// T5 backbone plus a language-modeling head for sequence-to-sequence
/// generation.
pub struct T5ForConditionalGeneration {
    // Encoder-decoder backbone.
    t5: T5Model,
    // Projects decoder states from d_model to vocab_size logits.
    lm_head: Linear,
}
337
338impl T5ForConditionalGeneration {
339 pub fn new(config: T5Config) -> Self {
341 let t5 = T5Model::new(config.clone());
342 let lm_head = Linear::new(config.d_model, config.vocab_size);
343
344 T5ForConditionalGeneration {
345 t5,
346 lm_head,
347 }
348 }
349
350 pub fn forward(&self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor, String> {
352 let output = self.t5.forward(input_ids, decoder_input_ids)?;
353 let logits = self.lm_head.forward(&output.last_hidden_state);
354 Ok(logits)
355 }
356
357 pub fn generate(&self, input_ids: &Tensor, max_length: usize) -> Result<Vec<usize>, String> {
359 let mut generated = vec![0usize];
361
362 for _ in 0..max_length {
363 let decoder_input = Tensor::from_slice(
365 &generated.iter().map(|&x| x as f32).collect::<Vec<_>>(),
366 &[1, generated.len()]
367 ).map_err(|e| format!("Failed to create decoder input: {:?}", e))?;
368
369 let logits = self.forward(input_ids, &decoder_input)?;
371
372 let next_token = self.sample_next_token(&logits)?;
374
375 if next_token == 1 {
377 break;
378 }
379
380 generated.push(next_token);
381 }
382
383 Ok(generated)
384 }
385
386 fn sample_next_token(&self, logits: &Tensor) -> Result<usize, String> {
388 let data = logits.data_f32();
389 let dims = logits.dims();
390
391 if dims.len() != 3 {
392 return Err(format!("Expected 3D logits, got {}D", dims.len()));
393 }
394
395 let seq_length = dims[1];
396 let vocab_size = dims[2];
397
398 let start = (seq_length - 1) * vocab_size;
400 let end = start + vocab_size;
401 let last_logits = &data[start..end];
402
403 let next_token = last_logits.iter()
405 .enumerate()
406 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
407 .map(|(idx, _)| idx)
408 .ok_or_else(|| "Failed to sample token".to_string())?;
409
410 Ok(next_token)
411 }
412}
413
/// T5 backbone plus a linear classification head.
pub struct T5ForSequenceClassification {
    // Encoder-decoder backbone.
    t5: T5Model,
    // Projects the pooled decoder state from d_model to num_labels.
    classifier: Linear,
    // Number of output classes.
    num_labels: usize,
}
423
424impl T5ForSequenceClassification {
425 pub fn new(config: T5Config, num_labels: usize) -> Self {
427 let t5 = T5Model::new(config.clone());
428 let classifier = Linear::new(config.d_model, num_labels);
429
430 T5ForSequenceClassification {
431 t5,
432 classifier,
433 num_labels,
434 }
435 }
436
437 pub fn forward(&self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor, String> {
439 let output = self.t5.forward(input_ids, decoder_input_ids)?;
440
441 let first_token = self.extract_first_token(&output.last_hidden_state)?;
443
444 let logits = self.classifier.forward(&first_token);
445 Ok(logits)
446 }
447
448 fn extract_first_token(&self, hidden_states: &Tensor) -> Result<Tensor, String> {
450 let data = hidden_states.data_f32();
451 let dims = hidden_states.dims();
452
453 if dims.len() != 3 {
454 return Err(format!("Expected 3D hidden states, got {}D", dims.len()));
455 }
456
457 let batch_size = dims[0];
458 let d_model = dims[2];
459
460 let mut result = Vec::with_capacity(batch_size * d_model);
461
462 for b in 0..batch_size {
463 let start = b * dims[1] * d_model;
464 let end = start + d_model;
465 result.extend_from_slice(&data[start..end]);
466 }
467
468 Tensor::from_slice(&result, &[batch_size, d_model])
469 .map_err(|e| format!("Failed to extract first token: {:?}", e))
470 }
471}
472
#[cfg(test)]
mod tests {
    use super::*;

    /// Preset constructors report the published model sizes.
    #[test]
    fn test_t5_config() {
        let small = T5Config::t5_small();
        assert_eq!(small.d_model, 512);
        assert_eq!(small.num_encoder_layers, 6);

        let base = T5Config::t5_base();
        assert_eq!(base.d_model, 768);
        assert_eq!(base.num_encoder_layers, 12);

        let large = T5Config::t5_large();
        assert_eq!(large.d_model, 1024);
        assert_eq!(large.num_encoder_layers, 24);
    }

    /// Embedding lookup maps [batch, seq] ids to [batch, seq, d_model].
    #[test]
    fn test_t5_embeddings() {
        let config = T5Config::t5_tiny();
        let embeddings = T5Embeddings::new(config);

        let input_ids = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let output = embeddings.forward(&input_ids).unwrap();

        assert_eq!(output.dims(), &[2, 2, 128]);
    }

    /// The backbone produces decoder states shaped like the decoder input.
    #[test]
    fn test_t5_model() {
        let config = T5Config::t5_tiny();
        let t5 = T5Model::new(config);

        let input_ids = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let decoder_input_ids = Tensor::from_slice(&[1.0, 2.0], &[2, 1]).unwrap();

        let output = t5.forward(&input_ids, &decoder_input_ids).unwrap();

        assert_eq!(output.last_hidden_state.dims(), &[2, 1, 128]);
    }

    /// The LM head projects decoder states to vocabulary-sized logits.
    #[test]
    fn test_t5_for_conditional_generation() {
        let config = T5Config::t5_tiny();
        let t5 = T5ForConditionalGeneration::new(config.clone());

        let input_ids = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let decoder_input_ids = Tensor::from_slice(&[1.0, 2.0], &[2, 1]).unwrap();

        let output = t5.forward(&input_ids, &decoder_input_ids).unwrap();

        assert_eq!(output.dims(), &[2, 1, 1000]);
    }
}