/**
* \file dnn/src/rocm/convolution/chanwise/fwd.cpp.hip
*
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "hip_header.h"
#include "./kern.h.hip"
#include "./kern_helper.h.hip"
using namespace megdnn;
using namespace rocm;
using namespace convolution;
using namespace chanwise;
namespace {

/*!
 * \brief Forward kernel of the channel-wise convolution.
 *
 * Launch layout, as consumed by the kernel:
 *   - blockIdx.x selects the input channel \c ic.
 *   - Every thread gets the linear id blockIdx.y * blockDim.x + threadIdx.x
 *     and walks the N * CHL_MUL * OH * OW outputs belonging to that channel
 *     with a step of blockDim.x * gridDim.y.
 *   - Dynamic shared memory has to hold CHL_MUL * FH * FW elements of T; the
 *     filters of channel \c ic are staged there before any output is
 *     computed.
 *
 * \tparam T element type of source, filter and destination
 * \tparam CHL_MUL_SET compile-time channel multiplier; 0 means "use
 *      param.chl_mul"
 * \tparam FH_SET compile-time filter height; 0 means "use param.flt_h"
 * \tparam FW_SET compile-time filter width; 0 means "use param.flt_w"
 *
 * \param dst output, addressed as
 *      (((n * IC + ic) * CHL_MUL + chl_mul) * OH + oh) * OW + ow
 * \param src input, addressed as ((n * IC + ic) * IH + ih) * IW + iw
 * \param flt_tot filters of all channels; channel \c ic starts at
 *      ic * FH * FW * CHL_MUL
 * \param param convolution geometry (batch, channels, sizes, pad, stride)
 */
