use burn::prelude::*;
use burn::nn::conv::{Conv1d, Conv1dConfig, Conv2d, Conv2dConfig};
/// Patch-embedding module that wraps either a 1-D or a 2-D convolution.
///
/// `new` populates exactly one of `conv1d` / `conv2d` depending on
/// `lead_wise`, and `forward` uses `lead_wise` to pick which one to run.
#[derive(Module, Debug)]
pub struct PatchEmbed<B: Backend> {
/// Convolution used by `forward` when `lead_wise == 0`; `new` leaves it `None` otherwise.
pub conv1d: Option<Conv1d<B>>,
/// Convolution used by `forward` when `lead_wise != 0`; `new` leaves it `None` otherwise.
pub conv2d: Option<Conv2d<B>>,
/// Mode selector: `0` selects the `conv1d` path, any other value selects the `conv2d` path.
pub lead_wise: usize,
}
impl<B: Backend> PatchEmbed<B> {
    /// Builds the patch-embedding layer on `device`.
    ///
    /// * `lead_wise == 0`: creates a bias-free `Conv1d` with `num_leads` input
    ///   channels, `width` output channels, and kernel size and stride both equal
    ///   to `patch_size_time`. `patch_size_ch` is not used in this mode.
    /// * `lead_wise != 0`: creates a bias-free `Conv2d` with a single input
    ///   channel, `width` output channels, and kernel size and stride both equal
    ///   to `[patch_size_ch, patch_size_time]`. `num_leads` is not used in this
    ///   mode.
    ///
    /// Only the convolution for the selected mode is stored; the other field is
    /// `None`.
    pub fn new(
        num_leads: usize,
        width: usize,
        patch_size_time: usize,
        patch_size_ch: usize,
        lead_wise: usize,
        device: &B::Device,
    ) -> Self {
        if lead_wise == 0 {
            let conv = Conv1dConfig::new(num_leads, width, patch_size_time)
                .with_stride(patch_size_time)
                .with_bias(false)
                .init(device);
            Self {
                conv1d: Some(conv),
                conv2d: None,
                lead_wise,
            }
        } else {
            let conv = Conv2dConfig::new([1, width], [patch_size_ch, patch_size_time])
                .with_stride([patch_size_ch, patch_size_time])
                .with_bias(false)
                .init(device);
            Self {
                conv1d: None,
                conv2d: Some(conv),
                lead_wise,
            }
        }
    }

    /// Applies the configured convolution and returns a rank-3 tensor whose
    /// last dimension is the convolution's output-channel dimension.
    ///
    /// * `lead_wise == 0`: the `Conv1d` output `[b, w, nt]` is returned as
    ///   `[b, nt, w]`.
    /// * `lead_wise != 0`: a singleton dimension is inserted at index 1 of `x`
    ///   before the `Conv2d`; its output `[b, w, lr, nt]` is rearranged to
    ///   `[b, lr, nt, w]` and flattened to `[b, lr * nt, w]`, so the `lr` index
    ///   varies slowest along the merged dimension.
    ///
    /// NOTE(review): `x` is presumably laid out as `[batch, leads, time]` —
    /// confirm against callers.
    ///
    /// # Panics
    ///
    /// Panics if the convolution matching `lead_wise` is `None`. `new` always
    /// sets it, so this can only happen when the public fields were modified
    /// after construction.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        if self.lead_wise == 0 {
            let conv = self
                .conv1d
                .as_ref()
                .expect("PatchEmbed invariant violated: conv1d must be Some when lead_wise == 0");
            // [b, w, nt] -> [b, nt, w]
            conv.forward(x).swap_dims(1, 2)
        } else {
            let conv = self
                .conv2d
                .as_ref()
                .expect("PatchEmbed invariant violated: conv2d must be Some when lead_wise != 0");
            // Insert the single input channel expected by the 2-D convolution.
            let out = conv.forward(x.unsqueeze_dim::<4>(1));
            let [b, w, lr, nt] = out.dims();
            // [b, w, lr, nt] -> [b, lr, w, nt] -> [b, lr, nt, w] -> [b, lr * nt, w]
            out.swap_dims(1, 2)
                .swap_dims(2, 3)
                .reshape([b, lr * nt, w])
        }
    }
}