#include "src/arm_common/conv_bias/quint8/algos.h"
#include "src/arm_common/conv_bias/quint8/stride1.h"
#include "src/arm_common/conv_bias/quint8/stride1_dotprod.h"
#include "src/arm_common/conv_bias/quint8/stride2.h"
#include "src/arm_common/conv_bias/quint8/stride2_dotprod.h"
#include "src/arm_common/elemwise_op.h"
#include "src/fallback/conv_bias/common.h"
#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_conv_bias_quint8)
using namespace megdnn;
using namespace arm_common;
/**
 * Reports whether the direct stride-1 quint8 convolution can handle \p param.
 *
 * The decision is delegated entirely to
 * direct_quint8_stride1::can_conv_direct_stride1_quint8(); the selection
 * strategy argument is not consulted.
 */
bool ConvBiasImpl::AlgoQU8DirectStride1::usable(
        const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
    const bool supported =
            direct_quint8_stride1::can_conv_direct_stride1_quint8(param);
    return supported;
}
/**
 * Returns the number of workspace bytes needed by the direct stride-1 quint8
 * convolution for \p param.
 *
 * The size is the total reported by the bundle that
 * direct_quint8_stride1::get_bundle() builds for this problem.
 */
size_t ConvBiasImpl::AlgoQU8DirectStride1::get_workspace(
        const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(
            megdnn_arm_common_conv_bias_quint8,
            midout_iv("AlgoQU8DirectStride1::get_workspace"_hash)) {
        // "Large group" means there are at least as many filter groups as
        // worker threads.
        const bool groups_cover_threads =
                param.nr_threads <= param.filter_meta.group;
        return direct_quint8_stride1::get_bundle(param, groups_cover_threads)
                .total_size_in_bytes();
    }
    MIDOUT_END();
    // Reached only if the region above did not return (presumably when midout
    // compiles the region out -- confirm against midout.h).
    return 0;
}
/**
 * Produces the kernel list for the direct stride-1 quint8 convolution.
 *
 * The kernels come from direct_quint8_stride1::get_kimpls(), which is told
 * whether the filter group count is at least the thread count.
 */
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoQU8DirectStride1::dispatch_kerns(
        const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(
            megdnn_arm_common_conv_bias_quint8,
            midout_iv("AlgoQU8DirectStride1::dispatch_kerns"_hash)) {
        const bool groups_cover_threads =
                param.nr_threads <= param.filter_meta.group;
        return direct_quint8_stride1::get_kimpls(param, groups_cover_threads);
    }
    MIDOUT_END();
    // Reached only if the region above did not return (presumably when midout
    // compiles the region out -- confirm against midout.h).
    return {};
}
/**
 * Reports whether the direct stride-2 quint8 convolution can handle \p param.
 *
 * The decision is delegated entirely to
 * direct_quint8_stride2::can_conv_direct_stride2_quint8(); the selection
 * strategy argument is not consulted.
 */
bool ConvBiasImpl::AlgoQU8DirectStride2::usable(
        const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
    const bool supported =
            direct_quint8_stride2::can_conv_direct_stride2_quint8(param);
    return supported;
}
/**
 * Returns the number of workspace bytes needed by the direct stride-2 quint8
 * convolution for \p param.
 *
 * The size is the total reported by the bundle that
 * direct_quint8_stride2::get_bundle() builds for this problem.
 */
size_t ConvBiasImpl::AlgoQU8DirectStride2::get_workspace(
        const NCBKernSizeParam& param) const {
    // The midout tag names this class and method so the region has its own
    // hash, distinct from the regions of the other quint8 algorithms.
    MIDOUT_BEGIN(
            megdnn_arm_common_conv_bias_quint8,
            midout_iv("AlgoQU8DirectStride2::get_workspace"_hash)) {
        // "Large group" means there are at least as many filter groups as
        // worker threads.
        bool large_group = param.filter_meta.group >= param.nr_threads;
        auto bundle = direct_quint8_stride2::get_bundle(param, large_group);
        return bundle.total_size_in_bytes();
    }
    MIDOUT_END();
    // Reached only if the region above did not return (presumably when midout
    // compiles the region out -- confirm against midout.h).
    return 0;
}
/**
 * Produces the kernel list for the direct stride-2 quint8 convolution.
 *
 * The kernels come from direct_quint8_stride2::get_kimpls(), which is told
 * whether the filter group count is at least the thread count.
 */
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoQU8DirectStride2::dispatch_kerns(
        const NCBKernSizeParam& param) const {
    // The midout tag names this class and method so the region has its own
    // hash, distinct from the regions of the other quint8 algorithms.
    MIDOUT_BEGIN(
            megdnn_arm_common_conv_bias_quint8,
            midout_iv("AlgoQU8DirectStride2::dispatch_kerns"_hash)) {
        bool large_group = param.filter_meta.group >= param.nr_threads;
        return direct_quint8_stride2::get_kimpls(param, large_group);
    }
    MIDOUT_END();
    // Reached only if the region above did not return (presumably when midout
    // compiles the region out -- confirm against midout.h).
    return {};
}
#if MGB_ENABLE_DOT
/**
 * Reports whether the dot-product direct stride-1 quint8 convolution can
 * handle \p param.
 *
 * Requires cpuinfo_has_arm_neon_dot() to be true; only then is
 * direct_dotprod_quint8_stride1::can_conv_direct_stride1_quint8() consulted.
 * The selection strategy argument is not consulted.
 */