template <typename T, int CHL_MUL_SET, int FH_SET, int FW_SET>
__global__ void kern_fwd(T* dst, const T* src, const T* flt_tot, Param param) {
    extern __shared__ uint8_t flt_storage[];
    T* const flt = reinterpret_cast<T*>(flt_storage);

    // Geometry; template arguments take precedence over the runtime values
    // whenever they are non-zero.
    const uint32_t ic = blockIdx.x;
    const uint32_t N = param.batch;
    const uint32_t IC = param.src_chl;
    const uint32_t IH = param.src_h;
    const uint32_t IW = param.src_w;
    const uint32_t CHL_MUL = CHL_MUL_SET ? CHL_MUL_SET : param.chl_mul;
    const uint32_t FH = FH_SET ? FH_SET : param.flt_h;
    const uint32_t FW = FW_SET ? FW_SET : param.flt_w;
    const uint32_t FSIZE = FH * FW;
    const uint32_t PH = param.pad_h;
    const uint32_t PW = param.pad_w;
    const uint32_t SH = param.stride_h;
    const uint32_t SW = param.stride_w;
    const uint32_t OH = param.out_h;
    const uint32_t OW = param.out_w;
    const uint32_t TOT_OUT = N * CHL_MUL * OH * OW;

    // Stage the CHL_MUL filters of this input channel in shared memory.
    // NOTE(review): block_memcpy is assumed to synchronize the block before
    // returning, since flt is read right below -- confirm in kern_helper.
    const uint32_t flt_per_chl = FSIZE * CHL_MUL;
    block_memcpy(flt, flt_tot + ic * flt_per_chl, flt_per_chl);

    const uint32_t step = blockDim.x * gridDim.y;
    for (uint32_t linear = blockIdx.y * blockDim.x + threadIdx.x;
         linear < TOT_OUT; linear += step) {
        // Decompose the linear output id into (n, chl_mul, oh, ow).
        uint32_t ow, oh, chl_mul, n;
        uint32_t rest = div_mod(linear, OW, ow);
        rest = div_mod(rest, OH, oh);
        if (CHL_MUL_SET == 1) {
            chl_mul = 0;
            n = rest;
        } else {
            n = div_mod(rest, CHL_MUL, chl_mul);
        }

        // Top-left input coordinate of the window; negative inside padding.
        const int ih = int(oh * SH) - int(PH);
        const int iw = int(ow * SW) - int(PW);

        const T* const flt_base = flt + chl_mul * FSIZE;
        // May point outside the image; only in-range taps are dereferenced.
        const T* const src_base =
                src + int(((n * IC + ic) * IH + ih) * IW + iw);

        T sum(0);
        if (FH_SET && FW_SET) {
            // Fixed filter size: unrolled loops with per-tap range checks.
            // The unsigned cast turns negative coordinates into huge values,
            // so one comparison rejects both "< 0" and ">= size".
#pragma unroll
            for (uint32_t fh = 0; fh < FH; ++fh) {
                if (static_cast<uint32_t>(fh + ih) >= IH) {
                    continue;
                }
#pragma unroll
                for (uint32_t fw = 0; fw < FW; ++fw) {
                    if (static_cast<uint32_t>(fw + iw) >= IW) {
                        continue;
                    }
                    sum += flt_base[fh * FW + fw] * src_base[fh * IW + fw];
                }
            }
        } else {
            // Runtime filter size: clip the tap ranges once up front.
            const int fh_begin = max(0, -ih);
            const int fh_end = min(int(FH), int(IH - ih));
            const int fw_begin = max(0, -iw);
            const int fw_end = min(int(FW), int(IW - iw));
            for (int fh = fh_begin; fh < fh_end; ++fh) {
                for (int fw = fw_begin; fw < fw_end; ++fw) {
                    sum += flt_base[fh * FW + fw] * src_base[fh * IW + fw];
                }
            }
        }

        const uint32_t dst_off =
                (((n * IC + ic) * CHL_MUL + chl_mul) * OH + oh) * OW + ow;
        dst[dst_off] = sum;
    }
}

}  // anonymous namespace
template <typename T>
void chanwise::run_fwd(T* dst, const T* src, const T* flt, const Param& param,
hipStream_t stream) {
void (*kern)(T*, const T*, const T*, Param);
if (param.chl_mul == 1) {
if (param.flt_h == 3 && param.flt_w == 3) {
kern = kern_fwd<T, 1, 3, 3>;
} else if (param.flt_h == 4 && param.flt_w == 4) {
kern = kern_fwd<T, 1, 4, 4>;
} else {
kern = kern_fwd<T, 1, 0, 0>;
}
} else {
kern = kern_fwd<T, 0, 0, 0>;
}
int nr_thread = 256,
nr_out_dimx = param.out_h * param.out_w * param.batch * param.chl_mul;
dim3 nr_block(param.src_chl,
std::min(512, max(nr_out_dimx / (nr_thread * 4), 1)));
uint32_t shared = param.chl_mul * param.flt_h * param.flt_w * sizeof(T);
hipLaunchKernelGGL(kern, nr_block, nr_thread, shared, stream, dst, src, flt, param);
after_kernel_launch();
}
// Explicit instantiations of chanwise::run_fwd, so that the template defined
// in this translation unit can be linked from other files.
namespace megdnn {
namespace rocm {
namespace convolution {
namespace chanwise {
// DO_INST(_ct): emits one explicit instantiation of run_fwd for the C type
// _ct (dst, src and filter pointers all share that element type).
#define DO_INST(_ct) \
template void run_fwd(_ct*, const _ct*, const _ct*, const Param&, \
hipStream_t);
// INST(_dt): forwards DTypeTrait<_dt>::ctype to DO_INST; the extra macro
// level avoids spelling the trait three times in the signature.
#define INST(_dt) DO_INST(DTypeTrait<_dt>::ctype)
// Applies INST to every dtype listed by MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT
// (presumably the floating-point computing dtypes -- see its definition).
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(INST)
// Keep the helper macros local to this file.
#undef INST
#undef DO_INST
} // namespace chanwise
} // namespace convolution
} // namespace rocm
} // namespace megdnn
// vim: syntax=cuda.doxygen