/**
* \file dnn/src/rocm/convolution/chanwise/bwd_data.cpp.hip
*
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "hip_header.h"
#include "./kern.h.hip"
#include "./kern_helper.h.hip"
using namespace megdnn;
using namespace rocm;
using namespace convolution;
using namespace chanwise;
namespace {
// Backward-data kernel for channel-wise convolution: for every input pixel it
// accumulates dst_grad * filter over all output positions and channel
// multipliers that touched that pixel, and stores the result into src_grad.
//
// Launch layout, as used by the indexing below:
//   * grid idx is (inp_chl, worker_index): blockIdx.x selects the input
//     channel ic;
//   * (blockIdx.y, threadIdx.x) form a flat worker index over the N * IH * IW
//     pixels of that channel; each worker advances by blockDim.x * gridDim.y
//     until every pixel has been handled.
// Dynamic shared memory must hold FH * FW * CHL_MUL elements of T (the filter
// slice belonging to the current input channel).
//
// Tensor addressing implied by the pointer arithmetic in the body:
//   src_grad: (N, IC, IH, IW)
//   dst_grad: (N, IC * CHL_MUL, OH, OW)
//   flt_tot:  (IC, CHL_MUL, FH, FW)
//
// Template parameters CHL_MUL_SET, FH_SET, FW_SET, SH_SET and SW_SET are
// compile-time copies of the matching Param fields; a value of 0 means "not
// specialized, read the value from param at run time".
template <typename T, int CHL_MUL_SET, int FH_SET, int FW_SET, int SH_SET,
int SW_SET>
__global__ void kern_bwd_data(T* src_grad, const T* dst_grad, const T* flt_tot,
Param param) {
// The dynamic shared buffer is declared as raw bytes and viewed as T.
extern __shared__ uint8_t flt_storage[];
T* const flt = reinterpret_cast<T*>(flt_storage);
// Geometry; every *_SET template argument that is non-zero overrides the
// corresponding runtime field. All index math below is done in uint32_t.
const uint32_t N = param.batch, IC = param.src_chl, ic = blockIdx.x,
IH = param.src_h, IW = param.src_w,
CHL_MUL = CHL_MUL_SET ? CHL_MUL_SET : param.chl_mul,
FH = FH_SET ? FH_SET : param.flt_h,
FW = FW_SET ? FW_SET : param.flt_w, FSIZE = FH * FW,
PH = param.pad_h, PW = param.pad_w,
SH = SH_SET ? SH_SET : param.stride_h,
SW = SW_SET ? SW_SET : param.stride_w, OH = param.out_h,
OW = param.out_w, TOT_OUT = N * IH * IW;
// Stage the FSIZE * CHL_MUL filter values of channel ic into shared memory.
// NOTE(review): block_memcpy is defined elsewhere; the reads of flt below
// rely on it making the copied data visible to every thread of the block
// (i.e. ending with a block barrier) -- confirm against its definition.
block_memcpy(flt, flt_tot + ic * FSIZE * CHL_MUL, FSIZE * CHL_MUL);
// Rebase both gradients to channel ic of batch element 0.
dst_grad += ic * CHL_MUL * OH * OW;
src_grad += ic * IH * IW;
// Flat pixel index of this thread and the stride between the pixels it owns.
uint32_t out_idx_ = blockIdx.y * blockDim.x + threadIdx.x,
nr_out_per_launch = blockDim.x * gridDim.y;
for (; out_idx_ < TOT_OUT; out_idx_ += nr_out_per_launch) {
// Split the flat index into (n, ih, iw), iw being the fastest dimension.
// NOTE(review): div_mod is defined elsewhere; its use here implies that it
// returns the quotient and stores the remainder in its third argument --
// confirm.
uint32_t out_idx = out_idx_, n, ih, iw;
out_idx = div_mod(out_idx, IW, iw);
out_idx = div_mod(out_idx, IH, ih);
n = out_idx;
const T* dst_grad_base = dst_grad + n * (IC * CHL_MUL * OH * OW);
T sum(0);
// An output position o contributes to input position i iff
// o * S + f == i + P for some filter tap 0 <= f < F, hence
//   o >= max(0, ceil_div(i + P - F + 1, S)) == max(i + P - F + S, 0) / S
//   o <= min((i + P) / S, O - 1)
// The int32_t cast reinterprets a wrapped-around unsigned difference as a
// negative number so that max(..., 0) clamps it to zero.
uint32_t ohmin = max(int32_t(ih + PH - FH + SH), 0) / SH,
owmin = max(int32_t(iw + PW - FW + SW), 0) / SW,
ohmax = min((ih + PH) / SH, OH - 1),
owmax = min((iw + PW) / SW, OW - 1);
// The condition only involves template arguments, so it is resolved at
// compile time. With unit stride and a known filter size the contributing
// window is at most FH x FW outputs, so iterate a fixed-size window (which
// the unroll pragmas can flatten) and mask positions beyond ohmax / owmax.
if (SH_SET == 1 && SW_SET == 1 && FH_SET && FW_SET) {
#pragma unroll
for (uint32_t doh = 0; doh < FH; ++doh) {
uint32_t oh = ohmin + doh;
if (oh <= ohmax) {
// Filter row that maps output row oh onto input row ih.
uint32_t fh = ih - oh * SH + PH;
#pragma unroll
for (uint32_t dow = 0; dow < FW; ++dow) {
uint32_t ow = owmin + dow;
if (ow <= owmax) {
// Filter column that maps output column ow onto input column iw.
uint32_t fw = iw - ow * SW + PW;
const T* pd = dst_grad_base + oh * OW + ow;
const T* pf = flt + fh * FW + fw;
// Walk the channel multipliers: consecutive multipliers are one OH * OW
// plane apart in dst_grad and one FSIZE-sized filter apart in flt.
#pragma unroll
for (uint32_t chl_mul = 0; chl_mul < CHL_MUL;
++chl_mul) {
sum += *pd * *pf;
pd += OH * OW;
pf += FSIZE;
}
}
}
}
}
} else {
// Generic path: loop directly over the runtime-sized contributing window.
for (uint32_t oh = ohmin; oh <= ohmax; ++oh) {
uint32_t fh = ih - oh * SH + PH;
for (uint32_t ow = owmin; ow <= owmax; ++ow) {
uint32_t fw = iw - ow * SW + PW;
const T* pd = dst_grad_base + oh * OW + ow;
const T* pf = flt + fh * FW + fw;
// Same per-multiplier accumulation as in the specialized path.
#pragma unroll
for (uint32_t chl_mul = 0; chl_mul < CHL_MUL; ++chl_mul) {
sum += *pd * *pf;
pd += OH * OW;
pf += FSIZE;
}
}
}
}
// src_grad already points at channel ic, so this addresses (n, ic, ih, iw).
src_grad[(n * (IC * IH) + ih) * IW + iw] = sum;
}
}
// Maps runtime convolution parameters onto a kern_bwd_data instantiation.
//
// Filters of 3x3 and 4x4 with a channel multiplier of 1 get instantiations
// with those values baked in; strides (1,1), (1,2), (2,1) and (2,2) are baked
// in as well. Everything else is instantiated with 0 for the respective
// template argument, which makes the kernel read the value from Param at run
// time.
template <typename T>
class KernDispatch {
public:
    // Signature shared by every kern_bwd_data instantiation.
    using kern_ptr_t = void (*)(T*, const T*, const T*, Param);

    // Select the kernel for the given channel multiplier, filter size
    // (fh x fw) and stride (sh, sw).
    static kern_ptr_t dispatch(int chl_mul, int fh, int fw, int sh, int sw) {
        // Only square filters with chl_mul == 1 have dedicated variants.
        if (chl_mul == 1 && fh == fw) {
            switch (fh) {
                case 3:
                    return d1<1, 3, 3>(sh, sw);
                case 4:
                    return d1<1, 4, 4>(sh, sw);
                default:
                    break;
            }
        }
        return d1<0, 0, 0>(sh, sw);
    }

private:
    // Second dispatch level: fix the stride for an already chosen
    // (channel multiplier, filter height, filter width) triple.
    template <int kChlMul, int kFH, int kFW>
    static kern_ptr_t d1(int sh, int sw) {
        if (sh == 1) {
            if (sw == 1)
                return kern_bwd_data<T, kChlMul, kFH, kFW, 1, 1>;
            if (sw == 2)
                return kern_bwd_data<T, kChlMul, kFH, kFW, 1, 2>;
        } else if (sh == 2) {
            if (sw == 1)
                return kern_bwd_data<T, kChlMul, kFH, kFW, 2, 1>;
            if (sw == 2)
                return kern_bwd_data<T, kChlMul, kFH, kFW, 2, 2>;
        }
        // Any other stride combination is handled with runtime strides.
        return kern_bwd_data<T, kChlMul, kFH, kFW, 0, 0>;
    }
};
} // anonymous namespace
template <typename T>
void chanwise::run_bwd_data(T* src_grad, const T* dst_grad, const T* flt,
const Param& param, hipStream_t stream) {
typename KernDispatch<T>::kern_ptr_t kern =
KernDispatch<T>::dispatch(param.chl_mul, param.flt_h, param.flt_w,
param.stride_h, param.stride_w);
int nr_thread = 256, nr_out_dimx = param.src_h * param.src_w * param.batch;
dim3 nr_block(param.src_chl,
std::min(512, max(nr_out_dimx / (nr_thread * 4), 1)));
uint32_t shared = param.chl_mul * param.flt_h * param.flt_w * sizeof(T);
hipLaunchKernelGGL(kern, nr_block, nr_thread, shared, stream, src_grad, dst_grad, flt,
param);
after_kernel_launch();
}
namespace megdnn {
namespace rocm {
namespace convolution {
namespace chanwise {

// run_bwd_data is a template defined in this file, so it is explicitly
// instantiated here, once per dtype enumerated by
// MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT (declared elsewhere); the element type
// is DTypeTrait<_dt>::ctype.
#define INST(_dt)                                                   \
    template void run_bwd_data(                                     \
            DTypeTrait<_dt>::ctype*, const DTypeTrait<_dt>::ctype*, \
            const DTypeTrait<_dt>::ctype*, const Param&, hipStream_t);
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(INST)
#undef INST

} // namespace chanwise
} // namespace convolution
} // namespace rocm
} // namespace megdnn
// vim: syntax=cuda.doxygen