use crate::layers::drop::drop_path::{DropPath, DropPathConfig};
use crate::models::swin::v2::window_attention::{
WindowAttention, WindowAttentionConfig, WindowAttentionMeta, sw_attn_mask,
};
use crate::models::swin::v2::windowing::{window_partition, window_reverse};
use bimm_contracts::{assert_shape_contract_periodically, define_shape_contract};
use burn::config::Config;
use burn::module::Module;
use burn::nn::activation::{Activation, ActivationConfig};
use burn::nn::{Dropout, DropoutConfig, LayerNorm, LayerNormConfig, Linear, LinearConfig};
use burn::prelude::{Backend, Tensor};
use burn::tensor::BasicOps;
pub trait BlockMlpMeta {
fn d_input(&self) -> usize;
fn d_hidden(&self) -> usize;
fn d_output(&self) -> usize;
fn drop(&self) -> f64;
}
#[derive(Config, Debug)]
pub struct BlockMlpConfig {
d_input: usize,
#[config(default = "None")]
d_hidden: Option<usize>,
#[config(default = "None")]
d_output: Option<usize>,
#[config(default = 0.)]
drop: f64,
#[config(default = "ActivationConfig::Relu")]
pub activation: ActivationConfig,
}
impl BlockMlpMeta for BlockMlpConfig {
fn d_input(&self) -> usize {
self.d_input
}
fn d_hidden(&self) -> usize {
self.d_hidden.unwrap_or(self.d_input)
}
fn d_output(&self) -> usize {
self.d_output.unwrap_or(self.d_input)
}
fn drop(&self) -> f64 {
self.drop
}
}
impl BlockMlpConfig {
pub fn init<B: Backend>(
&self,
device: &B::Device,
) -> BlockMlp<B> {
let d_input = self.d_input();
let d_hidden = self.d_hidden();
let d_output = self.d_output();
BlockMlp {
fc1: LinearConfig::new(d_input, d_hidden).init(device),
fc2: LinearConfig::new(d_hidden, d_output).init(device),
act: self.activation.init(device),
drop: DropoutConfig { prob: self.drop }.init(),
}
}
}
#[derive(Module, Debug)]
pub struct BlockMlp<B: Backend> {
fc1: Linear<B>,
fc2: Linear<B>,
act: Activation<B>,
drop: Dropout,
}
impl<B: Backend> BlockMlpMeta for BlockMlp<B> {
fn d_input(&self) -> usize {
self.fc1.weight.dims()[0]
}
fn d_hidden(&self) -> usize {
self.fc1.weight.dims()[1]
}
fn d_output(&self) -> usize {
self.fc2.weight.dims()[1]
}
fn drop(&self) -> f64 {
self.drop.prob
}
}
impl<B: Backend> BlockMlp<B> {
#[must_use]
pub fn forward<const D: usize>(
&self,
x: Tensor<B, D>,
) -> Tensor<B, D> {
assert_shape_contract_periodically!(
[..., "in"],
&x.dims(),
&[("in", self.d_input())]
);
let x = self.fc1.forward(x);
assert_shape_contract_periodically!(
[..., "h"],
&x.dims(),
&[("h", self.d_hidden())]
);
let x = self.act.forward(x);
let x = self.drop.forward(x);
let x = self.fc2.forward(x);
assert_shape_contract_periodically!(
[..., "out"],
&x.dims(),
&[("out", self.d_output())]
);
self.drop.forward(x)
}
}
#[must_use]
#[inline(always)]
fn with_shift<B: Backend, F, K>(
x: Tensor<B, 4, K>,
shift: isize,
f: F,
) -> Tensor<B, 4, K>
where
K: BasicOps<B>,
F: FnOnce(Tensor<B, 4, K>) -> Tensor<B, 4, K>,
{
let dims = [1, 2];
let x = if shift != 0 {
x.roll(&dims, &[-shift, -shift])
} else {
x
};
let x = f(x);
if shift != 0 {
x.roll(&dims, &[shift, shift])
} else {
x
}
}
pub trait ShiftedWindowTransformerBlockMeta {
fn d_input(&self) -> usize;
fn input_resolution(&self) -> [usize; 2];
fn input_height(&self) -> usize {
self.input_resolution()[0]
}
fn input_width(&self) -> usize {
self.input_resolution()[1]
}
fn d_output(&self) -> usize {
self.d_input()
}
fn output_resolution(&self) -> [usize; 2] {
self.input_resolution()
}
fn output_height(&self) -> usize {
self.input_height()
}
fn output_width(&self) -> usize {
self.input_width()
}
fn num_heads(&self) -> usize;
fn window_size(&self) -> usize;
fn shift_size(&self) -> usize;
fn swa_enabled(&self) -> bool {
self.shift_size() > 0
}
fn enable_qkv_bias(&self) -> bool;
fn drop_rate(&self) -> f64;
fn attn_drop_rate(&self) -> f64;
fn mlp_ratio(&self) -> f64;
fn drop_path_rate(&self) -> f64;
}
#[derive(Config, Debug)]
pub struct ShiftedWindowTransformerBlockConfig {
pub d_input: usize,
pub input_resolution: [usize; 2],
pub num_heads: usize,
#[config(default = 7)]
pub window_size: usize,
#[config(default = 0)]
pub shift_size: usize,
#[config(default = 4.0)]
pub mlp_ratio: f64,
#[config(default = true)]
pub enable_qkv_bias: bool,
#[config(default = 0.0)]
pub drop_rate: f64,
#[config(default = 0.0)]
pub attn_drop_rate: f64,
#[config(default = 512)]
pub attn_rpb_mlp_hidden_dim: usize,
#[config(default = "ActivationConfig::Relu")]
pub attn_rpb_mlp_activation: ActivationConfig,
#[config(default = 0.0)]
pub drop_path_rate: f64,
}
impl ShiftedWindowTransformerBlockMeta for ShiftedWindowTransformerBlockConfig {
fn d_input(&self) -> usize {
self.d_input
}
fn input_resolution(&self) -> [usize; 2] {
self.input_resolution
}
fn num_heads(&self) -> usize {
self.num_heads
}
fn window_size(&self) -> usize {
self.window_size
}
fn shift_size(&self) -> usize {
self.shift_size
}
fn enable_qkv_bias(&self) -> bool {
self.enable_qkv_bias
}
fn drop_rate(&self) -> f64 {
self.drop_rate
}
fn attn_drop_rate(&self) -> f64 {
self.attn_drop_rate
}
fn mlp_ratio(&self) -> f64 {
self.mlp_ratio
}
fn drop_path_rate(&self) -> f64 {
self.drop_path_rate
}
}
impl ShiftedWindowTransformerBlockConfig {
#[inline(always)]
fn check(&self) {
assert!(
self.d_input > 0,
"d_input must be greater than zero: {self:#?}"
);
assert!(
self.num_heads > 0,
"num_heads must be greater than zero: {self:#?}"
);
assert!(
self.window_size > 0,
"window_size must be greater than zero: {self:#?}"
);
let [h, w] = self.input_resolution;
assert!(
h > 0 && w > 0,
"input_resolution must be greater than zero: {self:#?}"
);
assert!(
h % self.window_size == 0 && w % self.window_size == 0,
"input_resolution must be divisible by window size: {self:#?}",
);
}
#[must_use]
pub fn init<B: Backend>(
&self,
device: &B::Device,
) -> ShiftedWindowTransformerBlock<B> {
self.check();
let hidden_dim = (self.d_input as f64 * self.mlp_ratio) as usize;
let block_mlp = BlockMlpConfig::new(self.d_input)
.with_d_hidden(Some(hidden_dim))
.with_drop(self.drop_rate)
.init(device);
let win_attn = WindowAttentionConfig::new(
self.d_input,
[self.window_size, self.window_size],
self.num_heads,
)
.with_enable_qkv_bias(self.enable_qkv_bias)
.with_attn_drop(self.attn_drop_rate)
.with_proj_drop(self.drop_rate)
.with_rpb_mlp_hidden_dim(self.attn_rpb_mlp_hidden_dim)
.with_rpb_mlp_activation(self.attn_rpb_mlp_activation.clone())
.init(device);
let shift_mask = if self.shift_size == 0 {
None
} else {
Some(
sw_attn_mask(
self.input_resolution,
self.window_size,
self.shift_size,
device,
)
.float()
.mul_scalar(-100.0),
)
};
ShiftedWindowTransformerBlock {
input_resolution: self.input_resolution,
window_size: self.window_size,
shift_size: self.shift_size,
shift_mask,
drop_path: DropPathConfig::new()
.with_drop_prob(self.drop_path_rate)
.init(),
norm1: LayerNormConfig::new(self.d_input).init(device),
norm2: LayerNormConfig::new(self.d_input).init(device),
win_attn,
block_mlp,
}
}
}
#[derive(Module, Debug)]
pub struct ShiftedWindowTransformerBlock<B: Backend> {
pub input_resolution: [usize; 2],
pub window_size: usize,
pub shift_size: usize,
pub shift_mask: Option<Tensor<B, 3>>,
pub drop_path: DropPath,
pub norm1: LayerNorm<B>,
pub norm2: LayerNorm<B>,
pub win_attn: WindowAttention<B>,
pub block_mlp: BlockMlp<B>,
}
impl<B: Backend> ShiftedWindowTransformerBlockMeta for ShiftedWindowTransformerBlock<B> {
fn d_input(&self) -> usize {
self.win_attn.d_input()
}
fn input_resolution(&self) -> [usize; 2] {
self.input_resolution
}
fn num_heads(&self) -> usize {
self.win_attn.num_heads()
}
fn window_size(&self) -> usize {
self.window_size
}
fn shift_size(&self) -> usize {
self.shift_size
}
fn enable_qkv_bias(&self) -> bool {
self.win_attn.enable_qkv_bias()
}
fn drop_rate(&self) -> f64 {
self.block_mlp.drop()
}
fn attn_drop_rate(&self) -> f64 {
self.win_attn.attn_drop()
}
fn mlp_ratio(&self) -> f64 {
self.block_mlp.d_hidden() as f64 / self.d_input() as f64
}
fn drop_path_rate(&self) -> f64 {
self.drop_path.drop_prob
}
}
impl<B: Backend> ShiftedWindowTransformerBlock<B> {
#[must_use]
pub fn forward(
&self,
x: Tensor<B, 3>,
) -> Tensor<B, 3> {
let [h, w] = self.input_resolution;
let env = [("height", h), ("width", w)];
define_shape_contract!(CONTRACT, ["batch", "height" * "width", "channels"]);
let [b, c] = CONTRACT.unpack_shape(&x.dims(), &["batch", "channels"], &env);
let x = self.with_skip(x, |x| {
let x = x.reshape([b, h, w, c]);
let x = with_shift(x, self.shift_size as isize, |x| self.apply_window(x, c));
let x = x.reshape([b, h * w, c]);
self.norm1.forward(x)
});
assert_shape_contract_periodically!(CONTRACT, &x.dims(), &env);
let x = self.with_skip(x, |x| self.norm2.forward(self.block_mlp.forward(x)));
assert_shape_contract_periodically!(CONTRACT, &x.dims(), &env);
x
}
#[must_use]
#[inline(always)]
fn with_skip<const D: usize, F>(
&self,
x: Tensor<B, D>,
f: F,
) -> Tensor<B, D>
where
F: FnOnce(Tensor<B, D>) -> Tensor<B, D>,
{
self.drop_path.with_skip(x, f)
}
#[must_use]
#[inline(always)]
fn apply_window(
&self,
x: Tensor<B, 4>,
c: usize,
) -> Tensor<B, 4> {
let [h, w] = self.input_resolution;
let ws = self.window_size as i32;
let c = c as i32;
let x_windows = window_partition(x, self.window_size);
let x_windows = x_windows.reshape([-1, ws * ws, c]);
let attn_windows = self.win_attn.forward(x_windows, self.shift_mask.clone());
let attn_windows = attn_windows.reshape([-1, ws, ws, c]);
window_reverse(attn_windows, self.window_size, h, w)
}
}
#[cfg(test)]
mod tests {
use super::*;
use burn::backend::NdArray;
use burn::tensor::Distribution;
#[test]
fn test_block_mlp_meta() {
{
let d_input = 4;
let config = BlockMlpConfig::new(d_input);
assert_eq!(config.d_input(), d_input);
assert_eq!(config.d_hidden(), d_input);
assert_eq!(config.d_output(), d_input);
assert_eq!(config.drop(), 0.);
}
{
let d_input = 4;
let d_hidden = 8;
let d_output = 6;
let drop = 0.1;
let config = BlockMlpConfig::new(d_input)
.with_d_hidden(Some(d_hidden))
.with_d_output(Some(d_output))
.with_drop(drop);
assert_eq!(config.d_input(), d_input);
assert_eq!(config.d_hidden(), d_hidden);
assert_eq!(config.d_output(), d_output);
assert_eq!(config.drop(), drop);
}
}
#[test]
fn test_mlp() {
impl_test_mlp::<NdArray>();
}
fn impl_test_mlp<B: Backend>() {
let device: B::Device = Default::default();
let a = 2;
let b = 3;
let d_input = 4;
let d_hidden = 8;
let d_output = 6;
let drop = 0.1;
let config = BlockMlpConfig::new(d_input)
.with_d_hidden(Some(d_hidden))
.with_d_output(Some(d_output))
.with_drop(drop);
let mlp: BlockMlp<B> = config.init(&device);
assert_eq!(mlp.d_input(), config.d_input());
assert_eq!(mlp.d_hidden(), config.d_hidden());
assert_eq!(mlp.d_output(), config.d_output());
assert_eq!(mlp.drop(), config.drop());
let distribution = Distribution::Normal(0., 1.);
let x = Tensor::random([a, b, d_input], distribution, &device);
let y = mlp.forward(x);
assert_eq!(y.dims(), [a, b, d_output]);
}
#[test]
fn test_with_shift() {
let device = Default::default();
let b = 1;
let h = 4;
let w = 4;
let c = 3;
let distribution = burn::tensor::Distribution::Uniform(0.0, 1.0);
let input = Tensor::<NdArray, 4>::random([b, h, w, c], distribution, &device);
let idx: Tensor<NdArray, 4> =
Tensor::arange(0..input.shape().num_elements() as i64, &device)
.reshape([b, h, w, c])
.float();
with_shift(input.clone(), 0, |x| x + idx.clone())
.to_data()
.assert_eq(&(input.clone() + idx.clone()).to_data(), true);
with_shift(input.clone(), 1, |x| x + idx.clone())
.to_data()
.assert_eq(
&({
let x = input.clone();
let x = x.roll(&[1, 2], &[-1, -1]);
let x = x + idx.clone();
x.roll(&[1, 2], &[1, 1])
})
.to_data(),
true,
);
}
#[test]
fn test_shifted_window_transformer_block_meta() {
let d_input = 128;
let num_heads = 4;
let input_resolution = [14, 14];
let config = ShiftedWindowTransformerBlockConfig::new(d_input, input_resolution, num_heads);
assert_eq!(config.d_input(), d_input);
assert_eq!(config.input_resolution(), input_resolution);
assert_eq!(config.input_height(), 14);
assert_eq!(config.input_width(), 14);
assert_eq!(config.d_output(), d_input);
assert_eq!(config.output_resolution(), input_resolution);
assert_eq!(config.output_height(), 14);
assert_eq!(config.output_width(), 14);
assert_eq!(config.num_heads(), num_heads);
assert_eq!(config.window_size(), 7);
assert_eq!(config.shift_size(), 0);
assert!(!config.swa_enabled());
assert!(config.enable_qkv_bias());
assert_eq!(config.drop_rate(), 0.0);
assert_eq!(config.attn_drop_rate(), 0.0);
assert_eq!(config.mlp_ratio(), 4.0);
assert_eq!(config.drop_path_rate(), 0.0);
let device = Default::default();
let block = config.init::<NdArray>(&device);
assert_eq!(block.d_input(), d_input);
assert_eq!(block.input_resolution(), input_resolution);
assert_eq!(block.input_height(), 14);
assert_eq!(block.input_width(), 14);
assert_eq!(block.d_output(), d_input);
assert_eq!(block.output_resolution(), input_resolution);
assert_eq!(block.output_height(), 14);
assert_eq!(block.output_width(), 14);
assert_eq!(block.num_heads(), num_heads);
assert_eq!(block.window_size(), 7);
assert_eq!(block.shift_size(), 0);
assert!(!block.swa_enabled());
assert!(block.enable_qkv_bias());
assert_eq!(block.drop_rate(), 0.0);
assert_eq!(block.attn_drop_rate(), 0.0);
assert_eq!(block.mlp_ratio(), 4.0);
assert_eq!(block.drop_path_rate(), 0.0);
}
#[should_panic(expected = "input_resolution must be greater than zero")]
#[test]
fn test_shifted_window_transformer_block_config_zero_resolution() {
let d_input = 128;
let num_heads = 4;
let input_resolution = [0, 14];
let config = ShiftedWindowTransformerBlockConfig::new(d_input, input_resolution, num_heads);
let _d = config.init::<NdArray>(&Default::default());
}
#[should_panic(expected = "input_resolution must be divisible by window size")]
#[test]
fn test_shifted_window_transformer_block_config_invalid_resolution() {
let d_input = 128;
let num_heads = 4;
let input_resolution = [15, 14];
let config = ShiftedWindowTransformerBlockConfig::new(d_input, input_resolution, num_heads);
let _d = config.init::<NdArray>(&Default::default());
}
#[should_panic(expected = "d_input must be greater than zero")]
#[test]
fn test_shifted_window_transformer_block_config_zero_d_input() {
let d_input = 0; let num_heads = 4;
let input_resolution = [14, 14];
let config = ShiftedWindowTransformerBlockConfig::new(d_input, input_resolution, num_heads);
let _d = config.init::<NdArray>(&Default::default());
}
#[should_panic(expected = "num_heads must be greater than zero")]
#[test]
fn test_shifted_window_transformer_block_config_zero_num_heads() {
let d_input = 128;
let num_heads = 0; let input_resolution = [14, 14];
let config = ShiftedWindowTransformerBlockConfig::new(d_input, input_resolution, num_heads);
let _d = config.init::<NdArray>(&Default::default());
}
#[should_panic(expected = "window_size must be greater than zero")]
#[test]
fn test_shifted_window_transformer_block_config_zero_window_size() {
let d_input = 128;
let num_heads = 4;
let input_resolution = [14, 14];
let window_size = 0;
let config = ShiftedWindowTransformerBlockConfig::new(d_input, input_resolution, num_heads)
.with_window_size(window_size);
let _d = config.init::<NdArray>(&Default::default());
}
#[test]
fn test_block() {
let b = 1;
let num_heads = 4;
let channels_per_head = 3;
let d_input = num_heads * channels_per_head;
let window_size = 4;
let h = 2 * window_size;
let w = 3 * window_size;
let input_resolution = [h, w];
let config = ShiftedWindowTransformerBlockConfig::new(d_input, input_resolution, num_heads)
.with_window_size(window_size);
let device = Default::default();
let block = config.init::<NdArray>(&device);
let distribution = burn::tensor::Distribution::Uniform(0.0, 1.0);
let input = Tensor::<NdArray, 3>::random([b, h * w, d_input], distribution, &device);
let _output = block.forward(input.clone());
assert_eq!(input.dims(), _output.dims());
}
}