#include "./opr_impl.h"
#include "megdnn/tensor_iter.h"
#include "src/naive/handle.h"
#include "src/common/indexing_multi_axis_vec_kdef.h"
#include "src/common/utils.h"
using namespace megdnn;
using namespace naive;
namespace {
/*!
 * \brief Reference kernel shared by the multi-axis vector indexing operators.
 *
 * Visits every element of \p value in iteration order. For each element the
 * matching element offset in \p data is assembled from
 *  - the indexed axes: the coordinate is read from the corresponding index
 *    tensor (negative entries are wrapped once by the axis extent), and
 *  - the remaining axes: the coordinate is taken directly from the position
 *    of the current \p value element.
 * Finally Opr::apply(data_element, value_element) is invoked.
 *
 * \tparam data_type element type used to access \p data and \p value
 * \tparam Opr policy type providing a static apply(data_elem, value_elem)
 * \tparam idx_type element type of the index tensors
 * \param data tensor addressed through the computed offsets
 * \param value tensor that is iterated; its coordinates drive the lookup
 * \param index list of indexers, each providing an axis and an index tensor
 * \param exec_info provides idx_axis, the position within the coordinates of
 *      \p value at which the dimensions of the index shape start
 *
 * An index entry that is still out of range after wrapping triggers
 * megdnn_assert.
 */
template <typename data_type, class Opr, typename idx_type = dt_int32>
void do_exec(
        const TensorND& data, const TensorND& value,
        const IndexingMultiAxisVec::IndexDesc& index,
        const IndexingMultiAxisVec::ExecInfo& exec_info) {
    // axes reported by get_nonindex_axes; they are resolved from the
    // coordinates of value rather than from an index tensor
    size_t nonidx_axes[TensorLayout::MAX_NDIM],
            nr_nonidx_axes = IndexingMultiAxisVec::get_nonindex_axes(
                    data.layout.ndim, index, nonidx_axes);
    // the layout is only read, so bind a reference instead of copying it
    const auto& data_layout = data.layout;
    auto data_ptr = data.ptr<data_type>();

    // per indexer: (axis in data, index values, index layout after broadcast)
    std::tuple<size_t, const idx_type*, TensorLayout> index_raw[TensorLayout::MAX_NDIM];
    const size_t nr_index = index.size();

    // shape deduced from all index tensors; every index layout is broadcast
    // to it below so that all of them can be addressed with one coordinate
    TensorShape idx_shape;
    {
        TensorShapeArray idx_shapes;
        for (size_t i = 0; i < nr_index; ++i) {
            idx_shapes.push_back(index[i].vec.layout);
        }
        Elemwise::deduce_shape(idx_shapes, idx_shape);
    }
    for (size_t i = 0; i < nr_index; ++i) {
        auto&& s = index[i];
        index_raw[i] = std::make_tuple(
                s.axis, s.vec.ptr<idx_type>(), s.vec.layout.broadcast(idx_shape));
    }

    auto value_iter = tensor_iter<data_type>(value).begin();
    const size_t nr_value_elems = value.layout.total_nr_elems();
    for (size_t elem = 0; elem < nr_value_elems; ++elem) {
        ptrdiff_t offset = 0;
        // coordinates of the current value element inside the index shape
        auto* index_idx = value_iter.idx() + exec_info.idx_axis;

        for (size_t k = 0; k < nr_index; ++k) {
            const size_t axis = std::get<0>(index_raw[k]);
            const size_t data_shape = data_layout.shape[axis];
            const ptrdiff_t data_stride = data_layout.stride[axis];

            // strides are signed, so accumulate the element offset in a
            // signed type instead of relying on unsigned wrap-around
            const TensorLayout& index_layout = std::get<2>(index_raw[k]);
            ptrdiff_t index_offset = 0;
            for (size_t dim = 0; dim < index_layout.ndim; ++dim) {
                index_offset += static_cast<ptrdiff_t>(index_idx[dim]) *
                                index_layout.stride[dim];
            }

            // widen before wrapping negative indices: adding the axis extent
            // in idx_type could truncate for very large axes and silently
            // select an unrelated element
            ptrdiff_t data_idx = std::get<1>(index_raw[k])[index_offset];
            if (data_idx < 0)
                data_idx += static_cast<ptrdiff_t>(data_shape);
            megdnn_assert(
                    data_idx >= 0 && static_cast<size_t>(data_idx) < data_shape,
                    "bad index value for index %zu at output %zu", k, *index_idx);
            offset += data_stride * data_idx;
        }

        for (size_t k = 0; k < nr_nonidx_axes; ++k) {
            const ptrdiff_t stride = data_layout.stride[nonidx_axes[k]];
            // non-indexed axes located at or after idx_axis are shifted by
            // the rank of the index shape in the coordinates of value
            const size_t idx =
                    value_iter.idx()[k + (k >= exec_info.idx_axis) * idx_shape.ndim];
            offset += stride * static_cast<ptrdiff_t>(idx);
        }

        Opr::apply(data_ptr[offset], *value_iter);
        ++value_iter;
    }
}
/*!
 * \brief Pick the do_exec instantiation that matches the dtype of \p data and
 *      submit it through MEGDNN_DISPATCH_CPU_KERN on \p handle.
 *
 * \tparam Opr element-wise policy forwarded to do_exec
 * \param handle handle passed to MEGDNN_DISPATCH_CPU_KERN
 * \param data forwarded to do_exec; its dtype selects the instantiation
 * \param value forwarded to do_exec
 * \param index forwarded to do_exec
 * \param exec_info forwarded to do_exec
 *
 * Calls megdnn_throw("bad dtype") when the dtype is neither enumerated by
 * MEGDNN_FOREACH_COMPUTING_DTYPE nor Bool.
 */
template <class Opr>
void dispatch_exec(
        HandleImpl* handle, const TensorND& data, const TensorND& value,
        const IndexingMultiAxisVec::IndexDesc& index,
        const IndexingMultiAxisVec::ExecInfo& exec_info) {
// expands to one switch case per dtype; every case returns after dispatching
#define DISPATCH_DTYPE_CASE(_dtype)                                            \
    case DTypeTrait<_dtype>::enumv: {                                          \
        MEGDNN_DISPATCH_CPU_KERN(                                              \
                handle, do_exec<DTypeTrait<_dtype>::ctype MEGDNN_COMMA Opr>(   \
                                data, value, index, exec_info));               \
        return;                                                                \
    }
    const auto dtype_enum = data.layout.dtype.enumv();
    switch (dtype_enum) {
        MEGDNN_FOREACH_COMPUTING_DTYPE(DISPATCH_DTYPE_CASE)
        DISPATCH_DTYPE_CASE(::megdnn::dtype::Bool)
        default:
            megdnn_throw("bad dtype");
    }
#undef DISPATCH_DTYPE_CASE
}
}
/*!
 * \brief Naive implementation of the multi-axis vector indexing operator.
 *
 * Runs check_exec on the layouts and workspace size, then dispatches the
 * OprFwd kernel with \p src as the data tensor and \p dst as the value tensor.
 */
