use burn::{
    module::{Ignored, Module, Param},
    nn::{
        conv::{Conv1d, Conv1dConfig},
        transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput},
        Dropout, DropoutConfig, LayerNorm, LayerNormConfig,
    },
    prelude::Backend,
    tensor::{Distribution, Tensor},
};

use crate::config::{
    ENCODER_DROPOUT, ENCODER_FF_DIM, ENCODER_MAX_PATCHES, ENCODER_NUM_HEADS, ENCODER_NUM_LAYERS,
    ENCODER_OUTPUT_DIM, PATCH_SIZE, TRANSFORMER_INPUT_DIM,
};
/// Configuration for [`TransformerCnnEncoder`]. Defaults come from the
/// crate-level constants in `crate::config`.
#[derive(Debug, Clone)]
pub struct TransformerCnnEncoderConfig {
    /// Final encoder output dimension (not consumed by `init` in this module).
    pub output_dim: usize,
    /// Model dimension (`d_model`) fed into the transformer.
    pub transformer_input_dim: usize,
    /// Number of raw input samples grouped into one patch.
    pub patch_size: usize,
    /// Number of attention heads per transformer layer.
    pub num_heads: usize,
    /// Number of transformer encoder layers.
    pub num_layers: usize,
    /// Hidden dimension of each layer's feed-forward block.
    pub ff_dim: usize,
    /// Dropout probability applied to the input and inside the transformer.
    pub dropout: f64,
    /// Maximum number of patches supported by the positional embedding.
    pub max_patches: usize,
}
impl Default for TransformerCnnEncoderConfig {
    fn default() -> Self {
        Self {
            output_dim: ENCODER_OUTPUT_DIM,
            transformer_input_dim: TRANSFORMER_INPUT_DIM,
            patch_size: PATCH_SIZE,
            num_heads: ENCODER_NUM_HEADS,
            num_layers: ENCODER_NUM_LAYERS,
            ff_dim: ENCODER_FF_DIM,
            dropout: ENCODER_DROPOUT,
            max_patches: ENCODER_MAX_PATCHES,
        }
    }
}
impl TransformerCnnEncoderConfig {
    /// Initializes the encoder on the given device.
    pub fn init<B: Backend>(&self, device: &B::Device) -> TransformerCnnEncoder<B> {
        // Non-overlapping 1D convolution: kernel size equals stride equals
        // patch_size, so each patch of raw samples maps to one d_model vector.
        let patch_embed = Conv1dConfig::new(1, self.transformer_input_dim, self.patch_size)
            .with_stride(self.patch_size)
            .with_bias(false)
            .init(device);
        // Learned positional embedding, initialized from N(0.0, 0.02), the
        // conventional BERT-style scale.
        let pos_embed = Param::from_tensor(Tensor::<B, 3>::random(
            [1, self.max_patches, self.transformer_input_dim],
            Distribution::Normal(0.0, 0.02),
            device,
        ));
        let input_norm = LayerNormConfig::new(self.transformer_input_dim).init::<B>(device);
        let input_dropout = DropoutConfig::new(self.dropout).init();
        let transformer = TransformerEncoderConfig::new(
            self.transformer_input_dim,
            self.ff_dim,
            self.num_heads,
            self.num_layers,
        )
        .with_dropout(self.dropout)
        .init(device);

        TransformerCnnEncoder {
            patch_embed,
            pos_embed,
            input_norm,
            input_dropout,
            transformer,
            patch_size: Ignored(self.patch_size),
            max_patches: Ignored(self.max_patches),
        }
    }
}
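
// A minimal sketch, not part of the original API: because every config field
// is `pub` and `Default` is implemented, struct update syntax is the natural
// way to override individual hyperparameters. The function name and values
// below are arbitrary placeholders.
#[allow(dead_code)]
fn small_encoder_example<B: Backend>(device: &B::Device) -> TransformerCnnEncoder<B> {
    TransformerCnnEncoderConfig {
        num_layers: 2,
        ff_dim: 256,
        ..Default::default()
    }
    .init(device)
}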
/// CNN-patchify + transformer encoder: the input signal is split into
/// fixed-size patches by a strided [`Conv1d`], augmented with a learned
/// positional embedding, then passed through a transformer encoder stack.
#[derive(Module, Debug)]
pub struct TransformerCnnEncoder<B: Backend> {
    patch_embed: Conv1d<B>,
    pos_embed: Param<Tensor<B, 3>>,
    input_norm: LayerNorm<B>,
    input_dropout: Dropout,
    transformer: TransformerEncoder<B>,
    // Hyperparameters kept only for runtime shape checks; `Ignored` keeps
    // them out of the module's parameter record.
    patch_size: Ignored<usize>,
    max_patches: Ignored<usize>,
}
impl<B: Backend> TransformerCnnEncoder<B> {
    /// Encodes a batch of raw 1D signals of shape `[batch, length]` into
    /// patch embeddings of shape `[batch, length / patch_size, d_model]`.
    pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 3> {
        let [b, l] = x.dims();
        assert!(
            l % *self.patch_size == 0,
            "Sequence length {l} must be divisible by patch_size {}",
            *self.patch_size
        );

        // [b, l] -> [b, 1, l]: add a channel axis for the convolution.
        let x = x.reshape([b, 1, l]);
        // [b, 1, l] -> [b, d_model, n_patches]: non-overlapping patch embedding.
        let x = self.patch_embed.forward(x);
        // [b, d_model, n_patches] -> [b, n_patches, d_model].
        let x = x.swap_dims(1, 2);

        let [_, n_patches, d] = x.dims();
        assert!(
            n_patches <= *self.max_patches,
            "Number of patches {n_patches} exceeds max_patches {}.",
            *self.max_patches
        );

        // Slice the learned positional table to the actual patch count and
        // broadcast it across the batch.
        let pos = self
            .pos_embed
            .val()
            .slice([0..1, 0..n_patches, 0..d])
            .expand([b, n_patches, d]);
        let x = x + pos;

        let x = self.input_norm.forward(x);
        let x = self.input_dropout.forward(x);
        let input = TransformerEncoderInput::new(x);
        self.transformer.forward(input)
    }
}
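
// A hedged smoke test sketching the expected shapes. It assumes the
// `burn-ndarray` backend is available as a dev-dependency and that the
// `crate::config` constants form a valid configuration (max_patches >= 4,
// d_model divisible by num_heads). The test name is illustrative.
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    #[test]
    fn forward_produces_one_embedding_per_patch() {
        let device = Default::default();
        let config = TransformerCnnEncoderConfig::default();
        let encoder = config.init::<NdArray>(&device);

        // Two signals, each exactly four patches long.
        let len = config.patch_size * 4;
        let x = Tensor::<NdArray, 2>::zeros([2, len], &device);

        let out = encoder.forward(x);
        assert_eq!(out.dims(), [2, 4, config.transformer_input_dim]);
    }
}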