use crate::models::swin::v2::window_attention::{
OffsetGridRelativePositionBias, RelativePositionBiasConfig, RelativePositionBiasMeta,
apply_attention_mask,
};
use bimm_contracts::{assert_shape_contract_periodically, unpack_shape_contract};
use burn::config::Config;
use burn::module::{Module, Param, ParamId};
use burn::nn::activation::ActivationConfig;
use burn::nn::{Dropout, DropoutConfig, Linear, LinearConfig};
use burn::prelude::{Backend, Tensor};
use burn::tensor::activation::softmax;
use burn::tensor::linalg::{Norm, vector_normalize};
/// Epsilon passed to `vector_normalize` when L2-normalizing the query and key tensors.
pub const EPS: f64 = 1e-12;
/// Read-only accessors for the hyper-parameters describing a window attention layer.
pub trait WindowAttentionMeta {
    /// Size of the input feature dimension.
    fn d_input(&self) -> usize;

    /// Window extent as `[height, width]`.
    fn window_shape(&self) -> [usize; 2];

    /// Height of the window; the first entry of [`Self::window_shape`].
    fn window_height(&self) -> usize {
        let [height, _] = self.window_shape();
        height
    }

    /// Width of the window; the second entry of [`Self::window_shape`].
    fn window_width(&self) -> usize {
        let [_, width] = self.window_shape();
        width
    }

    /// Number of attention heads.
    fn num_heads(&self) -> usize;

    /// Dropout probability used for the attention dropout.
    fn attn_drop(&self) -> f64;

    /// Dropout probability used for the projection dropout.
    fn proj_drop(&self) -> f64;

    /// Whether the query/key/value bias is enabled.
    fn enable_qkv_bias(&self) -> bool;
}
// The config answers every metadata query straight from its stored fields.
impl WindowAttentionMeta for WindowAttentionConfig {
    /// Returns the configured input feature dimension.
    fn d_input(&self) -> usize {
        self.d_input
    }

    /// Returns the configured number of attention heads.
    fn num_heads(&self) -> usize {
        self.num_heads
    }

    /// Returns the configured `[height, width]` of the window.
    fn window_shape(&self) -> [usize; 2] {
        self.window_shape
    }

    /// Returns the configured qkv-bias flag.
    fn enable_qkv_bias(&self) -> bool {
        self.enable_qkv_bias
    }

    /// Returns the configured attention dropout probability.
    fn attn_drop(&self) -> f64 {
        self.attn_drop
    }

    /// Returns the configured projection dropout probability.
    fn proj_drop(&self) -> f64 {
        self.proj_drop
    }
}
// The initialized module recovers its metadata from its fields and sub-modules.
impl<B: Backend> WindowAttentionMeta for WindowAttention<B> {
    /// Returns the stored input feature dimension.
    fn d_input(&self) -> usize {
        self.d_input
    }

    /// Returns the stored number of attention heads.
    fn num_heads(&self) -> usize {
        self.num_heads
    }

    /// Returns the window shape reported by the relative position bias module.
    fn window_shape(&self) -> [usize; 2] {
        self.rpb_module.window_shape()
    }

    /// Reports whether the query projection was built with a bias.
    fn enable_qkv_bias(&self) -> bool {
        let q_bias = &self.q_linear.bias;
        q_bias.is_some()
    }

    /// Returns the probability held by the attention dropout layer.
    fn attn_drop(&self) -> f64 {
        self.attn_drop.prob
    }

    /// Returns the probability held by the projection dropout layer.
    fn proj_drop(&self) -> f64 {
        self.proj_drop.prob
    }
}
#[cfg(test)]
mod tests {
    use crate::models::swin::v2::window_attention::*;
    use bimm_contracts::assert_shape_contract;
    use burn::backend::NdArray;
    use burn::prelude::Tensor;
    use burn::tensor::Distribution;

    /// The config and the module built from it must report the same metadata.
    #[test]
    fn test_window_attention_meta() {
        let heads = 8;
        let shape = [4, 4];
        let d_model = heads * 3;

        let config = WindowAttentionConfig::new(d_model, shape, heads);

        // Metadata as seen through the config.
        assert_eq!(config.d_input(), d_model);
        assert_eq!(config.window_shape(), shape);
        assert_eq!(config.num_heads(), heads);
        assert!(config.enable_qkv_bias());
        assert_eq!(config.attn_drop(), 0.0);
        assert_eq!(config.proj_drop(), 0.0);
        assert_eq!(config.window_height(), 4);
        assert_eq!(config.window_width(), 4);

        // Metadata as seen through the initialized module.
        let device = Default::default();
        let module = config.init::<NdArray>(&device);
        assert_eq!(module.d_input(), d_model);
        assert_eq!(module.window_shape(), shape);
        assert_eq!(module.num_heads(), heads);
        assert!(module.enable_qkv_bias());
        assert_eq!(module.attn_drop(), 0.0);
        assert_eq!(module.proj_drop(), 0.0);
    }

