#include "src/arm_common/conv_bias/f16/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/utils.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/winograd/winograd.h"
#include "src/arm_common/conv_bias/f16/helper.h"
#include "src/arm_common/elemwise_helper/op_unary.h"
#include "src/naive/matrix_mul/matrix_mul_helper.h"
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_winograd_fp16_F23)
using namespace megdnn;
using namespace arm_common;
namespace {
void transpose_4x4(const __fp16* src, __fp16* dst, int lda, int ldb) {
float16x4x2_t a0, a1;
a0.val[0] = vld1_f16(src + 0 * lda); a0.val[1] = vld1_f16(src + 1 * lda); a1.val[0] = vld1_f16(src + 2 * lda); a1.val[1] = vld1_f16(src + 3 * lda); float16x4x2_t b0 = vzip_f16(a0.val[0], a1.val[0]); float16x4x2_t b1 = vzip_f16(a0.val[1], a1.val[1]); float16x4x2_t c0 = vzip_f16(b0.val[0], b1.val[0]); float16x4x2_t c1 = vzip_f16(b0.val[1], b1.val[1]); vst1_f16(dst + 0 * ldb, c0.val[0]);
vst1_f16(dst + 1 * ldb, c0.val[1]);
vst1_f16(dst + 2 * ldb, c1.val[0]);
vst1_f16(dst + 3 * ldb, c1.val[1]);
}
struct InputTransform2X3 {
template <bool inner>
static void prepare(
const __fp16* input, __fp16* patch, __fp16* patchT, int ih_start,
int iw_start, size_t IH, size_t IW, size_t ic, size_t IC) {
constexpr size_t alpha = 2 + 3 - 1;
constexpr size_t alpha4 = alpha * 4;
if (!(inner && ic + 4 < IC)) {
memset(patch, 0, sizeof(__fp16) * 4 * alpha * alpha);
}
if (inner) {
const __fp16* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start;
for (size_t ico = 0; ico < 4; ++ico) {
if (ic + ico < IC) {
auto v0 = vld1_f16(input_ptr);
auto v1 = vld1_f16(input_ptr + IW);
auto v2 = vld1_f16(input_ptr + IW * 2);
auto v3 = vld1_f16(input_ptr + IW * 3);
vst1_f16(patch + ico * alpha4 + 0 * 4, v0);
vst1_f16(patch + ico * alpha4 + 1 * 4, v1);
vst1_f16(patch + ico * alpha4 + 2 * 4, v2);
vst1_f16(patch + ico * alpha4 + 3 * 4, v3);
input_ptr += IH * IW;
}
}
} else {
int ih0_act = std::max<int>(ih_start, 0),
ih1_act = std::min<int>(ih_start + alpha, IH),
iw0_act = std::max<int>(iw_start, 0),
iw1_act = std::min<int>(iw_start + alpha, IW);
for (size_t ico = 0; ico < 4; ++ico) {
if (ic + ico < IC) {
for (int ih = ih0_act; ih < ih1_act; ++ih) {
for (int iw = iw0_act; iw < iw1_act; ++iw) {
size_t iho = ih - ih_start, iwo = iw - iw_start;
patch[ico * alpha4 + iho * 4 + iwo] =
input[(ic + ico) * IH * IW + ih * IW + iw];
}
}
}
}
}
transpose_4x4(patch + 0 * 1, patchT + 0 * 4, 16, 4);
transpose_4x4(patch + 4 * 1, patchT + 4 * 4, 16, 4);
transpose_4x4(patch + 8 * 1, patchT + 8 * 4, 16, 4);
transpose_4x4(patch + 12 * 1, patchT + 12 * 4, 16, 4);
}
static void transform(
const __fp16* patchT, __fp16* input_transform_buf, size_t unit_idx,
size_t nr_units_in_tile, size_t ic, size_t IC) {
constexpr size_t alpha = 2 + 3 - 1;
#define cb(m, n) \
Vector<__fp16, 4> d##m##n = Vector<__fp16, 4>::load(patchT + m * 4 * 4 + n * 4);
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb);
#undef cb
#define cb(m) \
auto t0##m = d0##m - d2##m; \
auto t1##m = d1##m + d2##m; \
auto t2##m = d2##m - d1##m; \
auto t3##m = d3##m - d1##m;
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
#define cb(m) \
d##m##0 = t##m##0 - t##m##2; \
d##m##1 = t##m##1 + t##m##2; \
d##m##2 = t##m##2 - t##m##1; \
d##m##3 = t##m##3 - t##m##1;
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
#define cb(m, n) \
d##m##n.save( \
input_transform_buf + (m * alpha + n) * nr_units_in_tile * IC + \
unit_idx * IC + ic);
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb)
#undef cb
}
};
template <BiasMode bmode, typename Op>
struct OutputTransform2X3 {
static void transform(
const dt_float16* output_transform_buf, const dt_float16* bias,
dt_float16* output, dt_float16* transform_mid_buf, size_t oh_start,
size_t ow_start, size_t OH, size_t OW, size_t oc_start, size_t oc_end,
size_t oc_index, size_t unit_idx, size_t nr_units_in_tile,
const DType& src_dtype, const DType& dst_dtype) {
Op op(src_dtype, dst_dtype);
const __fp16* output_transform_ptr =
reinterpret_cast<const __fp16*>(output_transform_buf);
const __fp16* bias_ptr = reinterpret_cast<const __fp16*>(bias);
__fp16* output_ptr = reinterpret_cast<__fp16*>(output);
__fp16* transform_mid_ptr = reinterpret_cast<__fp16*>(transform_mid_buf);
constexpr size_t alpha = 2 + 3 - 1;
size_t OC = oc_end - oc_start;
size_t oc = oc_start + oc_index;
#define cb(m, n) \
auto v##m##n = Vector<__fp16, 4>::load( \
output_transform_ptr + (m * alpha + n) * nr_units_in_tile * OC + \
unit_idx * OC + oc_index);
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb);
#undef cb
#define cb(m) \
auto t0##m = v0##m + v1##m + v2##m; \
auto t1##m = v1##m - v2##m + v3##m;
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
v00 = t00 + t01 + t02;
v10 = t10 + t11 + t12;
v01 = t01 - t02 + t03;
v11 = t11 - t12 + t13;
Vector<__fp16, 4> vbias;
if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
vbias = Vector<__fp16, 4>::load(bias_ptr + oc);
v00 += vbias;
v10 += vbias;
v01 += vbias;
v11 += vbias;
}
float16x8_t result01, result23;
result01 = vcombine_f16(v00.value, v01.value);
result23 = vcombine_f16(v10.value, v11.value);
if (bmode != BiasMode::BIAS) {
result01 = op(result01);
result23 = op(result23);
}
vst1q_f16(transform_mid_ptr, result01);
vst1q_f16(transform_mid_ptr + 8, result23);
for (size_t oco = 0; oco < 4 && oc + oco < oc_end; ++oco) {
for (size_t oho = 0; oho < 2 && oh_start + oho < OH; ++oho) {
for (size_t owo = 0; owo < 2 && ow_start + owo < OW; ++owo) {
size_t oh = oh_start + oho;
size_t ow = ow_start + owo;
__fp16 res = transform_mid_ptr[oho * 2 * 4 + owo * 4 + oco];
if (bmode == BiasMode::BIAS) {
res += bias_ptr[(oc + oco) * OH * OW + oh * OW + ow];
res = op(res);
}
output_ptr[(oc + oco) * OH * OW + oh * OW + ow] = res;
}
}
}
}
};
}
namespace megdnn {
namespace arm_common {
namespace winograd {
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_4x4_f16)
void winograd_2x3_4x4_f16::filter(
const dt_float16* filter, dt_float16* filter_transform_buf,
dt_float16* transform_mid_buf, size_t OC, size_t IC, size_t oc_start,
size_t oc_end) {
constexpr int alpha = 2 + 3 - 1;
__fp16* filter_transbuf_ptr = reinterpret_cast<__fp16*>(filter_transform_buf);
__fp16* filter_transmid_ptr = reinterpret_cast<__fp16*>(transform_mid_buf);
for (size_t oc = oc_start; oc < oc_end; oc++) {
rep(ic, IC) {
const __fp16* filter_ptr =
reinterpret_cast<const __fp16*>(filter) + (oc * IC + ic) * 3 * 3;
float16x4_t v0 = vld1_f16(filter_ptr); float16x4_t v1 = vld1_f16(filter_ptr + 3); float16x4_t v2 = vld1_f16(filter_ptr + 5); float16x4_t v3 = vdup_n_f16(0);
v2 = vext_f16(v2, v3, 1);
v0 = vset_lane_f16(0, v0, 3);
v1 = vset_lane_f16(0, v1, 3);
#define cb(i) float16x4_t vsum##i;
UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
vsum0 = v0;
float16x4_t v0addv2 = vadd_f16(v0, v2);
float16x4_t v02addv1 = vadd_f16(v0addv2, v1);
float16x4_t v02subv1 = vsub_f16(v0addv2, v1);
vsum1 = vmul_n_f16(v02addv1, 0.5);
vsum2 = vmul_n_f16(v02subv1, 0.5);
vsum3 = v2;
#define cb(i) \
do { \
mid_buf1[0] = vget_lane_f16(vsum##i, 0); \
__fp16 a0a2 = vget_lane_f16(vsum##i, 0) + vget_lane_f16(vsum##i, 2); \
__fp16 a0a2adda1 = a0a2 + vget_lane_f16(vsum##i, 1); \
__fp16 a0a2suba1 = a0a2 - vget_lane_f16(vsum##i, 1); \
mid_buf1[1] = a0a2adda1 * 0.5; \
mid_buf1[2] = a0a2suba1 * 0.5; \
mid_buf1[3] = vget_lane_f16(vsum##i, 2); \
mid_buf1 += 4; \
} while (0);
__fp16* mid_buf1 = filter_transmid_ptr;
UNROLL_CALL_NOWRAPPER(4, cb);
mid_buf1 = filter_transmid_ptr;
#undef cb
rep(i, alpha) rep(j, alpha) {
filter_transbuf_ptr[(i * alpha + j) * OC * IC + ic * OC + oc] =
filter_transmid_ptr[i * alpha + j];
}
}
}
}
void winograd_2x3_4x4_f16::input(
const dt_float16* input, dt_float16* input_transform_buf,
dt_float16* transform_mid_buf, size_t IH, size_t IW, size_t IC, size_t PH,
size_t PW, size_t unit_start_idx, size_t nr_units_in_tile) {
megdnn_assert(IC % 4 == 0);
constexpr int alpha = 3 + 2 - 1;
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
dt_float16* patch = transform_mid_buf;
dt_float16* patchT = transform_mid_buf + 4 * alpha * alpha;
for (size_t ic = 0; ic < IC; ic += 4) {
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
size_t nh = index / units_w;
size_t nw = index % units_w;
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH;
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW;
if (ih_start >= 0 && ih_start + alpha <= static_cast<int>(IH) &&
iw_start >= 0 && iw_start + alpha <= static_cast<int>(IW)) {
InputTransform2X3::prepare<true>(
reinterpret_cast<const __fp16*>(input),
reinterpret_cast<__fp16*>(patch),
reinterpret_cast<__fp16*>(patchT), ih_start, iw_start, IH, IW,
ic, IC);
InputTransform2X3::transform(
reinterpret_cast<const __fp16*>(patchT),
reinterpret_cast<__fp16*>(input_transform_buf), unit_idx,
nr_units_in_tile, ic, IC);
} else {
InputTransform2X3::prepare<false>(
reinterpret_cast<const __fp16*>(input),
reinterpret_cast<__fp16*>(patch),
reinterpret_cast<__fp16*>(patchT), ih_start, iw_start, IH, IW,
ic, IC);
InputTransform2X3::transform(
reinterpret_cast<const __fp16*>(patchT),
reinterpret_cast<__fp16*>(input_transform_buf), unit_idx,
nr_units_in_tile, ic, IC);
}
}
}
}
void winograd_2x3_4x4_f16::output(
const dt_float16* output_transform_buf, const dt_float16* bias,
dt_float16* output, dt_float16* transform_mid_buf, BiasMode bmode,
NonlineMode nonline_mode, size_t OH, size_t OW, size_t oc_start, size_t oc_end,
size_t unit_start_idx, size_t nr_units_in_tile) {
#define cb(_bmode, _nonline_op, ...) \
OutputTransform2X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__);
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);
for (size_t oc = oc_start; oc < oc_end; oc += 4) {
size_t oc_index = oc - oc_start;
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
auto nh = index / units_w;
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_arm_common_winograd_fp16_F23, cb, __fp16, __fp16, bmode,
nonline_mode, output_transform_buf, bias, output, transform_mid_buf,
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx,
nr_units_in_tile, src_dtype, dst_dtype);
}
}
#undef cb
}
} } } #endif