#include "src/naive/group_local/opr_impl.h"
#include <algorithm>
#include <cstring>
#include "src/naive/handle.h"
namespace {
/*!
 * \brief Naive group-local convolution forward pass.
 *
 * Every output position owns a private kernel. Buffers are indexed as:
 *   src    : (N, IC, IH, IW)
 *   filter : (group, OH, OW, IC / group, FH, FW, OC / group)
 *   dst    : (N, OC, OH, OW)
 * Output channel block g only reads input channel block g. Taps that land in
 * the zero padding are skipped. Products are summed in a float accumulator
 * and converted to \p dtype when stored into dst.
 */
template <typename dtype>
void forward(
        const dtype* src, const dtype* filter, dtype* dst, size_t N, size_t IC,
        size_t IH, size_t IW, size_t FH, size_t FW, size_t OC, size_t OH, size_t OW,
        size_t group, size_t pad_h, size_t pad_w, size_t stride_h, size_t stride_w) {
    const size_t icpg = IC / group;                  // input channels per group
    const size_t ocpg = OC / group;                  // output channels per group
    const size_t kern_size = icpg * FH * FW * ocpg;  // filter elems per output pixel
    for (size_t b = 0; b < N; ++b) {
        const dtype* src_b = src + b * IC * IH * IW;
        dtype* dst_b = dst + b * OC * OH * OW;
        for (size_t g = 0; g < group; ++g) {
            const dtype* src_g = src_b + g * icpg * IH * IW;
            for (size_t k = 0; k < ocpg; ++k) {
                dtype* out_plane = dst_b + (g * ocpg + k) * OH * OW;
                for (size_t y = 0; y < OH; ++y) {
                    for (size_t x = 0; x < OW; ++x) {
                        // kernel private to output pixel (g, y, x)
                        const dtype* kern =
                                filter + ((g * OH + y) * OW + x) * kern_size;
                        float acc = 0;
                        for (size_t fy = 0; fy < FH; ++fy) {
                            // Unsigned wrap-around turns a tap above the image
                            // into a huge index, so one comparison rejects both
                            // borders.
                            const size_t iy = y * stride_h - pad_h + fy;
                            if (iy >= IH)
                                continue;
                            for (size_t fx = 0; fx < FW; ++fx) {
                                const size_t ix = x * stride_w - pad_w + fx;
                                if (ix >= IW)
                                    continue;
                                for (size_t c = 0; c < icpg; ++c) {
                                    const dtype wv =
                                            kern[((c * FH + fy) * FW + fx) * ocpg + k];
                                    const dtype iv = src_g[(c * IH + iy) * IW + ix];
                                    acc += wv * iv;
                                }
                            }
                        }
                        out_plane[y * OW + x] = acc;
                    }
                }
            }
        }
    }
}
/*!
 * \brief Naive group-local backward pass with respect to the input data.
 *
 * Buffers are indexed as:
 *   filter : (group, OH, OW, IC / group, FH, FW, OC / group)
 *   diff   : (N, OC, OH, OW)
 *   grad   : (N, IC, IH, IW)
 * grad is fully overwritten: it is zeroed first, then each diff element
 * scatters diff * filter into the input positions its kernel covered in the
 * forward pass. Taps that fall into the zero padding are skipped.
 */
void backward_data(
        const float* filter, const float* diff, float* grad, size_t N, size_t IC,
        size_t IH, size_t IW, size_t FH, size_t FW, size_t OC, size_t OH, size_t OW,
        size_t group, size_t pad_h, size_t pad_w, size_t stride_h, size_t stride_w) {
    const size_t ICg = IC / group;
    const size_t OCg = OC / group;
    // std::fill_n is a no-op for a zero count, whereas memset on the (possibly
    // null) pointer of an empty tensor is undefined behaviour.
    std::fill_n(grad, N * IC * IH * IW, 0.f);
    for (size_t n = 0; n < N; ++n)
        for (size_t gid = 0; gid < group; ++gid)
            for (size_t ocg = 0; ocg < OCg; ++ocg)
                for (size_t oh = 0; oh < OH; ++oh)
                    for (size_t ow = 0; ow < OW; ++ow) {
                        const size_t oc = gid * OCg + ocg;
                        // invariant across all kernel taps of this output pixel
                        const float dval = diff[((n * OC + oc) * OH + oh) * OW + ow];
                        for (size_t fh = 0; fh < FH; ++fh) {
                            // size_t wrap-around maps rows above the image to huge
                            // values, so a single comparison rejects both borders.
                            const size_t ih = oh * stride_h - pad_h + fh;
                            if (ih >= IH)
                                continue;
                            for (size_t fw = 0; fw < FW; ++fw) {
                                const size_t iw = ow * stride_w - pad_w + fw;
                                if (iw >= IW)
                                    continue;
                                for (size_t icg = 0; icg < ICg; ++icg) {
                                    const size_t ic = gid * ICg + icg;
                                    const float fval = filter
                                            [((((((gid * OH + oh) * OW + ow) * ICg +
                                                 icg) * FH +
                                                fh) * FW +
                                               fw) * OCg +
                                              ocg)];
                                    grad[((n * IC + ic) * IH + ih) * IW + iw] +=
                                            fval * dval;
                                }
                            }
                        }
                    }
}
/*!
 * \brief Naive group-local backward pass with respect to the filter.
 *
 * Buffers are indexed as:
 *   src  : (N, IC, IH, IW)
 *   diff : (N, OC, OH, OW)
 *   grad : (group, OH, OW, IC / group, FH, FW, OC / group)
 * grad is fully overwritten: it is zeroed first, then each kernel weight
 * accumulates src * diff summed over the batch. Taps that fall into the zero
 * padding contribute nothing.
 */
void backward_filter(
        const float* src, const float* diff, float* grad, size_t N, size_t IC,
        size_t IH, size_t IW, size_t FH, size_t FW, size_t OC, size_t OH, size_t OW,
        size_t group, size_t pad_h, size_t pad_w, size_t stride_h, size_t stride_w) {
    const size_t ICg = IC / group;
    const size_t OCg = OC / group;
    // std::fill_n is a no-op for a zero count, whereas memset on the (possibly
    // null) pointer of an empty tensor is undefined behaviour.
    std::fill_n(grad, group * OH * OW * ICg * FH * FW * OCg, 0.f);
    for (size_t n = 0; n < N; ++n)
        for (size_t gid = 0; gid < group; ++gid)
            for (size_t ocg = 0; ocg < OCg; ++ocg)
                for (size_t oh = 0; oh < OH; ++oh)
                    for (size_t ow = 0; ow < OW; ++ow) {
                        const size_t oc = gid * OCg + ocg;
                        // invariant across all kernel taps of this output pixel
                        const float dval = diff[((n * OC + oc) * OH + oh) * OW + ow];
                        for (size_t fh = 0; fh < FH; ++fh) {
                            // size_t wrap-around maps rows above the image to huge
                            // values, so a single comparison rejects both borders.
                            const size_t ih = oh * stride_h - pad_h + fh;
                            if (ih >= IH)
                                continue;
                            for (size_t fw = 0; fw < FW; ++fw) {
                                const size_t iw = ow * stride_w - pad_w + fw;
                                if (iw >= IW)
                                    continue;
                                for (size_t icg = 0; icg < ICg; ++icg) {
                                    const size_t ic = gid * ICg + icg;
                                    const float sval =
                                            src[((n * IC + ic) * IH + ih) * IW + iw];
                                    grad[((((((gid * OH + oh) * OW + ow) * ICg +
                                             icg) * FH +
                                            fh) * FW +
                                           fw) * OCg +
                                          ocg)] += sval * dval;
                                }
                            }
                        }
                    }
}
}
namespace megdnn {
namespace naive {
/*!
 * \brief Validate layouts, then dispatch the naive forward kernel for
 * float32 or (when float16 support is compiled in) float16 tensors.
 *
 * src dims are read as (N, IC, IH, IW) and dst dims as (N, OC, OH, OW).
 * filter dims 0, 4 and 5 supply group, FH and FW. Any other dtype
 * combination trips an internal assertion.
 */
void GroupLocalForwardImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
        _megdnn_workspace workspace) {
    check_exec(src.layout, filter.layout, dst.layout, workspace.size);

    const auto& in_shape = src.layout.shape;
    const auto& flt_shape = filter.layout.shape;
    const auto& out_shape = dst.layout.shape;
    const auto N = in_shape[0], IC = in_shape[1], IH = in_shape[2],
               IW = in_shape[3];
    const auto group = flt_shape[0], FH = flt_shape[4], FW = flt_shape[5];
    const auto OC = out_shape[1], OH = out_shape[2], OW = out_shape[3];

    const auto& p = param();
    const auto pad_h = p.pad_h;
    const auto pad_w = p.pad_w;
    const auto stride_h = p.stride_h;
    const auto stride_w = p.stride_w;

    const bool all_f32 = src.layout.dtype == dtype::Float32() &&
                         filter.layout.dtype == dtype::Float32() &&
                         dst.layout.dtype == dtype::Float32();
    if (all_f32) {
        MEGDNN_DISPATCH_CPU_KERN_OPR(forward(
                src.ptr<dt_float32>(), filter.ptr<dt_float32>(),
                dst.ptr<dt_float32>(), N, IC, IH, IW, FH, FW, OC, OH, OW, group,
                pad_h, pad_w, stride_h, stride_w));
    } else if (DNN_FLOAT16_SELECT(
                       src.layout.dtype == dtype::Float16() &&
                               filter.layout.dtype == dtype::Float16() &&
                               dst.layout.dtype == dtype::Float16(),
                       false)) {
        DNN_INC_FLOAT16(MEGDNN_DISPATCH_CPU_KERN_OPR(forward(
                src.ptr<dt_float16>(), filter.ptr<dt_float16>(),
                dst.ptr<dt_float16>(), N, IC, IH, IW, FH, FW, OC, OH, OW, group,
                pad_h, pad_w, stride_h, stride_w)););
    } else {
        megdnn_assert_internal(false);
    }
}
/*!
 * \brief Validate layouts, then dispatch the float32 naive data-gradient
 * kernel.
 *
 * grad dims are read as (N, IC, IH, IW) and diff dims as (N, OC, OH, OW).
 * filter dims 0, 4 and 5 supply group, FH and FW.
 */
void GroupLocalBackwardDataImpl::exec(
        _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
        _megdnn_workspace workspace) {
    check_exec(filter.layout, diff.layout, grad.layout, workspace.size);

    const auto& in_shape = grad.layout.shape;
    const auto& flt_shape = filter.layout.shape;
    const auto& out_shape = diff.layout.shape;
    const auto N = in_shape[0], IC = in_shape[1], IH = in_shape[2],
               IW = in_shape[3];
    const auto group = flt_shape[0], FH = flt_shape[4], FW = flt_shape[5];
    const auto OC = out_shape[1], OH = out_shape[2], OW = out_shape[3];

    const auto& p = param();
    const auto pad_h = p.pad_h;
    const auto pad_w = p.pad_w;
    const auto stride_h = p.stride_h;
    const auto stride_w = p.stride_w;

    // only a float32 kernel exists for this direction
    MEGDNN_DISPATCH_CPU_KERN_OPR(backward_data(
            filter.ptr<dt_float32>(), diff.ptr<dt_float32>(),
            grad.ptr<dt_float32>(), N, IC, IH, IW, FH, FW, OC, OH, OW, group,
            pad_h, pad_w, stride_h, stride_w));
}
/*!
 * \brief Validate layouts, then dispatch the float32 naive filter-gradient
 * kernel.
 *
 * src dims are read as (N, IC, IH, IW) and diff dims as (N, OC, OH, OW).
 * grad (the filter gradient) dims 0, 4 and 5 supply group, FH and FW.
 */
void GroupLocalBackwardFilterImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
        _megdnn_workspace workspace) {
    check_exec(src.layout, diff.layout, grad.layout, workspace.size);

    const auto& in_shape = src.layout.shape;
    const auto& flt_shape = grad.layout.shape;
    const auto& out_shape = diff.layout.shape;
    const auto N = in_shape[0], IC = in_shape[1], IH = in_shape[2],
               IW = in_shape[3];
    const auto group = flt_shape[0], FH = flt_shape[4], FW = flt_shape[5];
    const auto OC = out_shape[1], OH = out_shape[2], OW = out_shape[3];

    const auto& p = param();
    const auto pad_h = p.pad_h;
    const auto pad_w = p.pad_w;
    const auto stride_h = p.stride_h;
    const auto stride_w = p.stride_w;

    // only a float32 kernel exists for this direction
    MEGDNN_DISPATCH_CPU_KERN_OPR(backward_filter(
            src.ptr<dt_float32>(), diff.ptr<dt_float32>(),
            grad.ptr<dt_float32>(), N, IC, IH, IW, FH, FW, OC, OH, OW, group,
            pad_h, pad_w, stride_h, stride_w));
}
} }