    /// An unmasked forward pass must preserve the input shape.
    #[test]
    fn test_wa() {
        let b = 3;
        let num_windows = 2;
        let window_size = 4;
        let num_heads = 5;
        let channels_per_head = 3;
        let channels = num_heads * channels_per_head;
        let window_shape = [window_size, window_size];

        let device = Default::default();
        let module = WindowAttentionConfig::new(channels, window_shape, num_heads)
            .init::<NdArray>(&device);
        assert_eq!(module.d_input(), channels);
        assert_eq!(module.window_shape(), window_shape);
        assert_eq!(module.num_heads(), num_heads);

        let input_shape = [b * num_windows, window_size * window_size, channels];
        let input =
            Tensor::<NdArray, 3>::random(input_shape, Distribution::Uniform(0.0, 1.0), &device);

        let res = module.forward(input, None);

        assert_shape_contract!(
            [
                "bn" = "batch" * "num_windows",
                "window_size" ^ 2,
                "channels"
            ],
            &res.dims(),
            &[
                ("batch", b),
                ("num_windows", num_windows),
                ("window_size", window_size),
                ("channels", channels),
            ],
        );
    }
}
/// Configuration used to build a [`WindowAttention`] module via `init`.
#[derive(Config, Debug)]
pub struct WindowAttentionConfig {
/// Feature dimension; used as both input and output size of every linear layer.
pub d_input: usize,
/// Window extent as `[height, width]`.
pub window_shape: [usize; 2],
/// Number of attention heads.
pub num_heads: usize,
/// Whether the query and value projections are built with a bias.
#[config(default = true)]
pub enable_qkv_bias: bool,
/// Probability of the attention dropout layer.
#[config(default = 0.)]
pub attn_drop: f64,
/// Probability of the dropout layer applied after the output projection.
#[config(default = 0.)]
pub proj_drop: f64,
/// MLP hidden dimension forwarded to the relative position bias config.
#[config(default = 512)]
pub rpb_mlp_hidden_dim: usize,
/// MLP activation forwarded to the relative position bias config.
#[config(default = "ActivationConfig::Relu")]
pub rpb_mlp_activation: ActivationConfig,
}
/// Window attention module: separate q/k/v projections, a per-head logit scale,
/// a relative position bias, and an output projection with dropout.
#[derive(Module, Debug)]
pub struct WindowAttention<B: Backend> {
/// Feature dimension of the input.
pub d_input: usize,
/// Number of attention heads.
pub num_heads: usize,
/// Linear projection producing the queries.
pub q_linear: Linear<B>,
/// Linear projection producing the keys.
pub k_linear: Linear<B>,
/// Linear projection producing the values.
pub v_linear: Linear<B>,
/// Learnable per-head scale of shape `[num_heads, 1, 1]`; it is clamped and
/// exponentiated before being multiplied into the attention logits.
pub logit_scale: Param<Tensor<B, 3>>,
/// Module producing the relative position bias added to the attention logits.
pub rpb_module: OffsetGridRelativePositionBias<B>,
/// Output projection applied to the attended values.
pub proj: Linear<B>,
/// Dropout applied to the attention tensor.
pub attn_drop: Dropout,
/// Dropout applied after the output projection.
pub proj_drop: Dropout,
}
impl<B: Backend> WindowAttention<B> {
    /// Applies windowed multi-head cosine attention to a batch of windows.
    ///
    /// # Arguments
    ///
    /// - `x`: input of shape `[b_nw, n, c]`, where `n` must equal
    ///   `window_height * window_width`; this is enforced by the shape contract.
    /// - `mask`: optional attention mask, handed to `apply_attention_mask`
    ///   before the softmax (its expected shape is defined by that helper).
    ///
    /// # Returns
    ///
    /// A tensor of shape `[b_nw, n, c]`.
    #[must_use]
    pub fn forward(
        &self,
        x: Tensor<B, 3>,
        mask: Option<Tensor<B, 3>>,
    ) -> Tensor<B, 3> {
        let [wh, ww] = self.window_shape();
        let [b_nw, n, c] = unpack_shape_contract!(
            ["b_nw", "n" = "wh" * "ww", "c"],
            &x.dims(),
            &["b_nw", "n", "c"],
            &[("wh", wh), ("ww", ww)],
        );

        let q = self.q_linear.forward(x.clone());
        let k = self.k_linear.forward(x.clone());
        let v = self.v_linear.forward(x);

        // Split the channel dimension into heads:
        // [b_nw, n, c] -> [b_nw, num_heads, n, c_per_head].
        let c_per_head = c / self.num_heads;
        let qkv_shape = [b_nw, n, self.num_heads, c_per_head];
        let q = q.reshape(qkv_shape).swap_dims(1, 2);
        let k = k.reshape(qkv_shape).swap_dims(1, 2);
        let v = v.reshape(qkv_shape).swap_dims(1, 2);

        let attn = self.attention(b_nw, n, q, k, mask);

        // Merge the heads back into the channel dimension.
        let x = attn.matmul(v);
        let x = x.swap_dims(1, 2).reshape([b_nw, n, c]);

        let x = self.proj.forward(x);
        self.proj_drop.forward(x)
    }

