#include "./opr_impl.h"
#include "./kern.cuh"
#include "src/common/indexing_multi_axis_vec_kdef.h"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
using namespace indexing_multi_axis_vec;
namespace {
// Operator-independent part of the indexing executors.
//
// The constructor stores the call arguments, queries
// IndexingMultiAxisVec::get_value_iter_optimized_layout() and then runs the
// gen_offset_base() dispatch chain with the workspace as its int output
// buffer. The derived ExecImpl reads the protected members to build the
// parameter block handed to apply_opr().
class ExecImplHelper {
private:
    // Fills a GenOffsetBaseParam<nidx, idx_ndim> and hands it to
    // gen_offset_base().
    template <int nidx, int idx_ndim>
    void dispatch_gen_offset_base_nidx_ndim();

    // Picks the idx_ndim template argument from m_idx_shape.ndim.
    template <int nidx>
    void dispatch_gen_offset_base_nidx();

    // Picks the nidx template argument from m_index->size().
    void dispatch_gen_offset_base();

protected:
    using IndexDesc = IndexingMultiAxisVec::IndexDesc;
    using ExecInfo = IndexingMultiAxisVec::ExecInfo;

    // NOTE: the declaration order below is also the initialization order used
    // by the constructor; keep it stable.

    // stream forwarded to gen_offset_base() and apply_opr()
    cudaStream_t m_stream;
    // first tensor argument; its dtype selects the ctype used by ExecImpl
    const TensorND* const m_data;
    // second tensor argument; its element count becomes tot_size in ExecImpl
    const TensorND* const m_value;
    // index descriptors; each entry is read through its .vec and .axis members
    const IndexDesc* const m_index;
    // supplies error_tracker and error_info for the gen_offset_base() params
    const ExecInfo* const m_exec_info;
    // workspace viewed as int*: output of gen_offset_base(), later passed to
    // apply_opr() as offset_base
    int* const m_offset_base;
    // the three results of get_value_iter_optimized_layout(), in tuple order
    TensorLayout m_value_layout_on_data;
    size_t m_idx_axis;
    TensorShape m_idx_shape;
    // copy of exec_info.value_stride
    int m_value_stride;

public:
    ExecImplHelper(
            const TensorND& data, const TensorND& value, const IndexDesc& index,
            const Workspace& workspace, const ExecInfo& exec_info,
            cudaStream_t stream);
};
// Executor for one indexing operation. Opr is not interpreted here; it is only
// forwarded as a template argument to apply_opr().
//
// Construction (inherited from ExecImplHelper) already runs the
// gen_offset_base() dispatch; calling the object runs the apply_opr()
// dispatch.
template <class Opr>
class ExecImpl : public ExecImplHelper {
private:
    // Selects ctype from the dtype of m_data.
    void dispatch_exec();

    // Selects ndim from m_value_layout_on_data.ndim.
    template <typename ctype>
    void dispatch_exec_ctype();

    // Builds ApplyOprParam<ctype, ndim> and calls apply_opr().
    template <typename ctype, int ndim>
    void dispatch_exec_ctype_ndim();

public:
    using ExecImplHelper::ExecImplHelper;

