use burn::{
module::{Module, Param},
nn::{
conv::{Conv2d, Conv2dConfig},
Dropout, DropoutConfig, LayerNorm, LayerNormConfig, Linear, LinearConfig,
},
tensor::{
activation,
backend::Backend,
Distribution, Tensor,
},
};
use crate::config::{PoolType, SensorEncoderConfig};
/// Patch tokenizer: cuts a 4-D input into non-overlapping `patch_h x patch_w`
/// patches and linearly projects every patch to a `d_model`-wide token.
///
/// The projection is a convolution whose kernel and stride both equal the
/// patch size and which uses `Valid` padding, so samples that do not fill a
/// whole patch at the end of an axis are dropped.
#[derive(Module, Debug)]
pub struct PatchEmbedding<B: Backend> {
    /// Strided convolution that performs the per-patch projection.
    proj: Conv2d<B>,
    /// Number of whole patches along the first spatial axis (`time_steps / patch_h`).
    num_patches_t: usize,
    /// Number of whole patches along the second spatial axis (`num_channels / patch_w`).
    num_patches_c: usize,
    /// Width of each output token.
    d_model: usize,
}

impl<B: Backend> PatchEmbedding<B> {
    /// Builds the patch projection and records the expected patch grid.
    ///
    /// * `in_channels` - channel count of the convolution input.
    /// * `d_model` - channel count of the convolution output (token width).
    /// * `patch_h`, `patch_w` - kernel size and stride along the two spatial axes.
    /// * `time_steps`, `num_channels` - expected spatial extent of the input,
    ///   used only to compute [`Self::num_patches`].
    ///
    /// # Panics
    ///
    /// Panics if `patch_h` or `patch_w` is zero.
    pub fn new(
        in_channels: usize,
        d_model: usize,
        patch_h: usize,
        patch_w: usize,
        time_steps: usize,
        num_channels: usize,
        device: &B::Device,
    ) -> Self {
        assert!(
            patch_h > 0 && patch_w > 0,
            "patch dimensions must be non-zero (got {patch_h}x{patch_w})"
        );
        let proj = Conv2dConfig::new([in_channels, d_model], [patch_h, patch_w])
            .with_stride([patch_h, patch_w])
            .with_padding(burn::nn::PaddingConfig2d::Valid)
            .with_bias(true)
            .init(device);
        // With `Valid` padding and stride == kernel the convolution emits
        // floor(size / patch) positions per axis. Both counts therefore use
        // floor division; a ceil here would promise a patch the convolution
        // never produces.
        Self {
            proj,
            num_patches_t: time_steps / patch_h,
            num_patches_c: num_channels / patch_w,
            d_model,
        }
    }

    /// Projects `x` and flattens the patch grid into a token sequence.
    ///
    /// Returns a tensor of shape `[batch, patches, d_model]`, where `patches`
    /// is the product of the two spatial dimensions of the convolution output.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 3> {
        let out = self.proj.forward(x);
        // Flatten using the grid the convolution actually produced rather than
        // the precomputed counts, so the reshape can never disagree with the
        // tensor it is applied to.
        let [batch, d, pt, pc] = out.dims();
        out.reshape([batch, d, pt * pc]).swap_dims(1, 2)
    }

    /// Number of tokens expected for the input extent given to [`Self::new`].
    pub fn num_patches(&self) -> usize {
        self.num_patches_t * self.num_patches_c
    }
}
/// Two-layer feed-forward block: linear, GELU, dropout, linear, dropout.
#[derive(Module, Debug)]
pub struct MlpBlock<B: Backend> {
    /// Expansion layer, `d_model -> mlp_dim`.
    fc1: Linear<B>,
    /// Projection layer, `mlp_dim -> d_model`.
    fc2: Linear<B>,
    /// Dropout shared by both application points in `forward`.
    dropout: Dropout,
}