    /// Computes the attention weights from per-head queries and keys.
    ///
    /// `q` and `k` are L2-normalized along the last dimension, so their product
    /// is a cosine similarity. The logits are scaled and biased by
    /// [`Self::encode_attention`], optionally masked, and passed through a
    /// softmax over the last dimension. Attention dropout is applied to the
    /// resulting probabilities, i.e. after the softmax, so that dropped entries
    /// are actually zero weights rather than zero logits.
    ///
    /// Returns a tensor of shape `[b_nw, num_heads, Wh * Ww, Wh * Ww]`.
    #[must_use]
    fn attention(
        &self,
        b_nw: usize,
        n: usize,
        q: Tensor<B, 4>,
        k: Tensor<B, 4>,
        mask: Option<Tensor<B, 3>>,
    ) -> Tensor<B, 4> {
        let q = vector_normalize(q, Norm::L2, 3, EPS);
        let k = vector_normalize(k, Norm::L2, 3, EPS).swap_dims(2, 3);

        let attn = self.encode_attention(q.matmul(k));

        let attn = match mask {
            None => attn,
            Some(mask) => apply_attention_mask(b_nw, n, self.num_heads, attn, mask),
        };

        let attn = softmax(attn, 3);

        assert_shape_contract_periodically!(
            ["b_nw", "num_heads", "Wh" * "Ww", "Wh" * "Ww"],
            &attn.dims(),
            &[
                ("b_nw", b_nw),
                ("num_heads", self.num_heads()),
                ("Wh", self.window_shape()[0]),
                ("Ww", self.window_shape()[1]),
            ],
        );

        // Dropout acts on the normalized attention weights.
        self.attn_drop.forward(attn)
    }

    /// Returns the effective logit scale: the stored parameter clamped from
    /// above at `ln(1 / 0.01)` and then exponentiated.
    #[must_use]
    fn logit_scale(&self) -> Tensor<B, 3> {
        self.logit_scale.val().clamp_max((1.0f64 / 0.01).ln()).exp()
    }

    /// Returns the relative position bias produced by the bias module.
    #[inline(always)]
    #[must_use]
    fn relative_pos_bias(&self) -> Tensor<B, 3> {
        self.rpb_module.forward()
    }

    /// Scales the raw attention logits by the logit scale and adds the relative
    /// position bias; both rank-3 tensors are unsqueezed to rank 4 to broadcast
    /// against `attn`.
    #[inline(always)]
    #[must_use]
    fn encode_attention(
        &self,
        attn: Tensor<B, 4>,
    ) -> Tensor<B, 4> {
        attn * self.logit_scale().unsqueeze() + self.relative_pos_bias().unsqueeze()
    }
}
impl WindowAttentionConfig {
    /// Builds a [`WindowAttention`] module on `device` from this configuration.
    ///
    /// Bias layout follows the Swin Transformer V2 attention block:
    /// - the query and value projections carry a bias iff `enable_qkv_bias`;
    /// - the key projection never carries a learnable bias;
    /// - the output projection always carries a bias.
    ///
    /// The logit scale is initialized to `ln(10)` for every head, with shape
    /// `[num_heads, 1, 1]`.
    pub fn init<B: Backend>(
        &self,
        device: &B::Device,
    ) -> WindowAttention<B> {
        let d_input = self.d_input();
        let num_heads = self.num_heads();
        let window_shape = self.window_shape();

        WindowAttention {
            d_input,
            num_heads,
            q_linear: LinearConfig::new(d_input, d_input)
                .with_bias(self.enable_qkv_bias)
                .init(device),
            // The key projection has no learnable bias, independent of `enable_qkv_bias`.
            k_linear: LinearConfig::new(d_input, d_input)
                .with_bias(false)
                .init(device),
            v_linear: LinearConfig::new(d_input, d_input)
                .with_bias(self.enable_qkv_bias)
                .init(device),
            logit_scale: Param::initialized(
                ParamId::new(),
                Tensor::<B, 3>::ones([num_heads, 1, 1], device)
                    .mul_scalar(10.0)
                    .log(),
            ),
            attn_drop: DropoutConfig {
                prob: self.attn_drop,
            }
            .init(),
            rpb_module: RelativePositionBiasConfig::new(num_heads, window_shape)
                .with_mlp_hidden_dim(self.rpb_mlp_hidden_dim)
                .with_mlp_activation(self.rpb_mlp_activation.clone())
                .init_offset_grid_rpb(device),
            // The output projection keeps the default (enabled) bias.
            proj: LinearConfig::new(d_input, d_input).init(device),
            proj_drop: DropoutConfig {
                prob: self.proj_drop,
            }
            .init(),
        }
    }
}