#pragma once
#include <functional>
#include <string>
#include <tuple>
#include "megdnn/oprs/base.h"
#include "src/common/utils.h"
namespace megdnn {
#define MEGDNN_DECL_ALGO_TYPE(_type) \
uint32_t type() const override { \
return static_cast<std::underlying_type<AlgoType>::type>(AlgoType::_type); \
}
#define MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(_opr) \
static fallback::_opr::AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc)
#define MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(_opr) \
fallback::_opr::AlgoBase* _opr::get_algo_from_desc(const AlgorithmDesc& desc) { \
megdnn_assert( \
algo_pack().all_algos_map().find(desc) != \
algo_pack().all_algos_map().end()); \
return algo_pack().all_algos_map().at(desc); \
}
#define MEGDNN_DEF_GET_ALGO_FROM_DESC(_opr) \
_opr::Algorithm* _opr::get_algorithm_from_desc(const AlgorithmDesc& desc) { \
megdnn_assert( \
algo_pack().all_algos_map().find(desc) != \
algo_pack().all_algos_map().end()); \
return algo_pack().all_algos_map().at(desc); \
}
#define MEGDNN_FOREACH_ALGO_ATTRIBUTE_INHERITABLE(cb) \
cb(AlgoAttribute::ACCURACY_DEPEND_ON_BATCH)
template <typename AlgoBase>
class AlgoConstructMixin {
protected:
std::vector<std::unique_ptr<AlgoBase>> m_refhold;
typename AlgoBase::Mapper m_all_algos_map;
public:
AlgoBase* construct_and_get_algo(const detail::Algorithm::Info::Desc& desc) {
auto iter = m_all_algos_map.find(desc);
if (iter != m_all_algos_map.end()) {
return m_all_algos_map.at(desc);
}
std::string serialized_bin;
AlgoBase::serialize_write_pod(desc.type, serialized_bin);
serialized_bin += desc.param;
m_refhold.emplace_back(AlgoBase::deserialize(serialized_bin));
m_all_algos_map.emplace(desc, m_refhold.back().get());
return m_refhold.back().get();
}
void clear() {
m_all_algos_map.clear();
m_refhold.clear();
}
const typename AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
template <std::size_t I = 0, typename Opr, typename... Tp>
inline typename std::enable_if<I == sizeof...(Tp), void>::type set_sub_execution_policy(
const Opr*, std::tuple<Tp...>&) {}
template <std::size_t I = 0, typename Opr, typename... Tp>
inline typename std::enable_if <
I<sizeof...(Tp), void>::type set_sub_execution_policy(
const Opr* opr, std::tuple<Tp...>& t) {
std::get<I>(t)->execution_policy() = opr->execution_policy().sub_policy[I];
set_sub_execution_policy<I + 1, Opr, Tp...>(opr, t);
}
template <typename Opr, typename... SubOpr>
void set_execution_policy(const Opr* opr, SubOpr... sub_oprs) {
if (opr->execution_policy().algo.valid() &&
!opr->execution_policy().sub_policy.empty()) {
megdnn_assert(opr->execution_policy().sub_policy.size() == sizeof...(sub_oprs));
auto&& sub = std::make_tuple(sub_oprs...);
set_sub_execution_policy<0, Opr, SubOpr...>(opr, sub);
}
}
}
namespace std {
template <>
struct hash<megdnn::detail::Algorithm::Info::Desc> {
std::size_t operator()(const megdnn::detail::Algorithm::Info::Desc& desc) const {
return megdnn::hash_combine<size_t>(
megdnn::hash_combine<size_t>(
std::hash<std::string>()(desc.name),
megdnn::hash_combine<size_t>(
std::hash<std::string>()(desc.param),
std::hash<uint32_t>()(desc.type))),
std::hash<uint32_t>()(static_cast<uint32_t>(desc.handle_type)));
}
};
}