use burn::prelude::*;
#[allow(unused_imports)]
use burn::nn::{Linear, LinearConfig, LayerNorm, LayerNormConfig};
use crate::model::patch_embedding::PatchEmbedding;
use crate::model::transformer::TransformerEncoder;
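/// CBraMod backbone: per-patch embedding, a Transformer encoder, a linear
/// projection to `emb_dim`, and an output stage that either flattens the
/// encoder features or applies a classification head, depending on
/// `return_encoder_output`.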
#[derive(Module, Debug)]
pub struct CBraMod<B: Backend> {
    pub patch_embedding: PatchEmbedding<B>,
    pub encoder: TransformerEncoder<B>,
    pub proj_out: Linear<B>,
    pub final_ln: Option<LayerNorm<B>>,
    pub final_linear: Linear<B>,
    pub patch_size: usize,
    pub n_outputs: usize,
    pub n_chans: usize,
    pub n_times: usize,
    pub emb_dim: usize,
    pub return_encoder_output: bool,
}
impl<B: Backend> CBraMod<B> {
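    /// Builds the model on `device`. `n_times` should be a multiple of
    /// `patch_size`; `forward` checks this in debug builds.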
    pub fn new(
        n_outputs: usize,
        n_chans: usize,
        n_times: usize,
        patch_size: usize,
        dim_feedforward: usize,
        n_layer: usize,
        nhead: usize,
        emb_dim: usize,
        return_encoder_output: bool,
        device: &B::Device,
    ) -> Self {
        let patch_embedding = PatchEmbedding::new(patch_size, device);
        // The encoder width is dictated by the patch embedding.
        let d_model = patch_embedding.d_model;
        let encoder = TransformerEncoder::new(d_model, nhead, dim_feedforward, n_layer, device);
        let proj_out = LinearConfig::new(d_model, emb_dim).with_bias(true).init(device);
        let n_patches = n_times / patch_size;
        let final_dim = n_chans * n_patches * emb_dim;
        // No final LayerNorm is constructed here; `forward` treats the field
        // as optional and skips it when unset.
        let final_ln = None;
        let final_linear = if return_encoder_output {
            // Same-width projection; `forward` returns the flattened encoder
            // output before this layer would be applied.
            LinearConfig::new(emb_dim, emb_dim).with_bias(true).init(device)
        } else {
            // Classification head over the flattened encoder features.
            LinearConfig::new(final_dim, n_outputs).with_bias(true).init(device)
        };
        Self {
            patch_embedding, encoder, proj_out, final_ln, final_linear,
            patch_size, n_outputs, n_chans, n_times, emb_dim, return_encoder_output,
        }
    }
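    /// Runs the model on raw signals of shape `[batch, n_chans, n_times]`.
    /// Returns the flattened encoder features
    /// (`[batch, n_chans * n_patches * emb_dim]`) when `return_encoder_output`
    /// is set, otherwise class logits of shape `[batch, n_outputs]`.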
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 2> {
        let [batch, n_chans, n_times] = x.dims();
        debug_assert_eq!(n_times % self.patch_size, 0, "n_times must be a multiple of patch_size");
        let n_patches = n_times / self.patch_size;
        // Split each channel into non-overlapping patches:
        // [batch, n_chans, n_patches, patch_size].
        let x = x.reshape([batch, n_chans, n_patches, self.patch_size]);
        let x = self.patch_embedding.forward(x);
        let x = self.encoder.forward(x);
        let x = self.proj_out.forward(x);
        // Flatten per-sample features to [batch, n_chans * n_patches * emb_dim].
        let flat_dim = n_chans * n_patches * self.emb_dim;
        let x = x.reshape([batch, flat_dim]);
        if self.return_encoder_output {
            return x;
        }
        // Optional final normalization; a no-op while `final_ln` is `None`.
        let x = match &self.final_ln {
            Some(ln) => ln.forward(x),
            None => x,
        };
        self.final_linear.forward(x)
    }
}
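// A minimal smoke-test sketch: it assumes the `burn/ndarray` feature is
// enabled and that `PatchEmbedding`/`TransformerEncoder` accept the
// hypothetical sizes below; adjust both to match the actual crate.
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    #[test]
    fn forward_returns_class_logits() {
        type B = NdArray<f32>;
        let device = Default::default();
        // Hypothetical configuration: 4 classes, 22 channels, 1000 samples per
        // channel split into five 200-sample patches, feedforward width 800,
        // 12 layers, 8 heads, 200-dim embeddings, classification mode.
        let model: CBraMod<B> = CBraMod::new(4, 22, 1000, 200, 800, 12, 8, 200, false, &device);
        let x = Tensor::<B, 3>::zeros([2, 22, 1000], &device);
        // One logit row per batch element, one column per output class.
        assert_eq!(model.forward(x).dims(), [2, 4]);
    }
}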