use crate::{
ShapeOps, SharedArray,
element::{FloatNdArrayElement, IntNdArrayElement},
iter_range_par,
ops::padding::apply_padding_4d,
run_par,
sharing::UnsafeSharedRef,
};
use burn_backend::ElementConversion;
use burn_backend::ops::conv::calculate_pool_output_size;
use ndarray::Array4;
pub(crate) fn max_pool2d<E: FloatNdArrayElement>(
x: SharedArray<E>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
ceil_mode: bool,
) -> SharedArray<E> {
let [kernel_height, kernel_width] = kernel_size;
let [padding_height, padding_width] = padding;
let [stride_height, stride_width] = stride;
let [dilation_height, dilation_width] = dilation;
let [batch_size, channels, x_height, x_width] = x.shape().dims();
let inf = (-f32::INFINITY).elem::<E>();
let out_height = calculate_pool_output_size(
kernel_height,
stride_height,
padding_height,
dilation_height,
x_height,
ceil_mode,
);
let out_width = calculate_pool_output_size(
kernel_width,
stride_width,
padding_width,
dilation_width,
x_width,
ceil_mode,
);
let max_ih =
(out_height.saturating_sub(1)) * stride_height + (kernel_height - 1) * dilation_height;
let max_iw = (out_width.saturating_sub(1)) * stride_width + (kernel_width - 1) * dilation_width;
let padded_height = x_height + 2 * padding_height;
let padded_width = x_width + 2 * padding_width;
let extra_pad_h = max_ih.saturating_sub(padded_height.saturating_sub(1));
let extra_pad_w = max_iw.saturating_sub(padded_width.saturating_sub(1));
let total_padding = [padding_height + extra_pad_h, padding_width + extra_pad_w];
let x = apply_padding_4d::<E>(x, total_padding, inf);
let offset_h = extra_pad_h;
let offset_w = extra_pad_w;
let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf);
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
run_par!(|| {
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
let b = k / channels;
let c = k % channels;
let output = unsafe_shared_out.get();
for oh in 0..out_height {
for ow in 0..out_width {
let mut max_val = inf;
for kh in 0..kernel_height {
let ih = offset_h + oh * stride_height + kh * dilation_height;
for kw in 0..kernel_width {
let iw = offset_w + ow * stride_width + kw * dilation_width;
let val = x[[b, c, ih, iw]];
if val > max_val {
max_val = val;
}
}
}
output[[b, c, oh, ow]] = max_val;
}
}
})
});
output.into_dyn().into_shared()
}
pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement, I: IntNdArrayElement>(
x: SharedArray<E>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
ceil_mode: bool,
) -> (SharedArray<E>, SharedArray<I>) {
let [kernel_height, kernel_width] = kernel_size;
let [padding_height, padding_width] = padding;
let [stride_height, stride_width] = stride;
let [dilation_height, dilation_width] = dilation;
let [batch_size, channels, x_height, x_width] = x.shape().dims();
let inf = (-f32::INFINITY).elem::<E>();
let out_height = calculate_pool_output_size(
kernel_height,
stride_height,
padding_height,
dilation_height,
x_height,
ceil_mode,
);
let out_width = calculate_pool_output_size(
kernel_width,
stride_width,
padding_width,
dilation_width,
x_width,
ceil_mode,
);
let max_ih =
(out_height.saturating_sub(1)) * stride_height + (kernel_height - 1) * dilation_height;
let max_iw = (out_width.saturating_sub(1)) * stride_width + (kernel_width - 1) * dilation_width;
let padded_height = x_height + 2 * padding_height;
let padded_width = x_width + 2 * padding_width;
let extra_pad_h = max_ih.saturating_sub(padded_height.saturating_sub(1));
let extra_pad_w = max_iw.saturating_sub(padded_width.saturating_sub(1));
let total_padding = [padding_height + extra_pad_h, padding_width + extra_pad_w];
let x = apply_padding_4d::<E>(x, total_padding, inf);
let offset_h = extra_pad_h;
let offset_w = extra_pad_w;
let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf);
let mut indices = Array4::<I>::zeros((batch_size, channels, out_height, out_width));
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
let unsafe_shared_indices = UnsafeSharedRef::new(&mut indices);
run_par!(|| {
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
let b = k / channels;
let c = k % channels;
let output = unsafe_shared_out.get();
let indices = unsafe_shared_indices.get();
for oh in 0..out_height {
for ow in 0..out_width {
let mut max_val = inf;
let mut index = 0;
for kh in 0..kernel_height {
let ih = offset_h + oh * stride_height + kh * dilation_height;
for kw in 0..kernel_width {
let iw = offset_w + ow * stride_width + kw * dilation_width;
let val = x[[b, c, ih, iw]];
if val > max_val {
max_val = val;
let ih_orig = ih as i64 - (total_padding[0]) as i64;
let iw_orig = iw as i64 - (total_padding[1]) as i64;
let ih_clamped = ih_orig.max(0).min(x_height as i64 - 1);
let iw_clamped = iw_orig.max(0).min(x_width as i64 - 1);
index = ih_clamped * x_width as i64 + iw_clamped;
}
}
}
output[[b, c, oh, ow]] = max_val;
indices[[b, c, oh, ow]] = index.elem();
}
}
})
});
let output = output.into_dyn().into_shared();
let indices = indices.into_dyn().into_shared();
(output, indices)
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn max_pool2d_backward<E: FloatNdArrayElement, I: IntNdArrayElement>(
x: SharedArray<E>,
_kernel_size: [usize; 2],
_stride: [usize; 2],
_padding: [usize; 2],
_dilation: [usize; 2],
_ceil_mode: bool,
output_grad: SharedArray<E>,
indices: SharedArray<I>,
) -> SharedArray<E> {
let [_batch_size, _channels, height, width] = output_grad.shape().dims();
let [batch_size, channels, height_x, width_x] = x.shape().dims();
let output_grad = output_grad;
let indices = indices;
let mut output = Array4::zeros((batch_size, channels, height_x, width_x));
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
run_par!(|| {
iter_range_par!(0, batch_size * channels).for_each(|k| unsafe {
let b = k / channels;
let c = k % channels;
let output = unsafe_shared_out.get();
for h in 0..height {
for w in 0..width {
let index = indices[[b, c, h, w]].elem::<i64>();
let grad = output_grad[[b, c, h, w]];
let index_h = index as usize / width_x;
let index_w = index as usize % width_x;
output[[b, c, index_h, index_w]] += grad;
}
}
});
});
output.into_dyn().into_shared()
}