use burn::prelude::*;
use burn::nn::conv::{Conv2d, Conv2dConfig};

/// Splits a multi-channel time series into fixed-size temporal patches and
/// projects each patch to an `embed_dim`-dimensional embedding via a
/// 1 x `patch_size` convolution applied to each channel independently.
#[derive(Module, Debug)]
pub struct PatchEmbed<B: Backend> {
    pub proj: Conv2d<B>,
    pub n_chans: usize,
    pub n_times: usize,
    pub patch_size: usize,
    pub patch_stride: usize,
    pub padding_size: usize,
    pub n_patches: usize,
}
impl<B: Backend> PatchEmbed<B> {
    pub fn new(
        n_chans: usize,
        n_times: usize,
        patch_size: usize,
        patch_stride: usize,
        embed_dim: usize,
        device: &B::Device,
    ) -> Self {
        debug_assert!(n_times >= patch_size, "n_times must be at least patch_size");

        // Right-pad the time axis so that (n_times_padded - patch_size) is a
        // multiple of patch_stride; otherwise trailing samples would be dropped.
        let remainder = (n_times - patch_size) % patch_stride;
        let padding_size = if remainder != 0 { patch_stride - remainder } else { 0 };
        let n_times_padded = n_times + padding_size;
        let n_patches = (n_times_padded - patch_size) / patch_stride + 1;

        // A 1 x patch_size kernel with a 1 x patch_stride stride slides along
        // the time axis only, so each channel row is patched independently.
        // `Valid` padding is used because padding is applied manually in `forward`.
        let proj = Conv2dConfig::new([1, embed_dim], [1, patch_size])
            .with_stride([1, patch_stride])
            .with_padding(burn::nn::PaddingConfig2d::Valid)
            .with_bias(true)
            .init(device);

        Self { proj, n_chans, n_times, patch_size, patch_stride, padding_size, n_patches }
    }
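
    // Worked example of the padding arithmetic above (illustrative numbers,
    // not from the original source): with n_times = 1000, patch_size = 200,
    // patch_stride = 150:
    //   remainder    = (1000 - 200) % 150       = 50
    //   padding_size = 150 - 50                 = 100
    //   n_patches    = (1100 - 200) / 150 + 1   = 7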

    /// Input:  `[batch, n_chans, n_times]`
    /// Output: `[batch, n_patches, n_chans, embed_dim]`
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 4> {
        let [batch, n_chans, _n_times] = x.dims();
        let device = x.device();

        // Zero-pad the end of the time axis so every sample lands in a patch.
        let x = if self.padding_size > 0 {
            let pad = Tensor::zeros([batch, n_chans, self.padding_size], &device);
            Tensor::cat(vec![x, pad], 2)
        } else {
            x
        };

        // [batch, n_chans, T] -> [batch, 1, n_chans, T]: the conv treats the
        // signal as a single-channel image of height n_chans and width T.
        let x = x.unsqueeze_dim::<4>(1);
        // Conv2d: [batch, 1, n_chans, T] -> [batch, embed_dim, n_chans, n_patches]
        let x = self.proj.forward(x);
        let [_b, _embed_dim, _c, _np] = x.dims();
        // -> [batch, n_patches, n_chans, embed_dim]
        x.swap_dims(1, 3)
    }
}
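
// --- Usage sketch (not from the original source) ---
// A minimal shape check, assuming burn is built with the "ndarray" feature so
// that `burn::backend::NdArray` is available. The configuration values
// (8 channels, 1000 samples, patch 200, stride 100, embed 512) are
// illustrative, not prescribed by the module above.
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    #[test]
    fn patch_embed_output_shape() {
        type B = NdArray<f32>;
        let device = Default::default();

        let embed = PatchEmbed::<B>::new(8, 1000, 200, 100, 512, &device);

        // (1000 - 200) % 100 == 0, so no padding is needed and
        // n_patches = (1000 - 200) / 100 + 1 = 9.
        assert_eq!(embed.padding_size, 0);
        assert_eq!(embed.n_patches, 9);

        let x = Tensor::<B, 3>::zeros([2, 8, 1000], &device);
        let y = embed.forward(x);

        // [batch, n_patches, n_chans, embed_dim]
        assert_eq!(y.dims(), [2, 9, 8, 512]);
    }
}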