megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file dnn/src/common/elemwise/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "src/common/elemwise/kern_defs.cuh"
#include "src/common/utils.h"

#include "megdnn/oprs.h"
#include "megdnn/tensor_format.h"

#include "midout.h"
MIDOUT_DECL(megdnn_common_elemwise)

#include <mutex>
#include <vector>

using namespace megdnn;

namespace {
// Folds a sequence of tensor formats into a single result format.
//
// m_default is a default-constructed TensorFormat and serves as the reference
// "default" value; m_result starts out equal to it and is updated by feed().
class FormatDeducer {
public:
    // Account for one more input format (defined out of line further down).
    inline void feed(TensorFormat cur);

    // Whether f compares equal to the default-constructed format.
    bool is_default(TensorFormat f) const { return f == m_default; }

    // Format accumulated so far.
    TensorFormat get() const { return m_result; }

private:
    // Declaration order matters here: m_result is initialised from m_default,
    // so m_default has to be declared (and thus constructed) first.
    const TensorFormat m_default;
    TensorFormat m_result = m_default;
};
}  // anonymous namespace

using Mode = param::Elemwise::Mode;
using ModeTrait = ElemwiseForward::ModeTrait;

// Look up the trait record (allowed dtype categories, arity, name,
// commutability) of an elemwise mode.
//
// The records live in a function-local static vector indexed by the numeric
// value of the mode. The vector is populated on the first call (detected by
// it being empty) and is never resized afterwards, so the returned reference
// keeps referring to the same element on later calls.
//
// mode: the mode to look up; traits.at() rejects an index that is out of range.
// Returns a reference to the table entry for mode.
const ModeTrait& ModeTrait::from_mode(Mode mode) {
    static DNN_MUTEX mtx;
    static std::vector<ModeTrait> traits;

    // NOTE(review): presumably a scoped lock on mtx held until return, which
    // would serialise both the one-time table build and every lookup -- confirm
    // against the macro definition.
    MEGDNN_LOCK_GUARD(mtx);

    if (traits.empty()) {
        // Access the entry for mode m, growing the table when m lies past the
        // current end.
        auto get = [&](Mode m) -> ModeTrait& {
            auto im = static_cast<size_t>(m);
            if (im >= traits.size())
                traits.resize(im + 1);
            return traits[im];
        };

// Pass 1: set allow_int for every mode enumerated by the *_INT lists.
#define cb(_m)                                                  \
    MIDOUT_BEGIN(megdnn_common_elemwise, midout_iv(Mode::_m)) { \
        get(Mode::_m).allow_int = true;                         \
    }                                                           \
    MIDOUT_END();
        MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_INT(cb);
        MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb);
        MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_INT(cb);
#undef cb

// Pass 2: set allow_float for every mode enumerated by the *_FLOAT lists.
#define cb(_m)                                                  \
    MIDOUT_BEGIN(megdnn_common_elemwise, midout_iv(Mode::_m)) { \
        get(Mode::_m).allow_float = true;                       \
    }                                                           \
    MIDOUT_END();
        MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_FLOAT(cb);
        MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb);
        MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb);
#undef cb

// Pass 3: set allow_bool for every mode enumerated by the *_BOOL lists
// (only unary and binary lists are used here).
#define cb(_m)                                                  \
    MIDOUT_BEGIN(megdnn_common_elemwise, midout_iv(Mode::_m)) { \
        get(Mode::_m).allow_bool = true;                        \
    }                                                           \
    MIDOUT_END();
        MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_BOOL(cb);
        MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL(cb);
#undef cb

// Pass 4: record arity and name. _a is redefined to 1, 2 and 3 around the
// unary, binary and ternary lists respectively; the name is the stringified
// enumerator.
#define cb(_m)                                                  \
    MIDOUT_BEGIN(megdnn_common_elemwise, midout_iv(Mode::_m)) { \
        auto&& t = get(Mode::_m);                               \
        t.arity = _a;                                           \
        t.name = (#_m);                                         \
    }                                                           \
    MIDOUT_END();
#define _a 1
        MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_FLOAT(cb);
        MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_INT(cb);
        MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_BOOL(cb);
#undef _a
#define _a 2
        MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb);
        MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb);
        MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL(cb);
#undef _a
#define _a 3
        MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb);
        MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_INT(cb);
#undef _a
#undef cb

