megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file dnn/src/fallback/elemwise/opr_binary_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "./opr_impl.h"

#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/utils.h"
#include "src/naive/handle.h"

#include "midout.h"

MIDOUT_DECL(megdnn_fallback_elemwise_binary)

namespace megdnn {
namespace fallback {

/*!
 * \brief Execute a binary elemwise kernel for one dtype/mode pair.
 *
 * A few common layout combinations get hand-written CPU loops; everything
 * else is forwarded to naive::ElemwiseForwardImpl::exec. Each fast path is
 * wrapped in its own MIDOUT region (sub-tags 1, 21, 22, 31, 32).
 *
 * The fast paths write the output densely through dst[k] for
 * k = 0 .. (number of elements - 1).
 * NOTE(review): this assumes make_elemwise_op_param produced a contiguous
 * output layout; confirm against its implementation.
 *
 *  - max_ndim == 1: both operands are 1-D; output element k reads
 *    a[k * stride_a] and b[k * stride_b] for k < param.size.
 *  - max_ndim == 2: one operand is 1-D and is indexed by the flat output
 *    offset; the other is 2-D with shape (n0, n1) and arbitrary strides.
 *  - max_ndim == 3: one operand is 1-D; the other is 3-D with
 *    stride[0] == 0 and stride[2] == 0, so it only varies along axis 1.
 *
 * If both operands have ndim > 1, or no fast path matches, the naive
 * implementation is used.
 *
 * All index arithmetic is done in ptrdiff_t. Layout strides are signed, and
 * multiplying them by size_t counters would turn a negative stride into a
 * huge unsigned offset (out-of-range pointer arithmetic).
 *
 * \param param the two input operands (layouts and storage) for this op
 */
template <typename dtype, uint32_t mode>
void ElemwiseImpl::binary_kern(const ElemwiseOpParamN<2>& param) {
    using ctype = typename DTypeTrait<dtype>::ctype;
    using Kern = ElemwiseKern<megcorePlatformCPU, mode, ctype>;

    MIDOUT_BEGIN(megdnn_fallback_elemwise_binary, ctype, midout_iv(mode)) {
        // Fast path 1: both operands are 1-D, each with its own stride.
        if (param.max_ndim == 1) {
            MIDOUT_BEGIN(
                    megdnn_fallback_elemwise_binary, ctype, midout_iv(mode),
                    midout_iv(1)) {
                const ptrdiff_t tot = static_cast<ptrdiff_t>(param.size);
                const ptrdiff_t as = param[0].layout.stride[0];
                const ptrdiff_t bs = param[1].layout.stride[0];
                auto src0 = param[0];
                auto src1 = param[1];
                auto dst_tensor = *m_dst;

                MEGDNN_DISPATCH_CPU_KERN_OPR({
                    const ctype* __restrict a =
                            static_cast<const ctype*>(src0.raw_ptr());
                    const ctype* __restrict b =
                            static_cast<const ctype*>(src1.raw_ptr());
                    ctype* __restrict dst = static_cast<ctype*>(dst_tensor.raw_ptr());
                    for (ptrdiff_t i = 0; i < tot; ++i) {
                        dst[i] = Kern::apply(a[i * as], b[i * bs]);
                    }
                });
                return;
            }
            MIDOUT_END();
        }

        // The fast paths below need at least one operand to be 1-D.
        if (std::min(param[0].layout.ndim, param[1].layout.ndim) > 1) {
            return naive::ElemwiseForwardImpl::exec(*m_src, *m_dst);
        }

        // Fast path 2: one 1-D operand and one 2-D operand of shape (n0, n1).
        if (param.max_ndim == 2) {
            if (param[0].layout.ndim == 1) {
                MIDOUT_BEGIN(
                        megdnn_fallback_elemwise_binary, ctype, midout_iv(mode),
                        midout_iv(21)) {
                    const ptrdiff_t as = param[0].layout.stride[0];
                    const ptrdiff_t bs0 = param[1].layout.stride[0];
                    const ptrdiff_t bs1 = param[1].layout.stride[1];
                    const ptrdiff_t n0 =
                            static_cast<ptrdiff_t>(param[1].layout.shape[0]);
                    const ptrdiff_t n1 =
                            static_cast<ptrdiff_t>(param[1].layout.shape[1]);
                    auto src0 = param[0];
                    auto src1 = param[1];
                    auto dst_tensor = *m_dst;

                    MEGDNN_DISPATCH_CPU_KERN_OPR({
                        const ctype* __restrict a =
                                static_cast<const ctype*>(src0.raw_ptr());
                        const ctype* __restrict b =
                                static_cast<const ctype*>(src1.raw_ptr());
                        ctype* __restrict dst =
                                static_cast<ctype*>(dst_tensor.raw_ptr());
                        // toff is the flat output offset; the 1-D operand is
                        // indexed by it directly.
                        ptrdiff_t toff = 0;
                        for (ptrdiff_t i = 0; i < n0; ++i) {
                            for (ptrdiff_t j = 0; j < n1; ++j) {
                                dst[toff] = Kern::apply(
                                        a[as * toff], b[bs0 * i + bs1 * j]);
                                ++toff;
                            }
                        }
                    });
                    return;
                }
                MIDOUT_END();
            }

            // Mirror case: param[0] is the 2-D operand, param[1] is 1-D.
            MIDOUT_BEGIN(
                    megdnn_fallback_elemwise_binary, ctype, midout_iv(mode),
                    midout_iv(22)) {
                megdnn_assert(param[1].layout.ndim == 1);
                const ptrdiff_t bs = param[1].layout.stride[0];
                const ptrdiff_t as0 = param[0].layout.stride[0];
                const ptrdiff_t as1 = param[0].layout.stride[1];
                const ptrdiff_t n0 = static_cast<ptrdiff_t>(param[0].layout.shape[0]);
                const ptrdiff_t n1 = static_cast<ptrdiff_t>(param[0].layout.shape[1]);
                auto src0 = param[0];
                auto src1 = param[1];
                auto dst_tensor = *m_dst;

                MEGDNN_DISPATCH_CPU_KERN_OPR({
                    const ctype* __restrict a =
                            static_cast<const ctype*>(src0.raw_ptr());
                    const ctype* __restrict b =
                            static_cast<const ctype*>(src1.raw_ptr());
                    ctype* __restrict dst = static_cast<ctype*>(dst_tensor.raw_ptr());
                    ptrdiff_t toff = 0;
                    for (ptrdiff_t i = 0; i < n0; ++i) {
                        for (ptrdiff_t j = 0; j < n1; ++j) {
                            dst[toff] = Kern::apply(a[as0 * i + as1 * j], b[toff * bs]);
                            ++toff;
                        }
                    }
                });
                return;
            }
            MIDOUT_END();
        }

        // Fast path 3: one 1-D operand and one 3-D operand whose strides on
        // axes 0 and 2 are zero, so it only varies along axis 1.
        if (param.max_ndim == 3) {
            auto brd_101 = [](const TensorND& t) {
                auto&& l = t.layout;
                return l.ndim == 3 && l.stride[0] == 0 && l.stride[2] == 0;
            };
            if (param[0].layout.ndim == 1 && brd_101(param[1])) {
                MIDOUT_BEGIN(
                        megdnn_fallback_elemwise_binary, ctype, midout_iv(mode),
                        midout_iv(31)) {
                    const ptrdiff_t as = param[0].layout.stride[0];
                    const ptrdiff_t bs = param[1].layout.stride[1];
                    const ptrdiff_t n0 =
                            static_cast<ptrdiff_t>(param[1].layout.shape[0]);
                    const ptrdiff_t n1 =
                            static_cast<ptrdiff_t>(param[1].layout.shape[1]);
                    const ptrdiff_t n2 =
                            static_cast<ptrdiff_t>(param[1].layout.shape[2]);
                    auto src0 = param[0];
                    auto src1 = param[1];
                    auto dst_tensor = *m_dst;

                    MEGDNN_DISPATCH_CPU_KERN_OPR({
                        const ctype* __restrict a =
                                static_cast<const ctype*>(src0.raw_ptr());
                        const ctype* __restrict b =
                                static_cast<const ctype*>(src1.raw_ptr());
                        ctype* __restrict dst =
                                static_cast<ctype*>(dst_tensor.raw_ptr());
                        ptrdiff_t toff = 0;
                        for (ptrdiff_t i = 0; i < n0; ++i) {
                            for (ptrdiff_t j = 0; j < n1; ++j) {
                                for (ptrdiff_t k = 0; k < n2; ++k) {
                                    dst[toff] = Kern::apply(a[as * toff], b[bs * j]);
                                    ++toff;
                                }
                            }
                        }
                    });
                    return;
                }
                MIDOUT_END();
            }
            if (param[1].layout.ndim == 1 && brd_101(param[0])) {
                MIDOUT_BEGIN(
                        megdnn_fallback_elemwise_binary, ctype, midout_iv(mode),
                        midout_iv(32)) {
                    const ptrdiff_t as = param[0].layout.stride[1];
                    const ptrdiff_t bs = param[1].layout.stride[0];
                    const ptrdiff_t n0 =
                            static_cast<ptrdiff_t>(param[0].layout.shape[0]);
                    const ptrdiff_t n1 =
                            static_cast<ptrdiff_t>(param[0].layout.shape[1]);
                    const ptrdiff_t n2 =
                            static_cast<ptrdiff_t>(param[0].layout.shape[2]);
                    auto src0 = param[0];
                    auto src1 = param[1];
                    auto dst_tensor = *m_dst;
                    MEGDNN_DISPATCH_CPU_KERN_OPR({
                        const ctype* __restrict a =
                                static_cast<const ctype*>(src0.raw_ptr());
                        const ctype* __restrict b =
                                static_cast<const ctype*>(src1.raw_ptr());
                        ctype* __restrict dst =
                                static_cast<ctype*>(dst_tensor.raw_ptr());
                        ptrdiff_t toff = 0;
                        for (ptrdiff_t i = 0; i < n0; ++i) {
                            for (ptrdiff_t j = 0; j < n1; ++j) {
                                for (ptrdiff_t k = 0; k < n2; ++k) {
                                    dst[toff] = Kern::apply(a[as * j], b[bs * toff]);
                                    ++toff;
                                }
                            }
                        }
                    });
                    return;
                }
                MIDOUT_END();
            }
        }

        // No fast path applies.
        naive::ElemwiseForwardImpl::exec(*m_src, *m_dst);
    }
    MIDOUT_END();
}

//! Expand to a switch over the output dtype (m_dst->layout.dtype). \p _cb is
//! applied to every computing dtype of category \p _cat (INT or FLOAT); any
//! other dtype throws "bad dtype".
#define SWITCH_DTYPE(_cat, _cb)                    \
    switch (m_dst->layout.dtype.enumv()) {         \
        MEGDNN_FOREACH_COMPUTING_DTYPE_##_cat(_cb) \
        default:                                   \
            megdnn_throw("bad dtype");             \
    }

//! Case label shared by both dispatchers: forward to the typed kernel. Relies
//! on `mode` and `param` being in scope where it is expanded.
#define cb(_dt)                  \
    case DTypeTrait<_dt>::enumv: \
        return binary_kern<_dt, mode>(param);

/*!
 * \brief Run binary elemwise \p mode on an output with an integer computing
 *      dtype, dispatching to binary_kern for the matching dtype.
 */
template <uint32_t mode>
void ElemwiseImpl::exec_BINARY_INT() {
    auto param = make_elemwise_op_param<2>();
    SWITCH_DTYPE(INT, cb)
}

/*!
 * \brief Run binary elemwise \p mode on an output with a floating-point
 *      computing dtype, dispatching to binary_kern for the matching dtype.
 */
template <uint32_t mode>
void ElemwiseImpl::exec_BINARY_FLOAT() {
    auto param = make_elemwise_op_param<2>();
    SWITCH_DTYPE(FLOAT, cb)
}

#undef cb
#undef SWITCH_DTYPE
using Mode = param_enumv::Elemwise::Mode;

// Explicit instantiations of ElemwiseImpl::exec_BINARY_INT and
// ElemwiseImpl::exec_BINARY_FLOAT. The modes are grouped by which of the two
// dispatchers they are instantiated for.

//! modes instantiated for both the integer and the floating-point dispatcher
#define BINARY_MODES_INT_AND_FLOAT(_inst)                             \
    _inst(ABS_GRAD) _inst(ADD) _inst(FLOOR_DIV) _inst(MAX) _inst(MIN) \
            _inst(MOD) _inst(MUL) _inst(SIGMOID_GRAD) _inst(SUB)      \
                    _inst(SWITCH_GT0) _inst(TANH_GRAD) _inst(LT)      \
                            _inst(LEQ) _inst(EQ) _inst(FUSE_ADD_RELU)

//! modes instantiated only for the integer dispatcher
#define BINARY_MODES_INT_ONLY(_inst) _inst(SHL) _inst(SHR) _inst(RMULH)

//! modes instantiated only for the floating-point dispatcher
#define BINARY_MODES_FLOAT_ONLY(_inst)                                        \
    _inst(POW) _inst(TRUE_DIV) _inst(LOG_SUM_EXP) _inst(FUSE_ADD_SIGMOID)     \
            _inst(FUSE_ADD_TANH) _inst(FAST_TANH_GRAD) _inst(ATAN2)           \
                    _inst(H_SWISH_GRAD) _inst(FUSE_ADD_H_SWISH) _inst(SILU_GRAD) \
                            _inst(GELU_GRAD)

#define INST_INT(_m) \
    template void megdnn::fallback::ElemwiseImpl::exec_BINARY_INT<Mode::_m>();
#define INST_FLOAT(_m) \
    template void megdnn::fallback::ElemwiseImpl::exec_BINARY_FLOAT<Mode::_m>();

BINARY_MODES_INT_AND_FLOAT(INST_INT)
BINARY_MODES_INT_ONLY(INST_INT)

BINARY_MODES_INT_AND_FLOAT(INST_FLOAT)
BINARY_MODES_FLOAT_ONLY(INST_FLOAT)

#undef INST_FLOAT
#undef INST_INT
#undef BINARY_MODES_FLOAT_ONLY
#undef BINARY_MODES_INT_ONLY
#undef BINARY_MODES_INT_AND_FLOAT
}  // namespace fallback
}  // namespace megdnn

// vim: syntax=cpp.doxygen