#include "src/naive/sliding_window_transpose/opr_impl.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
#include <cstring>
namespace megdnn {
namespace naive {
/*!
 * \brief Naive forward kernel: accumulate every window patch of \p src into
 *      \p dst.
 *
 * \p dst is addressed as a dense 4-d tensor whose extents come from
 * dst.layout.shape[0..3] (called N, C, IH, IW below). \p src is addressed as
 * a dense run of window_h x window_w patches, one patch per
 * (n, c, window position) in loop order. Only the shape is consulted; strides
 * are ignored, so both tensors are treated as contiguous
 * (NOTE(review): presumably enforced by check_exec in the caller -- confirm).
 *
 * \p dst is zero-filled first and then every patch element is added to the
 * destination position it maps to, so positions covered by several windows
 * receive the sum of all contributions. Patch elements that map into the
 * padding area are dropped.
 *
 * \tparam T element type of both tensors
 * \param src tensor holding the window patches (read)
 * \param dst tensor receiving the accumulated result (overwritten)
 */
template <typename T>
void SlidingWindowTransposeForwardImpl::exec_internal(
        _megdnn_tensor_in src, _megdnn_tensor_out dst) {
    int N = dst.layout.shape[0], C = dst.layout.shape[1], IH = dst.layout.shape[2],
        IW = dst.layout.shape[3];
    auto sptr = src.ptr<T>();
    auto dptr = dst.ptr<T>();

    int window_h = static_cast<int>(param().window_h);
    int window_w = static_cast<int>(param().window_w);
    int pad_h = static_cast<int>(param().pad_h);
    int pad_w = static_cast<int>(param().pad_w);
    int stride_h = static_cast<int>(param().stride_h);
    int stride_w = static_cast<int>(param().stride_w);
    int dilate_h = static_cast<int>(param().dilate_h);
    int dilate_w = static_cast<int>(param().dilate_w);
    // extent covered by one dilated window along each axis
    int equ_window_h = dilate_h * (window_h - 1) + 1;
    int equ_window_w = dilate_w * (window_w - 1) + 1;

    // Element counts and offsets are kept in size_t: evaluating
    // n * C * IH * IW in int overflows (undefined behaviour) once the
    // destination holds more than INT_MAX elements.
    const size_t plane_size = static_cast<size_t>(IH) * static_cast<size_t>(IW);
    const size_t window_size =
            static_cast<size_t>(window_h) * static_cast<size_t>(window_w);

    memset(dptr, 0,
           sizeof(T) * static_cast<size_t>(N) * static_cast<size_t>(C) * plane_size);

    // running index of the current patch inside src
    size_t idx = 0;
    for (int n = 0; n < N; ++n) {
        for (int c = 0; c < C; ++c) {
            auto* dst_plane =
                    dptr + (static_cast<size_t>(n) * static_cast<size_t>(C) +
                            static_cast<size_t>(c)) *
                                   plane_size;
            for (int ih = -pad_h; ih + equ_window_h <= IH + pad_h; ih += stride_h) {
                for (int iw = -pad_w; iw + equ_window_w <= IW + pad_w;
                     iw += stride_w) {
                    const T* patch = sptr + idx * window_size;
                    for (int kh = 0; kh < window_h; ++kh) {
                        const int ih2 = ih + dilate_h * kh;
                        // whole patch row falls into the vertical padding
                        if (ih2 < 0 || ih2 >= IH)
                            continue;
                        auto* dst_row = dst_plane + static_cast<size_t>(ih2) *
                                                            static_cast<size_t>(IW);
                        for (int kw = 0; kw < window_w; ++kw) {
                            const int iw2 = iw + dilate_w * kw;
                            if (iw2 >= 0 && iw2 < IW) {
                                dst_row[iw2] += patch[kh * window_w + kw];
                            }
                        }
                    }
                    ++idx;
                }
            }
        }
    }
}
/*!
 * \brief Forward entry point: validate the call, then run the typed kernel
 *      that matches the dtype of \p src.
 *
 * check_exec() is invoked with both layouts and the workspace size before
 * anything else. The dtype enum of \p src is then compared against every
 * dtype listed by MEGDNN_FOREACH_COMPUTING_DTYPE; on the first match the
 * corresponding exec_internal<ctype> call is handed to
 * MEGDNN_DISPATCH_CPU_KERN_OPR and the function returns. If no dtype matches,
 * control reaches megdnn_assert_internal(0).
 *
 * \param src input tensor
 * \param dst output tensor
 * \param workspace workspace descriptor; only its size is used here
 */
void SlidingWindowTransposeForwardImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
    check_exec(src.layout, dst.layout, workspace.size);