bool ConvBiasImpl::AlgoDotU8DirectStride1::usable(
        const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
    // Short-circuit keeps the shape check from running without dot support.
    return cpuinfo_has_arm_neon_dot() &&
           direct_dotprod_quint8_stride1::can_conv_direct_stride1_quint8(param);
}
/**
 * Returns the number of workspace bytes needed by the dot-product direct
 * stride-1 quint8 convolution for \p param.
 *
 * The size is the total reported by the bundle that
 * direct_dotprod_quint8_stride1::get_bundle() builds for this problem.
 */
size_t ConvBiasImpl::AlgoDotU8DirectStride1::get_workspace(
        const NCBKernSizeParam& param) const {
    // The midout tag names this class and method so the region has its own
    // hash, distinct from the regions of the other quint8 algorithms.
    MIDOUT_BEGIN(
            megdnn_arm_common_conv_bias_quint8,
            midout_iv("AlgoDotU8DirectStride1::get_workspace"_hash)) {
        // "Large group" means there are at least as many filter groups as
        // worker threads.
        bool large_group = param.filter_meta.group >= param.nr_threads;
        auto bundle = direct_dotprod_quint8_stride1::get_bundle(param, large_group);
        return bundle.total_size_in_bytes();
    }
    MIDOUT_END();
    // Reached only if the region above did not return (presumably when midout
    // compiles the region out -- confirm against midout.h).
    return 0;
}
/**
 * Produces the kernel list for the dot-product direct stride-1 quint8
 * convolution.
 *
 * The kernels come from direct_dotprod_quint8_stride1::get_kimpls(), which is
 * told whether the filter group count is at least the thread count.
 */
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoDotU8DirectStride1::dispatch_kerns(
        const NCBKernSizeParam& param) const {
    // The midout tag names this class and method so the region has its own
    // hash, distinct from the regions of the other quint8 algorithms.
    MIDOUT_BEGIN(
            megdnn_arm_common_conv_bias_quint8,
            midout_iv("AlgoDotU8DirectStride1::dispatch_kerns"_hash)) {
        bool large_group = param.filter_meta.group >= param.nr_threads;
        return direct_dotprod_quint8_stride1::get_kimpls(param, large_group);
    }
    MIDOUT_END();
    // Reached only if the region above did not return (presumably when midout
    // compiles the region out -- confirm against midout.h).
    return {};
}
/**
 * Reports whether the dot-product direct stride-2 quint8 convolution can
 * handle \p param.
 *
 * Requires cpuinfo_has_arm_neon_dot() to be true; only then is
 * direct_dotprod_quint8_stride2::can_conv_direct_stride2_quint8() consulted.
 * The selection strategy argument is not consulted.
 */
bool ConvBiasImpl::AlgoDotU8DirectStride2::usable(
        const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
    // Short-circuit keeps the shape check from running without dot support.
    return cpuinfo_has_arm_neon_dot() &&
           direct_dotprod_quint8_stride2::can_conv_direct_stride2_quint8(param);
}
/**
 * Returns the number of workspace bytes needed by the dot-product direct
 * stride-2 quint8 convolution for \p param.
 *
 * The size is the total reported by the bundle that
 * direct_dotprod_quint8_stride2::get_bundle() builds for this problem.
 */
size_t ConvBiasImpl::AlgoDotU8DirectStride2::get_workspace(
        const NCBKernSizeParam& param) const {
    // The midout tag names this class and method so the region has its own
    // hash, distinct from the regions of the other quint8 algorithms.
    MIDOUT_BEGIN(
            megdnn_arm_common_conv_bias_quint8,
            midout_iv("AlgoDotU8DirectStride2::get_workspace"_hash)) {
        // "Large group" means there are at least as many filter groups as
        // worker threads.
        bool large_group = param.filter_meta.group >= param.nr_threads;
        auto bundle = direct_dotprod_quint8_stride2::get_bundle(param, large_group);
        return bundle.total_size_in_bytes();
    }
    MIDOUT_END();
    // Reached only if the region above did not return (presumably when midout
    // compiles the region out -- confirm against midout.h).
    return 0;
}
/**
 * Produces the kernel list for the dot-product direct stride-2 quint8
 * convolution.
 *
 * The kernels come from direct_dotprod_quint8_stride2::get_kimpls(), which is
 * told whether the filter group count is at least the thread count.
 */
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoDotU8DirectStride2::dispatch_kerns(
        const NCBKernSizeParam& param) const {
    // The midout tag names this class and method so the region has its own
    // hash, distinct from the regions of the other quint8 algorithms.
    MIDOUT_BEGIN(
            megdnn_arm_common_conv_bias_quint8,
            midout_iv("AlgoDotU8DirectStride2::dispatch_kerns"_hash)) {
        bool large_group = param.filter_meta.group >= param.nr_threads;
        return direct_dotprod_quint8_stride2::get_kimpls(param, large_group);
    }
    MIDOUT_END();
    // Reached only if the region above did not return (presumably when midout
    // compiles the region out -- confirm against midout.h).
    return {};
}
#endif