megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file dnn/src/arm_common/utils.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "src/common/utils.h"
#include <cstring>
#include "src/arm_common/simd_macro/marm_neon.h"

using namespace megdnn;

namespace {

template <typename dtype>
void transpose_naive(const dtype* src, dtype* dst, int lda, int ldb, int n, int m) {
    rep(i, n) rep(j, m) { dst[i * ldb + j] = src[j * lda + i]; }
}

void transpose_4x4_neon(const float* src, float* dst, int lda, int ldb) {
    float32x4x2_t a0, a1;
    a0.val[0] = vld1q_f32(src + 0 * lda);
    a0.val[1] = vld1q_f32(src + 1 * lda);
    a1.val[0] = vld1q_f32(src + 2 * lda);
    a1.val[1] = vld1q_f32(src + 3 * lda);
    float32x4x2_t b0 = vzipq_f32(a0.val[0], a1.val[0]);
    float32x4x2_t b1 = vzipq_f32(a0.val[1], a1.val[1]);
    float32x4x2_t c0 = vzipq_f32(b0.val[0], b1.val[0]);
    float32x4x2_t c1 = vzipq_f32(b0.val[1], b1.val[1]);
    vst1q_f32(dst + 0 * ldb, c0.val[0]);
    vst1q_f32(dst + 1 * ldb, c0.val[1]);
    vst1q_f32(dst + 2 * ldb, c1.val[0]);
    vst1q_f32(dst + 3 * ldb, c1.val[1]);
}

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
void transpose_8x8_neon(const dt_float16* src, dt_float16* dst, int lda, int ldb) {
    const __fp16* src_ptr = reinterpret_cast<const __fp16*>(src);
    __fp16* dst_ptr = reinterpret_cast<__fp16*>(dst);
    float16x8x4_t a0, a1;
    a0.val[0] = vld1q_f16(src_ptr + 0 * lda);  // A0A1A2A3A4A5A6A7
    a0.val[1] = vld1q_f16(src_ptr + 1 * lda);  // B0B1B2B3B4B5B6B7
    a0.val[2] = vld1q_f16(src_ptr + 2 * lda);  // C0C1C2C3C4C5C6C7
    a0.val[3] = vld1q_f16(src_ptr + 3 * lda);  // D0D1D2D3D4D5D6D7
    a1.val[0] = vld1q_f16(src_ptr + 4 * lda);  // E0E1E2E3E4E5E6E7
    a1.val[1] = vld1q_f16(src_ptr + 5 * lda);  // F0F1F2F3F4F5F6F7
    a1.val[2] = vld1q_f16(src_ptr + 6 * lda);  // G0G1G2G3G4G5G6G7
    a1.val[3] = vld1q_f16(src_ptr + 7 * lda);  // H0H1H2H3H4H5H6H7

    float16x8x2_t b0 =
            vzipq_f16(a0.val[0], a1.val[0]);  // A0E0A1E1A2E2A3E3 A4E4A5E5A6E6A7E7
    float16x8x2_t b1 =
            vzipq_f16(a0.val[2], a1.val[2]);  // C0G0C1G1C2G2C3G3 C4G4C5G5C6G6C7G7
    float16x8x2_t c0 =
            vzipq_f16(a0.val[1], a1.val[1]);  // B0F0B1F1B2F2B3F3 B4F4B5F5B6F6B7F7
    float16x8x2_t c1 =
            vzipq_f16(a0.val[3], a1.val[3]);  // D0H0D1H1D2H2D3H3 D4H4D5H5D6H6D7H7

    float16x8x2_t d0 =
            vzipq_f16(b0.val[0], b1.val[0]);  // A0C0E0G0A1C1E1G1 A2C2E2G2A3C3E3G3
    float16x8x2_t d1 =
            vzipq_f16(c0.val[0], c1.val[0]);  // B0D0F0H0B1D1F1H1 B2D2F2H2B3D3F3H3
    float16x8x2_t e0 =
            vzipq_f16(d0.val[0], d1.val[0]);  // A0B0C0D0E0F0G0H0 A1B1C1D1E1F1G1H1
    float16x8x2_t e1 =
            vzipq_f16(d0.val[1], d1.val[1]);  // A2B2C2D2E2F2G2H2 A3B3C3D3E3F3G3H3

    float16x8x2_t f0 =
            vzipq_f16(b0.val[1], b1.val[1]);  // A4C4E4G4A5C5E5G5 A6C6E6G6A7C7E7G7
    float16x8x2_t f1 =
            vzipq_f16(c0.val[1], c1.val[1]);  // B4D4F4H4B5D5F5H5 B6D6E6G6B7D7E7H7
    float16x8x2_t g0 =
            vzipq_f16(f0.val[0], f1.val[0]);  // A4B4C4D4E4F4G4H4 A5B5C5D5E5F5G5H5
    float16x8x2_t g1 =
            vzipq_f16(f0.val[1], f1.val[1]);  // A6B6C6D6E6F6G6H6 A7B7C7D7E7F7G7H7

    vst1q_f16(dst_ptr + 0 * ldb, e0.val[0]);
    vst1q_f16(dst_ptr + 1 * ldb, e0.val[1]);
    vst1q_f16(dst_ptr + 2 * ldb, e1.val[0]);
    vst1q_f16(dst_ptr + 3 * ldb, e1.val[1]);
    vst1q_f16(dst_ptr + 4 * ldb, g0.val[0]);
    vst1q_f16(dst_ptr + 5 * ldb, g0.val[1]);
    vst1q_f16(dst_ptr + 6 * ldb, g1.val[0]);
    vst1q_f16(dst_ptr + 7 * ldb, g1.val[1]);
}
#endif

}  // anonymous namespace

