use burn::module::{Param, ParamId};
use burn::nn::Linear;
use burn::prelude::*;
use crate::model::block::Block;
use crate::model::encoder::apply_masks;
use crate::model::norm::LNorm;
use crate::model::pos_embed::BrainHarmonyPosEmbed;
use crate::model::linear_zeros;
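
/// Transformer predictor head: given encoded context tokens, it predicts
/// representations for masked-out target patches. Context tokens are projected
/// into `predictor_embed_dim`, learnable mask tokens carrying the target
/// positions' embeddings are appended, the combined sequence runs through the
/// predictor blocks, and the predicted tokens are projected back to `embed_dim`.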
#[derive(Module, Debug)]
pub struct VisionTransformerPredictor<B: Backend> {
pub predictor_embed: Linear<B>,
pub mask_token: Param<Tensor<B, 3>>,
pub pos_embed: BrainHarmonyPosEmbed<B>,
pub predictor_blocks: Vec<Block<B>>,
pub predictor_norm: LNorm<B>,
pub predictor_proj: Linear<B>,
pub predictor_embed_dim: usize,
pub embed_dim: usize,
}

impl<B: Backend> VisionTransformerPredictor<B> {
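    /// Builds the predictor. `embed_dim` is the encoder token width and
    /// `predictor_embed_dim` the internal width of the predictor blocks; the
    /// `gradient`/`geoh`-driven positional embedding is configured by `grad_dim`,
    /// `geoh_dim`, `pos_mode`, and `use_cls_token`. The input and output
    /// projections are zero-initialized via `linear_zeros`.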
pub fn new(
num_patches_2d: (usize, usize),
embed_dim: usize,
predictor_embed_dim: usize,
depth: usize,
num_heads: usize,
mlp_ratio: f64,
qkv_bias: bool,
norm_eps: f64,
grad_dim: usize,
geoh_dim: usize,
pos_mode: &str,
use_cls_token: bool,
device: &B::Device,
) -> crate::error::Result<Self> {
let predictor_embed =
linear_zeros(embed_dim, predictor_embed_dim, true, device);
let mask_token = Param::initialized(
ParamId::new(),
Tensor::zeros([1, 1, predictor_embed_dim], device),
);
let pos_embed = BrainHarmonyPosEmbed::new(
grad_dim,
geoh_dim,
predictor_embed_dim,
predictor_embed_dim,
num_patches_2d,
pos_mode,
use_cls_token,
            false,
            device,
)?;
let predictor_blocks = (0..depth)
.map(|_| {
Block::new(
predictor_embed_dim,
num_heads,
mlp_ratio,
qkv_bias,
norm_eps,
device,
)
})
.collect();
let predictor_norm = LNorm::new(predictor_embed_dim, norm_eps, device);
let predictor_proj =
linear_zeros(predictor_embed_dim, embed_dim, true, device);
Ok(Self {
predictor_embed,
mask_token,
pos_embed,
predictor_blocks,
predictor_norm,
predictor_proj,
predictor_embed_dim,
embed_dim,
})
}
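
    /// Predicts target-token representations from context tokens.
    ///
    /// `x` holds the encoded context tokens, stacked along the batch dimension
    /// once per context mask in `masks_x`; `masks` are the target (prediction)
    /// masks. The output contains one block of predicted tokens per
    /// (context mask, target mask) pair, stacked along the batch dimension and
    /// projected back to `embed_dim`.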
pub fn forward(
&self,
x: Tensor<B, 3>,
gradient: Option<&Tensor<B, 2>>,
geoh: Option<&Tensor<B, 2>>,
masks_x: &[Tensor<B, 2, Int>],
masks: &[Tensor<B, 2, Int>],
) -> Tensor<B, 3> {
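        // `x` is stacked once per context mask, so recover the underlying batch size.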
let b = x.dims()[0] / masks_x.len();
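        // Map context tokens from the encoder width to the predictor width.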
let mut x = self.predictor_embed.forward(x);
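        // Add positional embeddings to the context tokens, gathered under the context masks.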
let (pos_emb, _) = self.pos_embed.forward(gradient, geoh);
let [_, n_pos, d_pos] = pos_emb.dims();
let pos_emb_ctx = pos_emb.clone().expand([b, n_pos, d_pos]);
let ctx_pos = apply_masks(pos_emb_ctx, masks_x);
x = x + ctx_pos;
let [_, n_ctxt, _d] = x.dims();
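        // Positional embeddings at the target positions, repeated so that they align
        // with the context tokens tiled per target mask below.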
let pos_emb_tgt = pos_emb.expand([b, n_pos, d_pos]);
let tgt_pos = apply_masks(pos_emb_tgt, masks);
let tgt_pos = repeat_interleave_batch(tgt_pos, b, masks_x.len());
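        // Prediction tokens: the shared learnable mask token broadcast to every
        // target position, plus that position's embedding.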
let [n_total, n_pred, pred_dim] = tgt_pos.dims();
let pred_tokens = self.mask_token.val().expand([n_total, n_pred, pred_dim]);
let pred_tokens = pred_tokens + tgt_pos;
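        // Tile the context once per target mask, append the prediction tokens, and
        // run the predictor blocks over the combined sequence.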
        let x = x.repeat_dim(0, masks.len());
        let mut x = Tensor::cat(vec![x, pred_tokens], 1);
for block in &self.predictor_blocks {
x = block.forward(x, None);
}
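        // Normalize, keep only the predicted target tokens, and project back to the
        // encoder width.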
let x = self.predictor_norm.forward(x);
let x = x.narrow(1, n_ctxt, n_pred);
self.predictor_proj.forward(x)
}
}
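
/// Repeats each consecutive `batch_size`-sized chunk of `x` `repeat` times along
/// dim 0, so chunks [m1, m2] with `repeat = 2` become [m1, m1, m2, m2].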
fn repeat_interleave_batch<B: Backend>(
x: Tensor<B, 3>,
batch_size: usize,
repeat: usize,
) -> Tensor<B, 3> {
let n = x.dims()[0] / batch_size;
let parts: Vec<Tensor<B, 3>> = (0..n)
.flat_map(|i| {
let chunk = x.clone().narrow(0, i * batch_size, batch_size);
(0..repeat).map(move |_| chunk.clone())
})
.collect();
Tensor::cat(parts, 0)
}