impl<B: Backend> MlpBlock<B> {
    /// Creates the block with the given model width, hidden width and dropout probability.
    pub fn new(d_model: usize, mlp_dim: usize, dropout: f64, device: &B::Device) -> Self {
        // Layers are initialised in declaration order (fc1, then fc2).
        let fc1 = LinearConfig::new(d_model, mlp_dim).init(device);
        let fc2 = LinearConfig::new(mlp_dim, d_model).init(device);
        let dropout = DropoutConfig::new(dropout).init();
        Self { fc1, fc2, dropout }
    }

    /// Applies the feed-forward transformation to the last dimension of `x`.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let hidden = self
            .dropout
            .forward(activation::gelu(self.fc1.forward(x)));
        self.dropout.forward(self.fc2.forward(hidden))
    }
}
/// Multi-head scaled dot-product self-attention with optional query chunking.
#[derive(Module, Debug)]
pub struct MultiHeadSelfAttention<B: Backend> {
    /// Query projection.
    q_proj: Linear<B>,
    /// Key projection.
    k_proj: Linear<B>,
    /// Value projection.
    v_proj: Linear<B>,
    /// Output projection applied to the concatenated heads.
    out_proj: Linear<B>,
    /// Number of attention heads.
    num_heads: usize,
    /// Width of one head (`d_model / num_heads`).
    head_dim: usize,
    /// Score scaling factor, `head_dim^-0.5`.
    scale: f32,
    /// Number of query positions processed per chunk; `0` disables chunking.
    chunk_size: usize,
    /// Dropout applied to the attention weights.
    dropout: Dropout,
}

impl<B: Backend> MultiHeadSelfAttention<B> {
    /// Creates the attention layer.
    ///
    /// # Panics
    ///
    /// Panics if `d_model` is not a multiple of `num_heads`.
    pub fn new(
        d_model: usize,
        num_heads: usize,
        dropout: f64,
        chunk_size: usize,
        device: &B::Device,
    ) -> Self {
        assert_eq!(d_model % num_heads, 0);
        let head_dim = d_model / num_heads;
        let scale = (head_dim as f32).powf(-0.5);

        // Projections are initialised in the order q, k, v, out.
        let q_proj = LinearConfig::new(d_model, d_model).init(device);
        let k_proj = LinearConfig::new(d_model, d_model).init(device);
        let v_proj = LinearConfig::new(d_model, d_model).init(device);
        let out_proj = LinearConfig::new(d_model, d_model).init(device);

        Self {
            q_proj,
            k_proj,
            v_proj,
            out_proj,
            num_heads,
            head_dim,
            scale,
            chunk_size,
            dropout: DropoutConfig::new(dropout).init(),
        }
    }

    /// Attention core: scaled scores, softmax over keys, dropout, weighted sum of values.
    ///
    /// `keys_t` must already have its last two dimensions swapped.
    fn attend(
        &self,
        queries: Tensor<B, 4>,
        keys_t: Tensor<B, 4>,
        values: Tensor<B, 4>,
    ) -> Tensor<B, 4> {
        let scores = queries.matmul(keys_t).mul_scalar(self.scale);
        let weights = self.dropout.forward(activation::softmax(scores, 3));
        weights.matmul(values)
    }

    /// Runs self-attention over `x` and returns a tensor of the same shape.
    ///
    /// When `chunk_size` is non-zero and smaller than the sequence length, the
    /// queries are processed in slices of at most `chunk_size` positions
    /// against the full keys/values and the results are concatenated along the
    /// sequence axis.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let [batch, seq, _d] = x.dims();
        let heads = self.num_heads;
        let width = self.head_dim;

        // [batch, seq, d_model] -> [batch, heads, seq, head_dim]
        let to_heads =
            |t: Tensor<B, 3>| t.reshape([batch, seq, heads, width]).swap_dims(1, 2);

        let q = to_heads(self.q_proj.forward(x.clone()));
        let k = to_heads(self.k_proj.forward(x.clone()));
        let v = to_heads(self.v_proj.forward(x));
        let k_t = k.swap_dims(2, 3);

