use burn::{
config::Config,
module::{Module, Param},
nn::{
Dropout, DropoutConfig, Gelu, Initializer, LayerNorm, LayerNormConfig, Linear,
LinearConfig,
conv::{Conv2d, Conv2dConfig},
},
tensor::{Bool, FloatDType, Tensor, backend::Backend, module, ops::AttentionModuleOptions},
};
/// Patch embedding layer: turns an image batch into a sequence of patch tokens
/// by running a strided convolution over it.
#[derive(Module, Debug)]
pub struct PatchEmbed<B: Backend> {
/// Convolution used as the patch projection; `new` configures it with kernel
/// size and stride both equal to `patch_size`.
pub proj: Conv2d<B>,
/// Side length of one square patch, as passed to `new`.
pub patch_size: usize,
/// Number of output channels of the projection, as passed to `new`.
pub embed_dim: usize,
}
impl<B: Backend> PatchEmbed<B> {
    /// Creates a patch embedding that maps `in_chans` input channels to `embed_dim` channels.
    ///
    /// The projection is a 2-D convolution whose kernel and stride are both
    /// `patch_size x patch_size`, so every output position covers one
    /// non-overlapping patch of the input.
    pub fn new(in_chans: usize, embed_dim: usize, patch_size: usize, device: &B::Device) -> Self {
        let window = [patch_size, patch_size];
        let conv_config = Conv2dConfig::new([in_chans, embed_dim], window).with_stride(window);
        Self {
            proj: conv_config.init(device),
            patch_size,
            embed_dim,
        }
    }

    /// Projects `x` and flattens the spatial grid into a token sequence.
    ///
    /// Returns `(tokens, height, width)`, where `height` and `width` are the last
    /// two dimensions of the convolution output and `tokens` is that output with
    /// dimensions 2.. merged and then dimensions 1 and 2 swapped.
    pub fn forward(&self, x: Tensor<B, 4>) -> (Tensor<B, 3>, usize, usize) {
        let feature_map = self.proj.forward(x);
        let [_, _, height, width] = feature_map.dims();
        // Merge the two spatial axes, then move the token axis ahead of the channel axis.
        let tokens: Tensor<B, 3> = feature_map.flatten(2, -1);
        (tokens.swap_dims(1, 2), height, width)
    }
}
/// Rotary position embedding for a 2-D grid of patches; `forward` produces the
/// sine and cosine tables for a given grid size.
#[derive(Module, Debug)]
pub struct RopePositionEmbedding<B: Backend> {
/// Periods `base^(k / d_head)` for `k = 0, 4, 8, ...` below `d_head`; created
/// by `new` with gradients disabled.
pub periods: Param<Tensor<B, 1>>,
/// Per-head channel count, computed by `new` as `embed_dim / num_heads`.
pub d_head: usize,
}
impl<B: Backend> RopePositionEmbedding<B> {
    /// Builds the period table for heads of size `embed_dim / num_heads`.
    ///
    /// The periods are `base^(k / d_head)` for `k = 0, 4, 8, ...` (below `d_head`)
    /// and are stored as a parameter with gradients disabled.
    pub fn new(embed_dim: usize, num_heads: usize, base: f32, device: &B::Device) -> Self {
        let d_head = embed_dim / num_heads;
        let exponents: Tensor<B, 1> =
            Tensor::arange_step(0..d_head as i64, 4, device).float() / d_head as f32;
        let base_tensor: Tensor<B, 1> = Tensor::from_floats([base], device);
        let periods = Param::from_tensor(base_tensor.powf(exponents)).no_grad();
        Self { periods, d_head }
    }

