import itertools
def _simd_config(simd):
    """Return the code-generation settings for the *simd* instruction set.

    Keys of the returned dict:
      width              -- floats per vector (step of the generated ``dw`` loop).
      mm_type            -- C vector type used for every temporary.
      mm_load / mm_store -- unaligned load/store intrinsics.
      mm_mul / mm_add    -- multiply and add intrinsics; ``None`` for ``'fma'``,
                            whose kernels use ``_mm256_fmadd_ps`` instead.
      mm_set1            -- broadcast intrinsic for one filter tap.
      mm_set1_sign       -- prefix put in front of the ``filter[...]`` argument
                            (``'&'`` because ``_mm256_broadcast_ss`` takes a
                            pointer, ``_mm_set1_ps`` takes a value).
      header             -- intrinsic headers to ``#include``.
      max_h_base         -- the tallest generated macro is ``max_h_base - fsize``
                            rows.  Live vectors in one macro are H accumulators
                            (``res*``), fsize broadcasts (``vf*``), ``tmp0`` and,
                            without FMA, ``tmp1``; with these bases that total
                            is exactly 16, presumably matching the 16 xmm/ymm
                            registers of x86-64 (unconfirmed).

    Raises:
        ValueError: if *simd* is not ``'sse'``, ``'avx'`` or ``'fma'``.
    """
    if simd == 'sse':
        return {
            'width': 4,
            'mm_type': "__m128",
            'mm_load': "_mm_loadu_ps",
            'mm_store': "_mm_storeu_ps",
            'mm_mul': "_mm_mul_ps",
            'mm_add': "_mm_add_ps",
            'mm_set1': "_mm_set1_ps",
            'mm_set1_sign': "",
            'header': ["xmmintrin.h"],
            'max_h_base': 14,
        }
    if simd == 'avx':
        return {
            'width': 8,
            'mm_type': "__m256",
            'mm_load': "_mm256_loadu_ps",
            'mm_store': "_mm256_storeu_ps",
            'mm_mul': "_mm256_mul_ps",
            'mm_add': "_mm256_add_ps",
            'mm_set1': "_mm256_broadcast_ss",
            'mm_set1_sign': "&",
            'header': ["immintrin.h", "avxintrin.h"],
            'max_h_base': 14,
        }
    if simd == 'fma':
        return {
            'width': 8,
            'mm_type': "__m256",
            'mm_load': "_mm256_loadu_ps",
            'mm_store': "_mm256_storeu_ps",
            'mm_mul': None,
            'mm_add': None,
            'mm_set1': "_mm256_broadcast_ss",
            'mm_set1_sign': "&",
            'header': ["immintrin.h", "avxintrin.h", "fmaintrin.h"],
            'max_h_base': 15,
        }
    raise ValueError(
        "unsupported simd {!r}; expected 'sse', 'avx' or 'fma'".format(simd))


