use burn::module::{Param, ParamId};
use burn::prelude::*;
use crate::model::block::Block;
use crate::model::norm::LNorm;
use crate::model::patch_embed::FlexiPatchEmbed;
use crate::model::pos_embed::BrainHarmonyPosEmbed;
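
/// Vision Transformer encoder with a flexible patch size, BrainHarmony
/// positional embeddings, an optional learnable CLS token, and a stack of
/// transformer blocks followed by a final layer norm.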
#[derive(Module, Debug)]
pub struct FlexVisionTransformer<B: Backend> {
pub patch_embed: FlexiPatchEmbed<B>,
pub pos_embed: BrainHarmonyPosEmbed<B>,
pub cls_token: Option<Param<Tensor<B, 3>>>,
pub blocks: Vec<Block<B>>,
pub norm: LNorm<B>,
pub embed_dim: usize,
pub num_heads: usize,
}
impl<B: Backend> FlexVisionTransformer<B> {
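    /// Builds the encoder. Errors are propagated from
    /// `BrainHarmonyPosEmbed::new`, the only fallible step.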
pub fn new(
signal_size: (usize, usize),
patch_size: usize,
in_chans: usize,
embed_dim: usize,
depth: usize,
num_heads: usize,
mlp_ratio: f64,
qkv_bias: bool,
norm_eps: f64,
grad_dim: usize,
geoh_dim: usize,
pred_embed_dim: usize,
pos_mode: &str,
use_cls_token: bool,
use_decoder: bool,
device: &B::Device,
) -> crate::error::Result<Self> {
let patch_embed =
FlexiPatchEmbed::new(signal_size, patch_size, in_chans, embed_dim, device);
let grid_size = patch_embed.num_patches_2d;
let pos_embed = BrainHarmonyPosEmbed::new(
grad_dim,
geoh_dim,
embed_dim,
pred_embed_dim,
grid_size,
pos_mode,
use_cls_token,
use_decoder,
device,
)?;
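        // Optional learnable CLS token, initialised to zeros.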
let cls_token = if use_cls_token {
Some(Param::initialized(
ParamId::new(),
Tensor::zeros([1, 1, embed_dim], device),
))
} else {
None
};
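        // Identical transformer blocks stacked `depth` times.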
let blocks = (0..depth)
.map(|_| Block::new(embed_dim, num_heads, mlp_ratio, qkv_bias, norm_eps, device))
.collect();
let norm = LNorm::new(embed_dim, norm_eps, device);
Ok(Self {
patch_embed,
pos_embed,
cls_token,
blocks,
norm,
embed_dim,
num_heads,
})
}
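
    /// Encodes a batch of signals into token embeddings.
    ///
    /// `x` is `[batch, channels, height, width]` and the output is
    /// `[batch, tokens, embed_dim]`. When `masks` is provided, only the patch
    /// tokens selected by each mask are kept (see [`apply_masks`]); the CLS
    /// token, if enabled, is re-attached after masking.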
pub fn forward(
&self,
x: Tensor<B, 4>,
gradient: Option<&Tensor<B, 2>>,
geoh: Option<&Tensor<B, 2>>,
patch_size: Option<usize>,
masks: Option<&[Tensor<B, 2, Int>]>,
attn_mask: Option<&Tensor<B, 2>>,
) -> Tensor<B, 3> {
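        // Patchify the input signal into a sequence of token embeddings:
        // [B, C, H, W] -> [B, N, D].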
let mut x = self.patch_embed.forward(x, patch_size);
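        // Positional embeddings for encoder and decoder, computed from the
        // gradient / geoh conditioning inputs (only the encoder part is used here).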
let (pos_emb_enc, _pos_emb_dec) = self.pos_embed.forward(gradient, geoh);
        if let Some(cls_param) = &self.cls_token {
            // Position 0 of the encoder positional embedding is reserved for the
            // CLS token; patch tokens take positions 1..=num_patches.
            let pos_patches = pos_emb_enc.clone().narrow(1, 1, x.dims()[1]);
            x = x + pos_patches;
            if let Some(mask_list) = masks {
                x = apply_masks(x, mask_list);
            }
            // Prepend the CLS token (with its own positional embedding) after
            // masking so it is never dropped by the masks.
            let cls_with_pos = cls_param.val() + pos_emb_enc.narrow(1, 0, 1);
            let cls_expanded = cls_with_pos.expand([x.dims()[0], 1, self.embed_dim]);
            x = Tensor::cat(vec![cls_expanded, x], 1);
} else {
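            // No CLS token: the positional embedding maps one-to-one onto the
            // patch tokens.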
x = x + pos_emb_enc;
if let Some(mask_list) = masks {
x = apply_masks(x, mask_list);
}
}
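        // Transformer trunk followed by the final layer norm.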
for block in &self.blocks {
x = block.forward(x, attn_mask);
}
self.norm.forward(x)
}
}
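
/// Gathers, for each mask, the patch tokens selected by its index tensor
/// (`[batch, keep]`) and concatenates the results along the batch dimension,
/// so the output batch size is `masks.len() * batch`.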
pub fn apply_masks<B: Backend>(
x: Tensor<B, 3>,
masks: &[Tensor<B, 2, Int>],
) -> Tensor<B, 3> {
let [_b, _n, d] = x.dims();
let parts: Vec<Tensor<B, 3>> = masks
.iter()
.map(|m| {
let [b_m, k] = m.dims();
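            // Broadcast the [batch, keep] index tensor across the feature
            // dimension so `gather` selects whole token embeddings along dim 1.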
let mask_exp = m.clone().unsqueeze_dim::<3>(2).expand([b_m, k, d]);
x.clone().gather(1, mask_exp)
})
.collect();
Tensor::cat(parts, 0)
}