    /// Computes the `(sin, cos)` rotary tables for a `height x width` patch grid.
    ///
    /// Both returned tensors have one row per grid position (row-major order).
    /// Each row holds the angles for the row coordinate followed by those for the
    /// column coordinate, and that whole vector is repeated twice.
    pub fn forward(&self, height: usize, width: usize) -> (Tensor<B, 2>, Tensor<B, 2>) {
        let device = self.periods.device();

        // Patch-centre coordinates along each axis: (i + 0.5) / extent.
        let rows: Tensor<B, 1> =
            (Tensor::arange(0..height as i64, &device).float() + 0.5) / (height as f32);
        let cols: Tensor<B, 1> =
            (Tensor::arange(0..width as i64, &device).float() + 0.5) / (width as f32);

        // Expand both axes over the full grid and flatten to one column each.
        let row_grid: Tensor<B, 2> = rows
            .unsqueeze_dim::<2>(1)
            .repeat_dim(1, width)
            .reshape([-1, 1]);
        let col_grid: Tensor<B, 2> = cols
            .unsqueeze::<2>()
            .repeat_dim(0, height)
            .reshape([-1, 1]);

        // Pair up (row, col) per position and rescale to 2 * c - 1.
        let grid: Tensor<B, 2> = Tensor::cat(vec![row_grid, col_grid], 1) * 2.0 - 1.0;

        let periods = self.periods.val().cast(FloatDType::F32).unsqueeze::<3>();
        let angles = grid.unsqueeze_dim::<3>(2) * std::f64::consts::PI * 2.0 / periods;

        // Merge the coordinate and frequency axes, then duplicate along the channel axis.
        let angles: Tensor<B, 2> = angles.flatten(1, 2);
        let tiled = Tensor::cat(vec![angles.clone(), angles], 1);

        (tiled.clone().sin(), tiled.cos())
    }
}
/// Linear layer whose bias is multiplied element-wise by a stored mask before
/// it is applied.
#[derive(Module, Debug)]
pub struct LinearKMaskedBias<B: Backend> {
/// Underlying linear layer providing the weight and the (optional) bias.
pub linear: Linear<B>,
/// Element-wise multiplier applied to `linear`'s bias in `forward`.
pub bias_mask: Param<Tensor<B, 1>>,
}
impl<B: Backend> LinearKMaskedBias<B> {
    /// Applies the linear layer to `input`, using `bias * bias_mask` as the bias.
    ///
    /// When the wrapped layer has no bias, none is applied.
    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
        let weight = self.linear.weight.val();
        let bias = self
            .linear
            .bias
            .as_ref()
            .map(|bias| bias.val().mul(self.bias_mask.val()));
        module::linear(input, weight, bias)
    }
}
/// Per-channel scaling applied to the last dimension of a 3-D tensor.
#[derive(Module, Debug)]
pub struct LayerScale<B: Backend> {
/// Scale vector with one entry per channel; `new` fills it with `init_values`.
pub gamma: Param<Tensor<B, 1>>,
}
impl<B: Backend> LayerScale<B> {
    /// Creates a scale vector of length `dim` with every entry set to `init_values`.
    pub fn new(dim: usize, init_values: f32, device: &B::Device) -> Self {
        let ones: Tensor<B, 1> = Tensor::ones([dim], device);
        Self {
            gamma: Param::from_tensor(ones * init_values),
        }
    }