def _gen_macro(H, mode, simd, fsize, cfg):
    """Return the C text of macro ``SIMD_H{H}``, which handles H output rows.

    The macro processes the band of rows starting at ``dh``.  For each
    ``width``-wide column chunk it loads H rows of ``dst`` into accumulators,
    so results are added onto the existing ``dst`` values.  It then adds the
    contribution of every filter column ``fw`` and stores the rows back.
    Each source row ``ih`` is loaded once and multiplied by every filter row
    whose output row lies inside the band.

    NOTE(review): the ``dw`` loop steps by full vectors with no tail handling.
    If ``dst_w`` is not a multiple of ``width``, the last chunk reads and writes
    past ``dst_w_end``.  Presumably callers pad rows; confirm.
    """
    mm_type = cfg['mm_type']
    mm_load = cfg['mm_load']
    out = ["""#define SIMD_H{H} do {{ \\
const size_t sh = dh; \\
const float *src_d = src + sh*src_w; \\
float *dst_d = dst + dh*dst_w; \\
size_t dw = dst_w_beg; \\
for (; dw < dst_w_end; dw += {width}) {{ \\
const size_t sw = dw; \\
float *dst_dd = dst_d + dw; \\
{mm_type} tmp0; \\
""".format(H=H, width=cfg['width'], mm_type=mm_type)]
    if simd != 'fma':
        # Without FMA the product needs its own scratch register.
        out.append(" {} tmp1; \\\n".format(mm_type))
    for h in range(H):
        out.append(" {t} res{h}; \\\nres{h} = {ld}(dst_dd + {h}*dst_w); \\\n"
                   .format(t=mm_type, h=h, ld=mm_load))
    out.append(" for (size_t fw = 0; fw < flt_w; ++fw) { \\\n"
               "const float *src_dd = src_d + sw + fw; \\\n")
    for fh in range(fsize):
        if mode == 'xcorr':
            index = "{}*flt_w+fw".format(fh)
        else:  # 'conv': filter columns are read right-to-left (horizontal flip)
            index = "{}*flt_w+flt_w-fw-1".format(fh)
        out.append(" {t} vf{fh} = {set1}({sign}filter[{idx}]); \\\n".format(
            t=mm_type, fh=fh, set1=cfg['mm_set1'],
            sign=cfg['mm_set1_sign'], idx=index))
    for ih in range(H + fsize - 1):
        out.append(" tmp0 = {ld}(src_dd + {ih}*src_w); \\\n"
                   .format(ld=mm_load, ih=ih))
        for fh in range(fsize):
            # Output row fed by source row ih through filter row fh;
            # 'conv' walks the filter rows in reverse (vertical flip).
            if mode == 'xcorr':
                oh = ih - fh
            else:
                oh = ih - (fsize - fh - 1)
            if not 0 <= oh < H:
                continue
            if simd == 'fma':
                out.append(" res{oh} = _mm256_fmadd_ps(tmp0, vf{fh}, res{oh}); \\\n"
                           .format(oh=oh, fh=fh))
            else:
                out.append(" tmp1 = {mul}(tmp0, vf{fh}); \\\n"
                           " res{oh} = {add}(res{oh}, tmp1); \\\n".format(
                               mul=cfg['mm_mul'], add=cfg['mm_add'],
                               fh=fh, oh=oh))
    out.append(" } \\\n")
    for h in range(H):
        out.append(" {st}(dst_dd + {h}*dst_w, res{h}); \\\n"
                   .format(st=cfg['mm_store'], h=h))
    out.append("} \\\n} while (0)\n\n")
    return ''.join(out)


def _gen_function(funcname, max_h, header):
    """Return the includes, the kernel definition and the trailing ``#undef`` lines.

    The kernel applies ``SIMD_H{max_h}`` to as many full bands of ``max_h``
    output rows as fit.  A ``switch`` then dispatches the remaining
    1..max_h-1 rows to the matching smaller macro.
    """
    out = ['#include <{}>\n'.format(name) for name in header]
    out.append("""#include <algorithm>
#include "../convolution_direct_special_cases.h"
namespace megdnn {{
namespace x86 {{
namespace detail {{
void {funcname}(const float *src, const float *filter, float *dst,
const size_t src_h, const size_t src_w, const size_t dst_h, const size_t dst_w,
const size_t flt_w)
{{
(void)src_h;
const size_t dst_h_beg = 0;
const size_t dst_h_end = dst_h;
const size_t dst_w_beg = 0;
const size_t dst_w_end = dst_w;
""".format(funcname=funcname))
    out.append("""
size_t dh = dst_h_beg;
for (; dh + {max_h} <= dst_h_end; dh += {max_h}) {{
SIMD_H{max_h};
}}
switch (dst_h_end - dh) {{
""".format(max_h=max_h))
    for H in range(1, max_h):
        out.append(" case {H}:\nSIMD_H{H};\nbreak;\n".format(H=H))
    out.append(" }\n}\n} // namespace detail\n} // namespace x86\n} // namespace megdnn\n")
    out.extend('#undef SIMD_H{}\n'.format(H) for H in range(1, max_h + 1))
    return ''.join(out)


