#pragma once
#include "./opr_impl.h"
#include "src/common/algo_chooser.h"
#include "src/common/utils.h"
#include "src/cuda/cudnn_wrapper.h"
#include "src/cuda/handle.h"
namespace megdnn {
namespace cuda {
namespace convolution {
//! Local shorthand for ConvolutionForward's canonized filter description, so
//! the helpers below can name it without the operator prefix.
typedef ConvolutionForward::CanonizedFilterMeta CanonizedFilterMeta;
/*!
 * \brief Everything needed to describe the sizes of one forward convolution.
 *
 * Consumed by is_cudnn_supported(), matmul_get_workspace_bundle() and
 * flip_filter() declared in this header. The layout members are raw pointers
 * that this struct never frees.
 * NOTE(review): presumably they refer to caller-owned layouts that must
 * outlive this object — confirm at the construction sites.
 *
 * Member order is part of the interface (aggregate initialization), so do not
 * reorder the fields below.
 */
struct ForwardSizeArgs {
    //! Backend handle the convolution is executed with.
    HandleImpl* handle;
    //! Layout of the input tensor.
    const TensorLayout* src_layout;
    //! Layout of the filter tensor.
    const TensorLayout* filter_layout;
    //! Canonized description of the filter (held by value).
    CanonizedFilterMeta filter_meta;
    //! Layout of the output tensor.
    const TensorLayout* dst_layout;
};
bool is_cudnn_supported(const ForwardSizeArgs& args);
SmallVector<size_t> matmul_get_workspace_bundle(const ForwardSizeArgs& args);
/*!
 * \brief cuDNN descriptor bundle for a forward convolution.
 *
 * set() populates all four descriptors from the input/output layouts, the
 * canonized filter and the operator parameter.
 */
struct CUDNNForwardDescs {
    TensorDesc src_desc;
    TensorDesc dst_desc;
    FilterDesc<param::Convolution> filter_desc;
    ConvDesc conv_desc;

    /*!
     * \brief Fill the descriptors in the order src, filter, dst, conv.
     *
     * Tensor descriptors use \p param's format. The convolution descriptor is
     * built from the source dtype, \p param and the filter's group count.
     */
    void set(
            const TensorLayout& src, const CanonizedFilterMeta& filter,
            const TensorLayout& dst, const param::Convolution& param) {
        const auto layout_format = param.format;
        src_desc.set(src, layout_format);
        filter_desc.set(filter);
        dst_desc.set(dst, layout_format);
        conv_desc.set(src.dtype, param, filter.group);
    }
};
/*!
 * \brief cuDNN descriptor bundle for convolution backward-data.
 *
 * \p diff is the incoming gradient and \p grad the gradient being produced;
 * both are described as tensors, while the filter gets a FilterDesc.
 */
struct CUDNNBwdDataDescs {
    TensorDesc diff_desc;
    TensorDesc grad_desc;
    FilterDesc<param::Convolution> filter_desc;
    ConvDesc conv_desc;

    /*!
     * \brief Fill the descriptors in the order filter, diff, grad, conv.
     *
     * The convolution descriptor is built from the *filter's* dtype, not from
     * a tensor layout's dtype.
     * NOTE(review): the forward and backward-filter bundles use src.dtype
     * instead; presumably filter and diff dtypes always match here — confirm.
     */
    void set(
            const CanonizedFilterMeta& filter, const TensorLayout& diff,
            const TensorLayout& grad, const param::Convolution& param) {
        const auto layout_format = param.format;
        filter_desc.set(filter);
        diff_desc.set(diff, layout_format);
        grad_desc.set(grad, layout_format);
        conv_desc.set(filter.dtype, param, filter.group);
    }
};
/*!
 * \brief cuDNN descriptor bundle for convolution backward-filter.
 *
 * The produced gradient has the filter's shape, so it is described by a
 * FilterDesc built from a CanonizedFilterMeta rather than by a TensorDesc.
 */
struct CUDNNBwdFilterDescs {
    TensorDesc diff_desc;
    TensorDesc src_desc;
    FilterDesc<param::Convolution> grad_desc;
    ConvDesc conv_desc;

    /*!
     * \brief Fill the descriptors in the order src, diff, grad, conv.
     *
     * Tensor descriptors use \p param's format. The convolution descriptor is
     * built from the source dtype, \p param and the gradient filter's group
     * count.
     */
    void set(
            const TensorLayout& src, const TensorLayout& diff,
            const CanonizedFilterMeta& grad, const param::Convolution& param) {
        const auto layout_format = param.format;
        src_desc.set(src, layout_format);
        diff_desc.set(diff, layout_format);
        grad_desc.set(grad);
        conv_desc.set(src.dtype, param, grad.group);
    }
};
//! NOTE(review): judging by its name and signature, this presumably writes a
//! spatially flipped copy of the filter described by \p args into
//! \p workspace and updates \p raw_ptr (non-const) to refer to it — confirm
//! against the implementation.
auto flip_filter(
        const ForwardSizeArgs& args,
        const Workspace& workspace,
        RefPtr& raw_ptr) -> void;
} } }