// The two fused multiply-add modes are registered by hand: all three dtype
// categories are allowed and the arity is given explicitly (3 and 4).
#define FUSE(_m, _arity)                                        \
    MIDOUT_BEGIN(megdnn_common_elemwise, midout_iv(Mode::_m)) { \
        auto&& t = get(Mode::_m);                               \
        t.allow_int = true;                                     \
        t.allow_float = true;                                   \
        t.allow_bool = true;                                    \
        t.arity = _arity;                                       \
        t.name = (#_m);                                         \
    }                                                           \
    MIDOUT_END();
        FUSE(FUSE_MUL_ADD3, 3);
        FUSE(FUSE_MUL_ADD4, 4);
#undef FUSE

// Flag commutable modes. Unlike get(), traits.at() does not grow the table,
// so an entry must already exist for the mode.
// NOTE(review): COMM goes through MEGDNN_ELEMWISE_MODE_ENABLE, which
// presumably drops the statement for modes that are compiled out -- confirm.
#define COMM_CB(_m)                                              \
    MIDOUT_BEGIN(megdnn_common_elemwise, midout_iv(Mode::_m)) {  \
        traits.at(static_cast<int>(Mode::_m)).commutable = true; \
    }                                                            \
    MIDOUT_END()
#define COMM(_m) MEGDNN_ELEMWISE_MODE_ENABLE(_m, COMM_CB)

        COMM(ADD);
        COMM(FUSE_ADD_RELU);
        COMM(FUSE_ADD_SIGMOID);
        COMM(FUSE_ADD_TANH);
        COMM(MUL);
        COMM(RMULH);
        COMM(MAX);
        COMM(MIN);
        COMM(EQ);
        COMM(LOG_SUM_EXP);

#undef COMM
#undef COMM_CB

#if MEGDNN_ELEMWISE_MODE_ENABLE_ALL
        // Sanity-check every table entry: arity must be non-zero, at least one
        // dtype category must be allowed, and a commutable mode must be binary.
        for (auto&& i : traits) {
            megdnn_assert(
                    i.arity && (i.allow_int || i.allow_float || i.allow_bool) &&
                    (!i.commutable || i.arity == 2));
        }
#else
#pragma message "elemwise mode stripped"
#endif
    }

    auto&& ret = traits.at(static_cast<int>(mode));
#if !MEGDNN_ELEMWISE_MODE_ENABLE_ALL
    // Without the whole-table check above, verify per lookup that the requested
    // entry was actually filled in (arity is non-zero).
    megdnn_assert(ret.arity);
#endif
    return ret;
}

// Compute the broadcast result shape of all shapes in src and store it in dst.
//
// Shapes are aligned on their trailing axes. For a pair of aligned extents the
// result is 0 if either extent is 0, otherwise the larger of the two; two
// different extents that are both greater than 1 are rejected. Axes present in
// only one operand are taken over unchanged. An input with ndim == 0 is
// rejected as well. Errors are raised through megdnn_throw with a message that
// lists every input shape.
void ElemwiseForward::deduce_shape(const TensorShapeArray& src, TensorShape& dst) {
    auto err = [&]() {
        std::string msg("bad input shape for polyadic operator: ");
        const char* separator = "";
        for (auto&& shp : src) {
            msg.append(separator);
            msg.append(shp.to_string());
            separator = ", ";
        }
        megdnn_throw(msg);
    };

    dst.ndim = 0;
    for (auto&& cur : src) {
        if (cur.ndim == 0)
            err();

        // Nothing accumulated yet, or only a scalar so far: take cur as is.
        if (dst.ndim == 0 || dst.is_scalar()) {
            dst = cur;
            continue;
        }
        // A scalar operand never changes an already non-scalar result.
        if (cur.is_scalar())
            continue;

        const int out_ndim = std::max(cur.ndim, dst.ndim);
        // `back` counts axes starting from the innermost (last) one.
        for (int back = 0; back < out_ndim; ++back) {
            const int cur_axis = cur.ndim - back - 1;
            const int dst_axis = dst.ndim - back - 1;

            if (cur_axis < 0) {
                // Only dst has this axis; its extent stays where it is.
                continue;
            }
            if (dst_axis < 0) {
                // Only cur has this axis; copy its extent over.
                dst.shape[cur_axis] = cur.shape[cur_axis];
                continue;
            }

            const size_t dst_ext = dst.shape[dst_axis];
            const size_t cur_ext = cur.shape[cur_axis];
            if (dst_ext != cur_ext && dst_ext > 1 && cur_ext > 1)
                err();

            // The merged extent is stored at this axis' position in the
            // out_ndim-dimensional result, i.e. the larger of the two indices.
            const int out_axis = std::max(cur_axis, dst_axis);
            dst.shape[out_axis] =
                    (dst_ext == 0 || cur_ext == 0) ? 0 : std::max(dst_ext, cur_ext);
        }
        dst.ndim = out_ndim;
    }
}

