use core::{marker::PhantomData, mem::transmute};
use crate::{SharedArray, iter_range_par, run_par, sharing::UnsafeSharedRef};
use burn_backend::{BoolStore, DType, Element, quantization::QuantValue};
use macerator::{Simd, VOrd};
use ndarray::{Array4, s};
use nhwc::max_pool2d_nhwc;
use super::{MinMax, should_use_simd};
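/// Probe at runtime whether min/max operations on element type `T` are
/// SIMD-accelerated for the instruction set selected by `macerator`.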
#[macerator::with_simd]
fn is_accelerated_impl<S: Simd, T: VOrd>(_x: PhantomData<T>) -> bool {
<T as VOrd>::is_min_max_accelerated::<S>()
}
fn is_accelerated<T: VOrd>() -> bool {
is_accelerated_impl::<T>(PhantomData)
}
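/// Dispatch `$func` on the runtime `DType` of `$x`, casting the array to the
/// matching concrete element type. Returns `Err($x)` when the dtype is
/// unsupported or min/max is not SIMD-accelerated for it, so the caller can
/// fall back to a non-SIMD path.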
macro_rules! launch_kernel {
($ty: ty, $func: ident, $x: expr, $($arg: expr),*) => {
match <$ty as Element>::dtype() {
DType::F64 if is_accelerated::<f64>() => Ok(cast($func::<f64>(cast($x), $($arg),*))),
DType::F32 if is_accelerated::<f32>() => Ok(cast($func::<f32>(cast($x), $($arg),*))),
DType::I64 if is_accelerated::<i64>() => Ok(cast($func::<i64>(cast($x), $($arg),*))),
DType::I32 if is_accelerated::<i32>() => Ok(cast($func::<i32>(cast($x), $($arg),*))),
DType::I16 if is_accelerated::<i16>() => Ok(cast($func::<i16>(cast($x), $($arg),*))),
DType::I8 if is_accelerated::<i8>() => Ok(cast($func::<i8>(cast($x), $($arg),*))),
DType::U64 if is_accelerated::<u64>() => Ok(cast($func::<u64>(cast($x), $($arg),*))),
DType::U32 if is_accelerated::<u32>() => Ok(cast($func::<u32>(cast($x), $($arg),*))),
DType::U16 if is_accelerated::<u16>() => Ok(cast($func::<u16>(cast($x), $($arg),*))),
DType::U8 if is_accelerated::<u8>() => Ok(cast($func::<u8>(cast($x), $($arg),*))),
DType::Bool(BoolStore::Native) if is_accelerated::<u8>() => Ok(cast($func::<u8>(cast($x), $($arg),*))),
DType::QFloat(scheme) => match scheme.value {
QuantValue::Q8F | QuantValue::Q8S if is_accelerated::<i8>() => Ok(cast($func::<i8>(cast($x), $($arg),*))),
_ => Err($x)
},
_ => Err($x),
}
};
}
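/// Attempt max pooling with the SIMD NHWC kernel. Returns `Err` with the
/// unmodified input when the channel count is too small for SIMD to pay off
/// or the layout is not channels-last (channel stride != 1).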
pub(crate) fn try_max_pool2d_simd<E: Element>(
x: SharedArray<E>,
ksize: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> Result<SharedArray<E>, SharedArray<E>> {
let [_, c, _, _] = x.shape().try_into().unwrap();
if !should_use_simd(c) || x.strides()[1] != 1 {
return Err(x);
}
launch_kernel!(E, max_pool2d_nhwc, x, ksize, stride, padding, dilation)
}
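/// Reinterpret the element type of an array without copying.
///
/// Safety: only sound when `T` and `E` have identical size and layout; the
/// `launch_kernel!` macro upholds this by selecting the concrete type from the
/// runtime dtype.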
fn cast<T, E>(tensor: SharedArray<T>) -> SharedArray<E> {
unsafe { transmute::<SharedArray<T>, SharedArray<E>>(tensor) }
}
mod nhwc {
use itertools::Itertools;
use macerator::{Simd, vload_unaligned, vstore_unaligned};
use ndarray::{ArrayView3, ArrayViewMut3, Ix4};
use seq_macro::seq;
use crate::ops::simd::lanes;
use super::*;
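// Number of SIMD registers kept live per channel block; must match the
// `seq!(N in 0..8)` unrolls below.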
const BLOCK_REGISTERS: usize = 8;
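/// Max pooling computed in NHWC (channels-last) order and permuted back to
/// NCHW at the end. The channel axis is split into three tiers: blocks of
/// `BLOCK_REGISTERS` SIMD vectors, single SIMD vectors, and a scalar
/// remainder.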
pub(crate) fn max_pool2d_nhwc<E: Element + VOrd + MinMax>(
x: SharedArray<E>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> SharedArray<E> {
let [kernel_height, kernel_width] = kernel_size;
let [pad_h, pad_w] = padding;
let [stride_height, stride_width] = stride;
let [dilation_height, dilation_width] = dilation;
let [batch_size, channels, x_height, x_width] = x.shape().try_into().unwrap();
let lanes = lanes::<E>();
let ch_block = lanes * BLOCK_REGISTERS;
let out_height = ((x_height + 2 * pad_h - dilation_height * (kernel_height - 1) - 1)
/ stride_height)
+ 1;
let out_width =
((x_width + 2 * pad_w - dilation_width * (kernel_width - 1) - 1) / stride_width) + 1;
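// The output is allocated uninitialized in NHWC order; every element is
// written by the loops below before the final permute back to NCHW.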
let mut output = unsafe {
Array4::<E>::uninit((batch_size, out_height, out_width, channels)).assume_init()
};
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
let x = x.into_dimensionality::<Ix4>().unwrap();
let x = x.view();
let x = x.permuted_axes([0, 2, 3, 1]);
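// Partition the channel axis: whole blocks of `BLOCK_REGISTERS` SIMD vectors,
// then single SIMD vectors, then a scalar remainder.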
let blocks = channels / ch_block;
let blocks_end = blocks * ch_block;
let simd_end = channels / lanes * lanes;
let simd_unblocked = (simd_end - blocks_end) / lanes;
let remainder = channels - simd_end;
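// One parallel pass per channel tier, each parallelized over
// (batch, channel chunk) pairs.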
run_par!(|| {
iter_range_par!(0, batch_size * blocks).for_each(|k| unsafe {
let block = k % blocks;
let b = k / blocks;
let output = unsafe_shared_out.get();
let x = x.slice(s![b, .., .., ..]);
let out = output.slice_mut(s![b, .., .., ..]);
loop_blocked(x, out, kernel_size, stride, padding, dilation, block);
});
iter_range_par!(0, batch_size * simd_unblocked).for_each(|k| unsafe {
let ch = (k % simd_unblocked) * lanes + blocks_end;
let b = k / simd_unblocked;
let output = unsafe_shared_out.get();
let x = x.slice(s![b, .., .., ..]);
let out = output.slice_mut(s![b, .., .., ..]);
loop_unblocked(x, out, kernel_size, stride, padding, dilation, ch);
});
iter_range_par!(0, batch_size * remainder).for_each(|k| unsafe {
let ch = (k % remainder) + simd_end;
let b = k / remainder;
let output = unsafe_shared_out.get();
let x = x.slice(s![b, .., .., ..]);
let out = output.slice_mut(s![b, .., .., ..]);
loop_scalar(x, out, kernel_size, stride, padding, dilation, ch);
});
});
output = output.permuted_axes([0, 3, 1, 2]);
output.into_dyn().into_shared()
}
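/// Pool one block of `BLOCK_REGISTERS * lanes` channels for a single batch
/// entry, keeping the running maxima of the whole block in SIMD registers.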
#[allow(
clippy::too_many_arguments,
clippy::erasing_op,
clippy::identity_op,
unused_mut
)]
#[inline(always)]
#[macerator::with_simd]
fn loop_blocked<'a, S: Simd, E: Element + VOrd + MinMax>(
x: ArrayView3<'a, E>,
mut out: ArrayViewMut3<'a, E>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
block: usize,
) where
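// Trivial bound; it appears to be required so the lifetime stays nameable in
// the `#[macerator::with_simd]` expansion.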
'a: 'a,
{
let [kernel_height, kernel_width] = kernel_size;
let [pad_h, pad_w] = padding;
let [stride_height, stride_width] = stride;
let [dilation_height, dilation_width] = dilation;
let (x_height, x_width, _) = x.dim();
let (out_height, out_width, _) = out.dim();
let lanes = E::lanes::<S>();
let ch_block = lanes * BLOCK_REGISTERS;
let min = E::MIN.splat::<S>();
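// Interior output positions: every kernel tap is guaranteed in bounds, so no
// per-tap bounds checks on the input coordinates are needed.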
for oh in pad_h..out_height.saturating_sub(pad_h) {
for ow in pad_w..out_width.saturating_sub(pad_w) {
seq!(N in 0..8 {
let mut acc~N = min;
});
let ch = block * ch_block;
let ch_end = ch + ch_block;
let mut out = out.slice_mut(s![oh, ow, ch..ch_end]);
for kh in 0..kernel_height {
let ih = oh * stride_height + kh * dilation_height - pad_h;
for kw in 0..kernel_width {
let iw = ow * stride_width + kw * dilation_width - pad_w;
let x = x.slice(s![ih, iw, ch..ch_end]);
seq!(N in 0..8 {
acc~N = acc~N.max(unsafe { vload_unaligned(&x[N * lanes]) });
});
}
}
seq!(N in 0..8 {
unsafe { vstore_unaligned(&mut out[N * lanes], acc~N) };
});
}
}
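// Border output positions (only present with padding): kernel taps that fall
// into the padded region are skipped.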
if (pad_h, pad_w) != (0, 0) {
let v_borders = (0..pad_h)
.chain(out_height.saturating_sub(pad_h)..out_height)
.cartesian_product(0..out_width);
let h_borders = (0..out_height)
.cartesian_product((0..pad_w).chain(out_width.saturating_sub(pad_w)..out_width));
for (oh, ow) in v_borders.chain(h_borders) {
seq!(N in 0..8 {
let mut acc~N = min;
});
let ch = block * ch_block;
let ch_end = ch + ch_block;
let mut out = out.slice_mut(s![oh, ow, ch..ch_end]);
for kh in 0..kernel_height {
let ih = oh * stride_height + kh * dilation_height;
if ih < pad_h || ih >= x_height + pad_h {
continue;
}
let ih = ih - pad_h;
for kw in 0..kernel_width {
let iw = ow * stride_width + kw * dilation_width;
if iw < pad_w || iw >= x_width + pad_w {
continue;
}
let iw = iw - pad_w;
let x = x.slice(s![ih, iw, ch..ch_end]);
seq!(N in 0..8 {
acc~N = acc~N.max(unsafe { vload_unaligned(&x[N * lanes]) });
});
}
}
seq!(N in 0..8 {
unsafe { vstore_unaligned(&mut out[N * lanes], acc~N) };
});
}
}
}
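/// Pool a single SIMD vector of channels starting at `ch` for one batch entry.
///
/// Safety: the caller must ensure at least `lanes` channels remain at offset
/// `ch` in both `x` and `out`; `max_pool2d_nhwc` upholds this by only passing
/// offsets no greater than `simd_end - lanes`.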
#[allow(clippy::too_many_arguments, unused_mut)]
#[inline(always)]
#[macerator::with_simd]
unsafe fn loop_unblocked<'a, S: Simd, E: Element + VOrd + MinMax>(
x: ArrayView3<'a, E>,
mut out: ArrayViewMut3<'a, E>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
ch: usize,
) where
'a: 'a,
{
let [kernel_height, kernel_width] = kernel_size;
let [pad_h, pad_w] = padding;
let [stride_height, stride_width] = stride;
let [dilation_height, dilation_width] = dilation;
let (x_height, x_width, _) = x.dim();
let (out_height, out_width, _) = out.dim();
for oh in pad_h..out_height.saturating_sub(pad_h) {
for ow in pad_w..out_width.saturating_sub(pad_w) {
let mut acc = E::MIN.splat::<S>();
let out = &mut out[[oh, ow, ch]];
for kh in 0..kernel_height {
let ih = oh * stride_height + kh * dilation_height - pad_h;
for kw in 0..kernel_width {
let iw = ow * stride_width + kw * dilation_width - pad_w;
acc = acc.max(unsafe { vload_unaligned(&x[[ih, iw, ch]]) });
}
}
unsafe { vstore_unaligned(out, acc) };
}
}
if (pad_h, pad_w) != (0, 0) {
let v_borders = (0..pad_h)
.chain(out_height.saturating_sub(pad_h)..out_height)
.cartesian_product(0..out_width);
let h_borders = (0..out_height)
.cartesian_product((0..pad_w).chain(out_width.saturating_sub(pad_w)..out_width));
for (oh, ow) in v_borders.chain(h_borders) {
let mut acc = E::MIN.splat::<S>();
let out = &mut out[[oh, ow, ch]];
for kh in 0..kernel_height {
let ih = oh * stride_height + kh * dilation_height;
if ih < pad_h || ih >= x_height + pad_h {
continue;
}
let ih = ih - pad_h;
for kw in 0..kernel_width {
let iw = ow * stride_width + kw * dilation_width;
if iw < pad_w || iw >= x_width + pad_w {
continue;
}
let iw = iw - pad_w;
acc = acc.max(unsafe { vload_unaligned(&x[[ih, iw, ch]]) });
}
}
unsafe { vstore_unaligned(out, acc) };
}
}
}
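/// Scalar fallback for the channels left over after the SIMD tiers: pools one
/// channel `ch` for a single batch entry, including the padded borders.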
fn loop_scalar<E: Element + MinMax>(
x: ArrayView3<'_, E>,
mut out: ArrayViewMut3<'_, E>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
ch: usize,
) {
let [kernel_height, kernel_width] = kernel_size;
let [pad_h, pad_w] = padding;
let [stride_height, stride_width] = stride;
let [dilation_height, dilation_width] = dilation;
let (x_height, x_width, _) = x.dim();
let (out_height, out_width, _) = out.dim();
for oh in 0..out_height {
for ow in 0..out_width {
let mut acc = E::MIN;
for kh in 0..kernel_height {
let ih = oh * stride_height + kh * dilation_height;
if ih < pad_h || ih >= x_height + pad_h {
continue;
}
let ih = ih - pad_h;
for kw in 0..kernel_width {
let iw = ow * stride_width + kw * dilation_width;
if iw < pad_w || iw >= x_width + pad_w {
continue;
}
let iw = iw - pad_w;
acc = acc.max(x[[ih, iw, ch]]);
}
}
out[[oh, ow, ch]] = acc;
}
}
}
}