#include "src/fallback/convolution/run_conv.h"
#include "midout.h"
#include "src/common/utils.h"
MIDOUT_DECL(megdnn_fallback_conv)
namespace {
/**
 * @brief Tell whether the compile-time-sized kernels can handle a filter.
 *
 * Only the filter extents are inspected. The remaining arguments are accepted
 * so that call sites can pass the whole problem geometry, but they are ignored.
 *
 * @param FH filter height
 * @param FW filter width
 * @return true iff the filter is square and its side length lies in [1, 7].
 */
bool can_run_xcorr_single_channel_templated(
        size_t /* IH */, size_t /* IW */, size_t FH, size_t FW, size_t /* OH */,
        size_t /* OW */, size_t /* PH */, size_t /* PW */, size_t /* SH */,
        size_t /* SW */) {
    if (FH != FW) {
        return false;
    }
    return 1 <= FH && FH <= 7;
}
/**
 * @brief Single-channel 2D cross-correlation for a square filter whose side
 *        length `ker_size` is known at compile time.
 *
 * Output element (oh, ow) accumulates
 *     filter[fh * ker_size + fw] * src[(oh*SH + fh - PH) * IW + (ow*SW + fw - PW)]
 * over the taps (fh, fw), visited row by row, whose input coordinate lies
 * inside the IH x IW image. Taps that would read outside the image are skipped.
 *
 * @param src      input plane, indexed as src[ih * IW + iw]
 * @param filter   ker_size * ker_size weights, indexed as filter[fh * ker_size + fw]
 * @param dst      output plane, indexed as dst[oh * OW + ow]
 * @param IH,IW    input height / width
 * @param OH,OW    output height / width
 * @param PH,PW    padding subtracted from the input row / column coordinate
 * @param SH,SW    stride multiplied onto the output row / column coordinate
 * @param add_mode false: every output element is zeroed before accumulating;
 *                 true: products are added onto the current contents of dst
 */
template <int ker_size>
void run_xcorr_single_channel_templated_impl(
        const float* __restrict src, const float* __restrict filter,
        float* __restrict dst, size_t IH, size_t IW, size_t OH, size_t OW, size_t PH,
        size_t PW, size_t SH, size_t SW, bool add_mode) {
    const size_t kernel_extent = static_cast<size_t>(ker_size);

    // Output rows [row_lo, row_hi) satisfy oh*SH - PH >= 0 and
    // oh*SH - PH + ker_size <= IH, i.e. every filter row lands inside the
    // image. Columns [col_lo, col_hi) are the analogous range horizontally.
    size_t row_lo = (PH + SH - 1) / SH;
    size_t row_hi = 0;
    if (IH + PH >= kernel_extent) {
        row_hi = (IH + PH - kernel_extent) / SH + 1;
    }
    size_t col_lo = (PW + SW - 1) / SW;
    size_t col_hi = 0;
    if (IW + PW >= kernel_extent) {
        col_hi = (IW + PW - kernel_extent) / SW + 1;
    }
    // An inverted range means no element is fully interior along that axis;
    // collapse it so that everything goes through the bounds-checked paths.
    if (row_lo > row_hi) {
        row_lo = 0;
        row_hi = 0;
    }
    if (col_lo > col_hi) {
        col_lo = 0;
        col_hi = 0;
    }

    // Zero the output element unless we are accumulating onto it.
    const auto begin_output = [&](size_t oh, size_t ow) {
        if (!add_mode) {
            dst[oh * OW + ow] = 0;
        }
    };
    // Add the contribution of one filter tap to one output element.
    const auto add_tap = [&](size_t oh, size_t ow, int fh, int fw) {
        dst[oh * OW + ow] += filter[fh * ker_size + fw] *
                             src[(oh * SH + fh - PH) * IW + (ow * SW + fw - PW)];
    };
    const auto row_in_image = [&](int ih) {
        return ih >= 0 && ih < static_cast<int>(IH);
    };
    const auto col_in_image = [&](int iw) {
        return iw >= 0 && iw < static_cast<int>(IW);
    };

    // Output element whose window may leave the image vertically and
    // horizontally: test both coordinates of every tap.
    const auto clipped_both = [&](size_t oh, size_t ow) {
        begin_output(oh, ow);
        const int ih0 = static_cast<int>(oh * SH - PH);
        const int iw0 = static_cast<int>(ow * SW - PW);
        for (int fh = 0; fh < ker_size; ++fh) {
            if (!row_in_image(ih0 + fh)) {
                continue;
            }
            for (int fw = 0; fw < ker_size; ++fw) {
                if (col_in_image(iw0 + fw)) {
                    add_tap(oh, ow, fh, fw);
                }
            }
        }
    };
    // Output element in an interior row: only the column needs testing.
    const auto clipped_cols = [&](size_t oh, size_t ow) {
        begin_output(oh, ow);
        const int iw0 = static_cast<int>(ow * SW - PW);
        for (int fh = 0; fh < ker_size; ++fh) {
            for (int fw = 0; fw < ker_size; ++fw) {
                if (col_in_image(iw0 + fw)) {
                    add_tap(oh, ow, fh, fw);
                }
            }
        }
    };
    // Output element whose whole window is inside the image: no tests.
    const auto unclipped = [&](size_t oh, size_t ow) {
        begin_output(oh, ow);
        for (int fh = 0; fh < ker_size; ++fh) {
            for (int fw = 0; fw < ker_size; ++fw) {
                add_tap(oh, ow, fh, fw);
            }
        }
    };

    // Top band.
    for (size_t oh = 0; oh < row_lo; ++oh) {
        for (size_t ow = 0; ow < OW; ++ow) {
            clipped_both(oh, ow);
        }
    }
    // Middle band: left edge, interior, right edge.
    for (size_t oh = row_lo; oh < row_hi; ++oh) {
        size_t ow = 0;
        for (; ow < col_lo; ++ow) {
            clipped_cols(oh, ow);
        }
        for (; ow < col_hi; ++ow) {
            unclipped(oh, ow);
        }
        for (; ow < OW; ++ow) {
            clipped_cols(oh, ow);
        }
    }
    // Bottom band.
    for (size_t oh = row_hi; oh < OH; ++oh) {
        for (size_t ow = 0; ow < OW; ++ow) {
            clipped_both(oh, ow);
        }
    }
}
/**
 * @brief Route a square-filter cross-correlation to the kernel instantiated
 *        for its filter size.
 *
 * The instantiation is chosen from FH alone, for FH in 1..7; each call is
 * wrapped in its own MIDOUT region tagged with the kernel size. Any other FH
 * is reported through megdnn_throw.
 *
 * @param FH filter height, used to pick the instantiation
 * @param FW filter width; unused here (the caller visible in this file only
 *           gets here after can_run_xcorr_single_channel_templated accepted
 *           the filter, which requires FH == FW)
 *
 * All other arguments are forwarded unchanged to
 * run_xcorr_single_channel_templated_impl.
 */
void run_xcorr_single_channel_templated(
        const float* src, const float* filter, float* dst, size_t IH, size_t IW,
        size_t FH, size_t FW, size_t OH, size_t OW, size_t PH, size_t PW, size_t SH,
        size_t SW, bool add_mode) {
    static_cast<void>(FW);
// Each case is braced so that whatever MIDOUT_BEGIN expands to stays local
// to that case.
#define RUN_FOR_KERNEL_SIZE(K) \
    case K: { \
        MIDOUT_BEGIN(megdnn_fallback_conv, K) { \
            run_xcorr_single_channel_templated_impl<K>( \
                    src, filter, dst, IH, IW, OH, OW, PH, PW, SH, SW, add_mode); \
        } \
        MIDOUT_END(); \
        return; \
    }
    switch (FH) {
        RUN_FOR_KERNEL_SIZE(1)
        RUN_FOR_KERNEL_SIZE(2)
        RUN_FOR_KERNEL_SIZE(3)
        RUN_FOR_KERNEL_SIZE(4)
        RUN_FOR_KERNEL_SIZE(5)
        RUN_FOR_KERNEL_SIZE(6)
        RUN_FOR_KERNEL_SIZE(7)
        default:
            break;
    }
#undef RUN_FOR_KERNEL_SIZE
    megdnn_throw("internal error in conv template dispatching: impossible");
}
/**
 * @brief Single-channel 2D cross-correlation for a filter of arbitrary
 *        run-time size FH_ x FW_.
 *
 * Output element (oh, ow) accumulates
 *     filter[fh * FW_ + fw] * src[(oh*SH + fh - PH) * IW + (ow*SW + fw - PW)]
 * over the taps (fh, fw), visited row by row, whose input coordinate lies
 * inside the IH x IW image. Taps that would read outside the image are skipped.
 *
 * @param src      input plane, indexed as src[ih * IW + iw]
 * @param filter   FH_ * FW_ weights, indexed as filter[fh * FW_ + fw]
 * @param dst      output plane, indexed as dst[oh * OW + ow]
 * @param IH,IW    input height / width
 * @param FH_,FW_  filter height / width
 * @param OH,OW    output height / width
 * @param PH,PW    padding subtracted from the input row / column coordinate
 * @param SH,SW    stride multiplied onto the output row / column coordinate
 * @param add_mode false: every output element is zeroed before accumulating;
 *                 true: products are added onto the current contents of dst
 */
void run_xcorr_single_channel_nontemplated(
        const float* src, const float* filter, float* dst, size_t IH, size_t IW,
        size_t FH_, size_t FW_, size_t OH, size_t OW, size_t PH, size_t PW, size_t SH,
        size_t SW, bool add_mode) {
    // Signed copies of the filter extents for the signed tap arithmetic below.
    const int filter_rows = static_cast<int>(FH_);
    const int filter_cols = static_cast<int>(FW_);

    // Output rows [row_lo, row_hi) and columns [col_lo, col_hi) are those
    // whose filter window lies entirely inside the image along that axis.
    size_t row_lo = (PH + SH - 1) / SH;
    size_t row_hi = 0;
    if (IH + PH >= FH_) {
        row_hi = (IH + PH - filter_rows) / SH + 1;
    }
    size_t col_lo = (PW + SW - 1) / SW;
    size_t col_hi = 0;
    if (IW + PW >= FW_) {
        col_hi = (IW + PW - filter_cols) / SW + 1;
    }
    // An inverted range means nothing is fully interior along that axis;
    // collapse it so that everything goes through the bounds-checked paths.
    if (row_lo > row_hi) {
        row_lo = 0;
        row_hi = 0;
    }
    if (col_lo > col_hi) {
        col_lo = 0;
        col_hi = 0;
    }

    // Zero the output element unless we are accumulating onto it.
    const auto begin_output = [&](size_t oh, size_t ow) {
        if (!add_mode) {
            dst[oh * OW + ow] = 0;
        }
    };
    // Add the contribution of one filter tap to one output element.
    const auto add_tap = [&](size_t oh, size_t ow, int fh, int fw) {
        dst[oh * OW + ow] += filter[fh * filter_cols + fw] *
                             src[(oh * SH + fh - PH) * IW + (ow * SW + fw - PW)];
    };
    const auto row_in_image = [&](int ih) {
        return ih >= 0 && ih < static_cast<int>(IH);
    };
    const auto col_in_image = [&](int iw) {
        return iw >= 0 && iw < static_cast<int>(IW);
    };

    // Window may leave the image in both directions: test every tap fully.
    const auto clipped_both = [&](size_t oh, size_t ow) {
        begin_output(oh, ow);
        const int ih0 = static_cast<int>(oh * SH - PH);
        const int iw0 = static_cast<int>(ow * SW - PW);
        for (int fh = 0; fh < filter_rows; ++fh) {
            if (!row_in_image(ih0 + fh)) {
                continue;
            }
            for (int fw = 0; fw < filter_cols; ++fw) {
                if (col_in_image(iw0 + fw)) {
                    add_tap(oh, ow, fh, fw);
                }
            }
        }
    };
    // Interior row: only the column needs testing.
    const auto clipped_cols = [&](size_t oh, size_t ow) {
        begin_output(oh, ow);
        const int iw0 = static_cast<int>(ow * SW - PW);
        for (int fh = 0; fh < filter_rows; ++fh) {
            for (int fw = 0; fw < filter_cols; ++fw) {
                if (col_in_image(iw0 + fw)) {
                    add_tap(oh, ow, fh, fw);
                }
            }
        }
    };
    // Whole window inside the image: no tests.
    const auto unclipped = [&](size_t oh, size_t ow) {
        begin_output(oh, ow);
        for (int fh = 0; fh < filter_rows; ++fh) {
            for (int fw = 0; fw < filter_cols; ++fw) {
                add_tap(oh, ow, fh, fw);
            }
        }
    };

    // Top band.
    for (size_t oh = 0; oh < row_lo; ++oh) {
        for (size_t ow = 0; ow < OW; ++ow) {
            clipped_both(oh, ow);
        }
    }
    // Middle band: left edge, interior, right edge.
    for (size_t oh = row_lo; oh < row_hi; ++oh) {
        size_t ow = 0;
        for (; ow < col_lo; ++ow) {
            clipped_cols(oh, ow);
        }
        for (; ow < col_hi; ++ow) {
            unclipped(oh, ow);
        }
        for (; ow < OW; ++ow) {
            clipped_cols(oh, ow);
        }
    }
    // Bottom band.
    for (size_t oh = row_hi; oh < OH; ++oh) {
        for (size_t ow = 0; ow < OW; ++ow) {
            clipped_both(oh, ow);
        }
    }
}
/**
 * @brief Single-channel cross-correlation entry point: picks between the
 *        compile-time-sized kernels and the generic kernel.
 *
 * The generic kernel is used whenever
 * can_run_xcorr_single_channel_templated rejects the geometry; that call is
 * wrapped in its own MIDOUT region. All arguments are forwarded unchanged to
 * the selected implementation.
 *
 * @param add_mode forwarded as is; the kernels zero each output element first
 *                 when it is false and accumulate onto dst when it is true
 */
void run_xcorr_single_channel(
        const float* src, const float* filter, float* dst, size_t IH, size_t IW,
        size_t FH, size_t FW, size_t OH, size_t OW, size_t PH, size_t PW, size_t SH,
        size_t SW, bool add_mode) {
    const bool use_templated = can_run_xcorr_single_channel_templated(
            IH, IW, FH, FW, OH, OW, PH, PW, SH, SW);
    if (!use_templated) {
        MIDOUT_BEGIN(megdnn_fallback_conv, void) {
            run_xcorr_single_channel_nontemplated(
                    src, filter, dst, IH, IW, FH, FW, OH, OW, PH, PW, SH, SW, add_mode);
        }
        MIDOUT_END();
        return;
    }
    run_xcorr_single_channel_templated(
            src, filter, dst, IH, IW, FH, FW, OH, OW, PH, PW, SH, SW, add_mode);
}
/**
 * @brief Single-channel backward-data accumulation for a square filter whose
 *        side length `ker_size` is known at compile time.
 *
 * Every element diff[ih * IW + iw] is scattered into grad: for each tap
 * (fh, fw), visited row by row,
 *     grad[(ih*SH - PH + fh) * OW + (iw*SW - PW + fw)]
 *         += filter[fh * ker_size + fw] * diff[ih * IW + iw]
 * provided the target coordinate lies inside the OH x OW plane. Elements of
 * diff are visited in row-major order. grad is only added to, never cleared.
 *
 * Note the naming: here IH/IW are the extents of `diff` and OH/OW are the
 * extents of `grad`.
 *
 * @param diff   plane that is read, indexed as diff[ih * IW + iw]
 * @param filter ker_size * ker_size weights, indexed as filter[fh * ker_size + fw]
 * @param grad   plane that is accumulated into, indexed as grad[oh * OW + ow]
 * @param IH,IW  height / width of diff
 * @param OH,OW  height / width of grad
 * @param PH,PW  padding subtracted from the grad row / column coordinate
 * @param SH,SW  stride multiplied onto the diff row / column coordinate
 */
template <int ker_size>
void conv_backdata_single_channel_templated_impl(
        const float* __restrict diff, const float* __restrict filter,
        float* __restrict grad, size_t IH, size_t IW, size_t OH, size_t OW, size_t PH,
        size_t PW, size_t SH, size_t SW) {
    const size_t kernel_extent = static_cast<size_t>(ker_size);

    // diff rows [row_lo, row_hi) satisfy ih*SH - PH >= 0 and
    // ih*SH - PH + ker_size <= OH, i.e. every filter row lands inside grad.
    // Columns [col_lo, col_hi) are the analogous range horizontally.
    size_t row_lo = (PH + SH - 1) / SH;
    size_t row_hi = 0;
    if (OH + PH >= kernel_extent) {
        row_hi = (OH + PH - kernel_extent) / SH + 1;
    }
    size_t col_lo = (PW + SW - 1) / SW;
    size_t col_hi = 0;
    if (OW + PW >= kernel_extent) {
        col_hi = (OW + PW - kernel_extent) / SW + 1;
    }
    // An inverted range means nothing is fully interior along that axis;
    // collapse it so that everything goes through the bounds-checked paths.
    if (row_lo > row_hi) {
        row_lo = 0;
        row_hi = 0;
    }
    if (col_lo > col_hi) {
        col_lo = 0;
        col_hi = 0;
    }

    // Add one tap's contribution of a diff value to grad.
    const auto add_tap = [&](int oh, int ow, int fh, int fw, float val) {
        grad[(oh + fh) * OW + (ow + fw)] += filter[fh * ker_size + fw] * val;
    };
    const auto row_in_grad = [&](int oh) {
        return oh >= 0 && oh < static_cast<int>(OH);
    };
    const auto col_in_grad = [&](int ow) {
        return ow >= 0 && ow < static_cast<int>(OW);
    };

    // Window may leave grad in both directions: test every tap fully.
    const auto clipped_both = [&](size_t ih, size_t iw) {
        const int oh0 = static_cast<int>(ih * SH - PH);
        const int ow0 = static_cast<int>(iw * SW - PW);
        const float val = diff[ih * IW + iw];
        for (int fh = 0; fh < ker_size; ++fh) {
            if (!row_in_grad(oh0 + fh)) {
                continue;
            }
            for (int fw = 0; fw < ker_size; ++fw) {
                if (col_in_grad(ow0 + fw)) {
                    add_tap(oh0, ow0, fh, fw, val);
                }
            }
        }
    };
    // Interior row: only the column needs testing.
    const auto clipped_cols = [&](size_t ih, size_t iw) {
        const int oh0 = static_cast<int>(ih * SH - PH);
        const int ow0 = static_cast<int>(iw * SW - PW);
        const float val = diff[ih * IW + iw];
        for (int fh = 0; fh < ker_size; ++fh) {
            for (int fw = 0; fw < ker_size; ++fw) {
                if (col_in_grad(ow0 + fw)) {
                    add_tap(oh0, ow0, fh, fw, val);
                }
            }
        }
    };
    // Whole window inside grad: no tests.
    const auto unclipped = [&](size_t ih, size_t iw) {
        const int oh0 = static_cast<int>(ih * SH - PH);
        const int ow0 = static_cast<int>(iw * SW - PW);
        const float val = diff[ih * IW + iw];
        for (int fh = 0; fh < ker_size; ++fh) {
            for (int fw = 0; fw < ker_size; ++fw) {
                add_tap(oh0, ow0, fh, fw, val);
            }
        }
    };

    // Top band.
    for (size_t ih = 0; ih < row_lo; ++ih) {
        for (size_t iw = 0; iw < IW; ++iw) {
            clipped_both(ih, iw);
        }
    }
    // Middle band: left edge, interior, right edge.
    for (size_t ih = row_lo; ih < row_hi; ++ih) {
        size_t iw = 0;
        for (; iw < col_lo; ++iw) {
            clipped_cols(ih, iw);
        }
        for (; iw < col_hi; ++iw) {
            unclipped(ih, iw);
        }
        for (; iw < IW; ++iw) {
            clipped_cols(ih, iw);
        }
    }
    // Bottom band.
    for (size_t ih = row_hi; ih < IH; ++ih) {
        for (size_t iw = 0; iw < IW; ++iw) {
            clipped_both(ih, iw);
        }
    }
}
/**
 * @brief Route a square-filter backward-data computation to the kernel
 *        instantiated for its filter size.
 *
 * The instantiation is chosen from FH alone, for FH in 1..7; each call is
 * wrapped in its own MIDOUT region tagged with the kernel size. Any other FH
 * is reported through megdnn_throw.
 *
 * @param src    forwarded as the `diff` argument of the kernel
 * @param filter forwarded unchanged
 * @param dst    forwarded as the `grad` argument of the kernel
 * @param FH     filter height, used to pick the instantiation
 * @param FW     filter width; unused here (the caller visible in this file
 *               only gets here after can_run_xcorr_single_channel_templated
 *               accepted the filter, which requires FH == FW)
 *
 * The remaining geometry arguments are forwarded unchanged.
 */
void conv_backdata_single_channel_templated(
        const float* src, const float* filter, float* dst, size_t IH, size_t IW,
        size_t FH, size_t FW, size_t OH, size_t OW, size_t PH, size_t PW, size_t SH,
        size_t SW) {
    megdnn_ignore(FW);
// Each case is braced so that whatever MIDOUT_BEGIN expands to stays local
// to that case.
#define RUN_FOR_KERNEL_SIZE(K) \
    case K: { \
        MIDOUT_BEGIN(megdnn_fallback_conv, K) { \
            conv_backdata_single_channel_templated_impl<K>( \
                    src, filter, dst, IH, IW, OH, OW, PH, PW, SH, SW); \
        } \
        MIDOUT_END(); \
        return; \
    }
    switch (FH) {
        RUN_FOR_KERNEL_SIZE(1)
        RUN_FOR_KERNEL_SIZE(2)
        RUN_FOR_KERNEL_SIZE(3)
        RUN_FOR_KERNEL_SIZE(4)
        RUN_FOR_KERNEL_SIZE(5)
        RUN_FOR_KERNEL_SIZE(6)
        RUN_FOR_KERNEL_SIZE(7)
        default:
            break;
    }
#undef RUN_FOR_KERNEL_SIZE
    megdnn_throw("internal error in conv_backdata template dispatching: impossible");
}
/**
 * @brief Single-channel backward-data accumulation for a filter of arbitrary
 *        run-time size FH_ x FW_.
 *
 * Every element diff[ih * IW + iw] is scattered into grad: for each tap
 * (fh, fw), visited row by row,
 *     grad[(ih*SH - PH + fh) * OW + (iw*SW - PW + fw)]
 *         += filter[fh * FW_ + fw] * diff[ih * IW + iw]
 * provided the target coordinate lies inside the OH x OW plane. Elements of
 * diff are visited in row-major order. grad is only added to, never cleared.
 *
 * Note the naming: here IH/IW are the extents of `diff` and OH/OW are the
 * extents of `grad`.
 *
 * @param diff    plane that is read, indexed as diff[ih * IW + iw]
 * @param filter  FH_ * FW_ weights, indexed as filter[fh * FW_ + fw]
 * @param grad    plane that is accumulated into, indexed as grad[oh * OW + ow]
 * @param IH,IW   height / width of diff
 * @param FH_,FW_ filter height / width
 * @param OH,OW   height / width of grad
 * @param PH,PW   padding subtracted from the grad row / column coordinate
 * @param SH,SW   stride multiplied onto the diff row / column coordinate
 */
void conv_backdata_single_channel_nontemplated(
        const float* diff, const float* filter, float* grad, size_t IH, size_t IW,
        size_t FH_, size_t FW_, size_t OH, size_t OW, size_t PH, size_t PW, size_t SH,
        size_t SW) {
    // Signed copies of the filter extents for the signed tap arithmetic below.
    const int filter_rows = static_cast<int>(FH_);
    const int filter_cols = static_cast<int>(FW_);

    // diff rows [row_lo, row_hi) and columns [col_lo, col_hi) are those whose
    // filter window lies entirely inside grad along that axis.
    size_t row_lo = (PH + SH - 1) / SH;
    size_t row_hi = 0;
    if (OH + PH >= FH_) {
        row_hi = (OH + PH - filter_rows) / SH + 1;
    }
    size_t col_lo = (PW + SW - 1) / SW;
    size_t col_hi = 0;
    if (OW + PW >= FW_) {
        col_hi = (OW + PW - filter_cols) / SW + 1;
    }
    // An inverted range means nothing is fully interior along that axis;
    // collapse it so that everything goes through the bounds-checked paths.
    if (row_lo > row_hi) {
        row_lo = 0;
        row_hi = 0;
    }
    if (col_lo > col_hi) {
        col_lo = 0;
        col_hi = 0;
    }

    // Add one tap's contribution of a diff value to grad.
    const auto add_tap = [&](int oh, int ow, int fh, int fw, float val) {
        grad[(oh + fh) * OW + (ow + fw)] += filter[fh * filter_cols + fw] * val;
    };
    const auto row_in_grad = [&](int oh) {
        return oh >= 0 && oh < static_cast<int>(OH);
    };
    const auto col_in_grad = [&](int ow) {
        return ow >= 0 && ow < static_cast<int>(OW);
    };

    // Window may leave grad in both directions: test every tap fully.
    const auto clipped_both = [&](size_t ih, size_t iw) {
        const int oh0 = static_cast<int>(ih * SH - PH);
        const int ow0 = static_cast<int>(iw * SW - PW);
        const float val = diff[ih * IW + iw];
        for (int fh = 0; fh < filter_rows; ++fh) {
            if (!row_in_grad(oh0 + fh)) {
                continue;
            }
            for (int fw = 0; fw < filter_cols; ++fw) {
                if (col_in_grad(ow0 + fw)) {
                    add_tap(oh0, ow0, fh, fw, val);
                }
            }
        }
    };
    // Interior row: only the column needs testing.
    const auto clipped_cols = [&](size_t ih, size_t iw) {
        const int oh0 = static_cast<int>(ih * SH - PH);
        const int ow0 = static_cast<int>(iw * SW - PW);
        const float val = diff[ih * IW + iw];
        for (int fh = 0; fh < filter_rows; ++fh) {
            for (int fw = 0; fw < filter_cols; ++fw) {
                if (col_in_grad(ow0 + fw)) {
                    add_tap(oh0, ow0, fh, fw, val);
                }
            }
        }
    };
    // Whole window inside grad: no tests.
    const auto unclipped = [&](size_t ih, size_t iw) {
        const int oh0 = static_cast<int>(ih * SH - PH);
        const int ow0 = static_cast<int>(iw * SW - PW);
        const float val = diff[ih * IW + iw];
        for (int fh = 0; fh < filter_rows; ++fh) {
            for (int fw = 0; fw < filter_cols; ++fw) {
                add_tap(oh0, ow0, fh, fw, val);
            }
        }
    };

    // Top band.
    for (size_t ih = 0; ih < row_lo; ++ih) {
        for (size_t iw = 0; iw < IW; ++iw) {
            clipped_both(ih, iw);
        }
    }
    // Middle band: left edge, interior, right edge.
    for (size_t ih = row_lo; ih < row_hi; ++ih) {
        size_t iw = 0;
        for (; iw < col_lo; ++iw) {
            clipped_cols(ih, iw);
        }
        for (; iw < col_hi; ++iw) {
            unclipped(ih, iw);
        }
        for (; iw < IW; ++iw) {
            clipped_cols(ih, iw);
        }
    }
    // Bottom band.
    for (size_t ih = row_hi; ih < IH; ++ih) {
        for (size_t iw = 0; iw < IW; ++iw) {
            clipped_both(ih, iw);
        }
    }
}
/**
 * @brief Single-channel backward-data entry point: picks between the
 *        compile-time-sized kernels and the generic kernel.
 *
 * The generic kernel is used whenever
 * can_run_xcorr_single_channel_templated rejects the geometry; that call is
 * wrapped in its own MIDOUT region. All arguments are forwarded unchanged to
 * the selected implementation.
 */
void conv_backdata_single_channel(
        const float* diff, const float* filter, float* grad, size_t IH, size_t IW,
        size_t FH, size_t FW, size_t OH, size_t OW, size_t PH, size_t PW, size_t SH,
        size_t SW) {
    const bool use_templated = can_run_xcorr_single_channel_templated(
            IH, IW, FH, FW, OH, OW, PH, PW, SH, SW);
    if (!use_templated) {
        MIDOUT_BEGIN(megdnn_fallback_conv, void) {
            conv_backdata_single_channel_nontemplated(
                    diff, filter, grad, IH, IW, FH, FW, OH, OW, PH, PW, SH, SW);
        }
        MIDOUT_END();
        return;
    }
    conv_backdata_single_channel_templated(
            diff, filter, grad, IH, IW, FH, FW, OH, OW, PH, PW, SH, SW);
}
}
namespace megdnn {
namespace fallback {
namespace convolution {
/**
 * @brief Multi-channel forward pass built from single-channel
 *        cross-correlations.
 *
 * For every (oc, ic) pair the FH x FW filter slice at
 * filter[(oc * IC + ic) * FH * FW] is applied to input plane ic and the
 * result goes to output plane oc. The first input channel overwrites the
 * output plane (add_mode false); later channels accumulate onto it.
 *
 * @param src       IC planes of IH * IW floats
 * @param filter    OC * IC slices of FH * FW floats
 * @param dst       OC planes of OH * OW floats
 * @param workspace scratch for one flipped filter slice; written with
 *                  FH * FW floats per (oc, ic) pair, and only touched when
 *                  xcorr is false
 * @param xcorr     true: use each filter slice as stored; false: use the
 *                  slice flipped along both axes
 */
void run_conv(
        const float* src, const float* filter, float* dst, void* workspace, size_t IH,
        size_t IW, size_t IC, size_t FH, size_t FW, size_t OH, size_t OW, size_t OC,
        size_t PH, size_t PW, size_t SH, size_t SW, bool xcorr) {
    const size_t taps = FH * FW;
    for (size_t oc = 0; oc < OC; ++oc) {
        float* out_plane = dst + oc * OH * OW;
        for (size_t ic = 0; ic < IC; ++ic) {
            const float* stored = filter + oc * IC * FH * FW + ic * FH * FW;
            const float* kernel = stored;
            if (!xcorr) {
                float* flipped = static_cast<float*>(workspace);
                // Flipping a row-major FH x FW slice along both axes maps
                // flat index i to taps - 1 - i, i.e. it reverses the slice.
                for (size_t i = 0; i < taps; ++i) {
                    flipped[i] = stored[taps - 1 - i];
                }
                kernel = flipped;
            }
            run_xcorr_single_channel(
                    src + ic * IH * IW, kernel, out_plane, IH, IW, FH, FW, OH, OW,
                    PH, PW, SH, SW, ic > 0);
        }
    }
}
/**
 * @brief Multi-channel backward-data pass built from single-channel
 *        accumulations.
 *
 * grad is zeroed first (IC * OH * OW floats). Then, for every (oc, ic) pair,
 * the FH x FW filter slice at filter[(oc * IC + ic) * FH * FW] is used to
 * accumulate plane oc of diff into plane ic of grad.
 *
 * Note the naming: IH/IW are the extents of each `diff` plane and OH/OW are
 * the extents of each `grad` plane.
 *
 * @param diff      OC planes of IH * IW floats
 * @param filter    OC * IC slices of FH * FW floats
 * @param grad      IC planes of OH * OW floats; overwritten
 * @param workspace scratch for one flipped filter slice; written with
 *                  FH * FW floats per (oc, ic) pair, and only touched when
 *                  xcorr is false
 * @param xcorr     true: use each filter slice as stored; false: use the
 *                  slice flipped along both axes
 */
void run_conv_backward_data(
        const float* diff, const float* filter, float* grad, void* workspace, size_t IH,
        size_t IW, size_t IC, size_t FH, size_t FW, size_t OH, size_t OW, size_t OC,
        size_t PH, size_t PW, size_t SH, size_t SW, bool xcorr) {
    // The single-channel kernels only add into grad, so start from zero.
    std::memset(grad, 0, sizeof(float) * IC * OH * OW);
    const size_t taps = FH * FW;
    for (size_t oc = 0; oc < OC; ++oc) {
        const float* diff_plane = diff + oc * IH * IW;
        for (size_t ic = 0; ic < IC; ++ic) {
            const float* stored = filter + oc * IC * FH * FW + ic * FH * FW;
            const float* kernel = stored;
            if (!xcorr) {
                float* flipped = static_cast<float*>(workspace);
                // Flipping a row-major FH x FW slice along both axes maps
                // flat index i to taps - 1 - i, i.e. it reverses the slice.
                for (size_t i = 0; i < taps; ++i) {
                    flipped[i] = stored[taps - 1 - i];
                }
                kernel = flipped;
            }
            conv_backdata_single_channel(
                    diff_plane, kernel, grad + ic * OH * OW, IH, IW, FH, FW, OH, OW,
                    PH, PW, SH, SW);
        }
    }
}
} } }