#pragma once
#include "src/common/algo_base.h"
#include "src/common/utils.h"
#include "src/fallback/matrix_mul/opr_impl.h"
namespace megdnn {
namespace arm_common {
class MatrixMulImpl : public fallback::MatrixMulImpl {
public:
using fallback::MatrixMulImpl::MatrixMulImpl;
bool is_thread_safe() const override { return true; }
class AlgoBase : public fallback::MatrixMulImpl::AlgoBase {
public:
AlgoBase() : fallback::MatrixMulImpl::AlgoBase() {
m_handle_type = Handle::HandleType::ARM_COMMON;
}
};
SmallVector<fallback::MatrixMulImpl::AlgoBase*> get_all_packed_algo() override;
MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(MatrixMulImpl);
protected:
class AlgoF32Gemv; class AlgoF32GemvMK4; class AlgoInt8x8x32Gemv; class AlgoInt8x8x32GemvMK4; class AlgoGevm; #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class AlgoF16Gemv;
#endif
#if MGB_ENABLE_DOT
class AlgoInt8x8x32GemvMK4Dot; #endif
class AlgoInt8x8x16; class AlgoPack;
public:
static const AlgoPack& algo_pack();
};
} }