megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import itertools

def gen(mode, simd, fsize):
    """Emit ``convolution_{mode}_fh{fsize}_{simd}.cpp`` into the current directory.

    The generated C++ file defines one kernel,
    ``megdnn::x86::detail::convolution_{mode}_fh{fsize}_{simd}``, specialised
    for a filter with ``fsize`` rows.  The kernel loads ``dst`` into its
    accumulators, adds the filtered ``src`` contribution and stores it back,
    so it accumulates into ``dst``.  The work is done by macros
    ``SIMD_H1`` .. ``SIMD_H<MAX_H>``, each computing ``H`` output rows over
    ``width`` output columns per inner iteration.  Full blocks of ``MAX_H``
    rows are processed first; a ``switch`` handles the remaining rows.

    Args:
        mode: ``'xcorr'`` (cross-correlation) or ``'conv'`` (the filter is
            read flipped both horizontally, via the column index, and
            vertically, via the output-row mapping).
        simd: ``'sse'``, ``'avx'`` or ``'fma'``; selects the intrinsics, the
            vector width and the included intrinsic headers.
        fsize: number of filter rows baked into the kernel; must be >= 1 and
            small enough that ``MAX_H`` (see below) is at least 1.

    Raises:
        ValueError: if ``mode`` or ``simd`` is unsupported, or ``fsize`` is
            out of range.  Validation happens before the output file is
            opened, so a bad call never leaves a truncated file behind.
    """
    if mode not in ('xcorr', 'conv'):
        raise ValueError(
            "unsupported mode {!r}; expected 'xcorr' or 'conv'".format(mode))
    funcname = "convolution_{mode}_fh{fsize}_{simd}".format(**vars())
    filename = funcname + ".cpp"
    # MAX_H is the tallest SIMD_H<H> macro generated.  Each macro keeps
    # H accumulators + fsize filter broadcasts + tmp0 (+ tmp1 when there is no
    # fused multiply-add) live at once, i.e. at most 16 vectors -- presumably
    # sized to the 16 x86-64 vector registers.
    if simd == "sse":
        budget = 14
        width = 4
        mm_type = "__m128"
        mm_load = "_mm_loadu_ps"
        mm_store = "_mm_storeu_ps"
        mm_mul = "_mm_mul_ps"
        mm_add = "_mm_add_ps"
        mm_set1 = "_mm_set1_ps"
        mm_set1_sign = ""
        header = ["xmmintrin.h"]
    elif simd == "avx":
        budget = 14
        width = 8
        mm_type = "__m256"
        mm_load = "_mm256_loadu_ps"
        mm_store = "_mm256_storeu_ps"
        mm_mul = "_mm256_mul_ps"
        mm_add = "_mm256_add_ps"
        mm_set1 = "_mm256_broadcast_ss"
        # _mm256_broadcast_ss takes a float pointer, hence the '&'.
        mm_set1_sign = "&"
        header = ["immintrin.h", "avxintrin.h"]
    elif simd == "fma":
        # _mm256_fmadd_ps needs no tmp1, so one more output row fits.
        budget = 15
        width = 8
        mm_type = "__m256"
        mm_load = "_mm256_loadu_ps"
        mm_store = "_mm256_storeu_ps"
        mm_set1 = "_mm256_broadcast_ss"
        mm_set1_sign = "&"
        header = ["immintrin.h", "avxintrin.h", "fmaintrin.h"]
    else:
        raise ValueError(
            "unsupported simd {!r}; expected 'sse', 'avx' or 'fma'".format(simd))
    MAX_H = budget - fsize
    if fsize < 1 or MAX_H < 1:
        # MAX_H < 1 would emit a reference to an undefined SIMD_H0 and a
        # `dh += 0` loop that never terminates.
        raise ValueError(
            "fsize={!r} out of range for simd={!r}; need 1 <= fsize <= {}"
            .format(fsize, simd, budget - 1))
    with open(filename, 'w') as f:
        # One macro per output-block height H.  NOTE(review): the generated
        # column loop steps by `width` with no scalar tail; presumably callers
        # guarantee dst_w is a multiple of width or provide padding -- confirm.
        for H in range(1, MAX_H+1):
            f.write("""#define SIMD_H{H} do {{ \\
const size_t sh = dh; \\
const float *src_d = src + sh*src_w; \\
float *dst_d = dst + dh*dst_w; \\
size_t dw = dst_w_beg; \\
for (; dw < dst_w_end; dw += {width}) {{ \\
    const size_t sw = dw; \\
    float *dst_dd = dst_d + dw; \\
    {mm_type} tmp0; \\
""".format(**vars()))
            if simd != "fma":
                f.write("    {mm_type} tmp1; \\\n".format(**vars()))
            # Accumulators start from the current dst contents.
            for h in range(H):
                f.write("""    {mm_type} res{h}; \\
    res{h} = {mm_load}(dst_dd + {h}*dst_w); \\
""".format(**vars()))
            f.write("""    for (size_t fw = 0; fw < flt_w; ++fw) {{ \\
        const float *src_dd = src_d + sw + fw; \\
""".format(**vars()))
            # Broadcast one filter tap per filter row; 'conv' reads the
            # filter columns mirrored.
            for fh in range(fsize):
                if mode == 'xcorr':
                    f.write("""        {mm_type} vf{fh} = {mm_set1}({mm_set1_sign}filter[{fh}*flt_w+fw]); \\
""".format(**vars()))
                else:
                    f.write("""        {mm_type} vf{fh} = {mm_set1}({mm_set1_sign}filter[{fh}*flt_w+flt_w-fw-1]); \\
""".format(**vars()))
            # Each input row ih feeds every output row oh it overlaps; 'conv'
            # pairs filter rows in reverse order (vertical flip).
            for ih in range(H+fsize-1):
                f.write("""        tmp0 = {mm_load}(src_dd + {ih}*src_w); \\
""".format(**vars()))
                for fh in range(fsize):
                    if mode == 'xcorr':
                        oh = ih - fh
                    else:
                        oh = ih - (fsize-fh-1)
                    if 0 <= oh < H:
                        if simd == "fma":
                            f.write("""        res{oh} = _mm256_fmadd_ps(tmp0, vf{fh}, res{oh}); \\
""".format(**vars()))
                        else:
                            f.write("""        tmp1 = {mm_mul}(tmp0, vf{fh}); \\
""".format(**vars()))
                            f.write("""        res{oh} = {mm_add}(res{oh}, tmp1); \\
""".format(**vars()))
            f.write("""    }} \\
""".format(**vars()))
            for h in range(H):
                f.write("""    {mm_store}(dst_dd + {h}*dst_w, res{h}); \\
""".format(**vars()))
            f.write("""}} \\
}} while (0)
""".format(**vars()))
            f.write("\n")

        for i in header:
            f.write('#include <{}>\n'.format(i))
        f.write("""#include <algorithm>

#include "../convolution_direct_special_cases.h"

namespace megdnn {{
namespace x86 {{
namespace detail {{

void {funcname}(const float *src, const float *filter, float *dst,
        const size_t src_h, const size_t src_w, const size_t dst_h, const size_t dst_w,
        const size_t flt_w)
{{
    (void)src_h;
    const size_t dst_h_beg = 0;
    const size_t dst_h_end = dst_h;
    const size_t dst_w_beg = 0;
    const size_t dst_w_end = dst_w;
""".format(**vars()))

        # Full MAX_H-row blocks first, then dispatch the 1..MAX_H-1 leftover rows.
        f.write("""
    size_t dh = dst_h_beg;
    for (; dh + {MAX_H} <= dst_h_end; dh += {MAX_H}) {{
        SIMD_H{MAX_H};
    }}
    switch (dst_h_end - dh) {{
""".format(**vars()))
        for H in range(1, MAX_H):
            f.write("""        case {H}:
            SIMD_H{H};
            break;
""".format(**vars()))
        f.write("""    }}
}}

}} // namespace detail
}} // namespace x86
}} // namespace megdnn
""".format(**vars()))
        # Undefine the macros so they do not leak out of this translation unit.
        for H in range(1, MAX_H+1):
            f.write("""#undef SIMD_H{H}
""".format(**vars()))