    // Runs the dispatch chain and then after_kernel_launch().
    // NOTE(review): after_kernel_launch() is defined outside this file;
    // presumably it checks for launch errors -- confirm.
    void operator()() {
        dispatch_exec();
        after_kernel_launch();
    }
};
}
// Stores the call arguments, derives the value iteration layout and fills the
// workspace through dispatch_gen_offset_base().
//
// data, value, index and exec_info are kept by address, so they have to stay
// alive for as long as this object is used. The workspace is only read here
// to obtain its int pointer.
ExecImplHelper::ExecImplHelper(
        const TensorND& data, const TensorND& value, const IndexDesc& index,
        const Workspace& workspace, const ExecInfo& exec_info, cudaStream_t stream)
        : m_stream{stream},
          m_data{&data},
          m_value{&value},
          m_index{&index},
          m_exec_info{&exec_info},
          m_offset_base{workspace.ptr<int>()} {
    // The returned value is deliberately dropped.
    // NOTE(review): presumably called only for its range check on the element
    // count of data -- confirm against the helper's definition.
    (void)safe_size_in_kern(data.layout.total_nr_elems());

    // Not read by the offset generation below, so it can be set up front.
    m_value_stride = exec_info.value_stride;

    std::tie(m_value_layout_on_data, m_idx_axis, m_idx_shape) =
            IndexingMultiAxisVec::get_value_iter_optimized_layout(
                    data.layout, value.layout, index, exec_info.idx_axis);

    // Needs m_idx_shape from the call above.
    dispatch_gen_offset_base();
}
template <int nidx, int idx_ndim>
void ExecImplHelper::dispatch_gen_offset_base_nidx_ndim() {
GenOffsetBaseParam<nidx, idx_ndim> param;
param.size = m_idx_shape.total_nr_elems();
param.output = m_offset_base;
param.error_tracker = m_exec_info->error_tracker;
param.error_info = m_exec_info->error_info;
megdnn_assert(m_idx_shape.ndim == idx_ndim);
for (int i = 0; i < nidx; ++i) {
auto&& dst = param.indexer[i];
auto&& src = m_index->at(i);
auto src_layout = src.vec.layout.broadcast(m_idx_shape);
for (size_t i = 0; i < idx_ndim; ++i) {
if (i) {
dst.shape[i - 1] = src_layout.shape[i];
}
dst.stride[i] = src_layout.stride[i];
}
dst.ptr = src.vec.ptr<int>();
param.data_shape[i] = m_data->layout.shape[src.axis];
param.data_stride[i] = m_data->layout.stride[src.axis];
}
gen_offset_base(param, m_stream);
}
template <int nidx>
void ExecImplHelper::dispatch_gen_offset_base_nidx() {
switch (m_idx_shape.ndim) {
#define cb(_n) \
case _n: \
return dispatch_gen_offset_base_nidx_ndim<nidx, _n>();
MEGDNN_FOREACH_TENSOR_NDIM(cb)
#undef cb
}
megdnn_throw("bad index ndim");
}
// Maps the number of index descriptors onto the nidx template argument of
// dispatch_gen_offset_base_nidx(). Counts not listed by
// MEGDNN_FOREACH_TENSOR_NDIM end in megdnn_throw("bad index size").
void ExecImplHelper::dispatch_gen_offset_base() {
    switch (m_index->size()) {
#define cb(_n)                               \
    case _n:                                 \
        dispatch_gen_offset_base_nidx<_n>(); \
        return;
        MEGDNN_FOREACH_TENSOR_NDIM(cb)
#undef cb
        default:
            megdnn_throw("bad index size");
    }
}
// First dispatch stage: turns the dtype of m_data into the ctype template
// argument of dispatch_exec_ctype(). Handles every dtype listed by
// MEGDNN_FOREACH_COMPUTING_DTYPE plus Bool; anything else ends in
// megdnn_throw("bad dtype").
template <class Opr>
void ExecImpl<Opr>::dispatch_exec() {
    const auto dtype_enum = m_data->layout.dtype.enumv();
    switch (dtype_enum) {
#define cb(_dtype)                                        \
    case DTypeTrait<_dtype>::enumv:                       \
        dispatch_exec_ctype<DTypeTrait<_dtype>::ctype>(); \
        return;
        MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
        cb(::megdnn::dtype::Bool)
#undef cb
        default:
            break;
    }
    megdnn_throw("bad dtype");
}
// Second dispatch stage: turns m_value_layout_on_data.ndim into the ndim
// template argument of dispatch_exec_ctype_ndim(). Values not listed by
// MEGDNN_FOREACH_TENSOR_NDIM end in megdnn_throw("bad data ndim").
template <class Opr>
template <typename ctype>
void ExecImpl<Opr>::dispatch_exec_ctype() {
    switch (m_value_layout_on_data.ndim) {
#define cb(_n)                                \
    case _n:                                  \
        dispatch_exec_ctype_ndim<ctype, _n>(); \
        return;
        MEGDNN_FOREACH_TENSOR_NDIM(cb)
#undef cb
        default:
            break;
    }
    megdnn_throw("bad data ndim");
}
// Final dispatch stage: fills an ApplyOprParam<ctype, ndim> from the state
// prepared by ExecImplHelper and calls apply_opr<ctype, ndim, Opr>() on
// m_stream.
template <class Opr>
template <typename ctype, int ndim>
void ExecImpl<Opr>::dispatch_exec_ctype_ndim() {
    ApplyOprParam<ctype, ndim> param;

    // Element count of value, passed through safe_size_in_kern() before it is
    // stored in the parameter block.
    param.tot_size = safe_size_in_kern(m_value->layout.total_nr_elems());
    param.data = m_data->ptr<ctype>();
    param.value = m_value->ptr<ctype>();

    // Offsets produced earlier by gen_offset_base().
    param.offset_base = m_offset_base;
    param.value_stride = m_value_stride;

    // The index axes occupy [idx_axis, idx_axis_end).
    param.idx_axis = m_idx_axis;
    param.idx_axis_end = m_idx_axis + m_idx_shape.ndim;
    param.idx_nelems = m_idx_shape.total_nr_elems();

    const auto& layout = m_value_layout_on_data;
    // All ndim strides are copied ...
    for (int axis = 0; axis < ndim; ++axis) {
        param.value_ly_on_data.stride[axis] = layout.stride[axis];
    }
    // ... while shapes are copied from axis 1 on, shifted down by one slot.
    for (int axis = 1; axis < ndim; ++axis) {
        param.value_ly_on_data.shape[axis - 1] = layout.shape[axis];
    }

    apply_opr<ctype, ndim, Opr>(param, m_stream);
}
// Workspace requirement: one int per index element. The executor views the
// workspace as an int buffer and lets gen_offset_base() write into it.
size_t IndexingMultiAxisVecImpl::get_workspace_in_bytes(size_t dst_idx_size) {
    const size_t bytes_per_offset = sizeof(int);
    return dst_idx_size * bytes_per_offset;
}
// Runs the OprFwd executor with src in the data role and dst in the value
// role.
//
// check_exec() validates the layouts and workspace size and returns the
// execution info, which is completed with this operator's error tracker and
// the handle's async error info before the executor is run.
void IndexingMultiAxisVecImpl::exec(
        _megdnn_tensor_in src, const IndexDesc& index, _megdnn_tensor_out dst,
        _megdnn_workspace workspace) {
    auto exec_info = check_exec(src.layout, index, dst.layout, workspace.size);
    exec_info.error_tracker = m_error_tracker;
    exec_info.error_info = async_error_info(handle());

    using Executor = ExecImpl<indexing_multi_axis_vec_kdef::OprFwd>;
    // Construction generates the offsets; the call applies the operator.
    Executor run{src, dst, index, workspace, exec_info, cuda_stream(handle())};
    run();
}
// Workspace requirement: one int per index element. The executor views the
// workspace as an int buffer and lets gen_offset_base() write into it.
size_t IndexingSetMultiAxisVecImpl::get_workspace_in_bytes(size_t value_idx_size) {
    const size_t bytes_per_offset = sizeof(int);
    return value_idx_size * bytes_per_offset;
}
// Runs the OprSet executor on data and value.
//
// check_exec() validates the layouts and workspace size and returns the
// execution info, which is completed with this operator's error tracker and
// the handle's async error info before the executor is run.
void IndexingSetMultiAxisVecImpl::exec(
        _megdnn_tensor_inout data, _megdnn_tensor_in value, const IndexDesc& index,
        _megdnn_workspace workspace) {
    auto exec_info = check_exec(data.layout, value.layout, index, workspace.size);
    exec_info.error_tracker = m_error_tracker;
    exec_info.error_info = async_error_info(handle());

    using Executor = ExecImpl<indexing_multi_axis_vec_kdef::OprSet>;
    // Construction generates the offsets; the call applies the operator.
    Executor run{data, value, index, workspace, exec_info, cuda_stream(handle())};
    run();
}
// Workspace requirement: one int per index element. The executor views the
// workspace as an int buffer and lets gen_offset_base() write into it.
size_t IndexingIncrMultiAxisVecImpl::get_workspace_in_bytes(size_t value_idx_size) {
    const size_t bytes_per_offset = sizeof(int);
    return value_idx_size * bytes_per_offset;
}
// Runs the OprAtomicIncr executor on data and value.
//
// check_exec() validates the layouts and workspace size and returns the
// execution info, which is completed with this operator's error tracker and
// the handle's async error info before the executor is run.
void IndexingIncrMultiAxisVecImpl::exec(
        _megdnn_tensor_inout data, _megdnn_tensor_in value, const IndexDesc& index,
        _megdnn_workspace workspace) {
    auto exec_info = check_exec(data.layout, value.layout, index, workspace.size);
    exec_info.error_tracker = m_error_tracker;
    exec_info.error_info = async_error_info(handle());

    using Executor = ExecImpl<OprAtomicIncr>;
    // Construction generates the offsets; the call applies the operator.
    Executor run{data, value, index, workspace, exec_info, cuda_stream(handle())};
    run();
}