#pragma once
#include <unordered_map>
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs/base.h"
#include "src/common/algo_base.h"
#include "src/common/utils.h"
#include "src/naive/matrix_mul/opr_impl.h"
namespace megdnn {
//! Pairs an algorithm data type with a matrix-mul format.
//! Both members are declared as 32-bit wide bit-fields; their order and
//! widths determine the object layout, so keep them as they are.
struct AlgoTypePack {
//! data-type category of the algorithm
detail::AlgoDataType data_type : 32;
//! matrix-mul format of the algorithm
param::MatrixMul::Format format : 32;
};
namespace fallback {
//! Matrix-mul operator declared in namespace fallback; it derives from
//! naive::MatrixMulForwardImpl and inherits that class's constructors.
class MatrixMulImpl : public naive::MatrixMulForwardImpl {
public:
//! bring the constructors of the naive implementation into this class
using naive::MatrixMulForwardImpl::MatrixMulForwardImpl;
//! short alias used by the nested parameter/algorithm types
using AlgoDataType = detail::AlgoDataType;
//! this operator always reports itself as thread safe
bool is_thread_safe() const override { return true; }
//! Workspace size in bytes for three tensor layouts
//! (presumably A, B and C in the same order as exec() -- TODO confirm).
size_t get_workspace_in_bytes(
const TensorLayout&, const TensorLayout&, const TensorLayout&) override;
//! Execute the operator with input tensors A and B, output tensor C and
//! the caller-provided workspace.
void exec(
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace) override;
/*!
 * \brief Size-only description of one matrix-mul problem: element types,
 *      extents, leading dimensions, transpose flags, compute mode and format.
 *
 * It carries no data pointers; KernParam extends it with those.
 */
struct KernSizeParam {
    //! element type of matrix A
    DType A_type;
    //! element type of matrix B
    DType B_type;
    //! element type of matrix C
    DType C_type;

    //! problem extents
    //! (presumably C is MxN, A is MxK and B is KxN -- TODO confirm)
    size_t M;
    size_t N;
    size_t K;

    //! leading dimensions of A, B and C
    //! (NOTE(review): unit, elements vs. bytes, is not visible here)
    size_t LDA;
    size_t LDB;
    size_t LDC;

    //! transpose flags for A and B
    bool trA;
    bool trB;

    //! compute mode requested through the operator param
    Param::ComputeMode compute_mode;
    //! matrix format requested through the operator param
    Param::Format format;

    //! derive the algorithm data-type category for this problem
    //! (defined out of line)
    AlgoDataType deduce_algo_data_type() const;
};
/*!
 * \brief KernSizeParam plus the data pointers and workspace a kernel needs
 *      to actually run.
 */
struct KernParam : public KernSizeParam {
    //! references to the storage of A, B and C
    RefPtr A_ptr;
    RefPtr B_ptr;
    RefPtr C_ptr;

    //! workspace buffer; null / zero-sized unless a caller fills it in
    void* workspace_ptr = nullptr;
    size_t workspace_size = 0;

    //! read-only pointer to A, viewed as elements of type \p T
    template <typename T>
    const T* A() const {
        auto raw_a = A_ptr.get_ptr();
        return static_cast<const T*>(raw_a);
    }

    //! read-only pointer to B, viewed as elements of type \p T
    template <typename T>
    const T* B() const {
        auto raw_b = B_ptr.get_ptr();
        return static_cast<const T*>(raw_b);
    }

    //! writable pointer to C, viewed as elements of type \p T
    template <typename T>
    T* C() const {
        auto raw_c = C_ptr.get_ptr();
        return static_cast<T*>(raw_c);
    }

