use burn::prelude::{Backend, Int, Tensor};
use burn::tensor::grid::{IndexPos, meshgrid_stack};
/// Builds the grid of every signed `(row, col)` index offset that can occur
/// between two positions inside a window.
///
/// For a window of shape `[h, w]` the row offsets span `-(h - 1)..=(h - 1)` and
/// the column offsets span `-(w - 1)..=(w - 1)`. The two axes are combined with
/// [`meshgrid_stack`] using [`IndexPos::Last`], so the result has shape
/// `[2 * h - 1, 2 * w - 1, 2]` with the `(row, col)` pair on the last dimension.
///
/// # Arguments
///
/// * `window_shape` - `[height, width]` of the window.
/// * `device` - device on which the tensors are created.
///
/// # Panics
///
/// Panics if either the height or the width is zero.
#[inline(always)]
#[must_use]
pub fn window_index_offset_grid<B: Backend>(
    window_shape: [usize; 2],
    device: &B::Device,
) -> Tensor<B, 3, Int> {
    let [rows, cols] = window_shape;
    assert_ne!(rows, 0, "Height must be non-zero");
    assert_ne!(cols, 0, "Width must be non-zero");

    // Offsets along a single axis of extent `n` cover -(n - 1) ..= (n - 1).
    let axis_offsets = |extent: usize| {
        let n = extent as i64;
        Tensor::<B, 1, Int>::arange_step(-(n - 1)..n, 1, device)
    };

    let axes = [axis_offsets(rows), axis_offsets(cols)];
    meshgrid_stack(&axes, IndexPos::Last)
}
/// Builds the grid of relative `(row, col)` offsets within a window, normalized
/// to the range `[-1.0, 1.0]`.
///
/// This is the integer grid from [`window_index_offset_grid`] converted to
/// float and divided, per axis, by `size - 1`, so the largest offset along each
/// axis maps to `1.0`. The result has shape `[2 * h - 1, 2 * w - 1, 2]`.
///
/// A dimension of size 1 has the single offset `0`; its divisor is clamped to
/// `1` so that axis yields `0.0` instead of the `NaN` a `0 / 0` division would
/// produce.
///
/// # Arguments
///
/// * `window_shape` - `[height, width]` of the window.
/// * `device` - device on which the tensors are created.
///
/// # Panics
///
/// Panics if either the height or the width is zero.
#[inline(always)]
#[must_use]
pub fn window_relative_offset_grid<B: Backend>(
    window_shape: [usize; 2],
    device: &B::Device,
) -> Tensor<B, 3> {
    // Validates the shape (non-zero) and yields the raw integer offsets.
    let offsets: Tensor<B, 3> = window_index_offset_grid::<B>(window_shape, device).float();

    // Per-axis normalizer `size - 1`, never allowed to drop below 1 so that a
    // size-1 axis does not trigger a division by zero.
    let [h, w] = window_shape;
    let divisor = Tensor::<B, 1, Int>::from_data(
        [h.saturating_sub(1).max(1), w.saturating_sub(1).max(1)],
        device,
    );

    // `unsqueeze` lifts the `[2]` divisor to rank 3 so it broadcasts over the
    // trailing `(row, col)` dimension of the grid.
    offsets.div(divisor.unsqueeze().float())
}
/// Builds the log-spaced relative offset grid for a window.
///
/// Each normalized offset `x` from [`window_relative_offset_grid`] is mapped to
/// `sign(x * base) * ln(1 + |x * base|) / ln(base)`.
///
/// The output has the same shape as [`window_relative_offset_grid`].
///
/// # Arguments
///
/// * `window_shape` - `[height, width]` of the window.
/// * `base` - scale applied to the offsets before the logarithm; its natural
///   log is also the final divisor, so `base == 1.0` divides by zero and a
///   non-positive `base` makes `ln(base)` undefined.
/// * `device` - device on which the tensors are created.
#[inline(always)]
#[must_use]
pub fn window_log1p_relative_offset_grid<B: Backend>(
    window_shape: [usize; 2],
    base: f64,
    device: &B::Device,
) -> Tensor<B, 3> {
    let scaled = window_relative_offset_grid::<B>(window_shape, device).mul_scalar(base);

    // The logarithm is taken on magnitudes only; the sign is re-applied after.
    let magnitude = scaled.clone().abs().log1p();

    scaled.sign().mul(magnitude).div_scalar(base.ln())
}
/// Builds the pairwise relative-position index table for a window.
///
/// For a window of shape `[h, w]`, the positions are flattened in row-major
/// order into `h * w` entries. Entry `[i, j]` of the result encodes the
/// `(row, col)` offset from position `j` to position `i` as a single integer:
///
/// ```text
/// (row_i - row_j + h - 1) * (2 * w - 1) + (col_i - col_j + w - 1)
/// ```
///
/// The result has shape `[h * w, h * w]` and its values lie in
/// `0..(2 * h - 1) * (2 * w - 1)`.
///
/// # Arguments
///
/// * `window_shape` - `[height, width]` of the window.
/// * `device` - device on which the tensors are created.
///
/// # Panics
///
/// Panics if either the height or the width is zero.
#[inline(always)]
#[must_use]
pub fn window_attention_relative_position_index<B: Backend>(
    window_shape: [usize; 2],
    device: &B::Device,
) -> Tensor<B, 2, Int> {
    let [h, w] = window_shape;
    // Reject empty windows up front: `2 * w - 1` below would otherwise
    // underflow `usize` when `w == 0`.
    assert_ne!(h, 0, "Height must be non-zero");
    assert_ne!(w, 0, "Width must be non-zero");

    let h = h as i64;
    let w = w as i64;
    let hw = h * w;

    // positions[0] holds row indices, positions[1] holds column indices.
    let positions: Tensor<B, 3, Int> = meshgrid_stack(
        &[
            Tensor::<B, 1, Int>::arange(0..h, device),
            Tensor::<B, 1, Int>::arange(0..w, device),
        ],
        IndexPos::First,
    );

    // [2, h, w] -> [2, h * w]
    let coords: Tensor<B, 2, Int> = positions.flatten(1, 2);

    // a[c, i, j] = coords[c, i]; b[c, i, j] = coords[c, j]
    let a: Tensor<B, 3, Int> = coords.reshape([2, hw as usize, 1]).expand([2, hw, hw]);
    let b = a.clone().permute([0, 2, 1]);

    // Pairwise differences, moved to layout [h * w, h * w, 2].
    let rel: Tensor<B, 3, Int> = (a - b).permute([1, 2, 0]);

    // Shift each axis by `size - 1` so every offset is non-negative.
    let shift = Tensor::<B, 1, Int>::from_data(window_shape, device) - 1;
    let rel: Tensor<B, 3, Int> = rel + shift.unsqueeze();

    // Scale the row offset by the number of distinct column offsets so that
    // the (row, col) pair collapses into a unique scalar when summed.
    let stride = Tensor::<B, 1, Int>::from_data([2 * window_shape[1] - 1, 1], device);
    let rel = rel.mul(stride.unsqueeze());

    rel.sum_dim(2).squeeze_dim::<2>(2)
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use burn::prelude::TensorData;
    use burn::tensor::Tolerance;

    /// A `[3, 2]` window yields a `[5, 3, 2]` grid of signed `(row, col)` offsets.
    #[test]
    fn test_window_index_offset_grid() {
        let device = Default::default();
        window_index_offset_grid::<NdArray>([3, 2], &device)
            .to_data()
            .assert_eq(
                &TensorData::from([
                    [[-2, -1], [-2, 0], [-2, 1]],
                    [[-1, -1], [-1, 0], [-1, 1]],
                    [[0, -1], [0, 0], [0, 1]],
                    [[1, -1], [1, 0], [1, 1]],
                    [[2, -1], [2, 0], [2, 1]],
                ]),
                false,
            );
    }

    /// The relative grid is the index grid divided per axis by `size - 1`.
    #[test]
    fn test_window_relative_offset_grid() {
        let device = Default::default();
        window_relative_offset_grid::<NdArray>([3, 2], &device)
            .to_data()
            .assert_eq(
                &TensorData::from([
                    [[-1., -1.], [-1., 0.], [-1., 1.]],
                    [[-0.5, -1.], [-0.5, 0.], [-0.5, 1.]],
                    [[0., -1.], [0., 0.], [0., 1.]],
                    [[0.5, -1.], [0.5, 0.], [0.5, 1.]],
                    [[1., -1.], [1., 0.], [1., 1.]],
                ]),
                false,
            );
    }

    /// With `base = 8`, offsets of magnitude 1 map to `ln(9) / ln(8)` and
    /// offsets of magnitude 0.5 map to `ln(5) / ln(8)`.
    #[test]
    fn test_window_log_offset_grid() {
        let device = Default::default();
        let base = 8.0;
        let actual = window_log1p_relative_offset_grid::<NdArray>([3, 2], base, &device);
        actual.to_data().assert_approx_eq(
            &TensorData::from([
                [
                    [-1.0566417, -1.0566417],
                    [-1.0566417, 0.0],
                    [-1.0566417, 1.0566417],
                ],
                [
                    [-0.773976, -1.0566417],
                    [-0.773976, 0.0],
                    [-0.773976, 1.0566417],
                ],
                [[0.0, -1.0566417], [0.0, 0.0], [0.0, 1.0566417]],
                [
                    [0.773976, -1.0566417],
                    [0.773976, 0.0],
                    [0.773976, 1.0566417],
                ],
                [
                    [1.0566417, -1.0566417],
                    [1.0566417, 0.0],
                    [1.0566417, 1.0566417],
                ],
            ]),
            Tolerance::<f64>::absolute(1e-5),
        );
    }

    /// A `[2, 3]` window yields a `[6, 6]` index table with values in `0..15`.
    #[test]
    fn test_relative_position_index() {
        let window_shape = [2, 3];
        let device = Default::default();
        let rel = window_attention_relative_position_index::<NdArray>(window_shape, &device);
        rel.to_data().assert_eq(
            &TensorData::from([
                [7, 6, 5, 2, 1, 0],
                [8, 7, 6, 3, 2, 1],
                [9, 8, 7, 4, 3, 2],
                [12, 11, 10, 7, 6, 5],
                [13, 12, 11, 8, 7, 6],
                [14, 13, 12, 9, 8, 7],
            ]),
            false,
        );
    }
}