// T5 text model: RMS-style layer norm, relative-position attention, feed-forward
// blocks, plus the encoder-only and encoder-decoder entry points defined below.
use crate::models::with_tracing::Embedding;
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use serde::Deserialize;
use std::sync::Arc;

#[derive(Debug, Clone)]
pub struct Linear {
    weight: Tensor,
    span: tracing::Span,
}

pub fn linear_no_bias(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> {
    let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
    let weight = vb.get_with_hints((d2, d1), "weight", init_ws)?;
    let span = tracing::span!(tracing::Level::TRACE, "linear");
    Ok(Linear { weight, span })
}

impl Module for Linear {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        // Cast the weight to the input dtype and broadcast it over any leading
        // batch dimensions before the matmul.
        let weight = self.weight.to_dtype(xs.dtype())?;
        let w = match *xs.dims() {
            [b1, b2, _, _] => weight.broadcast_left((b1, b2))?.t()?,
            [bsize, _, _] => weight.broadcast_left(bsize)?.t()?,
            _ => weight.t()?,
        };
        xs.matmul(&w)
    }
}

fn default_relative_attention_max_distance() -> usize {
    128
}

fn default_is_decoder() -> bool {
    false
}

fn default_use_cache() -> bool {
    true
}

fn default_tie_word_embeddings() -> bool {
    true
}

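// Builds the causal self-attention mask used in the decoder: entry (i, j) is 1
// when j > i (a future position) and 0 otherwise. For example, `get_mask(3, dev)`
// yields
//     [[0, 1, 1],
//      [0, 0, 1],
//      [0, 0, 0]]
// and `masked_fill` below replaces the 1-positions of the attention scores with
// `f32::NEG_INFINITY` before the softmax.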
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
    let mask: Vec<_> = (0..size)
        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
        .collect();
    Tensor::from_slice(&mask, (size, size), device)
}

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let shape = mask.shape();
    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
    let m = mask.where_cond(&on_true, on_false)?;
    Ok(m)
}

#[derive(Debug, Deserialize, Default, Clone, PartialEq)]
pub struct ActivationWithOptionalGating {
    pub gated: bool,
    pub activation: candle_nn::Activation,
}

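// Deserializes the `feed_forward_proj` configuration string: "gated-gelu" maps to
// (gated: true, NewGelu) and "gated-silu" to (gated: true, Silu); any other string
// (e.g. "relu") is parsed as a plain, ungated activation.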
pub fn deserialize_feed_forward_proj_activation<'de, D>(
    deserializer: D,
) -> std::result::Result<ActivationWithOptionalGating, D::Error>
where
    D: serde::de::Deserializer<'de>,
{
    match String::deserialize(deserializer)?.as_str() {
        "gated-gelu" => Ok(ActivationWithOptionalGating {
            gated: true,
            activation: candle_nn::Activation::NewGelu,
        }),
        "gated-silu" => Ok(ActivationWithOptionalGating {
            gated: true,
            activation: candle_nn::Activation::Silu,
        }),
        buf => {
            let activation = serde_plain::from_str(buf).map_err(serde::de::Error::custom)?;
            Ok(ActivationWithOptionalGating {
                gated: false,
                activation,
            })
        }
    }
}

#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    pub vocab_size: usize,
    pub d_model: usize,
    pub d_kv: usize,
    pub d_ff: usize,
    pub num_layers: usize,
    pub num_decoder_layers: Option<usize>,
    pub num_heads: usize,
    pub relative_attention_num_buckets: usize,
    #[serde(default = "default_relative_attention_max_distance")]
    pub relative_attention_max_distance: usize,
    pub dropout_rate: f64,
    pub layer_norm_epsilon: f64,
    pub initializer_factor: f64,
    #[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")]
    pub feed_forward_proj: ActivationWithOptionalGating,
    #[serde(default = "default_tie_word_embeddings")]
    pub tie_word_embeddings: bool,
    #[serde(default = "default_is_decoder")]
    pub is_decoder: bool,
    pub is_encoder_decoder: bool,
    #[serde(default = "default_use_cache")]
    pub use_cache: bool,
    pub pad_token_id: usize,
    pub eos_token_id: usize,
    pub decoder_start_token_id: Option<usize>,
}

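// These defaults match the t5-small hyper-parameters (512-dim model, 6 layers,
// 8 heads, 2048-dim feed-forward).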
impl Default for Config {
    fn default() -> Self {
        Self {
            vocab_size: 32128,
            d_model: 512,
            d_kv: 64,
            d_ff: 2048,
            num_layers: 6,
            num_decoder_layers: None,
            num_heads: 8,
            relative_attention_num_buckets: 32,
            relative_attention_max_distance: 128,
            dropout_rate: 0.1,
            layer_norm_epsilon: 1e-6,
            initializer_factor: 1.0,
            feed_forward_proj: ActivationWithOptionalGating {
                gated: false,
                activation: Activation::Relu,
            },
            tie_word_embeddings: true,
            is_decoder: false,
            is_encoder_decoder: true,
            use_cache: true,
            pad_token_id: 0,
            eos_token_id: 1,
            decoder_start_token_id: Some(0),
        }
    }
}

impl Config {
    // Configuration for the T5 text encoder used by the MusicGen "small" checkpoints.
    pub fn musicgen_small() -> Self {
        Self {
            d_ff: 3072,
            d_kv: 64,
            d_model: 768,
            dropout_rate: 0.1,
            eos_token_id: 1,
            feed_forward_proj: ActivationWithOptionalGating {
                gated: false,
                activation: Activation::Relu,
            },
            tie_word_embeddings: true,
            initializer_factor: 1.0,
            is_decoder: false,
            is_encoder_decoder: true,
            layer_norm_epsilon: 1e-6,
            num_decoder_layers: Some(12),
            num_heads: 12,
            num_layers: 12,
            pad_token_id: 0,
            decoder_start_token_id: Some(0),
            relative_attention_max_distance: 128,
            relative_attention_num_buckets: 32,
            use_cache: true,
            vocab_size: 32128,
        }
    }
}

#[derive(Debug, Clone)]
struct T5LayerNorm {
    weight: Tensor,
    variance_epsilon: f64,
    span: tracing::Span,
}

impl T5LayerNorm {
    fn load(h: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
        let weight = vb.get(h, "weight")?;
        Ok(Self {
            weight,
            variance_epsilon: eps,
            span: tracing::span!(tracing::Level::TRACE, "layer-norm"),
        })
    }
}

impl Module for T5LayerNorm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let dtype = xs.dtype();
        // T5 uses an RMS-style layer norm: scale by 1/sqrt(mean(x^2) + eps)
        // without subtracting the mean, computed in f32 for stability.
        let xs_f32 = xs.to_dtype(DType::F32)?;
        let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
        let xs = xs_f32.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
        let xs = xs.to_dtype(dtype)?;
        let xs = xs.broadcast_mul(&self.weight.to_dtype(dtype)?)?;
        Ok(xs)
    }
}

#[derive(Debug, Clone)]
struct T5DenseActDense {
    wi: Linear,
    wo: Linear,
    act: Activation,
    span: tracing::Span,
}

impl T5DenseActDense {
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let wi = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
        let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
        Ok(Self {
            wi,
            wo,
            act: Activation::Relu,
            span: tracing::span!(tracing::Level::TRACE, "dense-act-dense"),
        })
    }
}

impl Module for T5DenseActDense {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let xs = self.wi.forward(xs)?;
        let xs = self.act.forward(&xs)?;
        let xs = self.wo.forward(&xs)?;
        Ok(xs)
    }
}

#[derive(Debug, Clone)]
struct T5DenseGatedActDense {
    wi_0: Linear,
    wi_1: Linear,
    wo: Linear,
    act: Activation,
    span: tracing::Span,
}

impl T5DenseGatedActDense {
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let wi_0 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?;
        let wi_1 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?;
        let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
        Ok(Self {
            wi_0,
            wi_1,
            wo,
            act: cfg.feed_forward_proj.activation,
            span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
        })
    }
}

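// Gated feed-forward: wo(act(wi_0(x)) * wi_1(x)), i.e. the activation branch gates
// a parallel linear branch before the output projection.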
impl Module for T5DenseGatedActDense {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let hidden_gelu = self.act.forward(&self.wi_0.forward(xs)?)?;
        let hidden_linear = self.wi_1.forward(xs)?;
        let xs = hidden_gelu.broadcast_mul(&hidden_linear)?;
        let xs = self.wo.forward(&xs)?;
        Ok(xs)
    }
}

#[derive(Debug, Clone)]
struct T5LayerFF {
    dense_act: Option<T5DenseActDense>,
    gated_dense_act: Option<T5DenseGatedActDense>,
    layer_norm: T5LayerNorm,
    span: tracing::Span,
}

impl T5LayerFF {
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let layer_norm =
            T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
        let (dense_act, gated_dense_act) = if cfg.feed_forward_proj.gated {
            (
                None,
                Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
            )
        } else {
            (
                Some(T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?),
                None,
            )
        };
        Ok(Self {
            dense_act,
            gated_dense_act,
            layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "layer-ff"),
        })
    }
}

impl Module for T5LayerFF {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let ys = self.layer_norm.forward(xs)?;
        let ys = match &self.dense_act {
            Some(dense_act) => dense_act.forward(&ys)?,
            None => self.gated_dense_act.as_ref().unwrap().forward(&ys)?,
        };
        let xs = (xs + ys)?;
        Ok(xs)
    }
}

#[derive(Debug, Clone)]
struct T5Attention {
    q: Linear,
    k: Linear,
    v: Linear,
    o: Linear,
    n_heads: usize,
    d_kv: usize,
    relative_attention_bias: Option<Embedding>,
    relative_attention_num_buckets: usize,
    relative_attention_max_distance: usize,
    inner_dim: usize,
    use_cache: bool,
    kv_cache: Option<(Tensor, Tensor)>,
    span: tracing::Span,
    span_cache: tracing::Span,
    span_mm: tracing::Span,
    span_sm: tracing::Span,
}

impl T5Attention {
    fn load(
        has_relative_attention_bias: bool,
        decoder: bool,
        vb: VarBuilder,
        cfg: &Config,
    ) -> Result<Self> {
        let inner_dim = cfg.num_heads * cfg.d_kv;
        let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
        let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
        let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?;
        let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?;
        let relative_attention_bias = if has_relative_attention_bias {
            let emb = Embedding::new(
                cfg.relative_attention_num_buckets,
                cfg.num_heads,
                vb.pp("relative_attention_bias"),
            )?;
            Some(emb)
        } else {
            None
        };
        Ok(Self {
            q,
            k,
            v,
            o,
            n_heads: cfg.num_heads,
            d_kv: cfg.d_kv,
            relative_attention_bias,
            relative_attention_num_buckets: cfg.relative_attention_num_buckets,
            relative_attention_max_distance: cfg.relative_attention_max_distance,
            inner_dim,
            use_cache: cfg.use_cache && decoder,
            kv_cache: None,
            span: tracing::span!(tracing::Level::TRACE, "attention"),
            span_cache: tracing::span!(tracing::Level::TRACE, "attention-cache"),
            span_mm: tracing::span!(tracing::Level::TRACE, "attention-mm"),
            span_sm: tracing::span!(tracing::Level::TRACE, "attention-sm"),
        })
    }

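    // Self-attention when `key_value_states` is `None`, cross-attention over the
    // encoder output otherwise. When no `position_bias` is passed in and this layer
    // owns a `relative_attention_bias` table, the T5 relative-position buckets are
    // computed here: offsets smaller than `max_exact` each get their own bucket,
    // larger offsets share logarithmically spaced buckets capped by
    // `relative_attention_max_distance`, and future positions (j > i) are shifted
    // by `num_buckets` so past and future use disjoint bucket ranges.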
    fn forward(
        &mut self,
        xs: &Tensor,
        position_bias: Option<&Tensor>,
        key_value_states: Option<&Tensor>,
        mask: Option<&Tensor>,
    ) -> Result<(Tensor, Option<Tensor>)> {
        let _enter = self.span.enter();
        let kv_input = match key_value_states {
            None => xs,
            Some(key_value_states) => key_value_states,
        };
        let (b_sz, q_len) = (xs.dim(0)?, xs.dim(1)?);
        let kv_len = kv_input.dim(1)?;
        let q = self.q.forward(xs)?;
        let k = self.k.forward(kv_input)?;
        let v = self.v.forward(kv_input)?;
        let q = q
            .reshape((b_sz, q_len, self.n_heads, self.d_kv))?
            .transpose(1, 2)?
            .contiguous()?;
        let mut k = k
            .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
            .transpose(1, 2)?;
        let mut v = v
            .reshape((b_sz, kv_len, self.n_heads, self.d_kv))?
            .transpose(1, 2)?;

        if self.use_cache && key_value_states.is_none() {
            let _enter = self.span_cache.enter();
            // Append the new keys/values to the cache along the sequence dimension.
            if let Some((kv_cache_k, kv_cache_v)) = &self.kv_cache {
                k = Tensor::cat(&[kv_cache_k, &k], 2)?;
                v = Tensor::cat(&[kv_cache_v, &v], 2)?;
            };
            self.kv_cache = Some((k.clone(), v.clone()));
        };
        let k = k.contiguous()?;
        let v = v.contiguous()?;
        let scores = {
            let _enter = self.span_mm.enter();
            q.matmul(&k.t()?)?
        };
        let scores = match mask {
            None => scores,
            Some(mask) => masked_fill(
                &scores,
                &mask
                    .unsqueeze(0)?
                    .unsqueeze(0)?
                    .repeat((b_sz, self.n_heads))?,
                f32::NEG_INFINITY,
            )?,
        };

        let (scores, position_bias) = match position_bias {
            Some(position_bias) => (
                scores.broadcast_add(position_bias)?,
                Some(position_bias.clone()),
            ),
            None => match &self.relative_attention_bias {
                None => (scores, None),
                Some(relative_attention_bias) => {
                    let kv_len = k.dim(2)?;
                    let (q_start, q_end) = match self.use_cache {
                        true => ((kv_len - q_len) as u32, kv_len as u32),
                        false => (0_u32, kv_len as u32),
                    };
                    let num_buckets = self.relative_attention_num_buckets as u32 / 2;
                    let max_exact = num_buckets / 2;
                    let relative_position = (q_start..q_end)
                        .map(|i| {
                            (0..kv_len as u32)
                                .map(|j| {
                                    if i < j {
                                        if j - i < max_exact {
                                            j - i + num_buckets
                                        } else {
                                            let b = f32::log(
                                                (j - i) as f32 / max_exact as f32,
                                                self.relative_attention_max_distance as f32
                                                    / max_exact as f32,
                                            ) * (num_buckets - max_exact) as f32;
                                            u32::min(
                                                max_exact + num_buckets + b as u32,
                                                self.relative_attention_num_buckets as u32 - 1,
                                            )
                                        }
                                    } else if i - j < max_exact {
                                        i - j
                                    } else {
                                        let b = f32::log(
                                            (i - j) as f32 / max_exact as f32,
                                            self.relative_attention_max_distance as f32
                                                / max_exact as f32,
                                        ) * (num_buckets - max_exact) as f32;
                                        u32::min(max_exact + b as u32, num_buckets - 1)
                                    }
                                })
                                .collect::<Vec<u32>>()
                        })
                        .collect::<Vec<Vec<_>>>();
                    let relative_buckets = Tensor::new(relative_position, q.device())?;
                    let position_bias = relative_attention_bias
                        .forward(&relative_buckets)?
                        .permute((2, 0, 1))?
                        .unsqueeze(0)?
                        .to_dtype(scores.dtype())?;
                    (scores.broadcast_add(&position_bias)?, Some(position_bias))
                }
            },
        };

        let attn_weights = {
            let _enter = self.span_sm.enter();
            candle_nn::ops::softmax_last_dim(&scores)?
        };
        let attn_output = attn_weights.matmul(&v)?;
        let attn_output = attn_output
            .transpose(1, 2)?
            .reshape((b_sz, q_len, self.inner_dim))?;
        let attn_output = self.o.forward(&attn_output)?;
        Ok((attn_output, position_bias))
    }

    fn clear_kv_cache(&mut self) {
        self.kv_cache = None
    }
}

#[derive(Debug, Clone)]
struct T5LayerSelfAttention {
    self_attention: T5Attention,
    layer_norm: T5LayerNorm,
    span: tracing::Span,
}

impl T5LayerSelfAttention {
    fn load(h: bool, d: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let self_attention = T5Attention::load(h, d, vb.pp("SelfAttention"), cfg)?;
        let layer_norm =
            T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
        Ok(Self {
            self_attention,
            layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "self-attn"),
        })
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        position_bias: Option<&Tensor>,
        mask: Option<&Tensor>,
    ) -> Result<(Tensor, Option<Tensor>)> {
        let _enter = self.span.enter();
        let normed_xs = self.layer_norm.forward(xs)?;
        let (ys, position_bias) =
            self.self_attention
                .forward(&normed_xs, position_bias, None, mask)?;
        let ys = (xs + ys)?;
        Ok((ys, position_bias))
    }

    fn clear_kv_cache(&mut self) {
        self.self_attention.clear_kv_cache()
    }
}

#[derive(Debug, Clone)]
struct T5LayerCrossAttention {
    cross_attention: T5Attention,
    layer_norm: T5LayerNorm,
    span: tracing::Span,
}

impl T5LayerCrossAttention {
    fn load(decoder: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let cross_attention = T5Attention::load(false, decoder, vb.pp("EncDecAttention"), cfg)?;
        let layer_norm =
            T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
        Ok(Self {
            cross_attention,
            layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "cross-attn"),
        })
    }

    fn forward(
        &mut self,
        hidden_states: &Tensor,
        position_bias: Option<&Tensor>,
        key_value_states: &Tensor,
    ) -> Result<(Tensor, Option<Tensor>)> {
        let _enter = self.span.enter();
        let normed_hidden_states = self.layer_norm.forward(hidden_states)?;
        let (ys, position_bias) = self.cross_attention.forward(
            &normed_hidden_states,
            position_bias,
            Some(key_value_states),
            None,
        )?;
        let ys = (hidden_states + ys)?;
        Ok((ys, position_bias))
    }

    fn clear_kv_cache(&mut self) {
        self.cross_attention.clear_kv_cache()
    }
}

#[derive(Debug, Clone)]
struct T5Block {
    self_attn: T5LayerSelfAttention,
    cross_attn: Option<T5LayerCrossAttention>,
    ff: T5LayerFF,
    span: tracing::Span,
}

impl T5Block {
    fn load(
        has_relative_attention_bias: bool,
        decoder: bool,
        vb: VarBuilder,
        cfg: &Config,
    ) -> Result<Self> {
        let vb = vb.pp("layer");
        let self_attn =
            T5LayerSelfAttention::load(has_relative_attention_bias, decoder, vb.pp("0"), cfg)?;
        let cross_attn = if cfg.is_decoder {
            Some(T5LayerCrossAttention::load(decoder, vb.pp("1"), cfg)?)
        } else {
            None
        };
        let ff_i = if cross_attn.is_some() { 2 } else { 1 };
        let ff = T5LayerFF::load(vb.pp(ff_i.to_string()), cfg)?;
        Ok(Self {
            self_attn,
            cross_attn,
            ff,
            span: tracing::span!(tracing::Level::TRACE, "block"),
        })
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        position_bias: Option<&Tensor>,
        encoder_hidden_states: Option<&Tensor>,
    ) -> Result<(Tensor, Option<Tensor>)> {
        let _enter = self.span.enter();
        // Only decoder blocks (which carry a cross-attention layer) use a causal mask.
        let mask = match self.cross_attn.is_some() {
            true => {
                let mask_len = xs.dim(1)?;
                // A single query token needs no mask; this also sidesteps shape issues
                // when decoding one token at a time with the KV cache.
                if mask_len <= 1 {
                    None
                } else {
                    Some(get_mask(mask_len, xs.device())?)
                }
            }
            false => None,
        };
        let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?;
        if let Some(cross_attn) = &mut self.cross_attn {
            (xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?;
        }
        let xs = self.ff.forward(&xs)?;
        Ok((xs, position_bias))
    }

    fn clear_kv_cache(&mut self) {
        self.self_attn.clear_kv_cache();
        self.cross_attn.iter_mut().for_each(|c| c.clear_kv_cache());
    }
}

#[derive(Debug, Clone)]
struct T5Stack {
    block: Vec<T5Block>,
    shared: Arc<Embedding>,
    final_layer_norm: T5LayerNorm,
    span: tracing::Span,
}

impl T5Stack {
    fn load(decoder: bool, vb: VarBuilder, shared: &Arc<Embedding>, cfg: &Config) -> Result<Self> {
        let block = (0..cfg.num_layers)
            .map(|i| T5Block::load(i == 0, decoder, vb.pp(format!("block.{i}")), cfg))
            .collect::<Result<Vec<_>>>()?;
        let final_layer_norm = T5LayerNorm::load(
            cfg.d_model,
            cfg.layer_norm_epsilon,
            vb.pp("final_layer_norm"),
        )?;
        Ok(Self {
            block,
            shared: shared.clone(),
            final_layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "stack"),
        })
    }

    fn forward(
        &mut self,
        input_ids: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
    ) -> Result<Tensor> {
        self.forward_dt(input_ids, encoder_hidden_states, None)
    }

    fn forward_dt(
        &mut self,
        input_ids: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
        dtype: Option<DType>,
    ) -> Result<Tensor> {
        let _enter = self.span.enter();
        let input_embeds = self.shared.as_ref().forward(input_ids)?;
        let input_embeds = match dtype {
            None => input_embeds,
            Some(dtype) => input_embeds.to_dtype(dtype)?,
        };
        let mut hidden_states = input_embeds;
        let mut position_bias = None;
        for block in self.block.iter_mut() {
            (hidden_states, position_bias) = block.forward(
                &hidden_states,
                position_bias.as_ref(),
                encoder_hidden_states,
            )?
        }
        self.final_layer_norm.forward(&hidden_states)
    }

    fn clear_kv_cache(&mut self) {
        self.block.iter_mut().for_each(|b| b.clear_kv_cache())
    }
}

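/// Encoder-only T5 stack, e.g. for producing conditioning or sentence embeddings.
///
/// A minimal usage sketch, assuming this module is exposed as
/// `candle_transformers::models::t5` and that a `VarBuilder` with real weights is
/// available; zero-initialised weights are used below purely to show the call
/// pattern:
///
/// ```ignore
/// use candle::{DType, Device, Tensor};
/// use candle_nn::VarBuilder;
/// use candle_transformers::models::t5::{Config, T5EncoderModel};
///
/// let device = Device::Cpu;
/// let vb = VarBuilder::zeros(DType::F32, &device);
/// let cfg = Config::default();
/// let mut model = T5EncoderModel::load(vb, &cfg)?;
/// // Token ids would normally come from the matching tokenizer.
/// let input_ids = Tensor::new(&[[0u32, 1, 2, 1]], &device)?;
/// let hidden = model.forward(&input_ids)?; // (batch, seq_len, d_model)
/// ```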
#[derive(Debug, Clone)]
pub struct T5EncoderModel {
    encoder: T5Stack,
    device: Device,
    span: tracing::Span,
}

impl T5EncoderModel {
    pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        // Some checkpoints store the token embeddings under `shared`, others only
        // under the encoder/decoder `embed_tokens` prefixes.
        let shared_vb = if vb.contains_tensor("shared.weight") {
            vb.pp("shared")
        } else if vb.contains_tensor("decoder.embed_tokens") {
            vb.pp("decoder").pp("embed_tokens")
        } else {
            vb.pp("encoder").pp("embed_tokens")
        };
        let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
        let shared = Arc::new(shared);
        let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
        Ok(Self {
            encoder,
            device: vb.device().clone(),
            span: tracing::span!(tracing::Level::TRACE, "encoder"),
        })
    }

    pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.encoder.forward(input_ids, None)
    }

    pub fn forward_dt(&mut self, input_ids: &Tensor, dtype: Option<DType>) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.encoder.forward_dt(input_ids, None, dtype)
    }

    pub fn device(&self) -> &Device {
        &self.device
    }

    pub fn clear_kv_cache(&mut self) {
        self.encoder.clear_kv_cache()
    }
}

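/// Full T5 encoder-decoder with a language-modelling head (optionally tied to the
/// shared token embeddings).
///
/// A greedy-decoding sketch, assuming a loaded `model: T5ForConditionalGeneration`,
/// its `cfg: Config`, and an `input_ids` tensor of encoder token ids (set up as in
/// the `T5EncoderModel` example above); the loop bound and token handling are
/// illustrative only:
///
/// ```ignore
/// let encoder_output = model.encode(&input_ids)?;
/// let start = cfg.decoder_start_token_id.unwrap_or(cfg.pad_token_id) as u32;
/// let mut tokens = vec![start];
/// for _ in 0..32 {
///     // With the KV cache enabled, only the most recent token needs to be fed in.
///     let last = *tokens.last().unwrap();
///     let decoder_ids = Tensor::new(&[[last]], model.device())?;
///     let logits = model.decode(&decoder_ids, &encoder_output)?; // (1, vocab_size)
///     let next = logits.squeeze(0)?.argmax(0)?.to_scalar::<u32>()?;
///     if next as usize == cfg.eos_token_id {
///         break;
///     }
///     tokens.push(next);
/// }
/// model.clear_kv_cache();
/// ```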
#[derive(Debug, Clone)]
pub struct T5ForConditionalGeneration {
    encoder: T5Stack,
    decoder: T5Stack,
    d_model: usize,
    tie_word_embeddings: bool,
    lm_head: Option<Linear>,
    shared: Arc<Embedding>,
    device: Device,
    span_decode: tracing::Span,
    span_decode_head: tracing::Span,
}

impl T5ForConditionalGeneration {
    pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        assert!(cfg.is_encoder_decoder);
        let d_model = cfg.d_model;
        let shared_vb = if vb.contains_tensor("shared.weight") {
            vb.pp("shared")
        } else {
            vb.pp("decoder").pp("embed_tokens")
        };
        let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
        let shared = Arc::new(shared);

        let mut encoder_cfg = cfg.clone();
        encoder_cfg.is_decoder = false;
        encoder_cfg.use_cache = false;
        encoder_cfg.is_encoder_decoder = false;
        let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, &encoder_cfg)?;

        let mut decoder_cfg = cfg.clone();
        decoder_cfg.is_decoder = true;
        decoder_cfg.is_encoder_decoder = false;
        decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers);
        let decoder = T5Stack::load(true, vb.pp("decoder"), &shared, &decoder_cfg)?;

        let tie_word_embeddings = cfg.tie_word_embeddings;
        let lm_head = if tie_word_embeddings {
            None
        } else {
            Some(linear_no_bias(
                cfg.d_model,
                cfg.vocab_size,
                vb.pp("lm_head"),
            )?)
        };

        Ok(Self {
            encoder,
            decoder,
            d_model,
            tie_word_embeddings,
            lm_head,
            shared,
            device: vb.device().clone(),
            span_decode: tracing::span!(tracing::Level::TRACE, "decode"),
            span_decode_head: tracing::span!(tracing::Level::TRACE, "decode-head"),
        })
    }

    pub fn encode(&mut self, input_ids: &Tensor) -> Result<Tensor> {
        self.encoder.forward(input_ids, None)
    }

    pub fn decode(
        &mut self,
        decoder_input_ids: &Tensor,
        encoder_output: &Tensor,
    ) -> Result<Tensor> {
        let _enter = self.span_decode.enter();
        let decoder_output = self
            .decoder
            .forward(decoder_input_ids, Some(encoder_output))?;

        // When the output projection is tied to the input embeddings, the decoder
        // output is rescaled by sqrt(d_model) before projecting onto the vocabulary.
        let scaling_factor = if self.tie_word_embeddings {
            (self.d_model as f64).sqrt()
        } else {
            1.0
        };
        // Only the logits for the last position are returned.
        let sequence_output = ((decoder_output
            .narrow(1, decoder_output.dim(1)? - 1, 1)?
            .squeeze(1)?)
            * scaling_factor)?;
        let output = {
            let _enter = self.span_decode_head.enter();
            match self.lm_head {
                None => sequence_output.matmul(&self.shared.embeddings().t()?)?,
                Some(ref lm_head) => lm_head.forward(&sequence_output)?,
            }
        };
        Ok(output)
    }

    pub fn forward(&mut self, input_ids: &Tensor, decoder_input_ids: &Tensor) -> Result<Tensor> {
        let encoder_output = self.encode(input_ids)?;
        self.decode(decoder_input_ids, &encoder_output)
    }

    pub fn device(&self) -> &Device {
        &self.device
    }

    pub fn clear_kv_cache(&mut self) {
        self.encoder.clear_kv_cache();
        self.decoder.clear_kv_cache();
    }
}