use burn::{
module::Module,
nn::{
Gelu, LayerNorm, LayerNormConfig, PaddingConfig1d,
conv::{Conv1d, Conv1dConfig},
},
tensor::{Tensor, backend::Backend},
};
use crate::model::{
attention::MultiHeadAttention, feed_forward::FeedForward, positional::sinusoids,
};
/// One pre-norm Transformer encoder block:
/// `x -> x + SelfAttn(LN(x)) -> x + FFN(LN(x))`.
#[derive(Module, Debug)]
pub struct EncoderBlock<B: Backend> {
    // LayerNorm applied before self-attention (pre-norm placement).
    norm1: LayerNorm<B>,
    // Multi-head self-attention over the time axis.
    self_attn: MultiHeadAttention<B>,
    // LayerNorm applied before the feed-forward sub-layer.
    norm2: LayerNorm<B>,
    // Position-wise feed-forward network (hidden size 4 * d_model; see `new`).
    ffn: FeedForward<B>,
}
impl<B: Backend> EncoderBlock<B> {
    /// Builds a block of width `d_model` with `n_heads` attention heads and a
    /// feed-forward hidden size of `4 * d_model`.
    pub fn new(d_model: usize, n_heads: usize, device: &B::Device) -> Self {
        let norm1 = LayerNormConfig::new(d_model).init(device);
        let norm2 = LayerNormConfig::new(d_model).init(device);
        let self_attn = MultiHeadAttention::new(d_model, n_heads, device);
        let ffn = FeedForward::new(d_model, d_model * 4, device);
        Self {
            norm1,
            self_attn,
            norm2,
            ffn,
        }
    }

    /// Pre-norm residual forward pass over `[batch, seq, d_model]` activations.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        // Self-attention sub-layer with residual connection.
        let attn_out = self
            .self_attn
            .forward(self.norm1.forward(x.clone()), None, None);
        let x = x + attn_out;
        // Feed-forward sub-layer with residual connection.
        let ffn_out = self.ffn.forward(self.norm2.forward(x.clone()));
        x + ffn_out
    }
}
/// Whisper-style audio encoder: a two-layer convolutional stem (the second
/// conv downsamples time by 2), sinusoidal positional embeddings, a stack of
/// pre-norm Transformer blocks, and a final LayerNorm.
#[derive(Module, Debug)]
pub struct WhisperEncoder<B: Backend> {
    // Kernel-3 conv mapping n_mels -> d_model, stride 1 (see `new`).
    conv1: Conv1d<B>,
    // Kernel-3 conv, stride 2: halves the number of time frames (see `new`).
    conv2: Conv1d<B>,
    // Shared GELU activation applied after each stem conv.
    gelu: Gelu,
    // Transformer encoder stack, applied in order.
    blocks: Vec<EncoderBlock<B>>,
    // Final LayerNorm over the block stack's output.
    norm: LayerNorm<B>,
    // Model width; kept so `forward` can build positional embeddings.
    d_model: usize,
}
impl<B: Backend> WhisperEncoder<B> {
    /// Constructs the encoder.
    ///
    /// * `n_mels` — mel filterbank channels expected in the input spectrogram.
    /// * `d_model` — model width used throughout the encoder.
    /// * `n_heads` — attention heads per block.
    /// * `n_layers` — number of stacked [`EncoderBlock`]s.
    pub fn new(
        n_mels: usize,
        d_model: usize,
        n_heads: usize,
        n_layers: usize,
        device: &B::Device,
    ) -> Self {
        // Stem: two kernel-3 convs; the second halves the time resolution.
        let conv1 = Conv1dConfig::new(n_mels, d_model, 3)
            .with_padding(PaddingConfig1d::Explicit(1))
            .init(device);
        let conv2 = Conv1dConfig::new(d_model, d_model, 3)
            .with_stride(2)
            .with_padding(PaddingConfig1d::Explicit(1))
            .init(device);

        let mut blocks = Vec::with_capacity(n_layers);
        for _ in 0..n_layers {
            blocks.push(EncoderBlock::new(d_model, n_heads, device));
        }

        Self {
            conv1,
            conv2,
            gelu: Gelu::new(),
            blocks,
            norm: LayerNormConfig::new(d_model).init(device),
            d_model,
        }
    }

    /// Encodes a mel spectrogram `[batch, n_mels, time]` into hidden states
    /// `[batch, time / 2, d_model]` (time is halved by the stride-2 conv).
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        // Convolutional stem with GELU after each conv.
        let h = self.gelu.forward(self.conv1.forward(x));
        let h = self.gelu.forward(self.conv2.forward(h));

        // [batch, d_model, frames] -> [batch, frames, d_model].
        let h = h.swap_dims(1, 2);

        // Add fixed sinusoidal positional embeddings, broadcast over batch.
        let n_frames = h.shape().dims[1];
        let device = h.device();
        let pos = sinusoids::<B>(n_frames, self.d_model, &device);
        let mut h = h + pos.unsqueeze::<3>();

        // Transformer stack followed by the final LayerNorm.
        for block in &self.blocks {
            h = block.forward(h);
        }
        self.norm.forward(h)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    type TestBackend = NdArray;

    /// A `[1, 80, 100]` mel input must encode to `[1, 50, 384]`: the stride-2
    /// stem conv halves the 100 time frames and features are projected to d_model.
    #[test]
    fn test_encoder_output_shape() {
        let device = Default::default();
        // Tiny config: 80 mel bins, d_model = 384, 6 heads, 2 layers.
        let encoder = WhisperEncoder::<TestBackend>::new(80, 384, 6, 2, &device);
        let dist = burn::tensor::Distribution::Normal(0.0, 1.0);
        let mel = Tensor::<TestBackend, 3>::random([1, 80, 100], dist, &device);
        let out = encoder.forward(mel);
        assert_eq!(out.shape().dims, [1, 50, 384]);
    }
}