#include "src/fallback/conv_bias/conv1x1/conv1x1_utils.h"
namespace megdnn {
namespace fallback {
namespace conv1x1 {
namespace utils {
//! Build the per-thread workspace layout used by the conv1x1 kernels.
//!
//! The returned bundle has two slots:
//!   - slot 0: \p matmul_c_size bytes, forwarded unchanged;
//!   - slot 1: a temporary buffer holding oc_tile_size * OH * OW elements of
//!     the bias dtype. It is only reserved when src and dst are both
//!     QuantizedS8 or both Quantized8Asymm; otherwise it is zero bytes.
//!
//! \param param conv bias kernel size parameter; only the output spatial size
//!     (osz) and the src/dst/bias dtypes are read here
//! \param matmul_c_size size in bytes of the first workspace slot
//! \param oc_tile_size multiplier for the second slot (presumably the number
//!     of output channels handled per tile -- confirm against the caller)
WorkspaceBundle get_thread_bundle(
        const ConvBiasImpl::NCBKernSizeParam& param, size_t matmul_c_size,
        size_t oc_tile_size) {
    const size_t OH = param.osz[0];
    const size_t OW = param.osz[1];
    const bool is_dst_8bit =
            (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
             param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
            (param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
             param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);
    //! Size the buffer by the width of one bias element. sizeof() applied to
    //! param.bias_type would measure the dtype descriptor object itself, which
    //! is unrelated to the element width and varies with the platform.
    //! The element size is queried only on the 8-bit path, exactly where the
    //! buffer is needed.
    const size_t matmul_dst_bytes_per_thread =
            is_dst_8bit ? oc_tile_size * OH * OW * param.bias_type.size() : 0;
    return WorkspaceBundle{nullptr, {matmul_c_size, matmul_dst_bytes_per_thread}};
}
//! Derive the matrix-multiplication size parameter for a conv1x1 problem.
//!
//! \param param conv bias kernel size parameter; the filter/src/dst/bias
//!     dtypes and filter_meta (icpg, format) are read
//! \param n forwarded as the N extent; also the unscaled stride of the second
//!     and third operands
//! \param m forwarded as the M extent
//! \return a KernSizeParam whose K extent is filter_meta.icpg, whose three
//!     strides are scaled by pack_size(filter format), and whose format
//!     follows the filter format (NCHW44 -> MK4, NCHW44_DOT -> MK4_DOT,
//!     anything else -> DEFAULT)
MatrixMulImpl::KernSizeParam get_matmul_kern_param(
        const ConvBiasImpl::NCBKernSizeParam& param, size_t n, size_t m) {
    const size_t in_channels = param.filter_meta.icpg;
    const size_t pack = pack_size(param.filter_meta.format);

    //! 8-bit quantized path: src is QuantizedS8 or Quantized8Asymm and dst has
    //! the very same dtype enum. dst_type is only inspected when src matches.
    const auto src_enum = param.src_type.enumv();
    const bool src_is_8bit = src_enum == DTypeEnum::QuantizedS8 ||
                             src_enum == DTypeEnum::Quantized8Asymm;
    const bool is_dst_8bit = src_is_8bit && param.dst_type.enumv() == src_enum;

    //! Pick the matmul layout that corresponds to the filter layout.
    const auto filter_format = param.filter_meta.format;
    const auto matmul_format =
            filter_format == param::ConvBias::Format::NCHW44
                    ? param::MatrixMul::Format::MK4
                    : (filter_format == param::ConvBias::Format::NCHW44_DOT
                               ? param::MatrixMul::Format::MK4_DOT
                               : param::MatrixMul::Format::DEFAULT);

    //! Unscaled strides: K for the first operand, N for the other two.
    const size_t stride_a = in_channels;
    const size_t stride_b = n;
    const size_t stride_c = n;

    //! Positional initialisation; the third dtype is the bias dtype on the
    //! 8-bit path and the dst dtype otherwise.
    return {param.filter_type,
            param.src_type,
            is_dst_8bit ? param.bias_type : param.dst_type,
            m,
            n,
            in_channels,
            stride_a * pack,
            stride_b * pack,
            stride_c * pack,
            false,
            false,
            param::MatrixMul::ComputeMode::DEFAULT,
            matmul_format};
}
} } } }