// Take one more input format into account.
//
// Default formats are skipped. The first non-default format becomes the
// result; every later non-default format has to compare equal to it, otherwise
// the assertion below fails. (That default-format inputs must be scalar in the
// mixed case is verified by deduce_layout, not here.)
void FormatDeducer::feed(TensorFormat cur) {
    if (cur == m_default)
        return;

    if (m_result == m_default) {
        // No non-default format recorded yet: adopt this one.
        m_result = cur;
        return;
    }

    megdnn_assert(
            m_result == cur, "different input layout formats in elemwise: %s vs %s",
            m_result.impl()->to_string().c_str(), cur.impl()->to_string().c_str());
}

void ElemwiseForward::deduce_format(const TensorFormatArray& src, TensorFormat& dst) {
    FormatDeducer d;
    for (auto i : src) {
        d.feed(i);
    }
    dst = d.get();
}

void ElemwiseForward::deduce_layout(const TensorLayoutArray& src, TensorLayout& dst) {
    megdnn_assert(src.size() == mode_trait().arity);
    DType dtype;
    FormatDeducer format_deducer;
    for (auto&& i : src) {
        if (!dtype.valid()) {
            dtype = i.dtype;
            dst.format = i.format;
        } else {
            megdnn_assert(
                    dtype == i.dtype, "input dtype not unique: get %s and %s",
                    dtype.name(), i.dtype.name());
        }

        format_deducer.feed(i.format);
    }
    dst.format = format_deducer.get();
    if (!format_deducer.is_default(dst.format)) {
        for (auto&& i : src) {
            if (format_deducer.is_default(i.format)) {
                megdnn_assert(
                        i.collapse_contiguous().is_scalar(),
                        "default format can only be used on scalar, got %s",
                        i.to_string().c_str());
            }
        }
    }

    check_dtype(dtype);
    TensorShapeArray src_shp;
    for (auto&& i : src)
        src_shp.push_back(i);
    deduce_shape(src_shp, dst);
    dst.dtype = dtype;
    dst.init_contiguous_stride();
}

// Validate the inputs against dst and broadcast each of them to dst in place.
//
// Asserts that the input count equals the arity of the current mode, that all
// inputs share one dtype which check_dtype() accepts and which equals
// dst.dtype, and that dst is contiguous. Every layout pointed to by src is
// overwritten with the result of broadcasting it to dst.
void ElemwiseForward::check_layout_and_broadcast(
        const TensorLayoutPtrArray& src, const TensorLayout& dst) {
    megdnn_assert(src.size() == mode_trait().arity);

    DType dtype;
    for (auto i : src) {
        if (dtype.valid()) {
            megdnn_assert(dtype == i->dtype);
        } else {
            // The first input determines the expected dtype.
            dtype = i->dtype;
        }
        *i = i->broadcast(dst);
    }

    check_dtype(dtype);
    megdnn_assert(dtype == dst.dtype && dst.is_contiguous());
}

// Assert that dtype is valid and that its category (float / int / bool) is
// allowed by the trait of the current mode; any other category is rejected
// with megdnn_throw.
void ElemwiseForward::check_dtype(DType dtype) {
    megdnn_assert(dtype.valid());
    auto&& trait = mode_trait();

    const auto category = dtype.category();
    if (category == DTypeCategory::FLOAT) {
        megdnn_assert(trait.allow_float, "unsupport mode %s for float\n", trait.name);
    } else if (category == DTypeCategory::INT) {
        megdnn_assert(trait.allow_int, "unsupport mode %s for int\n", trait.name);
    } else if (category == DTypeCategory::BOOL) {
        megdnn_assert(trait.allow_bool, "unsupport mode %s for bool\n", trait.name);
    } else {
        megdnn_throw("bad dtype");
    }
}

// vim: syntax=cpp.doxygen