    // looked up once instead of once per candidate dtype
    const auto src_enumv = src.layout.dtype.enumv();

#define SLIDING_WINDOW_TRANSPOSE_FWD_DISPATCH(DType)                            \
    if (src_enumv == DTypeTrait<DType>::enumv) {                                \
        MEGDNN_DISPATCH_CPU_KERN_OPR(                                           \
                exec_internal<typename DTypeTrait<DType>::ctype>(src, dst););   \
        return;                                                                 \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE(SLIDING_WINDOW_TRANSPOSE_FWD_DISPATCH);
#undef SLIDING_WINDOW_TRANSPOSE_FWD_DISPATCH

    // only reached when the dtype is not one of the computing dtypes
    megdnn_assert_internal(0);
}
/*!
 * \brief Naive backward kernel: gather window patches from \p diff into
 *      \p grad.
 *
 * \p diff is addressed as a dense 4-d tensor whose extents come from
 * diff.layout.shape[0..3] (called N, C, IH, IW below). \p grad is addressed
 * as a dense run of window_h x window_w patches, one patch per
 * (n, c, window position) in loop order. Only the shape is consulted; strides
 * are ignored, so both tensors are treated as contiguous
 * (NOTE(review): presumably enforced by check_exec in the caller -- confirm).
 *
 * Every element of every patch is written: it receives the \p diff value at
 * the position it maps to, or zero when that position lies in the padding
 * area.
 *
 * \tparam T element type of both tensors
 * \param diff tensor that is sampled (read)
 * \param grad tensor receiving the gathered patches (written)
 */
template <typename T>
void SlidingWindowTransposeBackwardImpl::exec_internal(
        _megdnn_tensor_in diff, _megdnn_tensor_out grad) {
    int N = diff.layout.shape[0], C = diff.layout.shape[1], IH = diff.layout.shape[2],
        IW = diff.layout.shape[3];
    auto grad_ptr = grad.ptr<T>();
    auto diff_ptr = diff.ptr<T>();

    int window_h = static_cast<int>(param().window_h);
    int window_w = static_cast<int>(param().window_w);
    int pad_h = static_cast<int>(param().pad_h);
    int pad_w = static_cast<int>(param().pad_w);
    int stride_h = static_cast<int>(param().stride_h);
    int stride_w = static_cast<int>(param().stride_w);
    int dilate_h = static_cast<int>(param().dilate_h);
    int dilate_w = static_cast<int>(param().dilate_w);
    // extent covered by one dilated window along each axis
    int equ_window_h = dilate_h * (window_h - 1) + 1;
    int equ_window_w = dilate_w * (window_w - 1) + 1;

    // Offsets are kept in size_t: evaluating n * C * IH * IW in int overflows
    // (undefined behaviour) once diff holds more than INT_MAX elements.
    const size_t plane_size = static_cast<size_t>(IH) * static_cast<size_t>(IW);
    const size_t window_size =
            static_cast<size_t>(window_h) * static_cast<size_t>(window_w);

    // The padding value is built in T. Mixing a T operand with a float literal
    // in a conditional expression would promote integer T to float and round
    // large int32 values on the way through.
    const T zero = static_cast<T>(0.0f);

    // running index of the current patch inside grad
    size_t idx = 0;
    for (int n = 0; n < N; ++n) {
        for (int c = 0; c < C; ++c) {
            const T* diff_plane =
                    diff_ptr + (static_cast<size_t>(n) * static_cast<size_t>(C) +
                                static_cast<size_t>(c)) *
                                       plane_size;
            for (int ih = -pad_h; ih + equ_window_h <= IH + pad_h; ih += stride_h) {
                for (int iw = -pad_w; iw + equ_window_w <= IW + pad_w;
                     iw += stride_w) {
                    auto* patch = grad_ptr + idx * window_size;
                    for (int kh = 0; kh < window_h; ++kh) {
                        const int ih2 = ih + dilate_h * kh;
                        const bool row_inside = ih2 >= 0 && ih2 < IH;
                        for (int kw = 0; kw < window_w; ++kw) {
                            const int iw2 = iw + dilate_w * kw;
                            const int tap = kh * window_w + kw;
                            if (row_inside && iw2 >= 0 && iw2 < IW) {
                                patch[tap] = diff_plane
                                        [static_cast<size_t>(ih2) *
                                                 static_cast<size_t>(IW) +
                                         static_cast<size_t>(iw2)];
                            } else {
                                patch[tap] = zero;
                            }
                        }
                    }
                    ++idx;
                }
            }
        }
    }
}
/*!
 * \brief Backward entry point: validate the call, then run the typed kernel
 *      that matches the dtype of \p diff.
 *
 * check_exec() is invoked with both layouts and the workspace size before
 * anything else. The dtype enum of \p diff is then compared against every
 * dtype listed by MEGDNN_FOREACH_COMPUTING_DTYPE; on the first match the
 * corresponding exec_internal<ctype> call is handed to
 * MEGDNN_DISPATCH_CPU_KERN_OPR and the function returns. If no dtype matches,
 * control reaches megdnn_assert_internal(0).
 *
 * The match is done on the dtype enum, the same way
 * SlidingWindowTransposeForwardImpl::exec does it, instead of constructing a
 * temporary DType object for every candidate.
 *
 * \param diff tensor that is sampled
 * \param grad tensor that receives the result
 * \param workspace workspace descriptor; only its size is used here
 */
void SlidingWindowTransposeBackwardImpl::exec(
        _megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) {
    check_exec(diff.layout, grad.layout, workspace.size);

    // looked up once instead of once per candidate dtype
    const auto diff_enumv = diff.layout.dtype.enumv();

#define cb(DType)                                                               \
    if (diff_enumv == DTypeTrait<DType>::enumv) {                               \
        MEGDNN_DISPATCH_CPU_KERN_OPR(                                           \
                exec_internal<typename DTypeTrait<DType>::ctype>(diff, grad);); \
        return;                                                                 \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE(cb);
#undef cb

    // only reached when the dtype is not one of the computing dtypes
    megdnn_assert_internal(0);
}
} }