use burn::prelude::*;
use burn::nn::{Linear, LinearConfig};
use crate::model::transformer::EEGTransformer;

/// EEGPT feature extractor with a two-layer linear probe head for
/// downstream classification.
#[derive(Module, Debug)]
pub struct EEGPT<B: Backend> {
    /// Transformer encoder (the EEGPT target encoder) that turns raw EEG
    /// into per-patch summary-token embeddings.
    pub target_encoder: EEGTransformer<B>,
    /// First probe layer: `embed_num * embed_dim` -> `probe_hidden_dim`,
    /// applied independently to each patch.
    pub probe1: Linear<B>,
    /// Second probe layer: flattened patch features -> `n_outputs` logits.
    pub probe2: Linear<B>,
    pub n_outputs: usize,
    pub embed_num: usize,
    pub embed_dim: usize,
    pub n_patches: usize,
    pub probe_hidden_dim: usize,
}

impl<B: Backend> EEGPT<B> {
    /// Builds the encoder and initializes the two probe layers.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        n_outputs: usize, n_chans: usize, n_times: usize,
        patch_size: usize, patch_stride: usize,
        embed_num: usize, embed_dim: usize,
        depth: usize, n_heads: usize, mlp_ratio: f64,
        qkv_bias: bool, n_chan_embeddings: usize,
        probe_hidden_dim: usize, eps: f64,
        device: &B::Device,
    ) -> Self {
        let target_encoder = EEGTransformer::new(
            n_chans, n_times, patch_size, patch_stride,
            embed_num, embed_dim, depth, n_heads, mlp_ratio,
            qkv_bias, n_chan_embeddings, eps, device,
        );
        // The encoder determines how many temporal patches the input splits into.
        let n_patches = target_encoder.n_patches;
        // Probe 1 maps each patch's concatenated summary tokens to a hidden
        // feature; probe 2 reads out the flattened patch features as logits.
        let probe1 = LinearConfig::new(embed_num * embed_dim, probe_hidden_dim)
            .with_bias(true)
            .init(device);
        let probe2 = LinearConfig::new(n_patches * probe_hidden_dim, n_outputs)
            .with_bias(true)
            .init(device);
        Self {
            target_encoder, probe1, probe2,
            n_outputs, embed_num, embed_dim, n_patches, probe_hidden_dim,
        }
    }

    /// Runs the encoder and the probe head.
    ///
    /// `x` is a batch of EEG windows shaped `[batch, n_chans, n_times]`;
    /// `chan_ids` carries the integer channel identities the encoder embeds
    /// (one row per batch element). Returns `[batch, n_outputs]` logits.
    pub fn forward(&self, x: Tensor<B, 3>, chan_ids: Tensor<B, 2, Int>) -> Tensor<B, 2> {
        let [batch, _, _] = x.dims();
        let z = self.target_encoder.forward(x, chan_ids);
        // Fuse each patch's summary tokens into one feature vector per patch.
        let z = z.reshape([batch, self.n_patches, self.embed_num * self.embed_dim]);
        // Project per patch, then flatten across patches for the final readout.
        let h = self.probe1.forward(z);
        let h = h.reshape([batch, self.n_patches * self.probe_hidden_dim]);
        self.probe2.forward(h)
    }
}
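
// A minimal usage sketch, assuming burn's `ndarray` backend feature is
// enabled. The hyperparameters below (4 classes, 19 channels, 256 samples,
// 64/32 patching, etc.) are illustrative assumptions, not defaults taken
// from this crate or the EEGPT paper.
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    #[test]
    fn forward_yields_one_logit_row_per_window() {
        type B = NdArray;
        let device = Default::default();
        let model = EEGPT::<B>::new(
            4, 19, 256, // n_outputs, n_chans, n_times
            64, 32,     // patch_size, patch_stride
            4, 64,      // embed_num, embed_dim
            2, 4, 4.0,  // depth, n_heads, mlp_ratio
            true, 64,   // qkv_bias, n_chan_embeddings
            16, 1e-6,   // probe_hidden_dim, eps
            &device,
        );
        // Two dummy windows with all-zero samples and channel ids.
        let x = Tensor::<B, 3>::zeros([2, 19, 256], &device);
        let chan_ids = Tensor::<B, 2, Int>::zeros([2, 19], &device);
        let out = model.forward(x, chan_ids);
        assert_eq!(out.dims(), [2, 4]);
    }
}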