def gen_header(modes, simds, fsizes):
    """Write ``convolution_direct_special_cases.h`` into the current directory.

    The header declares one kernel per ``(mode, simd, fsize)`` combination,
    in ``itertools.product(modes, simds, fsizes)`` order, inside
    ``megdnn::x86::detail``.  Each declaration carries
    ``MEGDNN_ATTRIBUTE_TARGET("<simd>")`` so it can be compiled for that
    instruction set.
    """
    prologue = (
        "#pragma once\n"
        "\n"
        "#include <cstddef>\n"
        '#include "megdnn/arch.h"\n'
        "\n"
        "namespace megdnn {\n"
        "namespace x86 {\n"
        "namespace detail {\n"
    )
    # Each declaration is preceded by a blank line; the signature must match
    # the definitions emitted by gen().
    declaration = (
        "\n"
        "void convolution_{mode}_fh{fsize}_{simd}(const float *src, const float *filter, float *dst,\n"
        "        const size_t src_h, const size_t src_w, const size_t dst_h, const size_t dst_w,\n"
        '        const size_t flt_w) MEGDNN_ATTRIBUTE_TARGET("{simd}");\n'
    )
    epilogue = (
        "} // namespace detail\n"
        "} // namespace x86\n"
        "} // namespace megdnn\n"
    )
    with open('convolution_direct_special_cases.h', 'w') as out:
        out.write(prologue)
        out.write("".join([
            declaration.format(mode=m, simd=s, fsize=k)
            for m, s, k in itertools.product(modes, simds, fsizes)
        ]))
        out.write(epilogue)

if __name__ == '__main__':
    # The variant set is named once so the generated sources and the header
    # declarations cannot drift apart.
    modes = ['xcorr', 'conv']
    simds = ['sse', 'avx', 'fma']
    fsizes = range(1, 8)
    # Same generation order as nested loops over mode, then fsize, then simd.
    for mode, fsize, simd in itertools.product(modes, fsizes, simds):
        gen(mode, simd, fsize)
    gen_header(modes, simds, fsizes)