#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/common/winograd/winograd_helper.h"
#include "src/fallback/conv_bias/winograd/winograd.h"
#include "src/x86/avx_helper.h"
#include "src/x86/conv_bias/f32/strategy.h"
#include "src/x86/elemwise_helper/op_unary.h"
#include <x86intrin.h>
#ifdef WIN32
#include <avx2intrin.h>
#include <avxintrin.h>
#include <fmaintrin.h>
#include <smmintrin.h>
#endif
#include "midout.h"
MIDOUT_DECL(megdnn_x86_winograd_nchw88_fp32_F23_8x8)
using namespace megdnn;
using namespace x86;
namespace {
//! Edge length of one Winograd F(2, 3) tile: kernel size (3) + output block (2) - 1.
constexpr size_t alpha = 3 + 2 - 1;
struct InputTransform2X3_NCHW88 {
    //! Copy one alpha x alpha patch of the 8-channel block containing `ic`,
    //! whose top-left input pixel is (ih_start, iw_start), into patchT laid out
    //! as [alpha][alpha][8] floats.
    //!
    //! \tparam inner true when the whole patch lies inside the IH x IW image,
    //!     so rows can be copied without bounds checks; false when part of the
    //!     patch falls outside, in which case those entries are left as zero.
    //! \param patch unused by this NCHW88 variant.
    //! \param IC only kept for signature compatibility; not needed here.
    template <bool inner>
    MEGDNN_ATTRIBUTE_TARGET("avx2")
    static void prepare(
            const float* input, float* patch, float* patchT, int ih_start, int iw_start,
            size_t IH, size_t IW, size_t ic, size_t IC) {
        MEGDNN_MARK_USED_VAR(patch);
        MEGDNN_MARK_USED_VAR(IC);
        // One image row holds IW pixels of 8 interleaved channels.
        const size_t IW8 = IW * 8;
        // Start of the 8-channel block `ic / 8` (IH * IW8 floats per block).
        const float* block_ptr = input + ic / 8 * IH * IW8;
        if (inner) {
            // The loop below writes all alpha * alpha * 8 floats of patchT,
            // so no zero fill is needed on this path. ih_start/iw_start are
            // known to be non-negative here.
            const float* input_ptr = block_ptr +
                                     static_cast<size_t>(ih_start) * IW8 +
                                     static_cast<size_t>(iw_start) * 8;
            for (size_t ih = 0; ih < alpha; ih++) {
#define cb(i) auto v##i = _mm256_loadu_ps(input_ptr + 8 * i);
                UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
#define cb(i) _mm256_storeu_ps(patchT + ih * alpha * 8 + i * 8, v##i);
                UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
                input_ptr += IW8;
            }
        } else {
            // Part of the patch lies outside the image: zero everything, then
            // copy only the pixels inside [0, IH) x [0, IW).
            memset(patchT, 0, sizeof(float) * 8 * alpha * alpha);
            const int ih0_act = std::max<int>(ih_start, 0);
            const int ih1_act = std::min<int>(ih_start + alpha, IH);
            const int iw0_act = std::max<int>(iw_start, 0);
            const int iw1_act = std::min<int>(iw_start + alpha, IW);
            for (int ih = ih0_act; ih < ih1_act; ++ih) {
                for (int iw = iw0_act; iw < iw1_act; ++iw) {
                    size_t iho = ih - ih_start, iwo = iw - iw_start;
                    auto src = _mm256_loadu_ps(block_ptr + ih * IW8 + iw * 8);
                    _mm256_storeu_ps(patchT + iho * alpha * 8 + iwo * 8, src);
                }
            }
        }
    }

