use burn::prelude::*;
use burn::module::{Param, ParamId};
use burn::nn::{Embedding, EmbeddingConfig};
use crate::model::patch_embed::PatchEmbed;
use crate::model::transformer_block::TransformerBlock;
use crate::model::norm::OsfLayerNorm;
/// Transformer encoder over patched series.
///
/// The pipeline is:
/// 1. patch embedding,
/// 2. a prepended learned class token,
/// 3. additive positional embeddings,
/// 4. a stack of [`TransformerBlock`]s,
/// 5. a final [`OsfLayerNorm`].
#[derive(Module, Debug)]
pub struct OsfViT<B: Backend> {
    /// Produces the token sequence from the input series.
    /// Its output is concatenated with the class token along dim 1, so its last dim must equal `width`.
    pub patch_embed: PatchEmbed<B>,
    /// Learned class token; created by `new` with shape `[1, 1, width]` (zero-initialised).
    pub cls_token: Param<Tensor<B, 3>>,
    /// Positional table; created by `new` with shape `[1, n_max + 1, width]` (zero-initialised).
    /// Slot 0 lines up with the class token.
    /// The forward pass uses only the first `seq_len + 1` rows.
    pub pos_embedding: Param<Tensor<B, 3>>,
    /// Lead-group embedding table with `num_leads / patch_size_ch` rows.
    /// Present only when `lead_wise != 0`.
    /// NOTE(review): this table is built in `new` but never read by the forward
    /// methods in this file. Confirm whether lead identity is injected elsewhere
    /// (e.g. inside `PatchEmbed`) or whether it should be added to the tokens.
    pub lead_emb: Option<Embedding<B>>,
    /// Transformer blocks applied in order; `depth` of them are created by `new`.
    pub blocks: Vec<TransformerBlock<B>>,
    /// Layer norm applied to every token after the last block.
    pub norm: OsfLayerNorm<B>,
    /// Token embedding dimension.
    pub width: usize,
    /// Number of transformer blocks requested at construction.
    pub depth: usize,
    /// Non-zero selects lead-wise patching.
    /// In that mode the positional table is sized for `(num_leads / patch_size_ch)`
    /// lead groups times the time patches.
    pub lead_wise: usize,
}
impl<B: Backend> OsfViT<B> {
    /// Builds the encoder.
    ///
    /// The class token and positional table are zero-initialised. The positional
    /// table has one slot for the class token plus `n_max` patch slots, where:
    /// - `n_max = seq_len / patch_size_time` when `lead_wise == 0`;
    /// - `n_max = (num_leads / patch_size_ch) * (seq_len / patch_size_time)` otherwise.
    ///
    /// Integer division floors, so trailing samples or leads that do not fill a
    /// whole patch get no positional slot.
    ///
    /// Sub-modules are constructed in a fixed order: patch embedding, class token,
    /// positional table, lead embedding, blocks, norm. Keep this order so random
    /// initialisation stays reproducible.
    ///
    /// # Panics
    ///
    /// Panics if `patch_size_time` is zero, or if `lead_wise != 0` and
    /// `patch_size_ch` is zero.
    // The parameter list mirrors the model's hyper-parameters and is called positionally.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        num_leads: usize,
        seq_len: usize,
        patch_size_time: usize,
        patch_size_ch: usize,
        lead_wise: usize,
        width: usize,
        depth: usize,
        mlp_dim: usize,
        heads: usize,
        dim_head: usize,
        device: &B::Device,
    ) -> Self {
        assert!(
            patch_size_time > 0,
            "OsfViT::new: patch_size_time must be non-zero"
        );
        let num_patches_time = seq_len / patch_size_time;

        // Number of lead groups, only defined in lead-wise mode.
        // Computed once and shared by the positional-table size and the lead embedding.
        let lead_groups = if lead_wise != 0 {
            assert!(
                patch_size_ch > 0,
                "OsfViT::new: patch_size_ch must be non-zero when lead_wise != 0"
            );
            Some(num_leads / patch_size_ch)
        } else {
            None
        };
        let n_max = lead_groups.map_or(num_patches_time, |groups| groups * num_patches_time);

        let patch_embed = PatchEmbed::new(
            num_leads, width, patch_size_time, patch_size_ch, lead_wise, device,
        );
        let cls_token = Param::initialized(ParamId::new(), Tensor::zeros([1, 1, width], device));
        // `+ 1` reserves slot 0 for the class token.
        let pos_embedding = Param::initialized(
            ParamId::new(),
            Tensor::zeros([1, n_max + 1, width], device),
        );
        let lead_emb = lead_groups.map(|groups| EmbeddingConfig::new(groups, width).init(device));
        let blocks = (0..depth)
            .map(|_| TransformerBlock::new(width, width, mlp_dim, heads, dim_head, true, device))
            .collect();
        let norm = OsfLayerNorm::new(width, 1e-5, device);

        Self {
            patch_embed,
            cls_token,
            pos_embedding,
            lead_emb,
            blocks,
            norm,
            width,
            depth,
            lead_wise,
        }
    }

    /// Encodes `series` and returns `(cls, patches)` after the final layer norm.
    ///
    /// Returns:
    /// - `cls`: the first output token, `[batch, 1, width]`;
    /// - `patches`: the remaining tokens, `[batch, n, width]`, where `n` is the
    ///   number of tokens `patch_embed` produced for this input.
    ///
    /// # Panics
    ///
    /// Panics if `patch_embed` yields more tokens than the positional table was
    /// sized for when the model was built.
    pub fn forward_encoding(&self, series: Tensor<B, 3>) -> (Tensor<B, 3>, Tensor<B, 3>) {
        let tokens = self.patch_embed.forward(series);
        let batch = tokens.dims()[0];

        // Broadcast the single learned class token across the batch and prepend it.
        let cls_tok = self.cls_token.val().expand([batch, 1, self.width]);
        let x = Tensor::cat(vec![cls_tok, tokens], 1);

        // Use only as many positional rows as there are tokens (class token included).
        let seq = x.dims()[1];
        let pos_table = self.pos_embedding.val();
        let capacity = pos_table.dims()[1];
        assert!(
            seq <= capacity,
            "OsfViT::forward_encoding: {seq} tokens (incl. cls) exceed positional table size {capacity}"
        );
        let pe = pos_table.narrow(1, 0, seq).to_device(&x.device());

        let x = self
            .blocks
            .iter()
            .fold(x + pe, |hidden, block| block.forward(hidden));
        let x = self.norm.forward(x);

        // Re-read the length after the blocks rather than assuming they preserve it.
        let n_plus_one = x.dims()[1];
        let cls = x.clone().narrow(1, 0, 1);
        let patches = x.narrow(1, 1, n_plus_one - 1);
        (cls, patches)
    }

    /// Returns the normalised class-token representation, `[batch, 1, width]`.
    pub fn forward(&self, series: Tensor<B, 3>) -> Tensor<B, 3> {
        let (cls, _) = self.forward_encoding(series);
        cls
    }

    /// Returns the mean of the normalised patch tokens over the token axis.
    ///
    /// The reduced axis is kept with size 1, matching the rank-3 return type.
    /// NOTE(review): if `patch_embed` produced zero tokens, this averages an empty
    /// axis. Confirm that inputs always yield at least one patch.
    pub fn forward_avg_pool(&self, series: Tensor<B, 3>) -> Tensor<B, 3> {
        let (_, patches) = self.forward_encoding(series);
        patches.mean_dim(1)
    }
}