void IndexingMultiAxisVecImpl::exec(
        _megdnn_tensor_in src, const IndexDesc& index, _megdnn_tensor_out dst,
        _megdnn_workspace workspace) {
    using Kernel = indexing_multi_axis_vec_kdef::OprFwd;
    const auto exec_info = check_exec(src.layout, index, dst.layout, workspace.size);
    auto* naive_handle = static_cast<HandleImpl*>(handle());
    dispatch_exec<Kernel>(naive_handle, src, dst, index, exec_info);
}
/*!
 * \brief Naive implementation of the "set" variant of multi-axis vector
 *      indexing.
 *
 * Runs check_exec on the layouts and workspace size, then dispatches the
 * OprSet kernel with \p data as the data tensor and \p value as the value
 * tensor.
 *
 * NOTE(review): judging by the operator name, \p data is presumably the
 * tensor that gets written and \p value the one that is read, which is the
 * opposite of what the _in/_out annotations suggest -- confirm against the
 * declaration.
 */
void IndexingSetMultiAxisVecImpl::exec(
        _megdnn_tensor_in data, _megdnn_tensor_out value, const IndexDesc& index,
        _megdnn_workspace workspace) {
    using Kernel = indexing_multi_axis_vec_kdef::OprSet;
    const auto exec_info = check_exec(data.layout, value.layout, index, workspace.size);
    auto* naive_handle = static_cast<HandleImpl*>(handle());
    dispatch_exec<Kernel>(naive_handle, data, value, index, exec_info);
}
/*!
 * \brief Naive implementation of the "incr" variant of multi-axis vector
 *      indexing.
 *
 * Runs check_exec on the layouts and workspace size, then dispatches the
 * OprIncr kernel with \p data as the data tensor and \p value as the value
 * tensor.
 *
 * NOTE(review): judging by the operator name, \p data is presumably the
 * tensor that gets updated and \p value the one that is read, which is the
 * opposite of what the _in/_out annotations suggest -- confirm against the
 * declaration.
 */
void IndexingIncrMultiAxisVecImpl::exec(
        _megdnn_tensor_in data, _megdnn_tensor_out value, const IndexDesc& index,
        _megdnn_workspace workspace) {
    using Kernel = indexing_multi_axis_vec_kdef::OprIncr;
    const auto exec_info = check_exec(data.layout, value.layout, index, workspace.size);
    auto* naive_handle = static_cast<HandleImpl*>(handle());
    dispatch_exec<Kernel>(naive_handle, data, value, index, exec_info);
}