#pragma once
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/common/utils.h"
#include "src/fallback/convolution/opr_impl.h"
namespace megdnn {
namespace arm_common {
class ConvBiasImpl;
//! ConvolutionBackwardData operator of the arm_common backend.
//!
//! Derives from the fallback implementation, inherits its constructors and
//! overrides the ncb_1g_* hooks, the algorithm-set name and the packed
//! algorithm list declared below.
class ConvolutionBackwardDataImpl : public fallback::ConvolutionBackwardDataImpl {
public:
    //! Inherit every constructor of the fallback operator as-is.
    using fallback::ConvolutionBackwardDataImpl::ConvolutionBackwardDataImpl;

protected:
    //! Abstract base class for the algorithms of this operator.
    //!
    //! Construction tags the algorithm with the ARM_COMMON handle type.
    class AlgoBase : public fallback::ConvolutionBackwardDataImpl::AlgoBase {
    protected:
        //! Protected destructor: code outside the class hierarchy cannot
        //! destroy an algorithm through this type.
        ~AlgoBase() = default;

    public:
        AlgoBase() : fallback::ConvolutionBackwardDataImpl::AlgoBase() {
            this->m_handle_type = Handle::HandleType::ARM_COMMON;
        }

        //! Pure virtual; takes the operator and the kernel size param and
        //! returns a bool (by its name: whether the algorithm is usable).
        virtual bool usable(
                fallback::ConvolutionBackwardDataImpl* opr,
                const NCBKernSizeParam& param) const = 0;

        //! Pure virtual; returns a size_t for the given operator and param
        //! (by its name: the workspace size; unit not visible here).
        virtual size_t get_workspace(
                fallback::ConvolutionBackwardDataImpl* opr,
                const NCBKernSizeParam& param) const = 0;

        //! Pure virtual; returns the ncb_kern_t kernel for the given
        //! operator and param.
        virtual ncb_kern_t dispatch_kern(
                fallback::ConvolutionBackwardDataImpl* opr,
                const NCBKernSizeParam& param) const = 0;
    };

    //! Override: kernel lookup for \p algo with \p param.
    ncb_kern_t ncb_1g_dispatch_kern(
            Algorithm* algo,
            const NCBKernSizeParam& param) override;

    //! Override: workspace query for \p algo with \p param.
    size_t ncb_1g_get_workspace(
            Algorithm* algo,
            const NCBKernSizeParam& param) override;

    //! Override: name of the algorithm set of this operator.
    const char* get_algorithm_set_name() const override;

    //! Override: returns pointers to algorithms, typed as the fallback
    //! AlgoBase.
    SmallVector<fallback::ConvolutionBackwardDataImpl::AlgoBase*>
    get_all_packed_algo() override;

public:
    // Macro defined outside this file; presumably declares the
    // descriptor-to-algorithm lookup helper -- TODO confirm.
    MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl);

private:
#if MGB_ENABLE_DOT
    // Algorithm classes that are only declared when MGB_ENABLE_DOT is set.
    class AlgoSdot8DirectStride1;
    class AlgoSdot8DirectStride2;
    class AlgoUdot8DirectStride1;
    class AlgoUdot8DirectStride2;
#endif

    //! Forward declaration; the definition is not part of this header.
    class AlgoPack;

    //! Accessor returning a const reference to an AlgoPack.
    static const AlgoPack& algo_pack();
};
} }