        let chunked = self.chunk_size > 0 && self.chunk_size < seq;
        let ctx = if chunked {
            let pieces: Vec<Tensor<B, 4>> = (0..seq)
                .step_by(self.chunk_size)
                .map(|start| {
                    let end = (start + self.chunk_size).min(seq);
                    let q_chunk = q
                        .clone()
                        .slice([0..batch, 0..heads, start..end, 0..width]);
                    self.attend(q_chunk, k_t.clone(), v.clone())
                })
                .collect();
            Tensor::cat(pieces, 2)
        } else {
            self.attend(q, k_t, v)
        };

        // [batch, heads, seq, head_dim] -> [batch, seq, d_model]
        let merged = ctx.swap_dims(1, 2).reshape([batch, seq, heads * width]);
        self.out_proj.forward(merged)
    }
}
/// Pre-norm transformer encoder block: attention and MLP sub-layers, each
/// wrapped in a residual connection.
#[derive(Module, Debug)]
pub struct EncoderBlock<B: Backend> {
    /// Layer norm applied before attention.
    norm1: LayerNorm<B>,
    /// Self-attention sub-layer.
    attn: MultiHeadSelfAttention<B>,
    /// Layer norm applied before the MLP.
    norm2: LayerNorm<B>,
    /// Feed-forward sub-layer.
    mlp: MlpBlock<B>,
    /// Dropout applied to the attention output before the residual add.
    dropout: Dropout,
}

impl<B: Backend> EncoderBlock<B> {
    /// Creates one encoder block; `dropout` is used for the attention, the MLP
    /// and the residual dropout alike.
    pub fn new(
        d_model: usize,
        num_heads: usize,
        mlp_dim: usize,
        dropout: f64,
        chunk_size: usize,
        device: &B::Device,
    ) -> Self {
        // Sub-modules are built in declaration order.
        let norm1 = LayerNormConfig::new(d_model).init(device);
        let attn = MultiHeadSelfAttention::new(d_model, num_heads, dropout, chunk_size, device);
        let norm2 = LayerNormConfig::new(d_model).init(device);
        let mlp = MlpBlock::new(d_model, mlp_dim, dropout, device);
        let dropout = DropoutConfig::new(dropout).init();
        Self {
            norm1,
            attn,
            norm2,
            mlp,
            dropout,
        }
    }

    /// Applies the attention and MLP sub-layers with residual connections.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        // Attention sub-layer.
        let attended = self.attn.forward(self.norm1.forward(x.clone()));
        let x = self.dropout.forward(attended) + x;

        // MLP sub-layer; the MLP applies its own dropout internally.
        let fed_forward = self.mlp.forward(self.norm2.forward(x.clone()));
        fed_forward + x
    }
}
/// Attention-pooling head: a single learned probe vector attends over the
/// token sequence, and the pooled vector is refined by a residual MLP.
#[derive(Module, Debug)]
pub struct MAPHead<B: Backend> {
    /// Learned query of shape `[1, 1, d_model]`, broadcast over the batch.
    probe: Param<Tensor<B, 3>>,
    /// Query projection (applied to the probe).
    q_proj: Linear<B>,
    /// Key projection (applied to the tokens).
    k_proj: Linear<B>,
    /// Value projection (applied to the tokens).
    v_proj: Linear<B>,
    /// Output projection applied to the concatenated heads.
    out_proj: Linear<B>,
    /// Layer norm applied before the MLP.
    norm: LayerNorm<B>,
    /// Residual MLP, built with dropout probability `0.0`.
    mlp: MlpBlock<B>,
    /// Number of attention heads.
    num_heads: usize,
    /// Width of one head (`d_model / num_heads`).
    head_dim: usize,
    /// Score scaling factor, `head_dim^-0.5`.
    scale: f32,
}

