#include "./helper.h"
using namespace megdnn;
using namespace cuda;
using namespace convolution;
bool convolution::is_cudnn_supported(const ForwardSizeArgs& args) {
if (args.src_layout->dtype == args.filter_layout->dtype &&
args.src_layout->dtype == dtype::BFloat16()) {
return false;
}
if (args.handle->is_tegra_k1())
return false;
if (args.filter_meta.format == param::Convolution::Format::NCHW4) {
if (args.dst_layout->dtype.enumv() != DTypeEnum::Int8 &&
args.dst_layout->dtype.enumv() != DTypeEnum::QuantizedS8) {
return false;
}
} else if (
args.filter_meta.format != param::Convolution::Format::NCHW &&
args.filter_meta.format != param::Convolution::Format::NHWC) {
return false;
}
auto& fm = args.filter_meta;
bool supported = true;
supported &= (fm.spatial_ndim == 2);
#if CUDNN_VERSION < 7000
supported &= (fm.group == 1);
#endif
#if CUDNN_VERSION < 7500
supported &= (fm.dilation[0] == 1 && fm.dilation[1] == 1);
#endif
return supported;
}
/*!
 * \brief Compute the byte sizes of the workspace chunks used by the
 *      matmul-based convolution.
 *
 * Only non-grouped convolution is accepted (asserted). The output spatial
 * extent is read from dst shape axes 2 and 3 and the batch size from src
 * shape axis 0.
 *
 * The returned vector contains, in order:
 *  -# a buffer as large as the destination tensor;
 *  -# a buffer of IC * FH * FW * OH * OW * N elements;
 *  -# only if the filter must be flipped: a buffer of OC * IC * FH * FW
 *     elements.
 *
 * All sizes are in bytes, using the element size of the source dtype.
 *
 * \param args forward size arguments
 * \return chunk sizes in bytes
 */
SmallVector<size_t> convolution::matmul_get_workspace_bundle(
        const ForwardSizeArgs& args) {
    const auto& meta = args.filter_meta;
    megdnn_assert(meta.group == 1);

    const auto batch = args.src_layout->shape[0];
    const auto out_h = args.dst_layout->shape[2];
    const auto out_w = args.dst_layout->shape[3];
    const auto out_ch = meta.ocpg;
    const auto in_ch = meta.icpg;
    const auto flt_h = meta.spatial[0];
    const auto flt_w = meta.spatial[1];

    // Element size goes first in every product so that the whole
    // multiplication is carried out in size_t.
    const size_t elem_bytes = args.src_layout->dtype.size();

    SmallVector<size_t> sizes;
    sizes.push_back(elem_bytes * args.dst_layout->total_nr_elems());
    sizes.push_back(elem_bytes * in_ch * flt_h * flt_w * out_h * out_w * batch);
    if (meta.should_flip) {
        // Extra room for a flipped copy of the filter.
        sizes.push_back(elem_bytes * out_ch * in_ch * flt_h * flt_w);
    }
    return sizes;
}
void convolution::flip_filter(
const ForwardSizeArgs& args, const Workspace& workspace, RefPtr& ref_ptr) {
auto&& fm = args.filter_meta;
megdnn_assert(fm.group == 1 && fm.spatial_ndim == 2);
auto OC = fm.ocpg, IC = fm.icpg, FH = fm.spatial[0], FW = fm.spatial[1];
auto dtype = fm.dtype;
megdnn_assert(workspace.size >= dtype.size() * OC * IC * FH * FW);
TensorND src{{{OC, IC, FH, FW}, dtype}, ref_ptr},
dst{workspace.raw_ptr + (FH * FW - 1) * dtype.size(), src.layout};
dst.layout.stride[2] = -dst.layout.stride[2];
dst.layout.stride[3] = -dst.layout.stride[3];
args.handle->relayout_opr()->exec(src, dst);
ref_ptr.reset(workspace.raw_ptr);
}