namespace megdnn {

template <>
void transpose(
        const float* src, float* dst, size_t m, size_t n, ptrdiff_t lds,
        ptrdiff_t ldd) {
    if (lds == -1) {
        lds = n;
    }
    if (ldd == -1) {
        ldd = m;
    }

    for (size_t is = 0; is < n; is += 16) {
        for (size_t js = 0; js < m; js += 16) {
            auto ie = std::min(is + 16, n), je = std::min(js + 16, m), i = is;
            for (; i + 4 <= ie; i += 4) {
                auto j = js;
                for (; j + 4 <= je; j += 4) {
                    transpose_4x4_neon(src + j * lds + i, dst + i * ldd + j, lds, ldd);
                }
                if (j < je) {
                    transpose_naive(
                            src + j * lds + i, dst + i * ldd + j, lds, ldd, 4, je - j);
                }
            }
            if (i < ie) {
                transpose_naive(
                        src + js * lds + i, dst + i * ldd + js, lds, ldd, ie - i,
                        je - js);
            }
        }
    }
}

template <typename dtype>
void transpose_knc2nsck_helper(
        const dtype* src, dtype* dst, size_t k, size_t n, size_t c, size_t n_stride) {
    if (n_stride == k * c) {
        // dst is contiguous
        transpose(src, dst, k, n * c);
    } else {
        for (size_t i = 0; i < n; ++i) {
            transpose(src + i * c, dst + i * n_stride, k, c, n * c);
        }
    }
}

template <>
void transpose_knc2nsck(
        const float* src, float* dst, size_t k, size_t n, size_t c, size_t n_stride) {
    transpose_knc2nsck_helper(src, dst, k, n, c, n_stride);
}

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <>
void transpose(
        const dt_float16* src, dt_float16* dst, size_t m, size_t n, ptrdiff_t lds,
        ptrdiff_t ldd) {
    if (lds == -1) {
        lds = n;
    }
    if (ldd == -1) {
        ldd = m;
    }

    for (size_t is = 0; is < n; is += 16) {
        for (size_t js = 0; js < m; js += 16) {
            auto ie = std::min(is + 16, n), je = std::min(js + 16, m), i = is;
            for (; i + 8 <= ie; i += 8) {
                auto j = js;
                for (; j + 8 <= je; j += 8) {
                    transpose_8x8_neon(src + j * lds + i, dst + i * ldd + j, lds, ldd);
                }
                if (j < je) {
                    transpose_naive(
                            src + j * lds + i, dst + i * ldd + j, lds, ldd, 8, je - j);
                }
            }
            if (i < ie) {
                transpose_naive(
                        src + js * lds + i, dst + i * ldd + js, lds, ldd, ie - i,
                        je - js);
            }
        }
    }
}

template <>
void transpose_knc2nsck(
        const dt_float16* src, dt_float16* dst, size_t k, size_t n, size_t c,
        size_t n_stride) {
    transpose_knc2nsck_helper(src, dst, k, n, c, n_stride);
}
#endif

}  // namespace megdnn
// vim: syntax=cpp.doxygen