#pragma once
#include "megdnn/basic_types.h"
#include "megdnn/oprs/base.h"
#include "src/common/handle_impl.h"
#include "src/naive/convolution/algorithms.h"
#include "src/naive/convolution3d/algorithms.h"
#include "src/naive/local_share/algorithms.h"
#include "src/naive/matrix_mul/algorithms.h"
#include <functional>
#include <mutex>
#include <new>
#include <type_traits>
#include <utility>
namespace megdnn {
namespace naive {
class HandleImpl : public HandleImplHelper {
using KernFunc = MegcoreCPUDispatcher::Task;
using MultiThreadingKernFunc = MegcoreCPUDispatcher::MultiThreadingTask;
MegcoreCPUDispatcher* m_dispatcher;
static DefaultConvolutionForwardAlgorithm m_default_conv_fwd_algo;
static DefaultConvolutionBackwardDataAlgorithm m_default_conv_bwd_data_algo;
static DefaultConvolutionBackwardFilterAlgorithm m_default_conv_bwd_filter_algo;
static DefaultConvBiasForwardAlgorithm m_default_conv_bias_fwd_algo;
static DefaultConvolution3DForwardAlgorithm m_default_conv3d_fwd_algo;
static DefaultConvolution3DBackwardDataAlgorithm m_default_conv3d_bwd_data_algo;
static DefaultConvolution3DBackwardFilterAlgorithm m_default_conv3d_bwd_filter_algo;
static DefaultBatchConvBiasForwardAlgorithm m_default_batch_conv_bias_fwd_algo;
static DefaultLocalShareForwardAlgorithm m_default_local_share_fwd_algo;
static DefaultLocalShareBackwardDataAlgorithm m_default_local_share_bwd_data_algo;
static DefaultLocalShareBackwardFilterAlgorithm
m_default_local_share_bwd_filter_algo;
static DefaultMatrixMulAlgorithm m_default_matmul_fwd_algo;
static DefaultBatchedMatrixMulAlgorithm m_default_batched_matmul_fwd_algo;
static DefaultPoolingForwardAlgorithm m_default_pooling_fwd_algo;
static DefaultPoolingBackwardAlgorithm m_default_pooling_bwd_algo;
template <typename T>
void move_kern_func_to_new_kern_and_dispatch(T& func) {
m_dispatcher->dispatch(std::move(func));
func.~T();
}
template <typename T>
void move_kern_func_to_new_kern_and_dispatch(T& func, size_t parallelism) {
m_dispatcher->dispatch(std::move(func), parallelism);
func.~T();
}
public:
HandleImpl(
megcoreComputingHandle_t computing_handle,
HandleType type = HandleType::NAIVE);
template <typename Opr>
std::unique_ptr<Opr> create_operator();
ConvolutionForward::Algorithm* default_conv_fwd_algo() {
return &m_default_conv_fwd_algo;
}
ConvolutionBackwardData::Algorithm* default_conv_bwd_data_algo() {
return &m_default_conv_bwd_data_algo;
}
ConvolutionBackwardFilter::Algorithm* default_conv_bwd_filter_algo() {
return &m_default_conv_bwd_filter_algo;
}
ConvBiasForward::Algorithm* default_conv_bias_fwd_algo() {
return &m_default_conv_bias_fwd_algo;
}
Convolution3DForward::Algorithm* default_conv3d_fwd_algo() {
return &m_default_conv3d_fwd_algo;
}
Convolution3DBackwardData::Algorithm* default_conv3d_bwd_data_algo() {
return &m_default_conv3d_bwd_data_algo;
}
Convolution3DBackwardFilter::Algorithm* default_conv3d_bwd_filter_algo() {
return &m_default_conv3d_bwd_filter_algo;
}
BatchConvBiasForward::Algorithm* default_batch_conv_bias_fwd_algo() {
return &m_default_batch_conv_bias_fwd_algo;
}
LocalShareForward::Algorithm* default_local_share_fwd_algo() {
return &m_default_local_share_fwd_algo;
}
LocalShareBackwardData::Algorithm* default_local_share_bwd_data_algo() {
return &m_default_local_share_bwd_data_algo;
}
LocalShareBackwardFilter::Algorithm* default_local_share_bwd_filter_algo() {
return &m_default_local_share_bwd_filter_algo;
}
MatrixMulForward::Algorithm* default_matmul_fwd_algo() {
return &m_default_matmul_fwd_algo;
}
BatchedMatrixMulForward::Algorithm* default_batched_matmul_fwd_algo() {
return &m_default_batched_matmul_fwd_algo;
}
PoolingForward::Algorithm* default_pooling_fwd_algo() {
return &m_default_pooling_fwd_algo;
}
PoolingBackward::Algorithm* default_pooling_bwd_algo() {
return &m_default_pooling_bwd_algo;
}
Relayout* relayout_opr() override { return get_helper_opr<Relayout, 2>(this); }
template <class T>
void dispatch_kern(T&& kern) {
std::aligned_storage<sizeof(KernFunc), alignof(KernFunc)>::type s;
move_kern_func_to_new_kern_and_dispatch(
*new (&s) KernFunc(std::forward<T>(kern)));
}
template <class T>
void dispatch_kern(T&& kern, size_t parallelism) {
std::aligned_storage<
sizeof(MultiThreadingKernFunc), alignof(MultiThreadingKernFunc)>::type
s;
move_kern_func_to_new_kern_and_dispatch(
*new (&s) MultiThreadingKernFunc(std::forward<T>(kern)), parallelism);
}
MegcoreCPUDispatcher* megcore_dispatcher() const { return m_dispatcher; }
size_t image2d_pitch_alignment() const override;
static size_t exchange_image2d_pitch_alignment(size_t alignment);
HandleVendorType vendor_type() const override;
};
} }
// Wrap the statement(s) `_stmt` in a no-argument lambda that captures by copy
// ([=]) and pass it to `_handle->dispatch_kern`.
// - `_stmt` may contain several `;`-separated statements.
// - Any comma in `_stmt` must be inside parentheses; otherwise it splits the
//   macro argument.
// - `_handle` is parenthesised, so any pointer-valued expression works.
// - The closure is moved into dispatch_kern, so captured values are copied only
//   once, when the lambda is created.
// - If the dispatcher defers execution, anything reached through captured
//   pointers (including `this`) must stay alive until the kernel runs.
#define MEGDNN_DISPATCH_CPU_KERN(_handle, _stmt)        \
    do {                                                \
        auto _kern = [=]() { _stmt; };                  \
        (_handle)->dispatch_kern(::std::move(_kern));   \
    } while (0)
