#pragma once
#include <unordered_map>
#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
#include "src/x86/handle.h"
#include "src/x86/pooling/opr_impl.h"
namespace megdnn {
namespace x86 {
//! Shorthand so code in namespace megdnn::x86 can spell PoolingImpl::AlgoBase
//! as plain AlgoBase.
using AlgoBase = PoolingImpl::AlgoBase;
/**
 * \brief Abstract base class of the x86 pooling algorithms.
 *
 * Declares the algorithm identifiers, the argument bundles passed to an
 * algorithm, and the availability / execution interface that concrete
 * algorithms implement.
 */
class PoolingImpl::AlgoBase : public Algorithm {
public:
    //! Identifier of each concrete algorithm.
    //! NOTE(review): enumerator values are implicit (declaration order) and
    //! the MKL-DNN entries are conditional, so X86_Fallback's value depends on
    //! the build configuration -- confirm nothing persists these numbers.
    enum class AlgoType : uint32_t {
        X86_MeanW2S2AVX,
        X86_MeanW2S2SSE3,
        X86_MaxW2S2SSE,
        X86_MaxW3S3SSE,
        X86_MaxS1NCHW88AVX,
#if MEGDNN_X86_WITH_MKL_DNN
        X86_MKLDNNNCHW,
        X86_MKLDNNNCHW88,
#endif
        X86_Fallback
    };

    //! Lookup table from an algorithm descriptor to the algorithm object.
    using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

    //! Marks the algorithm as belonging to the X86 handle type.
    AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::X86; }
    virtual ~AlgoBase() = default;

    /**
     * \brief Arguments available when only tensor layouts are known.
     *
     * The constructor is defined out of line; presumably it derives \p handle
     * from \p opr -- confirm against the implementation file.
     */
    struct SizeArgs {
        HandleImpl* handle;
        PoolingImpl* opr;
        const TensorLayout layout_src, layout_dst;

        //! Textual representation of the arguments (defined out of line).
        std::string to_string() const;
        SizeArgs(PoolingImpl* opr, const TensorLayout& src, const TensorLayout& dst);
    };

    /**
     * \brief SizeArgs extended with the actual tensors and the workspace used
     * for one execution.
     */
    struct ExecArgs : public SizeArgs {
        const TensorND src_tensor, dst_tensor;
        Workspace workspace;

        ExecArgs(
                PoolingImpl* opr, _megdnn_tensor_in src, _megdnn_tensor_out dst,
                _megdnn_workspace workspace);
    };

    //! Whether this algorithm can handle the given arguments.
    virtual bool is_available(const SizeArgs& args) const = 0;

    //! Run the algorithm with the given arguments.
    virtual void exec(const ExecArgs& args) const = 0;

    //! Base implementation reports INVALID_ALGO_TYPE; the concrete algorithms
    //! in this header override it through MEGDNN_DECL_ALGO_TYPE.
    uint32_t type() const override { return INVALID_ALGO_TYPE; }

