use burn::{
module::Module,
nn::{Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig},
tensor::{Int, Tensor, TensorData, backend::Backend},
};
use crate::model::{attention::MultiHeadAttention, feed_forward::FeedForward};
fn causal_mask<B: Backend>(seq_len: usize, device: &B::Device) -> Tensor<B, 2> {
let mut data = vec![0.0f32; seq_len * seq_len];
for i in 0..seq_len {
for j in (i + 1)..seq_len {
data[i * seq_len + j] = f32::NEG_INFINITY;
}
}
Tensor::from_data(TensorData::new(data, [seq_len, seq_len]), device)
}
#[derive(Module, Debug)]
pub struct DecoderBlock<B: Backend> {
norm1: LayerNorm<B>,
self_attn: MultiHeadAttention<B>,
norm2: LayerNorm<B>,
cross_attn: MultiHeadAttention<B>,
norm3: LayerNorm<B>,
ffn: FeedForward<B>,
}
impl<B: Backend> DecoderBlock<B> {
pub fn new(d_model: usize, n_heads: usize, device: &B::Device) -> Self {
Self {
norm1: LayerNormConfig::new(d_model).init(device),
self_attn: MultiHeadAttention::new(d_model, n_heads, device),
norm2: LayerNormConfig::new(d_model).init(device),
cross_attn: MultiHeadAttention::new(d_model, n_heads, device),
norm3: LayerNormConfig::new(d_model).init(device),
ffn: FeedForward::new(d_model, d_model * 4, device),
}
}
pub fn forward(
&self,
x: Tensor<B, 3>,
encoder_output: Tensor<B, 3>,
mask: Tensor<B, 2>,
) -> Tensor<B, 3> {
let residual = x.clone();
let x = self.norm1.forward(x);
let x = self.self_attn.forward(x, None, Some(mask));
let x = x + residual;
let residual = x.clone();
let x = self.norm2.forward(x);
let x = self.cross_attn.forward(x, Some(encoder_output), None);
let x = x + residual;
let residual = x.clone();
let x = self.norm3.forward(x);
let x = self.ffn.forward(x);
x + residual
}
}
#[derive(Module, Debug)]
pub struct WhisperDecoder<B: Backend> {
token_embedding: Embedding<B>,
positional_embedding: Embedding<B>,
blocks: Vec<DecoderBlock<B>>,
norm: LayerNorm<B>,
}
impl<B: Backend> WhisperDecoder<B> {
pub fn new(
vocab_size: usize,
d_model: usize,
n_heads: usize,
n_layers: usize,
max_target_positions: usize,
device: &B::Device,
) -> Self {
let blocks = (0..n_layers)
.map(|_| DecoderBlock::new(d_model, n_heads, device))
.collect();
Self {
token_embedding: EmbeddingConfig::new(vocab_size, d_model).init(device),
positional_embedding: EmbeddingConfig::new(max_target_positions, d_model).init(device),
blocks,
norm: LayerNormConfig::new(d_model).init(device),
}
}
pub fn forward(&self, tokens: Tensor<B, 2, Int>, encoder_output: Tensor<B, 3>) -> Tensor<B, 3> {
let seq_len = tokens.shape().dims[1];
let device = tokens.device();
let tok_emb = self.token_embedding.forward(tokens);
let pos_ids: Vec<i64> = (0..seq_len as i64).collect();
let positions =
Tensor::<B, 2, Int>::from_data(TensorData::new(pos_ids, [1, seq_len]), &device);
let pos_emb = self.positional_embedding.forward(positions);
let x = tok_emb + pos_emb;
let mask = causal_mask::<B>(seq_len, &device);
let x = self.blocks.iter().fold(x, |x, block| {
block.forward(x, encoder_output.clone(), mask.clone())
});
let x = self.norm.forward(x);
let dims = x.shape().dims;
let (batch, seq, d) = (dims[0], dims[1], dims[2]);
let vocab = self.token_embedding.weight.val().shape().dims[0];
let w = self.token_embedding.weight.val().transpose(); x.reshape([batch * seq, d])
.matmul(w)
.reshape([batch, seq, vocab])
}
}
#[cfg(test)]
mod tests {
use super::*;
use burn::backend::NdArray;
type TestBackend = NdArray;
#[test]
fn test_decoder_output_shape() {
let device = Default::default();
let decoder = WhisperDecoder::<TestBackend>::new(
51865, 384, 6, 2, 448, &device,
);
let tokens = Tensor::<TestBackend, 2, Int>::zeros([1, 5], &device);
let encoder_out = Tensor::<TestBackend, 3>::random(
[1, 50, 384],
burn::tensor::Distribution::Normal(0.0, 1.0),
&device,
);
let logits = decoder.forward(tokens, encoder_out);
assert_eq!(logits.shape().dims, [1, 5, 51865]);
}
}