impl<B: Backend> MAPHead<B> {
    /// Creates the pooling head.
    ///
    /// # Panics
    ///
    /// Panics if `num_heads` is zero or `d_model` is not a multiple of
    /// `num_heads`. Checking here mirrors `MultiHeadSelfAttention::new` and
    /// surfaces a bad configuration at construction time instead of as a
    /// reshape failure inside `forward`.
    pub fn new(
        d_model: usize,
        num_heads: usize,
        mlp_dim: usize,
        device: &B::Device,
    ) -> Self {
        assert!(num_heads > 0, "num_heads must be non-zero");
        assert_eq!(
            d_model % num_heads,
            0,
            "d_model ({d_model}) must be divisible by num_heads ({num_heads})"
        );
        let head_dim = d_model / num_heads;
        let probe = Tensor::<B, 3>::random(
            [1, 1, d_model],
            Distribution::Uniform(-0.02, 0.02),
            device,
        );
        Self {
            probe: Param::from_tensor(probe),
            q_proj: LinearConfig::new(d_model, d_model).init(device),
            k_proj: LinearConfig::new(d_model, d_model).init(device),
            v_proj: LinearConfig::new(d_model, d_model).init(device),
            out_proj: LinearConfig::new(d_model, d_model).init(device),
            norm: LayerNormConfig::new(d_model).init(device),
            mlp: MlpBlock::new(d_model, mlp_dim, 0.0, device),
            num_heads,
            head_dim,
            scale: (head_dim as f32).powf(-0.5),
        }
    }

    /// Pools `x` (`[batch, seq, d_model]`) into one vector per batch element
    /// (`[batch, d_model]`).
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
        let [batch, seq, d] = x.dims();
        let h = self.num_heads;
        let hd = self.head_dim;

        // [batch, n, d_model] -> [batch, heads, n, head_dim]
        let split_heads =
            |t: Tensor<B, 3>, n: usize| t.reshape([batch, n, h, hd]).swap_dims(1, 2);

        let probe = self.probe.val().expand([batch, 1, d]);
        let q = split_heads(self.q_proj.forward(probe), 1);
        let k = split_heads(self.k_proj.forward(x.clone()), seq);
        let v = split_heads(self.v_proj.forward(x), seq);

        let scores = q.matmul(k.swap_dims(2, 3)).mul_scalar(self.scale);
        let attn = activation::softmax(scores, 3);
        let ctx = attn
            .matmul(v)
            .swap_dims(1, 2)
            .reshape([batch, 1, h * hd]);
        let ctx = self.out_proj.forward(ctx);

        // The residual MLP runs on the `[batch, 1, d_model]` tensor directly;
        // the singleton sequence axis is removed once, at the very end.
        let mlp_out = self.mlp.forward(self.norm.forward(ctx.clone()));
        (ctx + mlp_out).squeeze(1)
    }
}
/// Transformer encoder for sensor windows: patch embedding, learned position
/// embedding, a stack of encoder blocks, pooling, and L2 normalisation.
#[derive(Module, Debug)]
pub struct SensorEncoder<B: Backend> {
    /// Converts the input window into patch tokens.
    patch_embed: PatchEmbedding<B>,
    /// Learned position embedding of shape `[1, num_patches, d_model]`.
    pos_embed: Param<Tensor<B, 3>>,
    /// Encoder blocks, applied in order.
    blocks: Vec<EncoderBlock<B>>,
    /// Layer norm applied after the last block.
    norm: LayerNorm<B>,
    /// Attention-pooling head; `None` selects mean pooling over tokens.
    map_head: Option<MAPHead<B>>,
    /// Dropout applied to the position-embedded tokens.
    dropout: Dropout,
    /// Token width.
    d_model: usize,
}

impl<B: Backend> SensorEncoder<B> {
    /// Builds the encoder described by `cfg` on `device`.
    pub fn new(cfg: &SensorEncoderConfig, device: &B::Device) -> Self {
        let width = cfg.d_model;
        let token_count = cfg.num_patches();

        // Construction order: patch embedding, position embedding, blocks,
        // final norm, optional pooling head.
        let patch_embed = PatchEmbedding::new(
            1,
            width,
            cfg.patch_h,
            cfg.patch_w,
            cfg.time_steps,
            cfg.num_channels,
            device,
        );

        let pos_std = (1.0 / width as f64).sqrt();
        let pos_embed = Tensor::<B, 3>::random(
            [1, token_count, width],
            Distribution::Normal(0.0, pos_std),
            device,
        );

        let mut blocks: Vec<EncoderBlock<B>> = Vec::with_capacity(cfg.depth);
        for _ in 0..cfg.depth {
            blocks.push(EncoderBlock::new(
                width,
                cfg.num_heads,
                cfg.mlp_dim,
                cfg.dropout,
                cfg.attn_chunk_size,
                device,
            ));
        }

        let norm = LayerNormConfig::new(width).init(device);

        let map_head = (cfg.pool_type == PoolType::Map)
            .then(|| MAPHead::new(width, cfg.num_heads, cfg.mlp_dim, device));

        Self {
            patch_embed,
            pos_embed: Param::from_tensor(pos_embed),
            blocks,
            norm,
            map_head,
            dropout: DropoutConfig::new(cfg.dropout).init(),
            d_model: width,
        }
    }

