megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file dnn/src/common/algo_chooser.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#pragma once

#include <cstddef>
#include <limits>
#include <utility>
#include <vector>

#include "megdnn/common.h"
#include "megdnn/heuristic_cache.h"
#include "utils.h"

namespace megdnn {

template <class Opr, typename... Args>
size_t get_dnn_workspace(Opr* opr, Args&&... args) {
    TensorLayoutArray layouts{{args...}};
    HeuristicCache::Key key{opr->handle(),  opr->get_opr_type(), layouts.data(),
                            layouts.size(), &opr->param(),       sizeof(opr->param())};
    auto rst = HeuristicCache::instance().get(key);
    if (rst.policy.algo.valid()) {
        return rst.workspace;
    }

    typename Opr::AlgoBase::SizeArgs size_args(opr, std::forward<Args>(args)...);
    return get_algorithm(opr, std::forward<Args>(args)...)
            ->get_workspace_in_bytes(size_args);
}

/*!
 * \brief get user-configured algorithm, or heuristic algorithm
 */
template <class Opr, typename... Args>
typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) {
    typename Opr::AlgorithmDesc ret;
    auto set = opr->execution_policy().algo;
    if (set.valid()) {
        ret = set;
    } else {
        TensorLayoutArray layouts{{args...}};
        HeuristicCache::Key key{opr->handle(),  opr->get_opr_type(),
                                layouts.data(), layouts.size(),
                                &opr->param(),  sizeof(opr->param())};
        auto rst = HeuristicCache::instance().get(key);
        if (rst.policy.algo.valid()) {
            ret = rst.policy.algo;
        } else {
            ret = opr->get_algorithm_info_heuristic(
                             std::forward<Args>(args)...,
                             std::numeric_limits<size_t>::max(), AlgoAttribute::DEFAULT,
                             AlgoAttribute::DEFAULT)
                          .desc;
        }
    }
    return static_cast<typename Opr::AlgoBase*>(opr->get_algorithm_from_desc(ret));
}

/*!
 * \brief get user-configured algorithm, or heuristic algorithm. used in opencl
 * whose algo need to be constructed each time.
 */
template <class Opr, typename... Args>
typename Opr::AlgoBase* get_algorithm_or_construct(Opr* opr, Args&&... args) {
    auto set = opr->execution_policy().algo;
    if (set.valid()) {
        return opr->algo_pack().construct_and_get_algo(set);
    } else {
        return static_cast<typename Opr::AlgoBase*>(opr->get_algorithm_heuristic(
                std::forward<Args>(args)..., std::numeric_limits<size_t>::max(),
                AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT));
    }
}

/*!
 * \brief get all algorithms from algo_pack() that is available for current size
 */
template <class Opr>
std::vector<typename Opr::Algorithm*> get_all_algorithms(
        const typename Opr::AlgoBase::SizeArgs& args) {
    std::vector<typename Opr::Algorithm*> ret;
    ret.reserve(Opr::algo_pack().all_algos.size());
    for (auto i : Opr::algo_pack().all_algos) {
        if (i->is_available(args)) {
            ret.push_back(i);
        }
    }
    return ret;
}
template <class Opr>
std::vector<typename Opr::Algorithm*> get_all_algorithms_safe(
        const typename Opr::AlgoBase::SizeArgs& args) {
    auto ret_safe = get_all_algorithms<Opr>(args);
    megdnn_assert(!ret_safe.empty(), "no algorithm for %s", args.to_string().c_str());
    return ret_safe;
}

/*!
 * \brief a helper function to get an algorithm match attribute. If require a
 * algorithm with specified attribute, and the given algorithm match that
 * attribute, return the given algorithm. Otherwise return nullptr
 */
template <typename Opr>
typename Opr::Algorithm* get_algo_match_attribute(
        typename Opr::AlgoBase* algo, const AlgoAttribute& positive_attr,
        const AlgoAttribute& negative_attr) {
    if (algo->contain_attribute_all(positive_attr) &&
        !algo->contain_attribute_any(negative_attr)) {
        return algo;
    }
    return nullptr;
}

template <typename Opr>
typename Opr::Algorithm* get_algo_match_attribute(
        const std::vector<typename Opr::AlgoBase*>& algos,
        const typename Opr::AlgoBase::SizeArgs& args, size_t workspace_limit_in_bytes,
        const char* name,
        const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
        const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) {
    size_t min_workspace_limit_in_bytes = std::numeric_limits<size_t>::max();
    bool available_but_limited_by_workspace = false;
    bool available_but_attribute_mismatch = false;
    for (auto i : algos) {
        if (i->is_available_attribute(
                    args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
            return i;
        }
        if (i->is_available_attribute(args, positive_attr, negative_attr)) {
            if (i->get_workspace_in_bytes(args) > workspace_limit_in_bytes) {
                available_but_limited_by_workspace = true;
                min_workspace_limit_in_bytes = std::min(
                        min_workspace_limit_in_bytes, i->get_workspace_in_bytes(args));
            }
        }
        if (i->is_available(args)) {
            if (!(i->contain_attribute_all(positive_attr) &&
                  !i->contain_attribute_any(negative_attr)))
                available_but_attribute_mismatch = true;
        }
    }

    MEGDNN_MARK_USED_VAR(name);
    if (available_but_limited_by_workspace) {
        megdnn_throw(ssprintf(
                "no %s algorithm without attribute(%s) with "
                "attribute(%s) : %s workspace limit %zu is "
                "less than mini workspace limit %zu",
                name, Algorithm::attribute_str(negative_attr).c_str(),
                Algorithm::attribute_str(positive_attr).c_str(),
                args.to_string().c_str(), workspace_limit_in_bytes,
                min_workspace_limit_in_bytes));
    } else if (available_but_attribute_mismatch) {
        megdnn_throw(ssprintf(
                "no %s algorithm without attribute(%s) with attribute(%s)", name,
                Algorithm::attribute_str(negative_attr).c_str(),
                Algorithm::attribute_str(positive_attr).c_str()));
    } else {
        megdnn_throw(ssprintf("no usable %s algorithm", name));
    }
}

}  // namespace megdnn

// vim: syntax=cpp.doxygen