#pragma once
#include "src/arm_common/matrix_mul/opr_impl.h"
namespace megdnn {
namespace aarch64 {
class MatrixMulImpl : public arm_common::MatrixMulImpl {
public:
using arm_common::MatrixMulImpl::MatrixMulImpl;
class AlgoBase : public arm_common::MatrixMulImpl::AlgoBase {
public:
AlgoBase() : arm_common::MatrixMulImpl::AlgoBase() {
m_handle_type = Handle::HandleType::AARCH64;
}
};
SmallVector<fallback::MatrixMulImpl::AlgoBase*> get_all_packed_algo() override;
MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(MatrixMulImpl);
private:
class AlgoF32K8x12x1; class AlgoF32MK4_8x12x1; class AlgoF32K4x16x1; class AlgoF32MK4_4x16; class AlgoF32Gemv; #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class AlgoF16K8x24x1; class AlgoF16MK8_8x8; #endif
#if MGB_ENABLE_DOT
class AlgoInt8x8x32K8x12x4DotProd; class AlgoInt8x8x32MK4_8x12x4DotProd; #endif
class AlgoInt8x8x32MK4_4x4x16; class AlgoInt8x8x32K4x4x16; class AlgoInt8x8x32K8x8x8; class AlgoInt8x8x16K8x8x8; class AlgoInt8x8x16K4x4x16; class AlgoInt8x8x16MK4_16x12x4; class AlgoInt8x8x16MK4_4x4x8;
class AlgoInt16x16x32K12x8x1; class AlgoInt16x16x32MK8_8x8;
#if MGB_ENABLE_DOT
class AlgoQuint8K8x8x4DotProd; class AlgoQuint8GemvDotProd; #endif
class AlgoQuint8K8x8x8; class AlgoInt8x8x16MK4_K8x8x8; class AlgoInt4x4x16K8x8x8; class AlgoPack;
public:
static const AlgoPack& algo_pack();
};
} }