    /**
     * \brief Availability check that also filters by algorithm attributes.
     *
     * \param args arguments forwarded to is_available()
     * \param positive_attr passed to contain_attribute_all(); must hold
     * \param negative_attr passed to contain_attribute_any(); must not hold
     * \return true iff both attribute checks pass and is_available(args) is
     *         true. The attribute checks are evaluated first, so
     *         is_available() is skipped when they already fail.
     *
     * Declared const (it does not modify the object) so that it is callable
     * through a const AlgoBase reference or pointer, like is_available().
     */
    bool is_available_attribute(
            const SizeArgs& args,
            const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
            const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) const {
        return contain_attribute_all(positive_attr) &&
               !contain_attribute_any(negative_attr) && is_available(args);
    }
};
/*
 * ALGO_IMPL(_name) declares the final class PoolingImpl::Algo<_name>, derived
 * from AlgoBase:
 *   - name() returns "<_name>_POOLING";
 *   - attribute() returns AlgoAttribute::REPRODUCIBLE;
 *   - is_available() and exec() are only declared here and must be defined
 *     elsewhere;
 *   - the algorithm type is declared as X86_<_name> via MEGDNN_DECL_ALGO_TYPE.
 * The name is assembled from adjacent string literals (#_name "_POOLING").
 */
#define ALGO_IMPL(_name)                                                  \
    class PoolingImpl::Algo##_name final : public AlgoBase {              \
    public:                                                               \
        Algo##_name() : m_algo_name(#_name "_POOLING") {}                 \
        const char* name() const override { return m_algo_name.c_str(); } \
        AlgoAttribute attribute() const override {                        \
            return AlgoAttribute::REPRODUCIBLE;                           \
        }                                                                 \
        bool is_available(const SizeArgs& args) const override;           \
        void exec(const ExecArgs& args) const override;                   \
        MEGDNN_DECL_ALGO_TYPE(X86_##_name)                                \
                                                                          \
    private:                                                              \
        std::string m_algo_name;                                          \
    };
// Declare one PoolingImpl::Algo<name> class per entry below; each matches an
// AlgoBase::AlgoType enumerator named X86_<name>.
ALGO_IMPL(MeanW2S2AVX)
ALGO_IMPL(MeanW2S2SSE3)
ALGO_IMPL(MaxW2S2SSE)
ALGO_IMPL(MaxW3S3SSE)
ALGO_IMPL(MaxS1NCHW88AVX)
// These two are declared only when MEGDNN_X86_WITH_MKL_DNN is enabled.
#if MEGDNN_X86_WITH_MKL_DNN
ALGO_IMPL(MKLDNNNCHW)
ALGO_IMPL(MKLDNNNCHW88)
#endif
// The helper macro is not needed past this point; keep it from leaking out of
// this header.
#undef ALGO_IMPL
/**
 * \brief Pooling algorithm named "FALLBACK_POOLING" that reports itself as
 * available for every argument set.
 */
class PoolingImpl::AlgoFallback final : public AlgoBase {
public:
    AlgoFallback() : m_algo_name("FALLBACK_POOLING") {}

    const char* name() const override { return m_algo_name.c_str(); }

    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }

    //! Always true, whatever the arguments are.
    bool is_available(const SizeArgs&) const override { return true; }

    //! Deliberately a no-op.
    //! NOTE(review): presumably the operator performs the fallback computation
    //! itself when this algorithm is selected -- confirm against
    //! PoolingImpl::exec.
    void exec(const ExecArgs&) const override {}

    MEGDNN_DECL_ALGO_TYPE(X86_Fallback)

private:
    std::string m_algo_name;
};
/**
 * \brief Holds one instance of every x86 pooling algorithm declared in this
 * header, plus the containers used to look them up.
 *
 * The constructor is defined out of line; presumably it fills \c all_algos
 * and the descriptor map -- confirm against the implementation file.
 */
class PoolingImpl::AlgoPack : private NonCopyableObj {
    //! Descriptor -> algorithm lookup table, exposed through all_algos_map().
    AlgoBase::Mapper m_all_algos_map;

    // Algorithm instances. The declaration order is kept as is because it
    // determines the member construction order.
    AlgoMeanW2S2AVX algo_mean_w2s2_avx;
    AlgoMeanW2S2SSE3 algo_mean_w2s2_sse3;
    AlgoMaxW2S2SSE algo_max_w2s2_sse;
    AlgoMaxW3S3SSE algo_max_w3s3_sse;
#if MEGDNN_X86_WITH_MKL_DNN
    AlgoMKLDNNNCHW algo_mkldnn_nchw;
    AlgoMKLDNNNCHW88 algo_mkldnn_nchw88;
#endif
    AlgoMaxS1NCHW88AVX algo_max_w13s1_nchw88_avx;
    AlgoFallback algo_fallback;

public:
    AlgoPack();

    //! Pointers to algorithm objects; publicly accessible.
    std::vector<AlgoBase*> all_algos;

    //! Read-only access to the descriptor -> algorithm map.
    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};
} }