    /// Multiplies `x` by `gamma`, broadcast over the first two dimensions.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let [_, _, channels] = x.dims();
        let scale = self.gamma.val().reshape([1, 1, channels]);
        x * scale
    }
}
/// Low-rank adapter made of two matrices; `forward` computes `x * a * b`.
#[derive(Module, Debug)]
pub struct LoRA<B: Backend> {
/// First factor, applied to the input; `LoRAConfig::init` creates it as `[dim, rank]`.
pub a: Param<Tensor<B, 2>>,
/// Second factor, applied after `a`; `LoRAConfig::init` creates it as `[rank, dim]`.
pub b: Param<Tensor<B, 2>>,
}
/// Configuration used to build a [`LoRA`] adapter.
#[derive(Config, Debug)]
pub struct LoRAConfig {
/// Size of the adapter's input and output features.
pub dim: usize,
/// Inner (low-rank) dimension shared by the two factors.
pub rank: usize,
/// Initializer for the `a` factor; defaults to all zeros.
#[config(default = "Initializer::Zeros")]
pub a_initializer: Initializer,
/// Initializer for the `b` factor; defaults to Kaiming-uniform with gain `1/sqrt(3)`.
#[config(default = "Initializer::KaimingUniform{gain:1.0/3.0f64.sqrt(), fan_out_only:false}")]
pub b_initializer: Initializer,
}
impl LoRAConfig {
    /// Builds the adapter on `device`, creating `a` as `[dim, rank]` and `b` as `[rank, dim]`.
    ///
    /// The factors are created with `init_with` so that fan-dependent initializers
    /// (such as the default Kaiming-uniform for `b`) receive the fan sizes they
    /// need; plain `init` passes no fan information. Because the adapter computes
    /// `x * a * b`, each matrix's first dimension is its fan-in and its second
    /// dimension is its fan-out.
    pub fn init<B: Backend>(&self, device: &B::Device) -> LoRA<B> {
        let (dim, rank) = (self.dim, self.rank);
        LoRA {
            a: self
                .a_initializer
                .init_with([dim, rank], Some(dim), Some(rank), device),
            b: self
                .b_initializer
                .init_with([rank, dim], Some(rank), Some(dim), device),
        }
    }
}
impl<B: Backend> LoRA<B> {
    /// Applies the low-rank update: `x` is multiplied by `a`, then by `b`.
    ///
    /// Both factors get a leading unit dimension so they broadcast over the batch.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let down: Tensor<B, 3> = self.a.val().unsqueeze();
        let up: Tensor<B, 3> = self.b.val().unsqueeze();
        x.matmul(down).matmul(up)
    }
}
/// Multi-head self-attention with a fused QKV projection, optional rotary
/// position embedding and an optional LoRA adapter.
#[derive(Module, Debug)]
pub struct Attention<B: Backend> {
/// Fused projection producing queries, keys and values (`dim -> 3 * dim`),
/// with a masked bias.
pub qkv: LinearKMaskedBias<B>,
/// Output projection applied to the attention result.
pub proj: Linear<B>,
/// Dropout applied after the output projection.
pub drop_out: Dropout,
/// Optional adapter; when present its output is added to both the queries
/// and the values.
pub lora: Option<LoRA<B>>,
/// Number of attention heads the channel dimension is split into.
pub num_heads: usize,
}
impl<B: Backend> Attention<B> {
    /// Creates an attention layer over `dim` channels split across `num_heads` heads.
    ///
    /// The QKV bias mask starts as all zeros, so the QKV bias contributes nothing
    /// until the mask is overwritten (for example by loading stored weights).
    /// Dropout is created with probability `0.0` and no LoRA adapter is attached.
    pub fn new(dim: usize, num_heads: usize, device: &B::Device) -> Self {
        Self {
            qkv: LinearKMaskedBias {
                linear: LinearConfig::new(dim, dim * 3).with_bias(true).init(device),
                bias_mask: Param::from_tensor(Tensor::zeros([dim * 3], device)),
            },
            proj: LinearConfig::new(dim, dim).with_bias(true).init(device),
            drop_out: DropoutConfig::new(0.0).init(),
            lora: None,
            num_heads,
        }
    }

    /// Runs self-attention over `x` (`[batch, seq, dim]`).
    ///
    /// `repo`, when given, is a `(sin, cos)` pair of rotary tables applied to the
    /// queries and keys; see `apply_rope` for how leading tokens are handled.
    ///
    /// # Panics
    ///
    /// Panics if the rotary tables cover more positions than `x` has tokens, or
    /// if the tensor shapes are otherwise inconsistent with the reshapes below.
    pub fn forward(
        &self,
        x: Tensor<B, 3>,
        repo: Option<&(Tensor<B, 2>, Tensor<B, 2>)>,
    ) -> Tensor<B, 3> {
        let [batch_size, seq_len, dim] = x.dims();
        let head_dim = dim / self.num_heads;

        // Split the fused projection into per-head q, k, v: [batch, heads, seq, head_dim].
        let qkv = self
            .qkv
            .forward(x.clone())
            .reshape([batch_size, seq_len, 3, self.num_heads, head_dim]);
        let [mut q, mut k, mut v]: [Tensor<B, 4>; 3] = qkv
            .chunk(3, 2)
            .into_iter()
            .map(|part| part.squeeze_dim::<4>(2).swap_dims(1, 2))
            .collect::<Vec<_>>()
            .try_into()
            .expect("chunking the size-3 axis into 3 parts yields exactly 3 tensors");

        if let Some(lora) = self.lora.as_ref() {
            // The same adapter output is added to q and v, so evaluate it once
            // and share the result instead of running the two matmuls twice.
            let delta = lora
                .forward(x)
                .reshape([batch_size, seq_len, self.num_heads, head_dim])
                .swap_dims(1, 2);
            q = q + delta.clone();
            v = v + delta;
        }

        if let Some((sin, cos)) = repo {
            q = Self::apply_rope(q, sin, cos);
            k = Self::apply_rope(k, sin, cos);
        }

        let out = module::attention(q, k, v, None, None, AttentionModuleOptions::default());
        // Back to [batch, seq, dim] before the output projection.
        let out = out.swap_dims(1, 2).reshape([batch_size, seq_len, dim]);
        self.drop_out.forward(self.proj.forward(out))
    }

