use burn::prelude::*;
use burn::nn::Linear;
use crate::config::ModelConfig;
use super::linear_zeros;
use super::encoder::EEGEncoder;
/// EEG classification model: an [`EEGEncoder`] backbone followed by a
/// pooling + MLP classification head (see `forward`).
#[derive(Module, Debug)]
pub struct ClassificationModel<B: Backend> {
    /// Backbone that turns the raw `[batch, channels, seq_len, feat]` input
    /// into a token sequence `[batch, total_seq, feature_size]`.
    pub encoder: EEGEncoder<B>,
    /// `d -> d` projection applied to every (channel, time) token.
    pub full_linear: Linear<B>,
    /// `d -> d` projection applied per timestep after channel pooling.
    pub channel_linear: Linear<B>,
    /// Classification head, layer 1: `d -> d/2`.
    pub cls_fc1: Linear<B>,
    /// Classification head, layer 2: `d/2 -> d/4`.
    pub cls_fc2: Linear<B>,
    /// Classification head, layer 3: `d/4 -> num_classes` (logits).
    pub cls_fc3: Linear<B>,
    // Encoder feature width `d`, copied from `ModelConfig::feature_size`.
    pub feature_size: usize,
    // Count of leading global tokens the encoder prepends; `forward` strips
    // these before pooling.
    pub num_global_tokens: usize,
}
impl<B: Backend> ClassificationModel<B> {
    /// Builds the model: an [`EEGEncoder`] plus a zero-initialized head.
    ///
    /// The head funnels the encoder width `d = cfg.feature_size` down to
    /// `num_classes` through `d -> d/2 -> d/4 -> num_classes`; all linear
    /// layers are created via `linear_zeros` (weights start at zero).
    pub fn new(cfg: &ModelConfig, num_classes: usize, device: &B::Device) -> Self {
        let d = cfg.feature_size;
        Self {
            encoder: EEGEncoder::new(cfg, device),
            full_linear: linear_zeros::<B>(d, d, true, device),
            channel_linear: linear_zeros::<B>(d, d, true, device),
            cls_fc1: linear_zeros::<B>(d, d / 2, true, device),
            cls_fc2: linear_zeros::<B>(d / 2, d / 4, true, device),
            cls_fc3: linear_zeros::<B>(d / 4, num_classes, true, device),
            feature_size: d,
            num_global_tokens: cfg.num_global_tokens,
        }
    }

    /// Classifies a batch of EEG windows.
    ///
    /// `x` is `[batch, channels, seq_len, feat]`; returns class logits of
    /// shape `[batch, num_classes]`.
    ///
    /// Pipeline: encode -> drop leading global tokens -> per-token GELU
    /// projection -> mean over channels -> per-timestep GELU projection ->
    /// mean over time -> 3-layer MLP head.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 2> {
        use burn::tensor::activation::gelu;

        let [bs, ch, seq_len, _feat] = x.dims();
        let d = self.feature_size;

        // Encoder output is assumed to be
        // [bs, num_global_tokens + ch * seq_len, d] — TODO(review): confirm
        // against EEGEncoder.
        let features = self.encoder.forward(x);
        let total_seq = features.dims()[1];
        // Fail early with a clear message instead of an opaque reshape panic
        // if the encoder's token layout does not match ch * seq_len. This is
        // exactly the invariant the `reshape` below already requires.
        debug_assert_eq!(
            total_seq - self.num_global_tokens,
            ch * seq_len,
            "encoder token count does not match channels * seq_len"
        );
        let tokens = features.slice([0..bs, self.num_global_tokens..total_seq]);

        // Per-token projection over all (batch, channel, time) positions.
        let flat = tokens.reshape([bs * ch * seq_len, d]);
        let processed = gelu(self.full_linear.forward(flat));

        // Pool across channels; mean_dim keeps the reduced dim at size 1,
        // so the subsequent reshape folds it away.
        let reshaped = processed.reshape([bs, ch, seq_len, d]);
        let channel_pooled = reshaped.mean_dim(1);

        // Per-timestep projection on the channel-pooled features.
        let flat = channel_pooled.reshape([bs * seq_len, d]);
        let processed = gelu(self.channel_linear.forward(flat));

        // Pool across time, then run the MLP classification head.
        let processed = processed.reshape([bs, seq_len, d]);
        let time_pooled = processed.mean_dim(1).reshape([bs, d]);
        let h = gelu(self.cls_fc1.forward(time_pooled));
        let h = gelu(self.cls_fc2.forward(h));
        self.cls_fc3.forward(h)
    }
}