candle_transformers/models/t5.rs

//! T5 model implementation.
//!
//! T5 (Text-to-Text Transfer Transformer) is an encoder-decoder model that casts every
//! task into a unified text-to-text format. This implementation follows the original
//! model architecture.
//!
//! Key characteristics:
//! - Text-to-text framework
//! - Relative positional embeddings
//! - T5-specific layer normalization
//! - Encoder-decoder architecture
//! - Support for sequence-to-sequence tasks
//!
//! References:
//! - ⚡ [Interactive Wasm Example](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm)
//! - 💻 [GH Model](https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py)
//! - 🤗 [HF Link](https://huggingface.co/docs/transformers/model_doc/t5)
//! - 📝 [T5 Paper](https://arxiv.org/abs/1910.10683)
//!
//! # Encoder-decoder example
//!
//! ```bash
//! cargo run --example t5 --release -- \
//!   --model-id "t5-small" \
//!   --prompt "translate to German: A beautiful candle." \
//!   --decode
//! > ...
//! >  Eine schöne Kerze.
//! > 9 tokens generated (2.42 token/s)
//! ```
//!
//! Variants such as [flan-t5](https://huggingface.co/google/flan-t5-small), [flan-ul2](https://huggingface.co/google/flan-ul2) (with `--revision "refs/pr/25"`), and [Co-EdIT](https://huggingface.co/grammarly/coedit-large) are also supported.
//!
//! # Translation with MADLAD
//!
//! [MADLAD-400](https://arxiv.org/abs/2309.04662) is a series of multilingual machine translation T5 models trained on 250 billion tokens covering over 450 languages using publicly available data. These models are competitive with significantly larger models.
//!
//! ```bash
//! cargo run --example t5 --release -- \
//!   --model-id "jbochi/madlad400-3b-mt" \
//!   --prompt "<2de> How are you, my friend?" \
//!   --decode --temperature 0
//! ...
//!  Wie geht es dir, mein Freund?
//! ```
//!
//! # Sentence embedding example
//!
//! ```bash
//! cargo run --example t5 --release -- \
//!   --model-id "t5-small" --prompt "A beautiful candle."
//! ...
//! [[[ 0.0515, -0.0541, -0.0761, ..., -0.0392,  0.1511, -0.0265],
//!   [-0.0974,  0.0998, -0.1659, ..., -0.2450,  0.1738, -0.0164],
//!   [ 0.0624, -0.1024,  0.0430, ..., -0.1388,  0.0564, -0.2962],
//!   [-0.0389, -0.1173,  0.0026, ...,  0.1064, -0.1065,  0.0990],
//!   [ 0.1300,  0.0027, -0.0326, ...,  0.0026, -0.0317,  0.0851]]]
//! Tensor[[1, 5, 512], f32]
//! Took 303.766583ms
//! ```

use crate::models::with_tracing::Embedding;
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use serde::Deserialize;
use std::sync::Arc;

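/// A bias-free linear layer (T5 uses no biases in its projections); the weight is stored
/// as `(out_features, in_features)` and transposed on the fly in `forward`.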
#[derive(Debug, Clone)]
pub struct Linear {
    weight: Tensor,
    span: tracing::Span,
}

pub fn linear_no_bias(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> {
    let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
    let weight = vb.get_with_hints((d2, d1), "weight", init_ws)?;
    let span = tracing::span!(tracing::Level::TRACE, "linear");
    Ok(Linear { weight, span })
}

impl Module for Linear {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let weight = self.weight.to_dtype(xs.dtype())?;
        let w = match *xs.dims() {
            [b1, b2, _, _] => weight.broadcast_left((b1, b2))?.t()?,
            [bsize, _, _] => weight.broadcast_left(bsize)?.t()?,
            _ => weight.t()?,
        };
        xs.matmul(&w)
    }
}

fn default_relative_attention_max_distance() -> usize {
    128
}

fn default_is_decoder() -> bool {
    false
}

fn default_use_cache() -> bool {
    true
}

fn default_tie_word_embeddings() -> bool {
    true
}

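// Builds a causal mask of shape `(size, size)`: entry (i, j) is 1 when j > i, i.e. when
// position i would otherwise attend to a future position.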
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
    let mask: Vec<_> = (0..size)
        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
        .collect();
    Tensor::from_slice(&mask, (size, size), device)
}

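// Replaces the positions of `on_false` selected by `mask` with the scalar `on_true`
// (used in `T5Attention` to set masked attention scores to `f32::NEG_INFINITY`).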
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let shape = mask.shape();
    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
    let m = mask.where_cond(&on_true, on_false)?;
    Ok(m)
}

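/// Feed-forward activation parsed from the `feed_forward_proj` config field: plain values
/// such as `"relu"` map to an ungated activation, while `"gated-gelu"` / `"gated-silu"`
/// select the gated feed-forward variant.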
#[derive(Debug, Deserialize, Default, Clone, PartialEq)]
pub struct ActivationWithOptionalGating {
    pub gated: bool,
    pub activation: candle_nn::Activation,
}

pub fn deserialize_feed_forward_proj_activation<'de, D>(
    deserializer: D,
) -> std::result::Result<ActivationWithOptionalGating, D::Error>
where
    D: serde::de::Deserializer<'de>,
{
    match String::deserialize(deserializer)?.as_str() {
        "gated-gelu" => Ok(ActivationWithOptionalGating {
            gated: true,
            activation: candle_nn::Activation::NewGelu,
        }),
        "gated-silu" => Ok(ActivationWithOptionalGating {
            gated: true,
            activation: candle_nn::Activation::Silu,
        }),
        buf => {
            let activation = serde_plain::from_str(buf).map_err(serde::de::Error::custom)?;
            Ok(ActivationWithOptionalGating {
                gated: false,
                activation,
            })
        }
    }
}

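/// T5 model configuration; field names mirror the Hugging Face `config.json` for T5
/// checkpoints, so a config file can be deserialized directly with serde.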
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    pub vocab_size: usize,
    pub d_model: usize,
    pub d_kv: usize,
    pub d_ff: usize,
    pub num_layers: usize,
    pub num_decoder_layers: Option<usize>,
    pub num_heads: usize,
    pub relative_attention_num_buckets: usize,
    #[serde(default = "default_relative_attention_max_distance")]
    pub relative_attention_max_distance: usize,
    pub dropout_rate: f64,
    pub layer_norm_epsilon: f64,
    pub initializer_factor: f64,
    #[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")]
    pub feed_forward_proj: ActivationWithOptionalGating,
    #[serde(default = "default_tie_word_embeddings")]
    pub tie_word_embeddings: bool,
    #[serde(default = "default_is_decoder")]
    pub is_decoder: bool,
    pub is_encoder_decoder: bool,
    #[serde(default = "default_use_cache")]
    pub use_cache: bool,
    pub pad_token_id: usize,
    pub eos_token_id: usize,
    pub decoder_start_token_id: Option<usize>,
}

impl Default for Config {
    fn default() -> Self {
        Self {
            vocab_size: 32128,
            d_model: 512,
            d_kv: 64,
            d_ff: 2048,
            num_layers: 6,
            num_decoder_layers: None,
            num_heads: 8,
            relative_attention_num_buckets: 32,
            relative_attention_max_distance: 128,
            dropout_rate: 0.1,
            layer_norm_epsilon: 1e-6,
            initializer_factor: 1.0,
            feed_forward_proj: ActivationWithOptionalGating {
                gated: false,
                activation: Activation::Relu,
            },
            tie_word_embeddings: true,
            is_decoder: false,
            is_encoder_decoder: true,
            use_cache: true,
            pad_token_id: 0,
            eos_token_id: 1,
            decoder_start_token_id: Some(0),
        }
    }
}

impl Config {
    // https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L184
    pub fn musicgen_small() -> Self {
        Self {
            d_ff: 3072,
            d_kv: 64,
            d_model: 768,
            dropout_rate: 0.1,
            eos_token_id: 1,
            feed_forward_proj: ActivationWithOptionalGating {
                gated: false,
                activation: Activation::Relu,
            },
            tie_word_embeddings: true,
            initializer_factor: 1.0,
            is_decoder: false,
            is_encoder_decoder: true,
            layer_norm_epsilon: 1e-6,
            num_decoder_layers: Some(12),
            num_heads: 12,
            num_layers: 12,
            pad_token_id: 0,
            decoder_start_token_id: Some(0),
            relative_attention_max_distance: 128,
            relative_attention_num_buckets: 32,
            use_cache: true,
            vocab_size: 32128,
        }
    }
}

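// T5 uses a simplified layer norm (RMSNorm-style): the input is scaled by the reciprocal
// of its root-mean-square and multiplied by a learned weight; there is no mean
// subtraction and no bias.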
#[derive(Debug, Clone)]
struct T5LayerNorm {
    weight: Tensor,
    variance_epsilon: f64,
    span: tracing::Span,
}

impl T5LayerNorm {
    fn load(h: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
        let weight = vb.get(h, "weight")?;
        Ok(Self {
            weight,
            variance_epsilon: eps,
            span: tracing::span!(tracing::Level::TRACE, "layer-norm"),
        })
    }
}

impl Module for T5LayerNorm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let dtype = xs.dtype();
        let xs_f32 = xs.to_dtype(DType::F32)?;
        // variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
        let xs = xs_f32.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
        let xs = xs.to_dtype(dtype)?;
        let xs = xs.broadcast_mul(&self.weight.to_dtype(dtype)?)?;
        Ok(xs)
    }
}

#[derive(Debug, Clone)]
struct T5DenseActDense {
    wi: Linear,
    wo: Linear,
    act: Activation,
    span: tracing::Span,
}

impl T5DenseActDense {
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let wi = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
        let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
        Ok(Self {
            wi,
            wo,
            act: Activation::Relu,
            span: tracing::span!(tracing::Level::TRACE, "dense-act-dense"),
        })
    }
}

impl Module for T5DenseActDense {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let xs = self.wi.forward(xs)?;
        let xs = self.act.forward(&xs)?;
        let xs = self.wo.forward(&xs)?;
        Ok(xs)
    }
}

#[derive(Debug, Clone)]
struct T5DenseGatedActDense {
    wi_0: Linear,
    wi_1: Linear,
    wo: Linear,
    act: Activation,
    span: tracing::Span,
}

impl T5DenseGatedActDense {
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let wi_0 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?;
        let wi_1 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?;
        let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
        Ok(Self {
            wi_0,
            wi_1,
            wo,
            act: cfg.feed_forward_proj.activation,
            span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
        })
    }
}

impl Module for T5DenseGatedActDense {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let hidden_gelu = self.act.forward(&self.wi_0.forward(xs)?)?;
        let hidden_linear = self.wi_1.forward(xs)?;
        let xs = hidden_gelu.broadcast_mul(&hidden_linear)?;
        let xs = self.wo.forward(&xs)?;
        Ok(xs)
    }
}

#[derive(Debug, Clone)]
struct T5LayerFF {
    dense_act: Option<T5DenseActDense>,
    gated_dense_act: Option<T5DenseGatedActDense>,
    layer_norm: T5LayerNorm,
    span: tracing::Span,
}

impl T5LayerFF {
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let layer_norm =
            T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
        let (dense_act, gated_dense_act) = if cfg.feed_forward_proj.gated {
            (
                None,
                Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
            )
        } else {
            (
                Some(T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?),
                None,
            )
        };
        Ok(Self {
            dense_act,
            gated_dense_act,
            layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "layer-ff"),
        })
    }
}

impl Module for T5LayerFF {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let ys = self.layer_norm.forward(xs)?;
        let ys = match &self.dense_act {
            Some(dense_act) => dense_act.forward(&ys)?,
            None => self.gated_dense_act.as_ref().unwrap().forward(&ys)?,
        };
        let xs = (xs + ys)?;
        Ok(xs)
    }
}

#[derive(Debug, Clone)]
struct T5Attention {
    q: Linear,
    k: Linear,
    v: Linear,
    o: Linear,
    n_heads: usize,
    d_kv: usize,
    relative_attention_bias: Option<Embedding>,
    relative_attention_num_buckets: usize,
    relative_attention_max_distance: usize,
    inner_dim: usize,
    use_cache: bool,
    kv_cache: Option<(Tensor, Tensor)>,
    span: tracing::Span,
    span_cache: tracing::Span,
    span_mm: tracing::Span,
    span_sm: tracing::Span,
}

impl T5Attention {
    fn load(
        has_relative_attention_bias: bool,
        decoder: bool,
        vb: VarBuilder,
        cfg: &Config,
    ) -> Result<Self> {
        let inner_dim = cfg.num_heads * cfg.d_kv;
        let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
        let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
        let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?;
        let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?;
        let relative_attention_bias = if has_relative_attention_bias {
            let emb = Embedding::new(
                cfg.relative_attention_num_buckets,
                cfg.num_heads,
                vb.pp("relative_attention_bias"),
            )?;
            Some(emb)
        } else {
            None
        };
        Ok(Self {
            q,
            k,
            v,
            o,
            n_heads: cfg.num_heads,
            d_kv: cfg.d_kv,
            relative_attention_bias,
            relative_attention_num_buckets: cfg.relative_attention_num_buckets,
            relative_attention_max_distance: cfg.relative_attention_max_distance,
            inner_dim,
            use_cache: cfg.use_cache && decoder,
            kv_cache: None,
            span: tracing::span!(tracing::Level::TRACE, "attention"),
            span_cache: tracing::span!(tracing::Level::TRACE, "attention-cache"),
            span_mm: tracing::span!(tracing::Level::TRACE, "attention-mm"),
            span_sm: tracing::span!(tracing::Level::TRACE, "attention-sm"),
        })
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        position_bias: Option<&Tensor>,
        key_value_states: Option<&Tensor>,
        mask: Option<&Tensor>,
    ) -> Result<(Tensor, Option<Tensor>)> {
        // Performs self-attention (if key_value_states is None) or cross-attention
        // over the source sequence (provided by key_value_states).
        let _enter = self.span.enter();
        let kv_input = match key_value_states {
            None => xs,
            Some(key_value_states) => key_value_states,
        };
        let (b_sz, q_len) = (xs.dim(0)?, xs.dim(1)?);
        let kv_len = kv_input.dim(1)?;
        let q = self.q.forward(xs)?;
        let k = self.k.forward(kv_input)?;
        let v = self.v.forward(kv_input)?;
        let q = q
            .reshape((b_sz, q_len, self.n_heads, self.d_kv))?
            .transpose(1, 2)?
            .contiguous()?;
        let mut k = k
            .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
            .transpose(1, 2)?;
        let mut v = v
            .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
            .transpose(1, 2)?;

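        // For decoder self-attention, append the new keys/values to the cache so that
        // incremental decoding only needs the latest token as input. Cross-attention
        // (key_value_states.is_some()) is not cached and is recomputed from the encoder
        // output on every call.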
        if self.use_cache && key_value_states.is_none() {
            let _enter = self.span_cache.enter();
            if let Some((kv_cache_k, kv_cache_v)) = &self.kv_cache {
                k = Tensor::cat(&[kv_cache_k, &k], 2)?;
                v = Tensor::cat(&[kv_cache_v, &v], 2)?;
            };
            self.kv_cache = Some((k.clone(), v.clone()));
        };
        let k = k.contiguous()?;
        let v = v.contiguous()?;
        // TODO: Use flash_attn.
        let scores = {
            let _enter = self.span_mm.enter();
            q.matmul(&k.t()?)?
        };
        let scores = match mask {
            None => scores,
            Some(mask) => masked_fill(
                &scores,
                &mask
                    .unsqueeze(0)?
                    .unsqueeze(0)?
                    .repeat((b_sz, self.n_heads))?,
                f32::NEG_INFINITY,
            )?,
        };

        let (scores, position_bias) = match position_bias {
            Some(position_bias) => (
                scores.broadcast_add(position_bias)?,
                Some(position_bias.clone()),
            ),
            None => match &self.relative_attention_bias {
                None => (scores, None),
                Some(relative_attention_bias) => {
                    // This only handles the bidirectional case.
                    let kv_len = k.dim(2)?;
                    let (q_start, q_end) = match self.use_cache {
                        true => ((kv_len - q_len) as u32, kv_len as u32),
                        false => (0_u32, kv_len as u32),
                    };
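                    // Relative-position bucketing, bidirectional variant: one half of the
                    // buckets covers keys after the query (j > i), the other half keys at
                    // or before it. Within each half, small offsets get an exact bucket and
                    // larger offsets are log-spaced up to relative_attention_max_distance.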
                    let num_buckets = self.relative_attention_num_buckets as u32 / 2;
                    let max_exact = num_buckets / 2;
                    let relative_position = (q_start..q_end)
                        .map(|i| {
                            (0..kv_len as u32)
                                .map(|j| {
                                    if i < j {
                                        if j - i < max_exact {
                                            j - i + num_buckets
                                        } else {
                                            let b = f32::log(
                                                (j - i) as f32 / max_exact as f32,
                                                self.relative_attention_max_distance as f32
                                                    / max_exact as f32,
                                            ) * (num_buckets - max_exact) as f32;
                                            u32::min(
                                                max_exact + num_buckets + b as u32,
                                                self.relative_attention_num_buckets as u32 - 1,
                                            )
                                        }
                                    } else if i - j < max_exact {
                                        i - j
                                    } else {
                                        let b = f32::log(
                                            (i - j) as f32 / max_exact as f32,
                                            self.relative_attention_max_distance as f32
                                                / max_exact as f32,
                                        ) * (num_buckets - max_exact) as f32;
                                        u32::min(max_exact + b as u32, num_buckets - 1)
                                    }
                                })
                                .collect::<Vec<u32>>()
                        })
                        .collect::<Vec<Vec<_>>>();
                    let relative_buckets = Tensor::new(relative_position, q.device())?;
                    let position_bias = relative_attention_bias
                        .forward(&relative_buckets)?
                        .permute((2, 0, 1))?
                        .unsqueeze(0)?
                        .to_dtype(scores.dtype())?;
                    (scores.broadcast_add(&position_bias)?, Some(position_bias))
                    // TODO: position_bias_masked?
                }
            },
        };

        let attn_weights = {
            let _enter = self.span_sm.enter();
            candle_nn::ops::softmax_last_dim(&scores)?
        };
        let attn_output = attn_weights.matmul(&v)?;
        let attn_output = attn_output
            .transpose(1, 2)?
            .reshape((b_sz, q_len, self.inner_dim))?;
        let attn_output = self.o.forward(&attn_output)?;
        Ok((attn_output, position_bias))
    }

    fn clear_kv_cache(&mut self) {
        self.kv_cache = None
    }
}

#[derive(Debug, Clone)]
struct T5LayerSelfAttention {
    self_attention: T5Attention,
    layer_norm: T5LayerNorm,
    span: tracing::Span,
}

impl T5LayerSelfAttention {
    fn load(h: bool, d: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let self_attention = T5Attention::load(h, d, vb.pp("SelfAttention"), cfg)?;
        let layer_norm =
            T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
        Ok(Self {
            self_attention,
            layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "self-attn"),
        })
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        position_bias: Option<&Tensor>,
        mask: Option<&Tensor>,
    ) -> Result<(Tensor, Option<Tensor>)> {
        let _enter = self.span.enter();
        let normed_xs = self.layer_norm.forward(xs)?;
        let (ys, position_bias) =
            self.self_attention
                .forward(&normed_xs, position_bias, None, mask)?;
        let ys = (xs + ys)?;
        Ok((ys, position_bias))
    }

    fn clear_kv_cache(&mut self) {
        self.self_attention.clear_kv_cache()
    }
}

#[derive(Debug, Clone)]
struct T5LayerCrossAttention {
    cross_attention: T5Attention,
    layer_norm: T5LayerNorm,
    span: tracing::Span,
}

impl T5LayerCrossAttention {
    fn load(decoder: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let cross_attention = T5Attention::load(false, decoder, vb.pp("EncDecAttention"), cfg)?;
        let layer_norm =
            T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
        Ok(Self {
            cross_attention,
            layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "cross-attn"),
        })
    }

    fn forward(
        &mut self,
        hidden_states: &Tensor,
        position_bias: Option<&Tensor>,
        key_value_states: &Tensor,
    ) -> Result<(Tensor, Option<Tensor>)> {
        let _enter = self.span.enter();
        let normed_hidden_states = self.layer_norm.forward(hidden_states)?;
        let (ys, position_bias) = self.cross_attention.forward(
            &normed_hidden_states,
            position_bias,
            Some(key_value_states),
            None,
        )?;
        let ys = (hidden_states + ys)?;
        Ok((ys, position_bias))
    }

    fn clear_kv_cache(&mut self) {
        self.cross_attention.clear_kv_cache()
    }
}

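// A single transformer block: self-attention, optional cross-attention (decoder only),
// then the feed-forward layer, each with a pre-layer-norm residual connection.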
#[derive(Debug, Clone)]
struct T5Block {
    self_attn: T5LayerSelfAttention,
    cross_attn: Option<T5LayerCrossAttention>,
    ff: T5LayerFF,
    span: tracing::Span,
}

impl T5Block {
    fn load(
        has_relative_attention_bias: bool,
        decoder: bool,
        vb: VarBuilder,
        cfg: &Config,
    ) -> Result<Self> {
        let vb = vb.pp("layer");
        let self_attn =
            T5LayerSelfAttention::load(has_relative_attention_bias, decoder, vb.pp("0"), cfg)?;
        let cross_attn = if cfg.is_decoder {
            Some(T5LayerCrossAttention::load(decoder, vb.pp("1"), cfg)?)
        } else {
            None
        };
        let ff_i = if cross_attn.is_some() { 2 } else { 1 };
        let ff = T5LayerFF::load(vb.pp(ff_i.to_string()), cfg)?;
        Ok(Self {
            self_attn,
            cross_attn,
            ff,
            span: tracing::span!(tracing::Level::TRACE, "block"),
        })
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        position_bias: Option<&Tensor>,
        encoder_hidden_states: Option<&Tensor>,
    ) -> Result<(Tensor, Option<Tensor>)> {
        let _enter = self.span.enter();
        // TODO: Cache masks
        let mask = match self.cross_attn.is_some() {
            true => {
                let mask_len = xs.dim(1)?;
                // If the input sequence length is 1 there is no need for a mask; this also
                // avoids shape issues when using the KV cache in the decoder.
                if mask_len <= 1 {
                    None
                } else {
                    Some(get_mask(mask_len, xs.device())?)
                }
            }
            false => None,
        };
        let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?;
        // TODO: clamp for f16?
        if let Some(cross_attn) = &mut self.cross_attn {
            (xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?;
            // TODO: clamp for f16?
        }
        let xs = self.ff.forward(&xs)?;
        // TODO: clamp for f16?
        Ok((xs, position_bias))
    }

    fn clear_kv_cache(&mut self) {
        self.self_attn.clear_kv_cache();
        self.cross_attn.iter_mut().for_each(|c| c.clear_kv_cache());
    }
}

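// A stack of T5 blocks followed by a final layer norm; used for both the encoder and the
// decoder, sharing the same token embedding table.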
#[derive(Debug, Clone)]
struct T5Stack {
    block: Vec<T5Block>,
    shared: Arc<Embedding>,
    final_layer_norm: T5LayerNorm,
    span: tracing::Span,
}

impl T5Stack {
    fn load(decoder: bool, vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> {
        let block = (0..cfg.num_layers)
            .map(|i| T5Block::load(i == 0, decoder, vb.pp(format!("block.{i}")), cfg))
            .collect::<Result<Vec<_>>>()?;
        let final_layer_norm = T5LayerNorm::load(
            cfg.d_model,
            cfg.layer_norm_epsilon,
            vb.pp("final_layer_norm"),
        )?;
        Ok(Self {
            block,
            shared: shared.clone(),
            final_layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "stack"),
        })
    }

    fn forward(
        &mut self,
        input_ids: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
    ) -> Result<Tensor> {
        self.forward_dt(input_ids, encoder_hidden_states, None)
    }

    fn forward_dt(
        &mut self,
        input_ids: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
        dtype: Option<DType>,
    ) -> Result<Tensor> {
        let _enter = self.span.enter();
        let input_embeds = self.shared.as_ref().forward(input_ids)?;
        let input_embeds = match dtype {
            None => input_embeds,
            Some(dtype) => input_embeds.to_dtype(dtype)?,
        };
        let mut hidden_states = input_embeds;
        let mut position_bias = None;
        for block in self.block.iter_mut() {
            (hidden_states, position_bias) = block.forward(
                &hidden_states,
                position_bias.as_ref(),
                encoder_hidden_states,
            )?
        }
        self.final_layer_norm.forward(&hidden_states)
    }

    fn clear_kv_cache(&mut self) {
        self.block.iter_mut().for_each(|b| b.clear_kv_cache())
    }
}

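/// Encoder-only T5 model, e.g. for producing sentence or token embeddings.
///
/// A minimal usage sketch (the file paths are placeholders and the token ids below stand
/// in for real tokenizer output):
///
/// ```no_run
/// use candle::{DType, Device, Tensor};
/// use candle_nn::VarBuilder;
/// use candle_transformers::models::t5::{Config, T5EncoderModel};
///
/// fn main() -> Result<(), Box<dyn std::error::Error>> {
///     let device = Device::Cpu;
///     let config: Config = serde_json::from_str(&std::fs::read_to_string("config.json")?)?;
///     let vb = unsafe {
///         VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
///     };
///     let mut model = T5EncoderModel::load(vb, &config)?;
///     // Placeholder token ids with shape (batch = 1, seq_len = 4).
///     let input_ids = Tensor::new(&[[71u32, 1712, 14316, 1]], &device)?;
///     let embeddings = model.forward(&input_ids)?; // (1, 4, d_model)
///     println!("{embeddings}");
///     Ok(())
/// }
/// ```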
#[derive(Debug, Clone)]
pub struct T5EncoderModel {
    encoder: T5Stack,
    device: Device,
    span: tracing::Span,
}

impl T5EncoderModel {
    pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let shared_vb = if vb.contains_tensor("shared.weight") {
            vb.pp("shared")
        } else if vb.contains_tensor("decoder.embed_tokens") {
            vb.pp("decoder").pp("embed_tokens")
        } else {
            vb.pp("encoder").pp("embed_tokens")
        };
        let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
        let shared = Arc::new(shared);
        let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
        Ok(Self {
            encoder,
            device: vb.device().clone(),
            span: tracing::span!(tracing::Level::TRACE, "encoder"),
        })
    }

    pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.encoder.forward(input_ids, None)
    }

    pub fn forward_dt(&mut self, input_ids: &Tensor, dtype: Option<DType>) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.encoder.forward_dt(input_ids, None, dtype)
    }

    pub fn device(&self) -> &Device {
        &self.device
    }

    pub fn clear_kv_cache(&mut self) {
        self.encoder.clear_kv_cache()
    }
}

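/// Full encoder-decoder T5 model for conditional (text-to-text) generation.
///
/// A minimal greedy-decoding sketch (the file paths and encoder token ids are
/// placeholders; a real setup would tokenize the prompt with the matching tokenizer and
/// typically sample rather than take the argmax):
///
/// ```no_run
/// use candle::{DType, Device, Tensor};
/// use candle_nn::VarBuilder;
/// use candle_transformers::models::t5::{Config, T5ForConditionalGeneration};
///
/// fn main() -> Result<(), Box<dyn std::error::Error>> {
///     let device = Device::Cpu;
///     let config: Config = serde_json::from_str(&std::fs::read_to_string("config.json")?)?;
///     let vb = unsafe {
///         VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
///     };
///     let mut model = T5ForConditionalGeneration::load(vb, &config)?;
///     // Placeholder encoder input of shape (1, seq_len).
///     let input_ids = Tensor::new(&[[37u32, 1593, 19, 1]], &device)?;
///     let encoder_output = model.encode(&input_ids)?;
///     let start = config.decoder_start_token_id.unwrap_or(config.pad_token_id) as u32;
///     let mut tokens = vec![start];
///     for _ in 0..32 {
///         // With the decoder KV cache enabled, only the newest token is fed per step.
///         let decoder_ids = Tensor::new(&[[*tokens.last().unwrap()]], &device)?;
///         let logits = model.decode(&decoder_ids, &encoder_output)?.squeeze(0)?;
///         let next = logits.argmax(0)?.to_scalar::<u32>()?;
///         if next as usize == config.eos_token_id {
///             break;
///         }
///         tokens.push(next);
///     }
///     println!("generated token ids: {tokens:?}");
///     Ok(())
/// }
/// ```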
#[derive(Debug, Clone)]
pub struct T5ForConditionalGeneration {
    encoder: T5Stack,
    decoder: T5Stack,
    d_model: usize,
    tie_word_embeddings: bool,
    lm_head: Option<Linear>,
    shared: Arc<Embedding>,
    device: Device,
    span_decode: tracing::Span,
    span_decode_head: tracing::Span,
}

impl T5ForConditionalGeneration {
    pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        assert!(cfg.is_encoder_decoder);
        let d_model = cfg.d_model;
        let shared_vb = if vb.contains_tensor("shared.weight") {
            vb.pp("shared")
        } else {
            vb.pp("decoder").pp("embed_tokens")
        };
        let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
        let shared = Arc::new(shared);

        let mut encoder_cfg = cfg.clone();
        encoder_cfg.is_decoder = false;
        encoder_cfg.use_cache = false;
        encoder_cfg.is_encoder_decoder = false;
        let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, &encoder_cfg)?;

        let mut decoder_cfg = cfg.clone();
        decoder_cfg.is_decoder = true;
        decoder_cfg.is_encoder_decoder = false;
        decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers);
        let decoder = T5Stack::load(true, vb.pp("decoder"), &shared, &decoder_cfg)?;

        let tie_word_embeddings = cfg.tie_word_embeddings;
        let lm_head = if tie_word_embeddings {
            None
        } else {
            Some(linear_no_bias(
                cfg.d_model,
                cfg.vocab_size,
                vb.pp("lm_head"),
            )?)
        };

        Ok(Self {
            encoder,
            decoder,
            d_model,
            tie_word_embeddings,
            lm_head,
            shared,
            device: vb.device().clone(),
            span_decode: tracing::span!(tracing::Level::TRACE, "decode"),
            span_decode_head: tracing::span!(tracing::Level::TRACE, "decode-head"),
        })
    }

    pub fn encode(&mut self, input_ids: &Tensor) -> Result<Tensor> {
        self.encoder.forward(input_ids, None)
    }

    pub fn decode(
        &mut self,
        decoder_input_ids: &Tensor,
        encoder_output: &Tensor,
    ) -> Result<Tensor> {
        let _enter = self.span_decode.enter();
        let decoder_output = self
            .decoder
            .forward(decoder_input_ids, Some(encoder_output))?;

        let scaling_factor = if self.tie_word_embeddings {
            // Rescale output before projecting on vocab
            // See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            (self.d_model as f64).sqrt()
        } else {
            1.0
        };
        let sequence_output = ((decoder_output
            .narrow(1, decoder_output.dim(1)? - 1, 1)?
            .squeeze(1)?)
            * scaling_factor)?;
        let output = {
            let _enter = self.span_decode_head.enter();
            match self.lm_head {
                None => sequence_output.matmul(&self.shared.embeddings().t()?)?,
                Some(ref lm_head) => lm_head.forward(&sequence_output)?,
            }
        };
        Ok(output)
    }

    pub fn forward(&mut self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> {
        let encoder_output = self.encode(input_ids)?;
        self.decode(decoder_input_ids, &encoder_output)
    }

    pub fn device(&self) -> &Device {
        &self.device
    }

    pub fn clear_kv_cache(&mut self) {
        self.encoder.clear_kv_cache();
        self.decoder.clear_kv_cache();
    }
}