    /// Applies rotary position embedding to the trailing tokens of `x`.
    ///
    /// `sin` and `cos` have one row per position. The first
    /// `seq - rope_seq` tokens of `x` (those without a table row) pass through
    /// unchanged; the remaining tokens become `t * cos + rotate_half(t) * sin`,
    /// where `rotate_half` maps the two channel halves `(t1, t2)` to `(-t2, t1)`.
    fn apply_rope(x: Tensor<B, 4>, sin: &Tensor<B, 2>, cos: &Tensor<B, 2>) -> Tensor<B, 4> {
        let [_, _, seq, head_dim] = x.dims();
        let [rope_seq, _] = sin.dims();
        let num_prefix_tokens = seq
            .checked_sub(rope_seq)
            .expect("rotary tables cover more positions than the sequence has tokens");

        let [prefix, rope]: [Tensor<B, 4>; 2] = x
            .split_with_sizes(vec![num_prefix_tokens, rope_seq], 2)
            .try_into()
            .expect("splitting with two sizes yields exactly two tensors");

        let [x1, x2]: [Tensor<B, 4>; 2] = rope
            .clone()
            .split(head_dim / 2, 3)
            .try_into()
            .expect("head dimension must split into exactly two halves");
        let rotated = Tensor::cat(vec![x2.mul_scalar(-1.0), x1], 3);

        let cos = cos.clone().reshape([1, 1, rope_seq, head_dim]);
        let sin = sin.clone().reshape([1, 1, rope_seq, head_dim]);
        let rope = (rope * cos) + (rotated * sin);

        Tensor::cat(vec![prefix, rope], 2)
    }
}
/// Two-layer feed-forward network with a GELU activation in between.
#[derive(Module, Debug)]
pub struct Mlp<B: Backend> {
/// First linear layer (`in_features -> hidden_features`).
pub fc1: Linear<B>,
/// Activation applied between the two linear layers.
pub act: Gelu,
/// Second linear layer (`hidden_features -> in_features`).
pub fc2: Linear<B>,
}
impl<B: Backend> Mlp<B> {
    /// Creates the feed-forward network `in_features -> hidden_features -> in_features`.
    pub fn new(in_features: usize, hidden_features: usize, device: &B::Device) -> Self {
        let fc1 = LinearConfig::new(in_features, hidden_features).init(device);
        let fc2 = LinearConfig::new(hidden_features, in_features).init(device);
        Self {
            fc1,
            act: Gelu::new(),
            fc2,
        }
    }