// Operator-side shorthand for MEGDNN_DISPATCH_CPU_KERN. It calls an unqualified
// `handle()` at the expansion site (presumably an operator member function) and
// static_casts the result, unchecked, to ::megdnn::naive::HandleImpl*. That
// handle must therefore really be a naive HandleImpl.
#define MEGDNN_DISPATCH_CPU_KERN_OPR(_stmt)                          \
    MEGDNN_DISPATCH_CPU_KERN(                                        \
            static_cast<::megdnn::naive::HandleImpl*>(handle()),     \
            _stmt)
// Forward a multi-threaded kernel and its requested parallelism to
// `_handle->dispatch_kern(kernel, parallelism)`.
// - Despite its name, `_stmt` is passed as the first call argument, so it must
//   be a callable expression (e.g. a lambda), not a statement.
// - `_handle` is parenthesised, so any pointer-valued expression works.
#define MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN(_handle, _parallelism, _stmt) \
    do {                                                                     \
        (_handle)->dispatch_kern(_stmt, _parallelism);                       \
    } while (0)
// Operator-side shorthand for MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN. Note the
// argument order: (kernel, parallelism) here, versus (handle, parallelism,
// kernel) for the underlying macro. The handle comes from an unqualified
// `handle()` call at the expansion site, static_cast (unchecked) to
// ::megdnn::naive::HandleImpl*.
#define MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN_OPR(_stmt, _parallelism)   \
    MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN(                               \
            static_cast<::megdnn::naive::HandleImpl*>(handle()),         \
            _parallelism, _stmt)