def gen(mode, simd, fsize):
    """Generate ``convolution_{mode}_fh{fsize}_{simd}.cpp`` in the current directory.

    The file defines ``megdnn::x86::detail::convolution_{mode}_fh{fsize}_{simd}``.
    This is a kernel for a filter with ``fsize`` rows (``flt_w`` columns, given
    at run time).  It is built from unrolled macros ``SIMD_H1``..``SIMD_H{max_h}``,
    each of which handles a band of that many output rows.

    Args:
        mode: ``'xcorr'`` (filter used as-is) or ``'conv'`` (filter flipped
            horizontally and vertically).
        simd: ``'sse'``, ``'avx'`` or ``'fma'``.
        fsize: number of filter rows.  It must be at least 1, and small enough
            that the tallest band (``15 - fsize`` rows for FMA, ``14 - fsize``
            otherwise) has at least one row.

    Raises:
        ValueError: on an unknown *mode* or *simd*, or an out-of-range *fsize*.
            All arguments are validated before the output file is opened, so
            no partial file is left behind.
    """
    # Explicit checks instead of `assert False`: asserts vanish under -O.
    if mode not in ('xcorr', 'conv'):
        raise ValueError(
            "unsupported mode {!r}; expected 'xcorr' or 'conv'".format(mode))
    cfg = _simd_config(simd)
    max_h = cfg['max_h_base'] - fsize
    # max_h < 1 would emit no macros at all, plus a `dh += 0` loop in the C code.
    if fsize < 1 or max_h < 1:
        raise ValueError(
            "fsize must be in [1, {}] for simd {!r}, got {!r}".format(
                cfg['max_h_base'] - 1, simd, fsize))
    funcname = "convolution_{}_fh{}_{}".format(mode, fsize, simd)
    with open(funcname + ".cpp", 'w') as f:
        for H in range(1, max_h + 1):
            f.write(_gen_macro(H, mode, simd, fsize, cfg))
        f.write(_gen_function(funcname, max_h, cfg['header']))
def gen_header(modes, simds, fsizes):
    """Write ``convolution_direct_special_cases.h`` into the current directory.

    The header declares one kernel prototype,
    ``convolution_{mode}_fh{fsize}_{simd}``, for every (mode, simd, fsize)
    combination, in ``itertools.product`` order.  All prototypes sit in
    ``namespace megdnn::x86::detail``.  Each prototype is tagged with
    ``MEGDNN_ATTRIBUTE_TARGET("<simd>")``, which is presumably a target-attribute
    macro from ``megdnn/arch.h`` (not visible here).
    """
    prologue = """#pragma once
#include <cstddef>
#include "megdnn/arch.h"
namespace megdnn {
namespace x86 {
namespace detail {
"""
    epilogue = """} // namespace detail
} // namespace x86
} // namespace megdnn
"""
    prototype = """
void {name}(const float *src, const float *filter, float *dst,
const size_t src_h, const size_t src_w, const size_t dst_h, const size_t dst_w,
const size_t flt_w) MEGDNN_ATTRIBUTE_TARGET("{target}");
"""
    with open('convolution_direct_special_cases.h', 'w') as out:
        out.write(prologue)
        out.writelines(
            prototype.format(
                name="convolution_{}_fh{}_{}".format(mode, fsize, simd),
                target=simd,
            )
            for mode, simd, fsize in itertools.product(modes, simds, fsizes)
        )
        out.write(epilogue)
if __name__ == '__main__':
    # Single source of truth for the kernel variants.  The generated .cpp
    # definitions and the header declarations must cover exactly the same
    # (mode, simd, fsize) combinations; duplicating the lists let them drift.
    modes = ['xcorr', 'conv']
    simds = ['sse', 'avx', 'fma']
    fsizes = range(1, 8)
    for mode in modes:
        for fsize in fsizes:
            for simd in simds:
                gen(mode, simd, fsize)
    gen_header(modes, simds, fsizes)