    /// Encodes a batch of windows into L2-normalised embeddings of shape
    /// `[batch, d_model]`.
    ///
    /// A singleton channel axis is inserted at position 1 of `x` before patch
    /// embedding.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
        let [batch, _t, _c] = x.dims();

        let patches = self.patch_embed.forward(x.unsqueeze_dim(1));
        let [_, token_count, _] = patches.dims();
        let positions = self
            .pos_embed
            .val()
            .expand([batch, token_count, self.d_model]);
        let embedded = self.dropout.forward(patches + positions);

        let encoded = self
            .blocks
            .iter()
            .fold(embedded, |tokens, block| block.forward(tokens));
        let tokens = self.norm.forward(encoded);

        // Pool with the attention head when present, otherwise average the tokens.
        let pooled: Tensor<B, 2> = if let Some(head) = &self.map_head {
            head.forward(tokens)
        } else {
            tokens.mean_dim(1).squeeze(1)
        };
        l2_normalize(pooled)
    }
}
/// Scales every row of `x` to unit Euclidean length.
///
/// The row norm is clamped from below at `1e-12` before dividing.
pub fn l2_normalize<B: Backend>(x: Tensor<B, 2>) -> Tensor<B, 2> {
    let shape = x.dims();
    let squared_sum = x.clone().powf_scalar(2.0).sum_dim(1);
    let row_norm = squared_sum.sqrt().clamp_min(1e-12);
    x / row_norm.expand(shape)
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use crate::config::SensorEncoderConfig;

    type B = NdArray;

    /// Batch size shared by the shape tests.
    const BATCH: usize = 2;

    /// Smallest configuration that still exercises every layer (mean pooling, no chunking).
    fn tiny_cfg() -> SensorEncoderConfig {
        SensorEncoderConfig {
            time_steps: 40,
            num_channels: 4,
            patch_h: 10,
            patch_w: 2,
            d_model: 32,
            depth: 2,
            num_heads: 4,
            mlp_dim: 64,
            dropout: 0.0,
            pool_type: PoolType::Gap,
            head_zeroinit: false,
            attn_chunk_size: 0,
        }
    }

    /// Patch embedding yields `[batch, patches, d_model]`.
    #[test]
    fn test_patch_embedding_shape() {
        let device = Default::default();
        let cfg = tiny_cfg();
        let embed = PatchEmbedding::<B>::new(
            1,
            cfg.d_model,
            cfg.patch_h,
            cfg.patch_w,
            cfg.time_steps,
            cfg.num_channels,
            &device,
        );

        let input =
            Tensor::<B, 4>::zeros([BATCH, 1, cfg.time_steps, cfg.num_channels], &device);
        let [b, n, d] = embed.forward(input).dims();

        let expected_patches =
            (cfg.time_steps / cfg.patch_h) * (cfg.num_channels / cfg.patch_w);
        assert_eq!(b, BATCH);
        assert_eq!(n, expected_patches);
        assert_eq!(d, cfg.d_model);
    }

    /// The full encoder yields one `d_model`-wide embedding per batch element.
    #[test]
    fn test_encoder_forward_shape() {
        let device = Default::default();
        let cfg = tiny_cfg();
        let encoder = SensorEncoder::<B>::new(&cfg, &device);

        let input = Tensor::<B, 3>::zeros([BATCH, cfg.time_steps, cfg.num_channels], &device);
        let [b, d] = encoder.forward(input).dims();

        assert_eq!(b, BATCH);
        assert_eq!(d, cfg.d_model);
    }
}