    /// Applies `fc1`, then the activation, then `fc2`.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let hidden = self.act.forward(self.fc1.forward(x));
        self.fc2.forward(hidden)
    }
}
/// Transformer block: normalised attention and normalised MLP, each scaled by
/// a `LayerScale` and added back to its input.
#[derive(Module, Debug)]
pub struct Block<B: Backend> {
/// Normalisation applied before attention.
pub norm1: LayerNorm<B>,
/// Self-attention sub-layer.
pub attn: Attention<B>,
/// Scale applied to the attention output before the residual addition.
pub ls1: LayerScale<B>,
/// Normalisation applied before the MLP.
pub norm2: LayerNorm<B>,
/// Feed-forward sub-layer.
pub mlp: Mlp<B>,
/// Scale applied to the MLP output before the residual addition.
pub ls2: LayerScale<B>,
}
impl<B: Backend> Block<B> {
pub fn new(dim: usize, num_heads: usize, ffn_ratio: f64, device: &B::Device) -> Self {
let hidden_dim = (dim as f64 * ffn_ratio) as usize;
Self {
norm1: LayerNormConfig::new(dim).with_bias(true).init(device),
attn: Attention::new(dim, num_heads, device),
ls1: LayerScale::new(dim, 1e-5, device),
norm2: LayerNormConfig::new(dim).with_bias(true).init(device),
mlp: Mlp::new(dim, hidden_dim, device),
ls2: LayerScale::new(dim, 1e-5, device),
}
}
pub fn forward(
&self,
x: Tensor<B, 3>,
repo: Option<&(Tensor<B, 2>, Tensor<B, 2>)>,
) -> Tensor<B, 3> {
let x = self
.ls1
.forward(self.attn.forward(self.norm1.forward(x.clone()), repo))
+ x;
self.ls2
.forward(self.mlp.forward(self.norm2.forward(x.clone())))
+ x
}
}
/// Vision transformer backbone: patch embedding, a class token and storage
/// tokens prepended to the patch sequence, rotary position embedding, a stack
/// of transformer blocks and a final normalisation.
#[derive(Module, Debug)]
pub struct DinoVisionTransformer<B: Backend> {
/// Converts the input image into patch tokens.
pub patch_embed: PatchEmbed<B>,
/// Class token prepended to every sequence; `new` creates it as zeros of
/// shape `[1, 1, embed_dim]`.
pub cls_token: Param<Tensor<B, 3>>,
/// Extra tokens inserted after the class token; `new` creates them as zeros
/// of shape `[1, num_storage_tokens, embed_dim]`.
pub storage_tokens: Param<Tensor<B, 3>>,
/// Produces the rotary `(sin, cos)` tables for the patch grid.
pub rope_embed: RopePositionEmbedding<B>,
/// Transformer blocks, applied in order.
pub blocks: Vec<Block<B>>,
/// Normalisation applied to the output of the last block.
pub norm: LayerNorm<B>,
/// Embedding written over patch tokens selected by the mask passed to
/// `forward`; `new` creates it as zeros of shape `[1, embed_dim]`.
pub mask_token: Option<Param<Tensor<B, 2>>>,
}
impl<B: Backend> DinoVisionTransformer<B> {
pub fn new(
patch_size: usize,
embed_dim: usize,
num_storage_tokens: usize,
depth: usize,
num_heads: usize,
ffn_ratio: f64,
device: &B::Device,
) -> Self {
let patch_embed = PatchEmbed::new(3, embed_dim, patch_size, device);
let cls_token = Param::from_tensor(Tensor::zeros([1, 1, embed_dim], device));
let storage_tokens =
Param::from_tensor(Tensor::zeros([1, num_storage_tokens, embed_dim], device));
let rope_embed = RopePositionEmbedding::new(embed_dim, num_heads, 100.0, device).no_grad();
let blocks = vec![Block::new(embed_dim, num_heads, ffn_ratio, device); depth];
let norm = LayerNormConfig::new(embed_dim).init(device);
let mask_token = Param::from_tensor(Tensor::zeros([1, embed_dim], device));
Self {
patch_embed,
cls_token,
storage_tokens,
rope_embed,
blocks,
norm,
mask_token: Some(mask_token),
}
}
pub fn forward(&self, x: Tensor<B, 4>, masks: Option<&Tensor<B, 2, Bool>>) -> Tensor<B, 3> {
let (mut x, height, width) = self.patch_embed.forward(x);
let [batch_size, seq, dim] = x.dims();
if let Some(masks) = masks
&& let Some(mask_token) = self.mask_token.as_ref()
{
x = x.mask_where(
masks
.clone()
.reshape([1, seq, dim])
.repeat_dim(0, batch_size),
mask_token
.val()
.reshape([1, 1, dim])
.repeat_dim(1, seq)
.repeat_dim(0, batch_size),
);
}
let repo = self.rope_embed.forward(height, width);
let cls_token_batch = self.cls_token.val().repeat_dim(0, batch_size);
let storage_tokens_batch = self.storage_tokens.val().repeat_dim(0, batch_size);
x = Tensor::cat(vec![cls_token_batch, storage_tokens_batch, x], 1);
for block in &self.blocks {
x = block.forward(x, Some(&repo));
}
self.norm.forward(x)
}
}
/// Builds the small variant: width 384, 12 blocks, 6 heads, 4 storage tokens, FFN ratio 4.0.
pub fn vit_small<B: Backend>(patch_size: usize, device: &B::Device) -> DinoVisionTransformer<B> {
    let (embed_dim, depth, num_heads) = (384, 12, 6);
    let (num_storage_tokens, ffn_ratio) = (4, 4.0);
    DinoVisionTransformer::new(
        patch_size,
        embed_dim,
        num_storage_tokens,
        depth,
        num_heads,
        ffn_ratio,
        device,
    )
}
/// Builds the base variant: width 768, 12 blocks, 12 heads, 4 storage tokens, FFN ratio 4.0.
pub fn vit_base<B: Backend>(patch_size: usize, device: &B::Device) -> DinoVisionTransformer<B> {
    let (embed_dim, depth, num_heads) = (768, 12, 12);
    let (num_storage_tokens, ffn_ratio) = (4, 4.0);
    DinoVisionTransformer::new(
        patch_size,
        embed_dim,
        num_storage_tokens,
        depth,
        num_heads,
        ffn_ratio,
        device,
    )
}
/// Builds the large variant: width 1024, 24 blocks, 16 heads, 4 storage tokens, FFN ratio 4.0.
pub fn vit_large<B: Backend>(patch_size: usize, device: &B::Device) -> DinoVisionTransformer<B> {
    let (embed_dim, depth, num_heads) = (1024, 24, 16);
    let (num_storage_tokens, ffn_ratio) = (4, 4.0);
    DinoVisionTransformer::new(
        patch_size,
        embed_dim,
        num_storage_tokens,
        depth,
        num_heads,
        ffn_ratio,
        device,
    )
}
/// Builds the so400m variant: width 1152, 27 blocks, 18 heads, 4 storage tokens,
/// FFN ratio 3.777777778.
pub fn vit_so400m<B: Backend>(patch_size: usize, device: &B::Device) -> DinoVisionTransformer<B> {
    let (embed_dim, depth, num_heads) = (1152, 27, 18);
    let (num_storage_tokens, ffn_ratio) = (4, 3.777777778);
    DinoVisionTransformer::new(
        patch_size,
        embed_dim,
        num_storage_tokens,
        depth,
        num_heads,
        ffn_ratio,
        device,
    )
}
/// Builds the huge2 variant: width 1280, 32 blocks, 20 heads, 4 storage tokens, FFN ratio 4.0.
pub fn vit_huge2<B: Backend>(patch_size: usize, device: &B::Device) -> DinoVisionTransformer<B> {
    let (embed_dim, depth, num_heads) = (1280, 32, 20);
    let (num_storage_tokens, ffn_ratio) = (4, 4.0);
    DinoVisionTransformer::new(
        patch_size,
        embed_dim,
        num_storage_tokens,
        depth,
        num_heads,
        ffn_ratio,
        device,
    )
}
/// Builds the giant2 variant: width 1536, 40 blocks, 24 heads, 4 storage tokens, FFN ratio 4.0.
pub fn vit_giant2<B: Backend>(patch_size: usize, device: &B::Device) -> DinoVisionTransformer<B> {
    let (embed_dim, depth, num_heads) = (1536, 40, 24);
    let (num_storage_tokens, ffn_ratio) = (4, 4.0);
    DinoVisionTransformer::new(
        patch_size,
        embed_dim,
        num_storage_tokens,
        depth,
        num_heads,
        ffn_ratio,
        device,
    )
}
/// Builds the 7b variant: width 4096, 40 blocks, 32 heads, 4 storage tokens, FFN ratio 3.0.
pub fn vit_7b<B: Backend>(patch_size: usize, device: &B::Device) -> DinoVisionTransformer<B> {
    let (embed_dim, depth, num_heads) = (4096, 40, 32);
    let (num_storage_tokens, ffn_ratio) = (4, 3.0);
    DinoVisionTransformer::new(
        patch_size,
        embed_dim,
        num_storage_tokens,
        depth,
        num_heads,
        ffn_ratio,
        device,
    )
}