#include "src/naive/local/opr_impl.h"
#include <cstring>
#include "src/common/utils.h"
#include "src/naive/handle.h"
using namespace megdnn;
using namespace naive;
LocalForwardImpl::float_noncontig_batch_kern LocalForwardImpl::
dispatch_float_noncontig_batch(
const TensorLayout& src, const TensorLayout& ,
const TensorLayout& ) {
if (src.dtype == dtype::Float32()) {
if (param().mode == Mode::CROSS_CORRELATION) {
return &naive_kern<true, float>;
} else {
return &naive_kern<false, float>;
}
} else if (DNN_FLOAT16_SELECT(src.dtype == dtype::Float16(), false)) {
DNN_INC_FLOAT16(
megdnn_assert(src.dtype == dtype::Float16());
if (param().mode == Mode::CROSS_CORRELATION) {
return &naive_kern<true MEGDNN_COMMA dt_float16>;
} else { return &naive_kern<false MEGDNN_COMMA dt_float16>; });
} else {
megdnn_assert_internal(false);
return nullptr;
}
}
template <bool is_xcorr, typename dtype>
void LocalForwardImpl::naive_kern(const FloatNoncontigBatchKernParam& param) {
UNPACK_LOCAL_FLOAT_NONCONTIG_BATCH_KERN_PARAM(param, dtype);
static_cast<void>(workspace);
rep(n, N) rep(oc, OC) rep(oh, OH) rep(ow, OW) {
auto& dval = dst[n * OUT_BS + oc * OH * OW + oh * OW + ow];
dval = 0.0f;
rep(fh, FH) rep(fw, FW) {
size_t ih = SH * oh;
size_t iw = SW * ow;
if (is_xcorr) {
ih += fh;
iw += fw;
} else {
ih += FH - fh - 1;
iw += FW - fw - 1;
}
ih -= PH;
iw -= PW;
if (ih < static_cast<size_t>(IH) && iw < static_cast<size_t>(IW)) {
rep(ic, IC) {
auto sval = src[n * INP_BS + ic * IH * IW + ih * IW + iw];
auto fval =
filter[oh * OW * IC * FH * FW * OC +
ow * IC * FH * FW * OC + ic * FH * FW * OC +
fh * FW * OC + fw * OC + oc];
dval += sval * fval;
}
}
}
}
}
void LocalForwardImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
_megdnn_workspace workspace) {
exec_use_float_noncontig_batch(src, filter, dst, workspace);
}
LocalForwardImpl::FloatNoncontigBatchKernParam LocalForwardImpl::make_float_kern_param(
_megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
_megdnn_workspace workspace) const {
return {src.get_ref_ptr(), filter.get_ref_ptr(), dst.get_ref_ptr(),
src.layout.shape[0],
src.layout.shape[1], src.layout.shape[2], src.layout.shape[3],
dst.layout.shape[1], dst.layout.shape[2], dst.layout.shape[3],
filter.layout.shape[3], filter.layout.shape[4],
param().pad_h, param().pad_w, param().stride_h, param().stride_w,
src.layout.stride[0], dst.layout.stride[0], workspace.raw_ptr};
}
void LocalForwardImpl::exec_use_float_noncontig_batch(
_megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
_megdnn_workspace workspace) {
check_exec(src.layout, filter.layout, dst.layout, workspace.size);
auto fp = make_float_kern_param(src, filter, dst, workspace);
auto kptr = dispatch_float_noncontig_batch(src.layout, filter.layout, dst.layout);
auto kern = [fp, kptr]() { kptr(fp); };
static_cast<naive::HandleImpl*>(handle())->dispatch_kern(kern);
}
void LocalBackwardDataImpl::exec(
_megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) {
check_exec(filter.layout, diff.layout, grad.layout, workspace.size);
size_t N = grad.layout.shape[0], IC = grad.layout.shape[1],
IH = grad.layout.shape[2], IW = grad.layout.shape[3];
size_t FH = filter.layout.shape[3], FW = filter.layout.shape[4];
size_t OC = diff.layout.shape[1], OH = diff.layout.shape[2],
OW = diff.layout.shape[3];
size_t ph = param().pad_h, pw = param().pad_w;
size_t sh = param().stride_h, sw = param().stride_w;
auto mode = param().mode;
auto kern = [=]() {
auto gptr = grad.ptr<dt_float32>(), fptr = filter.ptr<dt_float32>(),
hptr = diff.ptr<dt_float32>();
memset(gptr, 0, sizeof(float_t) * N * IC * IH * IW);
rep(n, N) rep(oc, OC) rep(oh, OH) rep(ow, OW) {
auto& hval = hptr[((n * OC + oc) * OH + oh) * OW + ow];
rep(ic, IC) rep(fh, FH) rep(fw, FW) {
size_t ih = -ph + sh * oh;
size_t iw = -pw + sw * ow;
if (mode == Mode::CROSS_CORRELATION) {
ih += fh;
iw += fw;
} else {
ih += FH - fh - 1;
iw += FW - fw - 1;
}
if (ih < IH && iw < IW) {
auto& gval = gptr[(((n * IC + ic) * IH) + ih) * IW + iw];
auto fval =
fptr[((((oh * OW + ow) * IC + ic) * FH + fh) * FW + fw) *
OC +
oc];
gval += fval * hval;
}
}
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(kern());
}
void LocalBackwardFilterImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) {
check_exec(src.layout, diff.layout, grad.layout, workspace.size);
size_t N = src.layout.shape[0], IC = src.layout.shape[1], IH = src.layout.shape[2],
IW = src.layout.shape[3];
size_t FH = grad.layout.shape[3], FW = grad.layout.shape[4];
size_t OC = diff.layout.shape[1], OH = diff.layout.shape[2],
OW = diff.layout.shape[3];
size_t ph = param().pad_h, pw = param().pad_w;
size_t sh = param().stride_h, sw = param().stride_w;
auto mode = param().mode;
auto kern = [=]() {
auto gptr = grad.ptr<dt_float32>(), sptr = src.ptr<dt_float32>(),
hptr = diff.ptr<dt_float32>();
memset(gptr, 0, sizeof(float_t) * OH * OW * IC * FH * FW * OC);
rep(n, N) rep(oc, OC) rep(oh, OH) rep(ow, OW) {
auto& hval = hptr[((n * OC + oc) * OH + oh) * OW + ow];
rep(ic, IC) rep(fh, FH) rep(fw, FW) {
size_t ih = -ph + sh * oh;
size_t iw = -pw + sw * ow;
if (mode == Mode::CROSS_CORRELATION) {
ih += fh;
iw += fw;
} else {
ih += FH - fh - 1;
iw += FW - fw - 1;
}
if (ih < IH && iw < IW) {
auto sval = sptr[((n * IC + ic) * IH + ih) * IW + iw];
auto& gval =
gptr[((((oh * OW + ow) * IC + ic) * FH + fh) * FW + fw) *
OC +
oc];
gval += sval * hval;
}
}
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(kern());
}