#include "./algo.h"
#include "src/cuda/handle.h"
#include "src/cuda/utils.h"
#include "./quint4x4x32_wmma/activation_u4.cuh"
#include "./quint4x4x32_wmma/reduce_with_scale_data.cuh"
#include "./quint4x4x32_wmma/wmma_conv_integer_u4.cuh"
#include "./reduce_filter.cuh"
using namespace megdnn;
using namespace cuda;
using namespace activation_u4;
#if CUDA_VERSION >= 10000
// Reports whether this WMMA-based algorithm can run the convolution described
// by `args`. Every requirement below must hold for the algorithm to be usable.
bool ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::is_available(
        const SizeArgs& args) const {
    // Source and destination must both be contiguous.
    if (!(args.src_layout->is_contiguous() && args.dst_layout->is_contiguous())) {
        return false;
    }
    // Any non-empty z tensor is rejected.
    if (args.z_layout->ndim > 0) {
        return false;
    }

    using Param = param::ConvBias;
    const auto& fm = args.filter_meta;

    // Filter window must be square with side 3, 5 or 7.
    const auto fh = fm.spatial[0];
    const auto fw = fm.spatial[1];
    const bool kernel_shape_ok = (fh == fw) && (fh == 3 || fh == 5 || fh == 7);

    // Only unit stride is handled.
    const bool unit_stride = fm.stride[0] == 1 && fm.stride[1] == 1;

    // Dimension 3 of the destination (read as OW in exec) must be a multiple of 8.
    const bool ow_aligned = (*args.dst_layout)[3] % 8 == 0;

    const auto& conv_param = args.opr->param();
    // Dense convolution only.
    const bool dense = conv_param.sparse == Param::Sparse::DENSE;
    // The filter must not need flipping.
    const bool no_flip = !fm.should_flip;
    // Only unit dilation is handled.
    const bool unit_dilation = fm.dilation[0] == 1 && fm.dilation[1] == 1;
    // Tensor format must be NCHW8.
    const bool nchw8 = conv_param.format == Param::Format::NCHW8;

    // The current device must report compute capability 7.5 or newer.
    const auto& prop = current_device_prop();
    const bool sm75_or_newer =
            prop.major > 7 || (prop.major == 7 && prop.minor >= 5);

    // Only RELU and IDENTITY activations have a dispatch path in exec.
    const bool activation_ok =
            conv_param.nonlineMode == Param::NonlineMode::RELU ||
            conv_param.nonlineMode == Param::NonlineMode::IDENTITY;

    // Dimension 1 of the source times 8 (read as IC in exec) must be a
    // multiple of 32.
    const bool ic_aligned = ((*args.src_layout)[1] * 8) % 32 == 0;

    return kernel_shape_ok && unit_stride && ow_aligned && dense && no_flip &&
           unit_dilation && nchw8 && sm75_or_newer && activation_ok && ic_aligned;
}
// Builds the workspace partition used by exec on top of `raw_ptr`.
//   slot 0: OC int32 values plus the scratch size reported by
//           do_dispatch_reduce_workspace_in_bytes for the filter reduction
//   slot 1: N * OH * OW int32 values
//   slot 2: relayouted filter, present only when
//           get_workspace_in_bytes_do_conv reports a non-zero size
WorkspaceBundle ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::get_workspace_bundle(
        dt_byte* raw_ptr, const SizeArgs& args) const {
    const auto& src = *args.src_layout;
    const auto& flt = *args.filter_layout;
    const auto& dst = *args.dst_layout;

    const size_t batch = src[0];
    const size_t out_channels = flt[0];
    const size_t in_channels = flt[1] * 8;
    const size_t filter_h = flt[2];
    const size_t filter_w = flt[3];
    const size_t out_h = dst[2];
    const size_t out_w = dst[3];

    // Shape handed to the reduction-workspace query: (A, B, C).
    size_t reduce_a = out_channels;
    size_t reduce_b = in_channels * filter_h * filter_w / 8;
    size_t reduce_c = 1;
    const size_t zp_filter_bytes =
            out_channels * sizeof(int32_t) +
            do_dispatch_reduce_workspace_in_bytes(reduce_a, reduce_b, reduce_c);

    const size_t zp_data_bytes = batch * out_h * out_w * sizeof(int32_t);
    const size_t relayout_bytes = get_workspace_in_bytes_do_conv(args);

    if (relayout_bytes == 0) {
        // No relayout buffer needed: two-slot bundle.
        return WorkspaceBundle{raw_ptr, {zp_filter_bytes, zp_data_bytes}};
    }
    return WorkspaceBundle{
            raw_ptr, {zp_filter_bytes, zp_data_bytes, relayout_bytes}};
}
// Returns the total number of workspace bytes required for `args`, obtained by
// laying out the bundle on a null base pointer and asking for its total size.
size_t ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::get_workspace_in_bytes(
        const SizeArgs& args) const {
    auto bundle = get_workspace_bundle(nullptr, args);
    return bundle.total_size_in_bytes();
}
// Returns true when the filter window is 3x3; exec uses this to select the
// "fhxfw" kernel instead of the "1xfw" kernel.
bool ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::use_kernel_fhxfw(
        const SizeArgs& args) const {
    const auto& window = args.filter_meta.spatial;
    return window[0] == 3 && window[1] == 3;
}
// Returns the size in bytes of the relayouted-filter buffer needed by the
// convolution kernel itself. The "fhxfw" kernel needs none; otherwise the
// buffer holds OC * IC * FH * FW / 2 bytes.
size_t ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::get_workspace_in_bytes_do_conv(
        const SizeArgs& args) const {
    if (use_kernel_fhxfw(args)) {
        return 0_z;
    }
    const auto& flt = *args.filter_layout;
    const size_t out_channels = flt[0];
    const size_t in_channels = flt[1] * 8;
    const size_t filter_h = flt[2];
    const size_t filter_w = flt[3];
    // NOTE(review): the halving presumably reflects two 4-bit values packed
    // per byte -- confirm against the relayout kernel.
    const size_t filter_elems = out_channels * in_channels * filter_h * filter_w;
    return filter_elems / 2;
}
void ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::exec(const ExecArgs& args) const {
auto&& handle = concrete_handle(args.opr->handle());
auto&& ws_bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
auto&& ws_zp_filter = ws_bundle.get_workspace(0);
auto&& ws_zp_data = ws_bundle.get_workspace(1);
size_t N = args.src_layout->operator[](0), IC = args.src_layout->operator[](1) * 8,
IH = args.src_layout->operator[](2), IW = args.src_layout->operator[](3),
OC = args.filter_layout->operator[](0), FH = args.filter_meta.spatial[0],
FW = args.filter_meta.spatial[1], OH = args.dst_layout->operator[](2),
OW = args.dst_layout->operator[](3), PH = args.filter_meta.padding[0],
PW = args.filter_meta.padding[1], SH = args.filter_meta.stride[0],
SW = args.filter_meta.stride[1];
int32_t zp_data = args.src_layout->dtype.param<dtype::Quantized4Asymm>().zero_point;
int32_t zp_filter =
args.filter_layout->dtype.param<dtype::Quantized4Asymm>().zero_point;
int32_t zp_data_filter = zp_data * zp_filter * FH * FW * IC;
auto&& stream = cuda_stream(handle);
do_dispatch_reduce_with_scale_filter_4bit<false>(
static_cast<uint8_t*>(args.filter_tensor->raw_ptr()), -zp_data, OC,
FH * FW * IC / 8, ws_zp_filter.ptr<int32_t>(), stream);
do_dispatch_reduce_with_scale_data_u4(
ws_zp_data.ptr<int32_t>(),
static_cast<uint8_t*>(args.src_tensor->raw_ptr()), N, IH, IW, OH, OW, PH,
PW, FH, FW, SH, SW, IC, -zp_filter, static_cast<uint8_t>(zp_data), stream);
if (use_kernel_fhxfw(args)) {
wmma_conv_integer_subbyte::_do_wmma_conv_integer_subbyte_fhxfw(
static_cast<uint8_t*>(args.src_tensor->raw_ptr()),
static_cast<uint8_t*>(args.filter_tensor->raw_ptr()),
args.dst_tensor->compatible_ptr<int32_t>(), N, IH, IW, OH, OW, PH, PW,
IC, OC, FH, FW, SH, SW, static_cast<uint8_t>(zp_data), stream);
} else {
auto&& ws_relayout_filter = ws_bundle.get_workspace(2);
wmma_conv_integer_subbyte::_do_wmma_conv_integer_subbyte_1xfw(
static_cast<uint8_t*>(args.src_tensor->raw_ptr()),
static_cast<uint8_t*>(args.filter_tensor->raw_ptr()),
args.dst_tensor->compatible_ptr<int32_t>(),
ws_relayout_filter.ptr<uint8_t>(), N, IH, IW, OH, OW, PH, PW, IC, OC,
FH, FW, SH, SW, static_cast<uint8_t>(zp_data), stream);
}
int s0 = args.bias_layout->stride[0], s1 = args.bias_layout->stride[1],
s2 = args.bias_layout->stride[2], s3 = args.bias_layout->stride[3];
s0 = args.bias_layout->shape[0] == 1 ? 0 : s0;
s1 = args.bias_layout->shape[1] == 1 ? 0 : s1;
s2 = args.bias_layout->shape[2] == 1 ? 0 : s2;
s3 = args.bias_layout->shape[3] == 1 ? 0 : s3;
activation_u4::BiasVisitor visitor{
args.bias_tensor->compatible_ptr<int32_t>(), s0, s1, s2, s3};
auto&& param = args.opr->param();
if (param.nonlineMode == Param::NonlineMode::RELU) {
do_dispatch_activation_u4<ActivationRELU>(
args.dst_tensor->compatible_ptr<int32_t>(), visitor,
ws_zp_data.ptr<int32_t>(), ws_zp_filter.ptr<int32_t>(), zp_data_filter,
N, OC, OH, OW, stream);
} else if (param.nonlineMode == Param::NonlineMode::IDENTITY) {
do_dispatch_activation_u4<ActivationIdentity>(
args.dst_tensor->compatible_ptr<int32_t>(), visitor,
ws_zp_data.ptr<int32_t>(), ws_zp_filter.ptr<int32_t>(), zp_data_filter,
N, OC, OH, OW, stream);
}
}
#endif