use crate::models::swin::v2::windowing::window_partition;
use burn::prelude::{Backend, Bool, Int, Tensor};
/// Adds a per-window attention mask to batched window attention scores.
///
/// `attn` is viewed as `[b, num_windows, num_heads, n, n]`, where `num_windows` is the
/// leading dimension of `mask` and `b = b_nw / num_windows`. The mask is viewed as
/// `[1, num_windows, 1, n, n]` and added to it, so the same window mask is applied to
/// every batch element and every head. The sum is then flattened back to four dimensions.
///
/// # Arguments
///
/// * `b_nw` - Size of the leading (batch times windows) dimension of `attn`.
/// * `n` - Number of positions per window; the last two dimensions of `attn` and `mask`.
/// * `num_heads` - Number of attention heads.
/// * `attn` - Attention scores holding `b_nw * num_heads * n * n` elements.
/// * `mask` - Additive mask whose leading dimension is the number of windows.
///
/// # Returns
///
/// The masked scores, reshaped to `[b_nw, num_heads, n, n]`.
///
/// # Panics
///
/// Panics if `mask` has no windows, if `b_nw` is not a multiple of the number of windows,
/// or if the tensor shapes are incompatible with the reshapes / addition described above.
#[inline(always)]
#[must_use]
pub fn apply_attention_mask<B: Backend>(
    b_nw: usize,
    n: usize,
    num_heads: usize,
    attn: Tensor<B, 4>,
    mask: Tensor<B, 3>,
) -> Tensor<B, 4> {
    let [num_windows, _, _] = mask.dims();

    // Integer division would otherwise truncate silently (or divide by zero) and the
    // failure would only surface later as an opaque reshape error.
    assert!(
        num_windows > 0 && b_nw % num_windows == 0,
        "Batched window count {b_nw} is not a multiple of the mask window count {num_windows}"
    );
    let batch = b_nw / num_windows;

    // Separate the batch and window axes so the mask lines up with the window axis.
    let attn: Tensor<B, 5> = attn.reshape([batch, num_windows, num_heads, n, n]);
    // [num_windows, n, n] -> [num_windows, 1, n, n] -> [1, num_windows, 1, n, n].
    let mask: Tensor<B, 5> = mask.unsqueeze_dim::<4>(1).unsqueeze::<5>();

    // `batch * num_windows == b_nw` is guaranteed by the assertion above, so the leading
    // dimension can be stated explicitly instead of inferred.
    (attn + mask).reshape([b_nw, num_heads, n, n])
}
/// Builds the region-label image from which the shifted-window attention mask is derived.
///
/// Each axis of an `[h, w]` grid is cut into three consecutive bands:
/// `0..(len - window_size)`, `(len - window_size)..(len - shift_size)` and
/// `(len - shift_size)..len`. The 3 x 3 combinations of a row band with a column band are
/// numbered `0..=8` in row-major order, and every cell of the grid receives the number of
/// the region it falls in.
///
/// # Arguments
///
/// * `input_shape` - Grid size as `[height, width]`.
/// * `window_size` - Side length of an attention window.
/// * `shift_size` - Size of the window shift.
/// * `device` - Device on which the tensor is created.
///
/// # Returns
///
/// An integer tensor of shape `[height, width]` holding the region labels.
///
/// # Panics
///
/// Panics if `window_size` is zero, if the height or width is not divisible by
/// `window_size`, or if `shift_size` is larger than `window_size`.
#[must_use]
fn sw_img_mask<B: Backend>(
    input_shape: [usize; 2],
    window_size: usize,
    shift_size: usize,
    device: &B::Device,
) -> Tensor<B, 2, Int> {
    let [h, w] = input_shape;

    // Checked first so the divisibility tests below cannot hit a remainder-by-zero.
    assert!(window_size > 0, "Window size must be greater than zero");
    assert_eq!(
        h % window_size,
        0,
        "Height {h} is not divisible by window size {window_size}"
    );
    assert_eq!(
        w % window_size,
        0,
        "Width {w} is not divisible by window size {window_size}"
    );
    // Guarantees the three bands below are consecutive and never inverted.
    assert!(
        shift_size <= window_size,
        "Shift size {shift_size} is larger than window size {window_size}"
    );

    // All arithmetic stays in `usize`; saturating subtraction covers a zero-length axis.
    let bands = |extent: usize| {
        let unshifted_end = extent.saturating_sub(window_size);
        let shifted_start = extent.saturating_sub(shift_size);
        [
            0..unshifted_end,
            unshifted_end..shifted_start,
            shifted_start..extent,
        ]
    };
    let row_bands = bands(h);
    let col_bands = bands(w);

    let mut img_mask = Tensor::<B, 2, Int>::zeros([h, w], device);
    let mut label: i32 = 0;
    for rows in &row_bands {
        for cols in &col_bands {
            // A zero-sized region has nothing to label, and skipping it avoids handing a
            // zero-length range to the tensor slicing API. The label still advances so
            // the numbering of the remaining regions is unaffected.
            if !rows.is_empty() && !cols.is_empty() {
                let fill = Tensor::<B, 2, Int>::full([rows.len(), cols.len()], label, device);
                img_mask = img_mask.slice_assign([rows.clone(), cols.clone()], fill);
            }
            label += 1;
        }
    }
    img_mask
}
/// Computes the boolean shifted-window attention mask for a grid.
///
/// The region-label image from `sw_img_mask` is split into windows with
/// `window_partition`, each window is flattened to `window_size * window_size` labels,
/// and every pair of positions inside a window is compared.
///
/// # Arguments
///
/// * `input_shape` - Grid size as `[height, width]`.
/// * `window_size` - Side length of an attention window.
/// * `shift_size` - Size of the window shift.
/// * `device` - Device on which the tensors are created.
///
/// # Returns
///
/// A boolean tensor with one `[window_size^2, window_size^2]` matrix per window; an entry
/// is `true` exactly when the two positions carry different region labels.
///
/// # Panics
///
/// Panics under the same conditions as `sw_img_mask`.
#[must_use]
pub fn sw_attn_mask<B: Backend>(
    input_shape: [usize; 2],
    window_size: usize,
    shift_size: usize,
    device: &B::Device,
) -> Tensor<B, 3, Bool> {
    let positions_per_window = window_size * window_size;

    // [h, w] -> [h, w, 1] -> [1, h, w, 1], the layout handed to `window_partition`.
    let labels = sw_img_mask::<B>(input_shape, window_size, shift_size, device)
        .unsqueeze_dim::<3>(2)
        .unsqueeze::<4>();

    // One row of flattened labels per window.
    let per_window = window_partition(labels, window_size);
    let per_window = per_window.reshape([-1, positions_per_window as i32]);

    // Pairwise label differences within each window; non-zero means "different region".
    let along_last = per_window.clone().unsqueeze_dim::<3>(1);
    let along_middle = per_window.unsqueeze_dim::<3>(2);
    (along_last - along_middle).not_equal_elem(0)
}
// Unit tests for the shifted-window mask helpers, run on the `NdArray` backend.
#[cfg(test)]
mod tests {
use super::*;
use burn::backend::NdArray;
use burn::prelude::{Tensor, TensorData, s};
// A height that is not a multiple of the window size must trip the height assertion.
#[should_panic(expected = "Height 5 is not divisible by window size 2")]
#[test]
fn test_sw_img_mask_height_not_divisible() {
let device = Default::default();
let _d = sw_img_mask::<NdArray>([5, 4], 2, 1, &device);
}
// A width that is not a multiple of the window size must trip the width assertion.
#[should_panic(expected = "Width 5 is not divisible by window size 2")]
#[test]
fn test_sw_img_mask_width_not_divisible() {
let device = Default::default();
let _d = sw_img_mask::<NdArray>([4, 5], 2, 1, &device);
}
// With all-zero attention scores, the output for every batch element and every head
// must equal the mask of the corresponding window.
#[test]
fn test_apply_attention_mask() {
let b = 2;
let nw = 2;
let b_nw = b * nw;
let ws = 2;
let n = ws * ws;
let num_heads = 5;
let device = Default::default();
let attn = Tensor::<NdArray, 4>::zeros([b_nw, num_heads, n, n], &device);
// Two distinct [n, n] window masks so a mix-up between windows would be detected.
let mask = Tensor::<NdArray, 3>::from_data(
[
[
[0.0, 0.25, 0.5, 0.75],
[1.0, 1.25, 1.5, 1.75],
[2.0, 2.25, 2.5, 2.75],
[3.0, 3.25, 3.5, 3.75],
],
[
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
],
],
&device,
);
let res = apply_attention_mask(b_nw, n, num_heads, attn, mask.clone());
assert_eq!(res.dims(), [b_nw, num_heads, n, n]);
// Split the leading dimension back into (batch, window) to index each window directly.
let res = res.reshape([b, nw, num_heads, n, n]);
for bi in 0..b {
for wi in 0..nw {
// [1, 1, heads, n, n] -> [heads, n, n]
let window = res
.clone()
.slice(s![bi, wi, .., ..])
.squeeze_dim::<4>(0)
.squeeze_dim::<3>(0);
// [1, n, n] -> [n, n]
let wmask: Tensor<NdArray, 2> =
mask.clone().slice(s![wi, .., ..]).squeeze_dim::<2>(0);
for hi in 0..num_heads {
let h_attn = window.clone().slice(s![hi, .., ..]).squeeze_dim::<2>(0);
h_attn.to_data().assert_eq(&wmask.to_data(), true);
}
}
}
}
// Checks `sw_attn_mask` against hand-written expected masks for three configurations.
// Each expected tensor holds one square matrix per window (four windows in every case).
#[test]
fn test_attn_mask() {
let device = Default::default();
// 4x4 grid, window size 2, shift 1: four 4x4 matrices.
sw_attn_mask::<NdArray>([4, 4], 2, 1, &device)
.to_data()
.assert_eq(
&TensorData::from([
[
[false, false, false, false],
[false, false, false, false],
[false, false, false, false],
[false, false, false, false],
],
[
[false, true, false, true],
[true, false, true, false],
[false, true, false, true],
[true, false, true, false],
],
[
[false, false, true, true],
[false, false, true, true],
[true, true, false, false],
[true, true, false, false],
],
[
[false, true, true, true],
[true, false, true, true],
[true, true, false, true],
[true, true, true, false],
],
]),
true,
);
// 6x6 grid, window size 3, shift 1: four 9x9 matrices.
sw_attn_mask::<NdArray>([6, 6], 3, 1, &device)
.to_data()
.assert_eq(
&TensorData::from([
[
[
false, false, false, false, false, false, false, false, false,
],
[
false, false, false, false, false, false, false, false, false,
],
[
false, false, false, false, false, false, false, false, false,
],
[
false, false, false, false, false, false, false, false, false,
],
[
false, false, false, false, false, false, false, false, false,
],
[
false, false, false, false, false, false, false, false, false,
],
[
false, false, false, false, false, false, false, false, false,
],
[
false, false, false, false, false, false, false, false, false,
],
[
false, false, false, false, false, false, false, false, false,
],
],
[
[false, false, true, false, false, true, false, false, true],
[false, false, true, false, false, true, false, false, true],
[true, true, false, true, true, false, true, true, false],
[false, false, true, false, false, true, false, false, true],
[false, false, true, false, false, true, false, false, true],
[true, true, false, true, true, false, true, true, false],
[false, false, true, false, false, true, false, false, true],
[false, false, true, false, false, true, false, false, true],
[true, true, false, true, true, false, true, true, false],
],
[
[false, false, false, false, false, false, true, true, true],
[false, false, false, false, false, false, true, true, true],
[false, false, false, false, false, false, true, true, true],
[false, false, false, false, false, false, true, true, true],
[false, false, false, false, false, false, true, true, true],
[false, false, false, false, false, false, true, true, true],
[true, true, true, true, true, true, false, false, false],
[true, true, true, true, true, true, false, false, false],
[true, true, true, true, true, true, false, false, false],
],
[
[false, false, true, false, false, true, true, true, true],
[false, false, true, false, false, true, true, true, true],
[true, true, false, true, true, false, true, true, true],
[false, false, true, false, false, true, true, true, true],
[false, false, true, false, false, true, true, true, true],
[true, true, false, true, true, false, true, true, true],
[true, true, true, true, true, true, false, false, true],
[true, true, true, true, true, true, false, false, true],
[true, true, true, true, true, true, true, true, false],
],
]),
true,
);
// Same 6x6 grid and window size 3, but shift 2: the region boundaries move.
sw_attn_mask::<NdArray>([6, 6], 3, 2, &device)
.to_data()
.assert_eq(
&TensorData::from([
[
[
false, false, false, false, false, false, false, false, false,
],
[
false, false, false, false, false, false, false, false, false,
],
[
false, false, false, false, false, false, false, false, false,
],
[
false, false, false, false, false, false, false, false, false,
],
[
false, false, false, false, false, false, false, false, false,
],
[
false, false, false, false, false, false, false, false, false,
],
[
false, false, false, false, false, false, false, false, false,
],
[
false, false, false, false, false, false, false, false, false,
],
[
false, false, false, false, false, false, false, false, false,
],
],
[
[false, true, true, false, true, true, false, true, true],
[true, false, false, true, false, false, true, false, false],
[true, false, false, true, false, false, true, false, false],
[false, true, true, false, true, true, false, true, true],
[true, false, false, true, false, false, true, false, false],
[true, false, false, true, false, false, true, false, false],
[false, true, true, false, true, true, false, true, true],
[true, false, false, true, false, false, true, false, false],
[true, false, false, true, false, false, true, false, false],
],
[
[false, false, false, true, true, true, true, true, true],
[false, false, false, true, true, true, true, true, true],
[false, false, false, true, true, true, true, true, true],
[true, true, true, false, false, false, false, false, false],
[true, true, true, false, false, false, false, false, false],
[true, true, true, false, false, false, false, false, false],
[true, true, true, false, false, false, false, false, false],
[true, true, true, false, false, false, false, false, false],
[true, true, true, false, false, false, false, false, false],
],
[
[false, true, true, true, true, true, true, true, true],
[true, false, false, true, true, true, true, true, true],
[true, false, false, true, true, true, true, true, true],
[true, true, true, false, true, true, false, true, true],
[true, true, true, true, false, false, true, false, false],
[true, true, true, true, false, false, true, false, false],
[true, true, true, false, true, true, false, true, true],
[true, true, true, true, false, false, true, false, false],
[true, true, true, true, false, false, true, false, false],
],
]),
true,
);
}
}