    //! Apply the F(2, 3) input transform to the 4x4 patch in patchT and
    //! scatter the 16 resulting 8-float vectors into input_transform_buf,
    //! laid out as [alpha * alpha][IC / 8][nr_units_in_tile][8].
    //!
    //! Both passes combine elements with the rows [1 0 -1 0], [0 1 1 0],
    //! [0 -1 1 0], [0 -1 0 1]. The last row has the opposite sign of the
    //! common textbook B^T row; the output transform in this file uses the
    //! matching sign (+v3) so the final result is unchanged.
    MEGDNN_ATTRIBUTE_TARGET("avx2")
    static void transform(
            const float* patchT, float* input_transform_buf, size_t unit_idx,
            size_t nr_units_in_tile, size_t ic, size_t IC) {
#define cb(m, n) \
    Vector<float, 8> d##m##n = Vector<float, 8>::load(patchT + m * alpha * 8 + n * 8);
        UNROLL_CALL_NOWRAPPER_D2(4, 4, cb);
#undef cb
        // First pass: for every column m, combine the four rows.
#define cb(m)                   \
    auto t0##m = d0##m - d2##m; \
    auto t1##m = d1##m + d2##m; \
    auto t2##m = d2##m - d1##m; \
    auto t3##m = d3##m - d1##m;
        UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
        // Second pass: for every row m, combine the four columns (results
        // overwrite the d vectors).
#define cb(m)                    \
    d##m##0 = t##m##0 - t##m##2; \
    d##m##1 = t##m##1 + t##m##2; \
    d##m##2 = t##m##2 - t##m##1; \
    d##m##3 = t##m##3 - t##m##1;
        UNROLL_CALL_NOWRAPPER(4, cb);
#undef cb
        size_t ICB = IC / 8;
        size_t icb = ic / 8;
#define cb(m, n)                                                                 \
    d##m##n.save(                                                                \
            input_transform_buf + (m * alpha + n) * ICB * nr_units_in_tile * 8 + \
            icb * nr_units_in_tile * 8 + unit_idx * 8);
        UNROLL_CALL_NOWRAPPER_D2(4, 4, cb)
#undef cb
    }
};
struct FilterTransform2X3_MCHW88 {
MEGDNN_ATTRIBUTE_TARGET("avx2")
static void transform(
const float* filter, float* filter_transform_buf, float* transform_mid_buf,
size_t OC, size_t IC, size_t oc_start, size_t oc_end) {
MEGDNN_MARK_USED_VAR(transform_mid_buf);
megdnn_assert(
(oc_end - oc_start) % 8 == 0 && oc_start % 8 == 0 && oc_end % 8 == 0 &&
IC % 8 == 0 && OC % 8 == 0,
"Winograd filter transform input param is not times of 8!");
size_t OCB = OC / 8;
size_t ICB = IC / 8;
for (size_t ocb = oc_start / 8; ocb < oc_end / 8; ocb++) {
for (size_t icb = 0; icb < ICB; icb++) {
for (size_t ic_inner = 0; ic_inner < 8; ic_inner++) {
const float* fptr =
filter + (ocb * ICB + icb) * 3 * 3 * 8 * 8 + ic_inner * 8;
#define cb(m, n) \
Vector<float, 8> g##m##n = Vector<float, 8>::load(fptr + (m * 3 + n) * 8 * 8);
UNROLL_CALL_NOWRAPPER_D2(3, 3, cb)
#undef cb
#define FILTER_TRANSFORM(n, wd, g) \
auto wd##n##0 = g##0##n; \
tmp0 = (g##0##n + g##2##n) * 0.5; \
tmp1 = g##1##n * 0.5; \
auto wd##n##1 = tmp0 + tmp1; \
auto wd##n##2 = tmp0 - tmp1; \
auto wd##n##3 = g##2##n;
Vector<float, 8> tmp0, tmp1;
UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g);
UNROLL_CALL_RAW(4, FILTER_TRANSFORM, ret, wd);
#undef FILTER_TRANSFORM
#define cb_save(m, n) \
ret##m##n.save( \
filter_transform_buf + (m * alpha + n) * OCB * ICB * 8 * 8 + \
ocb * ICB * 8 * 8 + icb * 8 * 8 + ic_inner * 8);
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb_save)
#undef cb_save
}
}
}
}
};
template <BiasMode bmode, typename Op>
struct OutputTransform2X3_NCHW88 {
    //! Inverse F(2, 3) transform of one 4x4 tile of 8-channel vectors into a
    //! 2x2 output block. It combines with the rows [1 1 1 0] and [0 1 -1 1]
    //! along both axes, then applies the bias for `bmode` and the activation
    //! `Op`. Only pixels with oh < OH and ow < OW are written.
    //!
    //! output_transform_buf is read as [alpha * alpha][OCB][nr_units_in_tile][8]
    //! with OCB = (oc_end - oc_start) / 8. output (and bias in BIAS mode) are
    //! addressed as [.. / 8][OH][OW][8]. In BROADCAST_CHANNEL_BIAS mode, 8 bias
    //! values are read at bias + oc.
    MEGDNN_ATTRIBUTE_TARGET("avx2")
    static void transform(
            const float* output_transform_buf, const float* bias, float* output,
            float* transform_mid_buf, size_t oh_start, size_t ow_start, size_t OH,
            size_t OW, size_t oc_start, size_t oc_end, size_t oc_index, size_t unit_idx,
            size_t nr_units_in_tile, const DType& src_dtype, const DType& dst_dtype) {
        MEGDNN_MARK_USED_VAR(transform_mid_buf);
        Op op(src_dtype, dst_dtype);
        const size_t oc = oc_start + oc_index;
        // Distance (in floats) between two consecutive (m, n) planes.
        const size_t plane_stride = (oc_end - oc_start) / 8 * nr_units_in_tile * 8;
        const float* tile_ptr =
                output_transform_buf + oc_index / 8 * nr_units_in_tile * 8 + unit_idx * 8;
        // Offset of this 8-channel block inside the output / BIAS tensor.
        const size_t channel_base = oc / 8 * OH * OW * 8;

#define LOAD_TILE(m, n)        \
    Vector<float, 8> v##m##n = \
            Vector<float, 8>::load(tile_ptr + (m * alpha + n) * plane_stride);
        UNROLL_CALL_NOWRAPPER_D2(4, 4, LOAD_TILE);
#undef LOAD_TILE

        // Column pass: for every column n, reduce four rows to two.
#define COLUMN_PASS(n)                  \
    auto s0##n = v0##n + v1##n + v2##n; \
    auto s1##n = v1##n - v2##n + v3##n;
        UNROLL_CALL_NOWRAPPER(4, COLUMN_PASS);
#undef COLUMN_PASS

        // Row pass: for every remaining row m, reduce four columns to two.
#define ROW_PASS(m)                                         \
    Vector<float, 8> r##m##0 = s##m##0 + s##m##1 + s##m##2; \
    Vector<float, 8> r##m##1 = s##m##1 - s##m##2 + s##m##3;
        UNROLL_CALL_NOWRAPPER(2, ROW_PASS);
#undef ROW_PASS

        if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) {
            Vector<float, 8> vbias = Vector<float, 8>::load(bias + oc);
#define ADD_CHANNEL_BIAS(m, n) r##m##n += vbias;
            UNROLL_CALL_RAW_D2(2, 2, ADD_CHANNEL_BIAS);
#undef ADD_CHANNEL_BIAS
        }
        // With a full BIAS tensor the activation has to follow the per-pixel
        // bias add, so it is deferred to the store below.
        if (bmode != BiasMode::BIAS) {
#define ACTIVATE(m, n) r##m##n = op(r##m##n.value);
            UNROLL_CALL_RAW_D2(2, 2, ACTIVATE);
#undef ACTIVATE
        }

#define STORE_PIXEL(m, n)                                                    \
    if (oh_start + m < OH && ow_start + n < OW) {                            \
        const size_t offset =                                                \
                channel_base + (oh_start + m) * OW * 8 + (ow_start + n) * 8; \
        if (bmode == BiasMode::BIAS) {                                       \
            r##m##n += Vector<float, 8>::load(bias + offset);                \
            r##m##n = op(r##m##n.value);                                     \
        }                                                                    \
        r##m##n.save(output + offset);                                       \
    }
        UNROLL_CALL_RAW_D2(2, 2, STORE_PIXEL);
#undef STORE_PIXEL
    }
};
}
namespace megdnn {
namespace x86 {
namespace winograd {
MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_nchw88_2x3_8x8_f)

//! Filter transform entry point of the strategy: forwards every argument
//! unchanged to FilterTransform2X3_MCHW88::transform.
void winograd_nchw88_2x3_8x8_f::filter(
        const float* filter, float* filter_transform_buf, float* transform_mid_buf,
        size_t OC, size_t IC, size_t oc_start, size_t oc_end) {
    using Transform = FilterTransform2X3_MCHW88;
    Transform::transform(
            filter, filter_transform_buf, transform_mid_buf, OC, IC, oc_start, oc_end);
}
void winograd_nchw88_2x3_8x8_f::input(
const float* input, float* input_transform_buf, float* transform_mid_buf,
size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, size_t unit_start_idx,
size_t nr_units_in_tile) {
megdnn_assert(IC % 8 == 0);
auto units_w = div_ceil<size_t>(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE);
float* patch = transform_mid_buf;
float* patchT = transform_mid_buf + 8 * alpha * alpha;
for (size_t ic = 0; ic < IC; ic += 8) {
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
size_t nh = index / units_w;
size_t nw = index % units_w;
int ih_start = nh * OUTPUT_BLOCK_SIZE - PH;
int iw_start = nw * OUTPUT_BLOCK_SIZE - PW;
if (ih_start >= 0 && ih_start + alpha <= static_cast<size_t>(IH) &&
iw_start >= 0 && iw_start + alpha <= static_cast<size_t>(IW)) {
InputTransform2X3_NCHW88::prepare<true>(
input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC);
InputTransform2X3_NCHW88::transform(
patchT, input_transform_buf, unit_idx, nr_units_in_tile, ic,
IC);
} else {
InputTransform2X3_NCHW88::prepare<false>(
input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC);
InputTransform2X3_NCHW88::transform(
patchT, input_transform_buf, unit_idx, nr_units_in_tile, ic,
IC);
}
}
}
}
void winograd_nchw88_2x3_8x8_f::output(
const float* output_transform_buf, const float* bias, float* output,
float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, size_t OH,
size_t OW, size_t oc_start, size_t oc_end, size_t unit_start_idx,
size_t nr_units_in_tile) {
#define cb(_bmode, _nonline_op, ...) \
OutputTransform2X3_NCHW88<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__);
auto units_w = div_ceil<size_t>(OW, OUTPUT_BLOCK_SIZE);
size_t OC = oc_end - oc_start;
megdnn_assert(
OC % 8 == 0 && oc_start % 8 == 0 && oc_end % 8 == 0,
"Winograd output transform input param is not times of 8!");
for (size_t oc = oc_start; oc + 8 <= oc_end; oc += 8) {
size_t oc_index = oc - oc_start;
rep(unit_idx, nr_units_in_tile) {
size_t index = unit_start_idx + unit_idx;
auto nh = index / units_w;
auto nw = index % units_w;
size_t oh_start = nh * OUTPUT_BLOCK_SIZE;
size_t ow_start = nw * OUTPUT_BLOCK_SIZE;
DISPATCH_CONV_WINOGRAD_BIAS(
megdnn_x86_winograd_nchw88_fp32_F23_8x8, cb, SIMDType::AVX2, float,
float, bmode, nonline_mode, output_transform_buf, bias, output,
transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end,
oc_index, unit_idx, nr_units_in_tile, src_dtype, dst_dtype);
}
}
#undef cb
}
} } }