use alloc::vec::Vec;
use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use crate::tensor::Bool;
use crate::{
self as burn,
nn::{attention::MhaCache, cache::TensorCache, Initializer},
};
use crate::{
config::Config,
nn::{
attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
Dropout, DropoutConfig, LayerNorm, LayerNormConfig,
},
tensor::{backend::Backend, Tensor},
};
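/// Configuration to create a [TransformerDecoder] using [TransformerDecoderConfig::init].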
#[derive(Config)]
pub struct TransformerDecoderConfig {
    /// The size of the model.
    pub d_model: usize,
    /// The size of the position-wise feed-forward network.
    pub d_ff: usize,
    /// The number of attention heads.
    pub n_heads: usize,
    /// The number of stacked decoder layers.
    pub n_layers: usize,
    /// The dropout rate. Default: 0.1
    #[config(default = 0.1)]
    pub dropout: f64,
    /// Apply layer normalization before each sub-layer (pre-norm) instead of after it. Default: false
    #[config(default = false)]
    pub norm_first: bool,
    /// Use "quiet softmax" in the attention layers instead of the regular softmax. Default: false
    #[config(default = false)]
    pub quiet_softmax: bool,
    /// The initializer used for the module parameters. Default: KaimingUniform
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
    )]
    pub initializer: Initializer,
}
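/// The Transformer Decoder module, as described in
/// [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
///
/// Should be created with [TransformerDecoderConfig]. A minimal usage sketch, where
/// `MyBackend`, `target` and `memory` stand in for whatever [Backend] implementation
/// and tensors you actually use:
///
/// ```rust,ignore
/// let device = Default::default();
/// let decoder = TransformerDecoderConfig::new(512, 2048, 8, 6).init::<MyBackend>(&device);
/// let input = TransformerDecoderInput::new(target, memory);
/// let output = decoder.forward(input); // [batch_size, seq_length, d_model]
/// ```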
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct TransformerDecoder<B: Backend> {
    /// The stack of decoder layers.
    pub layers: Vec<TransformerDecoderLayer<B>>,
    /// The size of the model.
    pub d_model: usize,
    /// The size of the position-wise feed-forward network.
    pub d_ff: usize,
    /// The number of attention heads.
    pub n_heads: usize,
    /// The number of stacked decoder layers.
    pub n_layers: usize,
    /// The dropout rate.
    pub dropout: f64,
    /// Whether layer normalization is applied before each sub-layer (pre-norm).
    pub norm_first: bool,
    /// Whether "quiet softmax" is used in the attention layers.
    pub quiet_softmax: bool,
}
impl<B: Backend> ModuleDisplay for TransformerDecoder<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("d_model", &self.d_model)
.add("d_ff", &self.d_ff)
.add("n_heads", &self.n_heads)
.add("n_layers", &self.n_layers)
.add("dropout", &self.dropout)
.add("norm_first", &self.norm_first)
.add("quiet_softmax", &self.quiet_softmax)
.optional()
}
}
impl TransformerDecoderConfig {
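    /// Initialize a new [TransformerDecoder] module on the given device.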
pub fn init<B: Backend>(&self, device: &B::Device) -> TransformerDecoder<B> {
let layers = (0..self.n_layers)
.map(|_| TransformerDecoderLayer::new(self, device))
.collect::<Vec<_>>();
TransformerDecoder {
layers,
d_model: self.d_model,
d_ff: self.d_ff,
n_heads: self.n_heads,
n_layers: self.n_layers,
dropout: self.dropout,
norm_first: self.norm_first,
quiet_softmax: self.quiet_softmax,
}
}
}
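/// Input for the [TransformerDecoder] forward pass: the target sequence, the encoder
/// memory, and their optional padding and attention masks.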
#[derive(Debug)]
pub struct TransformerDecoderInput<B: Backend> {
target: Tensor<B, 3>,
target_mask_pad: Option<Tensor<B, 2, Bool>>,
target_mask_attn: Option<Tensor<B, 3, Bool>>,
memory: Tensor<B, 3>,
memory_mask_pad: Option<Tensor<B, 2, Bool>>,
memory_mask_attn: Option<Tensor<B, 3, Bool>>,
}
impl<B: Backend> TransformerDecoderInput<B> {
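    /// Create a new input from the target sequence and the encoder memory.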
pub fn new(target: Tensor<B, 3>, memory: Tensor<B, 3>) -> Self {
Self {
target,
target_mask_pad: None,
target_mask_attn: None,
memory,
memory_mask_pad: None,
memory_mask_attn: None,
}
}
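    /// Register the padding mask for the encoder memory.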
pub fn memory_mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
self.memory_mask_pad = Some(mask_pad);
self
}
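    /// Register the attention mask for the encoder memory.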
pub fn memory_mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
self.memory_mask_attn = Some(mask_attn);
self
}
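    /// Register the padding mask for the target sequence.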
pub fn target_mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
self.target_mask_pad = Some(mask_pad);
self
}
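    /// Register the attention mask for the target sequence (e.g. an autoregressive mask).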
pub fn target_mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
self.target_mask_attn = Some(mask_attn);
self
}
}
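/// A single layer of the [TransformerDecoder]: self-attention over the target, cross-attention
/// over the encoder memory, and a position-wise feed-forward network, each wrapped in a
/// residual connection with layer normalization and dropout.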
#[derive(Module, Debug)]
pub struct TransformerDecoderLayer<B: Backend> {
cross_attn: MultiHeadAttention<B>,
self_attn: MultiHeadAttention<B>,
pwff: PositionWiseFeedForward<B>,
norm_1: LayerNorm<B>,
norm_2: LayerNorm<B>,
norm_3: LayerNorm<B>,
dropout: Dropout,
norm_first: bool,
}
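/// Autoregressive cache for a single [TransformerDecoderLayer].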
struct TransformerDecoderLayerAutoregressiveCache<B: Backend> {
cross_attn: MhaCache<B>,
self_attn: MhaCache<B>,
pwff: TensorCache<B, 3>,
norm_1: TensorCache<B, 3>,
norm_2: TensorCache<B, 3>,
norm_3: TensorCache<B, 3>,
}
impl<B: Backend> TransformerDecoderLayerAutoregressiveCache<B> {
fn empty() -> Self {
Self {
cross_attn: MhaCache::autoregressive_cross_attention(),
self_attn: MhaCache::autoregressive(),
pwff: TensorCache::empty(),
norm_1: TensorCache::empty(),
norm_2: TensorCache::empty(),
norm_3: TensorCache::empty(),
}
}
}
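/// Autoregressive cache for the whole [TransformerDecoder], with one entry per layer.
/// Created with [TransformerDecoder::new_autoregressive_cache].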
pub struct TransformerDecoderAutoregressiveCache<B: Backend> {
layers: Vec<TransformerDecoderLayerAutoregressiveCache<B>>,
}
impl<B: Backend> TransformerDecoderAutoregressiveCache<B> {
fn empty(num_layers: usize) -> Self {
Self {
layers: (0..num_layers)
.map(|_| TransformerDecoderLayerAutoregressiveCache::empty())
.collect(),
}
}
}
impl<B: Backend> TransformerDecoderLayer<B> {
fn new(config: &TransformerDecoderConfig, device: &B::Device) -> Self {
let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
.with_initializer(config.initializer.clone())
.with_dropout(config.dropout)
.with_quiet_softmax(config.quiet_softmax)
.init(device);
let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
.with_initializer(config.initializer.clone())
.with_dropout(config.dropout)
.with_quiet_softmax(config.quiet_softmax)
.init(device);
let norm_1 = LayerNormConfig::new(config.d_model).init(device);
let norm_2 = LayerNormConfig::new(config.d_model).init(device);
let norm_3 = LayerNormConfig::new(config.d_model).init(device);
let dropout = DropoutConfig::new(config.dropout).init();
let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff)
.with_dropout(config.dropout)
.init(device);
Self {
cross_attn,
self_attn,
norm_1,
norm_2,
norm_3,
pwff,
dropout,
norm_first: config.norm_first,
}
}
fn forward(&self, mut input: TransformerDecoderInput<B>) -> TransformerDecoderInput<B> {
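        // Self-attention block over the target sequence (residual connection + dropout).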
let x = input.target;
let mut residual_path = x.clone();
if self.norm_first {
residual_path = self.norm_3.forward(residual_path);
}
let mut self_attn_input = MhaInput::self_attn(residual_path);
if let Some(mask_pad) = &input.target_mask_pad {
self_attn_input = self_attn_input.mask_pad(mask_pad.clone());
}
if let Some(mask_attn) = &input.target_mask_attn {
self_attn_input = self_attn_input.mask_attn(mask_attn.clone());
}
let residual_path = self.self_attn.forward(self_attn_input).context;
let residual_path = self.dropout.forward(residual_path);
let mut x = x + residual_path;
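        // Cross-attention block over the encoder memory (residual connection + dropout).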
let residual_path = if self.norm_first {
self.norm_1.forward(x.clone())
} else {
x = self.norm_1.forward(x);
x.clone()
};
let mut cross_attn_input =
MhaInput::new(residual_path, input.memory.clone(), input.memory.clone());
if let Some(mask_pad) = &input.memory_mask_pad {
cross_attn_input = cross_attn_input.mask_pad(mask_pad.clone());
}
if let Some(mask_attn) = &input.memory_mask_attn {
cross_attn_input = cross_attn_input.mask_attn(mask_attn.clone());
}
let residual_path = self.cross_attn.forward(cross_attn_input).context;
let residual_path = self.dropout.forward(residual_path);
let mut x = x + residual_path;
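        // Position-wise feed-forward block (residual connection + dropout).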
let residual_path = if self.norm_first {
self.norm_2.forward(x.clone())
} else {
x = self.norm_2.forward(x);
x.clone()
};
let residual_path = self.pwff.forward(residual_path);
let residual_path = self.dropout.forward(residual_path);
let mut x = x + residual_path;
if !self.norm_first {
x = self.norm_3.forward(x)
}
input.target = x;
input
}
fn forward_autoregressive_inference(
&self,
mut input: TransformerDecoderInput<B>,
cache: &mut TransformerDecoderLayerAutoregressiveCache<B>,
) -> TransformerDecoderInput<B> {
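        // Same flow as `forward`, but every sub-module goes through its autoregressive
        // cache so that work from previous decoding steps can be reused.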
let x = input.target;
let mut residual_path = x.clone();
if self.norm_first {
residual_path = cache
.norm_3
.forward_autoregressive(residual_path, 1, |x| self.norm_3.forward(x));
}
let mut self_attn_input = MhaInput::self_attn(residual_path);
if let Some(mask_pad) = &input.target_mask_pad {
self_attn_input = self_attn_input.mask_pad(mask_pad.clone());
}
if let Some(mask_attn) = &input.target_mask_attn {
self_attn_input = self_attn_input.mask_attn(mask_attn.clone());
}
let residual_path = self
.self_attn
.forward_cache(self_attn_input, &mut cache.self_attn)
.context;
let residual_path = self.dropout.forward(residual_path);
let mut x = x + residual_path;
let residual_path = if self.norm_first {
cache
.norm_1
.forward_autoregressive(x.clone(), 1, |x| self.norm_1.forward(x))
} else {
x = cache
.norm_1
.forward_autoregressive(x, 1, |x| self.norm_1.forward(x));
x.clone()
};
let mut cross_attn_input =
MhaInput::new(residual_path, input.memory.clone(), input.memory.clone());
if let Some(mask_pad) = &input.memory_mask_pad {
cross_attn_input = cross_attn_input.mask_pad(mask_pad.clone());
}
if let Some(mask_attn) = &input.memory_mask_attn {
cross_attn_input = cross_attn_input.mask_attn(mask_attn.clone());
}
let residual_path = self
.cross_attn
.forward_cache(cross_attn_input, &mut cache.cross_attn)
.context;
let residual_path = self.dropout.forward(residual_path);
let mut x = x + residual_path;
let residual_path = if self.norm_first {
cache
.norm_2
.forward_autoregressive(x.clone(), 1, |x| self.norm_2.forward(x))
} else {
x = cache
.norm_2
.forward_autoregressive(x, 1, |x| self.norm_2.forward(x));
x.clone()
};
let residual_path = cache
.pwff
.forward_autoregressive(residual_path, 1, |x| self.pwff.forward(x));
let residual_path = self.dropout.forward(residual_path);
let mut x = x + residual_path;
if !self.norm_first {
x = cache
.norm_3
.forward_autoregressive(x, 1, |x| self.norm_3.forward(x))
}
input.target = x;
input
}
}
impl<B: Backend> TransformerDecoder<B> {
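    /// Applies the forward pass, returning a tensor of shape `[batch_size, seq_length, d_model]`.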
pub fn forward(&self, mut input: TransformerDecoderInput<B>) -> Tensor<B, 3> {
for layer in self.layers.iter() {
input = layer.forward(input);
}
input.target
}
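    /// Applies the forward pass using an autoregressive cache, for token-by-token decoding.
    ///
    /// A rough sketch of the decoding loop (`decoder`, `target`, `memory` and the index
    /// bounds are placeholders):
    ///
    /// ```rust,ignore
    /// let mut cache = decoder.new_autoregressive_cache();
    /// for i in 1..=seq_length {
    ///     let step = target.clone().slice([0..batch_size, 0..i, 0..d_model]);
    ///     let input = TransformerDecoderInput::new(step, memory.clone());
    ///     let output = decoder.forward_autoregressive_inference(input, &mut cache);
    ///     // The newly decoded position is the last one.
    ///     let next = output.slice([0..batch_size, i - 1..i, 0..d_model]);
    /// }
    /// ```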
pub fn forward_autoregressive_inference(
&self,
mut input: TransformerDecoderInput<B>,
cache: &mut TransformerDecoderAutoregressiveCache<B>,
) -> Tensor<B, 3> {
for i in 0..self.layers.len() {
let layer = self.layers.get(i).unwrap();
let cache = cache.layers.get_mut(i).unwrap();
input = layer.forward_autoregressive_inference(input, cache);
}
input.target
}
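    /// Create an empty autoregressive cache with one entry per decoder layer.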
pub fn new_autoregressive_cache(&self) -> TransformerDecoderAutoregressiveCache<B> {
TransformerDecoderAutoregressiveCache::empty(self.layers.len())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::Distribution;
use crate::{nn::attention::generate_autoregressive_mask, TestBackend};
#[test]
fn test_autoregressive_norm_last() {
let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
TestBackend::seed(0);
test_autoregressive(
TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers)
.with_norm_first(false),
)
}
#[test]
fn test_autoregressive_norm_first() {
let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
TestBackend::seed(0);
test_autoregressive(
TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(true),
)
}
fn test_autoregressive(config: TransformerDecoderConfig) {
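        // Decode the full sequence in one shot, then token by token with the autoregressive
        // cache, and check that both paths produce (approximately) the same output.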
let device = Default::default();
let [batch_size, seq_length, d_model] = [3, 4, config.d_model];
let transformer = config.init(&device);
let memory = Tensor::<TestBackend, 3>::random(
[batch_size, seq_length, d_model],
Distribution::Default,
&device,
);
let target = Tensor::<TestBackend, 3>::random(
[batch_size, seq_length, d_model],
Distribution::Default,
&device,
);
let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &target.device());
let input = TransformerDecoderInput::new(target.clone(), memory.clone())
.target_mask_attn(mask_attn);
let output_1 = transformer.forward(input);
let mut output_2 = Vec::new();
let mut cache = transformer.new_autoregressive_cache();
for i in 1..seq_length + 1 {
let target = target.clone().slice([0..batch_size, 0..i, 0..d_model]);
let mask_attn = generate_autoregressive_mask(batch_size, i, &target.device());
let input = TransformerDecoderInput::new(target.clone(), memory.clone())
.target_mask_attn(mask_attn);
            let next_tok = transformer
                .forward_autoregressive_inference(input, &mut cache)
                .slice([0..batch_size, i - 1..i, 0..d_model]);
output_2.push(next_tok);
}
let output_2 = Tensor::cat(output_2, 1);
output_1
.into_data()
.assert_approx_eq(&output_2.into_data(), 3);
}
#[test]
fn display() {
let config = TransformerDecoderConfig::new(2, 4, 2, 3);
let transformer = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{}", transformer),
"TransformerDecoder {d_model: 2, d_ff: 4, n_heads: 2, n_layers: 3, \
dropout: 0.1, norm_first: false, quiet_softmax: false, params: 246}"
);
}
}