#include "src/common/local/local_decl.inl"
#include "src/common/macro_helper.h"
#include "src/common/utils.h"
namespace {
using namespace megdnn;
template <int N, int OC>
void local_xcorr_tpl(const LocalKParam& kparam) MEGDNN_SIMD_ATTRIBUTE_TARGET;
template <int N, int OC>
void local_xcorr_tpl(const LocalKParam& kparam) {
const float* src = static_cast<const float*>(kparam.src.get_ptr());
const float* filter = static_cast<const float*>(kparam.filter.get_ptr());
float* dst = static_cast<float*>(kparam.dst.get_ptr());
float* workspace = static_cast<float*>(kparam.workspace);
const int IC = kparam.ic, IH = kparam.ih, IW = kparam.iw, OH = kparam.oh,
OW = kparam.ow, FH = kparam.fh, FW = kparam.fw;
const uint32_t PH = kparam.ph, PW = kparam.pw, SH = kparam.sh, SW = kparam.sw;
const ptrdiff_t INP_BS = kparam.inp_bs, OUT_BS = kparam.out_bs;
float* dst2 = workspace;
const int width = MEGDNN_SIMD_WIDTH;
memset(dst2, 0, sizeof(float) * OH * OW * N * OC);
float* dst2_hwnc = dst2;
rep(oh, OH) rep(ow, OW) {
const float* src_bak = src;
rep(ic, IC) {
rep(fh, FH) for (int fw = 0; fw < FW; ++fw, filter += OC) {
int ih = -PH + oh * SH + fh;
int iw = -PW + ow * SW + fw;
if (ih < 0 || ih >= IH || iw < 0 || iw >= IW)
continue;
float* dst2_bak = dst2;
rep(n, N) {
float s = src[n * INP_BS + ih * IW + iw];
const float* filter_bak = filter;
MEGDNN_SIMD_TYPE vs = MEGDNN_SIMD_SET1(s);
int oc = 0;
for (; oc + 4 * width <= OC; oc += 4 * width, filter += 4 * width) {
MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width);
MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1 * width);
MEGDNN_SIMD_TYPE vf2 = MEGDNN_SIMD_LOADU(filter + 2 * width);
MEGDNN_SIMD_TYPE vf3 = MEGDNN_SIMD_LOADU(filter + 3 * width);
MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width);
MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1 * width);
MEGDNN_SIMD_TYPE vd2 = MEGDNN_SIMD_LOADU(dst2 + oc + 2 * width);
MEGDNN_SIMD_TYPE vd3 = MEGDNN_SIMD_LOADU(dst2 + oc + 3 * width);
vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0);
vd1 = MEGDNN_SIMD_FMADD(vf1, vs, vd1);
vd2 = MEGDNN_SIMD_FMADD(vf2, vs, vd2);
vd3 = MEGDNN_SIMD_FMADD(vf3, vs, vd3);
MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0);
MEGDNN_SIMD_STOREU(dst2 + oc + 1 * width, vd1);
MEGDNN_SIMD_STOREU(dst2 + oc + 2 * width, vd2);
MEGDNN_SIMD_STOREU(dst2 + oc + 3 * width, vd3);
}
if (oc + 2 * width <= OC) {
MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width);
MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1 * width);
MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width);
MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1 * width);
vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0);
vd1 = MEGDNN_SIMD_FMADD(vf1, vs, vd1);
MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0);
MEGDNN_SIMD_STOREU(dst2 + oc + 1 * width, vd1);
oc += 2 * width;
filter += 2 * width;
}
if (oc + 1 * width <= OC) {
MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width);
MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width);
vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0);
MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0);
oc += 1 * width;
filter += 1 * width;
}
for (; oc < OC; ++oc, ++filter) {
dst2[oc] += s * (*filter);
}
filter = filter_bak;
dst2 += OC;
}
dst2 = dst2_bak;
}
src += IH * IW;
}
src = src_bak;
dst2 += N * OC;
}
transpose_knc2nsck(dst2_hwnc, dst, OH * OW, N, OC, OUT_BS);
}
void local_xcorr_generic(const LocalKParam& kparam) MEGDNN_SIMD_ATTRIBUTE_TARGET;
void local_xcorr_generic(const LocalKParam& kparam) {
UNPACK_LOCAL_FLOAT_NONCONTIG_BATCH_KERN_PARAM(kparam, float);
float* dst2 = workspace;
const int width = MEGDNN_SIMD_WIDTH;
memset(dst2, 0, sizeof(float) * OH * OW * N * OC);
float* dst2_hwnc = dst2;
rep(oh, OH) rep(ow, OW) {
const float* src_bak = src;
rep(ic, IC) {
rep(fh, FH) for (int fw = 0; fw < FW; ++fw, filter += OC) {
int ih = -PH + oh * SH + fh;
int iw = -PW + ow * SW + fw;
if (ih < 0 || ih >= IH || iw < 0 || iw >= IW)
continue;
float* dst2_bak = dst2;
rep(n, N) {
float s = src[n * INP_BS + ih * IW + iw];
const float* filter_bak = filter;
MEGDNN_SIMD_TYPE vs = MEGDNN_SIMD_SET1(s);
int oc = 0;
for (; oc + 4 * width <= OC; oc += 4 * width, filter += 4 * width) {
MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width);
MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1 * width);
MEGDNN_SIMD_TYPE vf2 = MEGDNN_SIMD_LOADU(filter + 2 * width);
MEGDNN_SIMD_TYPE vf3 = MEGDNN_SIMD_LOADU(filter + 3 * width);
MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width);
MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1 * width);
MEGDNN_SIMD_TYPE vd2 = MEGDNN_SIMD_LOADU(dst2 + oc + 2 * width);
MEGDNN_SIMD_TYPE vd3 = MEGDNN_SIMD_LOADU(dst2 + oc + 3 * width);
vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0);
vd1 = MEGDNN_SIMD_FMADD(vf1, vs, vd1);
vd2 = MEGDNN_SIMD_FMADD(vf2, vs, vd2);
vd3 = MEGDNN_SIMD_FMADD(vf3, vs, vd3);
MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0);
MEGDNN_SIMD_STOREU(dst2 + oc + 1 * width, vd1);
MEGDNN_SIMD_STOREU(dst2 + oc + 2 * width, vd2);
MEGDNN_SIMD_STOREU(dst2 + oc + 3 * width, vd3);
}
if (oc + 2 * width <= OC) {
MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width);
MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1 * width);
MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width);
MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1 * width);
vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0);
vd1 = MEGDNN_SIMD_FMADD(vf1, vs, vd1);
MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0);
MEGDNN_SIMD_STOREU(dst2 + oc + 1 * width, vd1);
oc += 2 * width;
filter += 2 * width;
}
if (oc + 1 * width <= OC) {
MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width);
MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width);
vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0);
MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0);
oc += 1 * width;
filter += 1 * width;
}
for (; oc < OC; ++oc, ++filter) {
dst2[oc] += s * (*filter);
}
filter = filter_bak;
dst2 += OC;
}
dst2 = dst2_bak;
}
src += IH * IW;
}
src = src_bak;
dst2 += N * OC;
}
transpose_knc2nsck(dst2_hwnc, dst, OH * OW, N, OC, OUT_BS);
}
template <int N, int OC>
void local_conv_tpl(const LocalKParam& kparam) MEGDNN_SIMD_ATTRIBUTE_TARGET;
template <int N, int OC>
void local_conv_tpl(const LocalKParam& kparam) {
const float* src = static_cast<const float*>(kparam.src.get_ptr());
const float* filter = static_cast<const float*>(kparam.filter.get_ptr());
float* dst = static_cast<float*>(kparam.dst.get_ptr());
float* workspace = static_cast<float*>(kparam.workspace);
const int IC = kparam.ic, IH = kparam.ih, IW = kparam.iw, OH = kparam.oh,
OW = kparam.ow, FH = kparam.fh, FW = kparam.fw;
const uint32_t PH = kparam.ph, PW = kparam.pw, SH = kparam.sh, SW = kparam.sw;
const ptrdiff_t INP_BS = kparam.inp_bs, OUT_BS = kparam.out_bs;
float* dst2 = workspace;
const int width = MEGDNN_SIMD_WIDTH;
memset(dst2, 0, sizeof(float) * OH * OW * N * OC);
float* dst2_hwnc = dst2;
rep(oh, OH) rep(ow, OW) {
const float* src_bak = src;
rep(ic, IC) {
rep(fh, FH) for (int fw = 0; fw < FW; ++fw, filter += OC) {
int ih = -PH + oh * SH + (FH - fh - 1);
int iw = -PW + ow * SW + (FW - fw - 1);
if (ih < 0 || ih >= IH || iw < 0 || iw >= IW)
continue;
float* dst2_bak = dst2;
rep(n, N) {
float s = src[n * INP_BS + ih * IW + iw];
const float* filter_bak = filter;
MEGDNN_SIMD_TYPE vs = MEGDNN_SIMD_SET1(s);
int oc = 0;
for (; oc + 4 * width <= OC; oc += 4 * width, filter += 4 * width) {
MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width);
MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1 * width);
MEGDNN_SIMD_TYPE vf2 = MEGDNN_SIMD_LOADU(filter + 2 * width);
MEGDNN_SIMD_TYPE vf3 = MEGDNN_SIMD_LOADU(filter + 3 * width);
MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width);
MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1 * width);
MEGDNN_SIMD_TYPE vd2 = MEGDNN_SIMD_LOADU(dst2 + oc + 2 * width);
MEGDNN_SIMD_TYPE vd3 = MEGDNN_SIMD_LOADU(dst2 + oc + 3 * width);
vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0);
vd1 = MEGDNN_SIMD_FMADD(vf1, vs, vd1);
vd2 = MEGDNN_SIMD_FMADD(vf2, vs, vd2);
vd3 = MEGDNN_SIMD_FMADD(vf3, vs, vd3);
MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0);
MEGDNN_SIMD_STOREU(dst2 + oc + 1 * width, vd1);
MEGDNN_SIMD_STOREU(dst2 + oc + 2 * width, vd2);
MEGDNN_SIMD_STOREU(dst2 + oc + 3 * width, vd3);
}
if (oc + 2 * width <= OC) {
MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width);
MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1 * width);
MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width);
MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1 * width);
vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0);
vd1 = MEGDNN_SIMD_FMADD(vf1, vs, vd1);
MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0);
MEGDNN_SIMD_STOREU(dst2 + oc + 1 * width, vd1);
oc += 2 * width;
filter += 2 * width;
}
if (oc + 1 * width <= OC) {
MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width);
MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width);
vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0);
MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0);
oc += 1 * width;
filter += 1 * width;
}
for (; oc < OC; ++oc, ++filter) {
dst2[oc] += s * (*filter);
}
filter = filter_bak;
dst2 += OC;
}
dst2 = dst2_bak;
}
src += IH * IW;
}
src = src_bak;
dst2 += N * OC;
}
transpose_knc2nsck(dst2_hwnc, dst, OH * OW, N, OC, OUT_BS);
}
void local_conv_generic(const LocalKParam& kparam) MEGDNN_SIMD_ATTRIBUTE_TARGET;
void local_conv_generic(const LocalKParam& kparam) {
UNPACK_LOCAL_FLOAT_NONCONTIG_BATCH_KERN_PARAM(kparam, float);
float* dst2 = workspace;
const int width = MEGDNN_SIMD_WIDTH;
memset(dst2, 0, sizeof(float) * OH * OW * N * OC);
float* dst2_hwnc = dst2;
rep(oh, OH) rep(ow, OW) {
const float* src_bak = src;
rep(ic, IC) {
rep(fh, FH) for (int fw = 0; fw < FW; ++fw, filter += OC) {
int ih = -PH + oh * SH + (FH - fh - 1);
int iw = -PW + ow * SW + (FW - fw - 1);
if (ih < 0 || ih >= IH || iw < 0 || iw >= IW)
continue;
float* dst2_bak = dst2;
rep(n, N) {
float s = src[n * INP_BS + ih * IW + iw];
const float* filter_bak = filter;
MEGDNN_SIMD_TYPE vs = MEGDNN_SIMD_SET1(s);
int oc = 0;
for (; oc + 4 * width <= OC; oc += 4 * width, filter += 4 * width) {
MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width);
MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1 * width);
MEGDNN_SIMD_TYPE vf2 = MEGDNN_SIMD_LOADU(filter + 2 * width);
MEGDNN_SIMD_TYPE vf3 = MEGDNN_SIMD_LOADU(filter + 3 * width);
MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width);
MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1 * width);
MEGDNN_SIMD_TYPE vd2 = MEGDNN_SIMD_LOADU(dst2 + oc + 2 * width);
MEGDNN_SIMD_TYPE vd3 = MEGDNN_SIMD_LOADU(dst2 + oc + 3 * width);
vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0);
vd1 = MEGDNN_SIMD_FMADD(vf1, vs, vd1);
vd2 = MEGDNN_SIMD_FMADD(vf2, vs, vd2);
vd3 = MEGDNN_SIMD_FMADD(vf3, vs, vd3);
MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0);
MEGDNN_SIMD_STOREU(dst2 + oc + 1 * width, vd1);
MEGDNN_SIMD_STOREU(dst2 + oc + 2 * width, vd2);
MEGDNN_SIMD_STOREU(dst2 + oc + 3 * width, vd3);
}
if (oc + 2 * width <= OC) {
MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width);
MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1 * width);
MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width);
MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1 * width);
vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0);
vd1 = MEGDNN_SIMD_FMADD(vf1, vs, vd1);
MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0);
MEGDNN_SIMD_STOREU(dst2 + oc + 1 * width, vd1);
oc += 2 * width;
filter += 2 * width;
}
if (oc + 1 * width <= OC) {
MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width);
MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width);
vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0);
MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0);
oc += 1 * width;
filter += 1 * width;
}
for (; oc < OC; ++oc, ++filter) {
dst2[oc] += s * (*filter);
}
filter = filter_bak;
dst2 += OC;
}
dst2 = dst2_bak;
}
src += IH * IW;
}
src = src_bak;
dst2 += N * OC;
}
transpose_knc2nsck(dst2_hwnc, dst, OH * OW, N, OC, OUT_BS);
}
}
namespace megdnn {
#define FUNC_NAME CONCAT_STR(local_xcorr_, MEGDNN_SIMD_NAME)
/*!
 * SIMD local cross-correlation entry point (name built by FUNC_NAME from
 * MEGDNN_SIMD_NAME).
 *
 * Batch size 1 or 2 combined with 16, 32, 48 or 64 output channels is routed
 * to the compile-time specialization local_xcorr_tpl<N, OC>; every other
 * shape falls through to local_xcorr_generic.
 */
void FUNC_NAME(const LocalKParam& kparam) {
    const auto batch = kparam.n;
    const auto out_channels = kparam.oc;
    if (batch == 1) {
        switch (out_channels) {
            case 16:
                local_xcorr_tpl<1, 16>(kparam);
                return;
            case 32:
                local_xcorr_tpl<1, 32>(kparam);
                return;
            case 48:
                local_xcorr_tpl<1, 48>(kparam);
                return;
            case 64:
                local_xcorr_tpl<1, 64>(kparam);
                return;
            default:
                break;
        }
    } else if (batch == 2) {
        switch (out_channels) {
            case 16:
                local_xcorr_tpl<2, 16>(kparam);
                return;
            case 32:
                local_xcorr_tpl<2, 32>(kparam);
                return;
            case 48:
                local_xcorr_tpl<2, 48>(kparam);
                return;
            case 64:
                local_xcorr_tpl<2, 64>(kparam);
                return;
            default:
                break;
        }
    }
    local_xcorr_generic(kparam);
}
#undef FUNC_NAME

#define FUNC_NAME CONCAT_STR(local_conv_, MEGDNN_SIMD_NAME)
/*!
 * SIMD local convolution entry point (name built by FUNC_NAME from
 * MEGDNN_SIMD_NAME).
 *
 * Batch size 1 or 2 combined with 16, 32, 48 or 64 output channels is routed
 * to the compile-time specialization local_conv_tpl<N, OC>; every other
 * shape falls through to local_conv_generic.
 */
void FUNC_NAME(const LocalKParam& kparam) {
    const auto batch = kparam.n;
    const auto out_channels = kparam.oc;
    if (batch == 1) {
        switch (out_channels) {
            case 16:
                local_conv_tpl<1, 16>(kparam);
                return;
            case 32:
                local_conv_tpl<1, 32>(kparam);
                return;
            case 48:
                local_conv_tpl<1, 48>(kparam);
                return;
            case 64:
                local_conv_tpl<1, 64>(kparam);
                return;
            default:
                break;
        }
    } else if (batch == 2) {
        switch (out_channels) {
            case 16:
                local_conv_tpl<2, 16>(kparam);
                return;
            case 32:
                local_conv_tpl<2, 32>(kparam);
                return;
            case 48:
                local_conv_tpl<2, 48>(kparam);
                return;
            case 64:
                local_conv_tpl<2, 64>(kparam);
                return;
            default:
                break;
        }
    }
    local_conv_generic(kparam);
}
#undef FUNC_NAME
}
#include "src/common/macro_helper_epilogue.h"