    //! workspace buffer viewed as elements of type \p T
    template <typename T>
    T* workspace() const {
        void* const raw_ws = workspace_ptr;
        return static_cast<T*>(raw_ws);
    }
};
//! kernel entry point: receives the complete kernel parameters
using kern_t = void (*)(const KernParam&);
//! kernel entry point that additionally receives two panel pointers
//! (presumably pre-packed A and B panels -- TODO confirm)
using kern_naked_t =
        void (*)(const KernParam&, const void* a_panel, const void* b_panel);
class AlgoBase : public Algorithm {
protected:
virtual ~AlgoBase() = default;
bool can_be_treated_as_int8x8x32(const KernSizeParam& param) const {
return param.A_type.enumv() == param.B_type.enumv() &&
(param.A_type.enumv() == DTypeEnum::Int8 ||
param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
(param.C_type.enumv() == DTypeEnum::Int32 ||
param.C_type.enumv() == DTypeEnum::QuantizedS32) &&
param.compute_mode == Param::ComputeMode::DEFAULT &&
param.format == param::MatrixMul::Format::DEFAULT;
}
bool can_be_treated_as_int8x8x16(const KernSizeParam& param) const {
return param.A_type.enumv() == param.B_type.enumv() &&
(param.A_type.enumv() == DTypeEnum::Int8 ||
param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
(param.C_type.enumv() == DTypeEnum::Int16 ||
param.C_type.enumv() == DTypeEnum::QuantizedS16);
}
public:
AlgoBase() { m_handle_type = Handle::HandleType::FALLBACK; }
enum class AlgoType : uint32_t {
FB_F32K8x12x1 = 1 << 0,
FB_GEMV,
FB_NAIVE,
#if MEGDNN_X86
X86_F32_BLAS = 1 << 8,
X86_F32_MKL_PACKA,
X86_INT8X8X32_AVX2_2X4X16,
X86_INT8X8X32_AVX2_4X16X2,
X86_INT8X8X16_AVX2,
X86_INT8X8X16_SSE,
X86_INT8X8X32_SSE_4X8X2,
X86_F32_MK8_8X8,
X86_F32_6x16,
X86_INT8X8X32_VNNI,
X86_INT8X8X32_MKLDNN,
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7
ARM_COMMON_INT8X8X16 = 1 << 8,
ARM_COMMON_INT8X8X32_GEMV,
ARM_COMMON_INT8X8X32_GEMV_MK4,
ARM_COMMON_INT8X8X32_GEMV_MK4_DOT,
ARM_COMMON_F32_GEMV_MK4,
ARM_COMMON_F16_GEMV,
ARM_COMMON_GEVM,
#if MEGDNN_AARCH64
AARCH64_F32_K8X12X1 = 1 << 16,
AARCH64_F32_MK4_K8X12X1,
AARCH64_F32_K4X16X1,
AARCH64_F32_MK4_4x16,
AARCH64_F32_GEMV,
AARCH64_F16_K8X24X1,
AARCH64_F16_MK8_8X8,
AARCH64_INT8X8X32_K8X12X4_DOTPROD,
AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD,
AARCH64_INT8X8X32_MK4_4X4X16,
AARCH64_INT8X8X32_K4X4X16,
AARCH64_INT8X8X32_K8X8X8,
AARCH64_INT8X8X16_K8X8X8,
AARCH64_INT8X8X16_K4X4X16,
AARCH64_INT8X8X16_MK4_16X12X4,
AARCH64_INT8X8X16_MK4_K8X8X8,
AARCH64_INT8X8X16_MK4_4X4X8,
AARCH64_INT16X16X32_K12X8X1,
AARCH64_INT16X16X32_MK8_8X8,
AARCH64_QUINT8_K8X8X4_DOTPROD,
AARCH64_QUINT8_GEMV_DOTPROD,
AARCH64_QUINT8_K8X8X8,
AARCH64_INT4X4X16_K8X8X8,
#else
ARMV7_F32 = 1 << 16,
ARMV7_F32_MK4_PACK_4X12,
ARMV7_F32_MK4_4x8,
ARMV7_F16_K4X16X1,
ARMV7_F16_MK8_4X8,
ARMV7_INT8_K6X8X4,
ARMV7_QUINT8_K4X8X4,
ARMV7_INT8_MK4_8X4X4_DOTPROD,
ARMV7_F32_GEMV,
ARMV7_INT8X8X32_K4X2X16,
ARMV7_INT8X8X32_K4X8X8,
ARMV7_QUINT8_K4X8X8,
ARMV7_INT8X8X16_K4X2X16,
ARMV7_INT8X8X16_K4X8X8,
ARMV7_INT8X8X16_MK4_K8X8X4,
ARMV7_INT16X16X32_K12X4X1,
ARMV7_INT16X16X32_MK8_4X8,
ARMV7_INT8X8X32_MK4_4X2X16,
ARMV7_INT8X8X16_K8X8X4
#endif
#endif
};
enum class AlgoSet : uint32_t {
ALGO_TYPE_GEMM = 0,
ALGO_TYPE_GEMV = 1,
};
enum class PackMode : uint32_t {
DEFAULT = 0,
NO_PACK = 1,
ONLY_PACKA = 2,
};
struct InnerBlockSize {
size_t m, n, k;
};
struct MatmulDescription {
PackMode packmode;
InnerBlockSize innerblocksize;
AlgoTypePack algo_type;
size_t packa_type_size;
};
virtual bool usable(const KernSizeParam&) const = 0;
virtual bool preferred(const KernSizeParam&) const { return true; }
virtual size_t get_workspace(const KernSizeParam&) const = 0;
virtual kern_t get_kern(const KernSizeParam&) const = 0;
virtual kern_naked_t get_kern_naked(const KernSizeParam&) const {
megdnn_assert(0);
};
virtual AlgoSet algoset() const { return AlgoSet::ALGO_TYPE_GEMM; }
virtual PackMode packmode() const { return PackMode::DEFAULT; }
virtual void pack_A(const KernParam&, void*, size_t, size_t) const {
megdnn_assert(0);
};
virtual void pack_B(const KernParam&, void*, size_t, size_t) const {
megdnn_assert(0);
};
virtual WorkspaceBundle get_bundle(const KernSizeParam&) const {
megdnn_assert(0);
};
virtual InnerBlockSize get_inner_block_size() const { megdnn_assert(0); };
bool preferred_attribute(
const KernSizeParam& param,
const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) {
return contain_attribute_all(positive_attr) &&
!contain_attribute_any(negative_attr) && preferred(param);
};
virtual MatmulDescription matmul_description() const = 0;
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
};
private:
//! Forward declarations of nested classes defined outside this header
//! (presumably the concrete AlgoBase subclasses -- TODO confirm).
class AlgoF32K8x12x1; class AlgoGemv;
class AlgoNaive;
//! type of the object returned by algo_pack(); defined outside this header
class AlgoPack;
//! accessor for a shared, read-only AlgoPack instance
static const AlgoPack& algo_pack();
//! resolve an algorithm descriptor to an Algorithm pointer
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
public:
//! List of AlgoBase pointers known to this operator; virtual so that
//! derived operators can provide their own list.
virtual SmallVector<AlgoBase*> get_all_packed_algo();
//! List of AlgoBase pointers selected by \p algo_type
//! (presumably those whose description matches it -- TODO confirm).
SmallVector<AlgoBase*> select_algo_type(AlgoTypePack algo_type);
protected:
//! build a KernSizeParam from the layouts of A, B and C
KernSizeParam make_kern_size_param(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C);
//! build a KernParam from the tensors A, B, C and the workspace
KernParam make_kern_param(
_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
_megdnn_workspace workspace);
//! algorithms for the given layouts
std::vector<Algorithm*> get_all_algorithms(
const TensorLayout& A, const TensorLayout& B,
const TensorLayout& C) override;
//! "safe" variant of get_all_algorithms()
//! (NOTE(review): how it differs is not visible in this header)
std::vector<Algorithm*> get_all_algorithms_safe(
const TensorLayout& A, const TensorLayout& B,
const TensorLayout& C) override;
//! Pick one algorithm for the given layouts, taking a workspace limit in
//! bytes and positive/negative attribute filters.
Algorithm* get_algorithm_heuristic(
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) override;
};
} }