use burn::nn::Linear;
use burn::prelude::*;

use crate::model::linear_zeros;

/// Patch embedding over `(ROI, time)` signals: each ROI's time axis is split
/// into non-overlapping patches of `patch_size` steps, and every patch is
/// linearly projected to the embedding dimension.
#[derive(Module, Debug)]
pub struct FlexiPatchEmbed<B: Backend> {
    pub proj: Linear<B>,
    pub patch_size: usize,
    pub num_patches: usize,
    pub num_patches_2d: (usize, usize),
}
impl<B: Backend> FlexiPatchEmbed<B> {
    pub fn new(
        signal_size: (usize, usize),
        patch_size: usize,
        in_chans: usize,
        embed_dim: usize,
        device: &B::Device,
    ) -> Self {
        let (n_rois, n_time) = signal_size;
        assert_eq!(
            n_time % patch_size,
            0,
            "time dimension ({n_time}) must be divisible by patch_size ({patch_size})"
        );
        let n_time_patches = n_time / patch_size;
        let num_patches = n_rois * n_time_patches;
        // Each patch is flattened to `in_chans * patch_size` values before projection.
        let proj = linear_zeros(in_chans * patch_size, embed_dim, true, device);
        Self {
            proj,
            patch_size,
            num_patches,
            num_patches_2d: (n_rois, n_time_patches),
        }
    }
    /// Projects `[batch, channels, rois, time]` signals to
    /// `[batch, num_patches, embed_dim]`. An override `patch_size` must keep
    /// `channels * patch_size` equal to the projection's input size unless
    /// the projection weights are resampled.
    pub fn forward(&self, x: Tensor<B, 4>, patch_size: Option<usize>) -> Tensor<B, 3> {
        let [b, c, h, w] = x.dims();
        let ps = patch_size.unwrap_or(self.patch_size);
        let n_t = w / ps;
        // [b, c, h, w] -> [b, c, h, n_t, ps]: split the time axis into patches.
        let x = x.reshape([b, c, h, n_t, ps]);
        // [b, h, n_t, c, ps]: group each patch's channels and time steps together.
        let x = x.permute([0, 2, 3, 1, 4]);
        // [b, h * n_t, c * ps]: one flattened vector per patch.
        let x = x.reshape([b, h * n_t, c * ps]);
        self.proj.forward(x)
    }
}
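
// A minimal shape-check sketch, assuming burn's `ndarray` feature is enabled
// (for the `NdArray` backend); any other backend works the same way.
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    #[test]
    fn forward_yields_one_embedding_per_patch() {
        let device = Default::default();
        // 8 ROIs x 64 time steps with patch size 16 -> an (8, 4) patch grid.
        let embed = FlexiPatchEmbed::<NdArray>::new((8, 64), 16, 1, 32, &device);
        assert_eq!(embed.num_patches, 32);
        assert_eq!(embed.num_patches_2d, (8, 4));

        // Batch of 2, single channel: output is [batch, num_patches, embed_dim].
        let x = Tensor::<NdArray, 4>::zeros([2, 1, 8, 64], &device);
        let out = embed.forward(x, None);
        assert_eq!(out.dims(), [2, 32, 32]);
    }
}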