#include "megdnn/oprs.h"
#include "src/common/utils.h"
using namespace megdnn;
namespace {
/*!
 * \brief number of index elements used to size the workspace
 *
 * The index axis is axes[0] when all entries of \p axes are consecutive
 * (axes[i] == axes[i - 1] + 1); otherwise it is 0. The result is the product
 * of \p idx_ndim extents of \p shp starting at that axis.
 *
 * \param shp shape whose extents are multiplied; must have at least one dim
 * \param axes indexed axes; must contain \p nr_axes entries. Entries are
 *      asserted to be increasing only up to the first gap, because the scan
 *      stops there.
 * \param nr_axes number of entries in \p axes; must be non-zero
 * \param idx_ndim number of consecutive extents of \p shp to multiply
 * \return product of the selected extents (1 if \p idx_ndim is 0)
 */
size_t get_index_size_for_workspace(
        const TensorShape& shp, const size_t* axes, size_t nr_axes, size_t idx_ndim) {
    // validate the arguments before touching axes[0]: with nr_axes == 0 there
    // is no element to read
    megdnn_assert(shp.ndim && nr_axes);
    size_t idx_axis = axes[0];
    for (size_t i = 1; i < nr_axes; ++i) {
        megdnn_assert(axes[i] > axes[i - 1]);
        if (axes[i] != axes[i - 1] + 1) {
            // non-consecutive axes: the index dims are counted from axis 0
            idx_axis = 0;
            break;
        }
    }
    megdnn_assert(
            shp.ndim > idx_axis, "index on the %zuth axis; but shape is %s", idx_axis,
            shp.to_string().c_str());
    // every extent multiplied below must be a real dimension of shp
    megdnn_assert(
            idx_axis + idx_ndim <= shp.ndim,
            "index dims [%zu, %zu) out of range for shape %s", idx_axis,
            idx_axis + idx_ndim, shp.to_string().c_str());
    size_t idx_size = 1;
    for (size_t i = 0; i < idx_ndim; ++i) {
        idx_size *= shp.shape[idx_axis + i];
    }
    return idx_size;
}
}  // anonymous namespace
/*!
 * \brief strip an index descriptor down to (layout, axis) pairs
 *
 * For every entry of \p index, the layout of its index vector and its axis are
 * copied into the entry at the same position of the result.
 */
IndexingMultiAxisVecBase::IndexDescLayoutOnly IndexingMultiAxisVecBase::
        extract_index_layout(const IndexDesc& index) {
    IndexDescLayoutOnly layouts(index.size());
    size_t pos = 0;
    for (auto& desc : index) {
        auto& slot = layouts[pos++];
        slot.axis = desc.axis;
        slot.layout = desc.vec.layout;
    }
    return layouts;
}
/*!
 * \brief deduce the layout of the indexing result
 *
 * \param data layout of the tensor being indexed
 * \param index per-axis index layouts; must be non-empty, of dtype Int32, and
 *      with strictly increasing axes that are all below data.ndim
 * \param[out] dst receives data's dtype and a contiguous layout made of the
 *      non-indexed axes of \p data, with the shape deduced from all index
 *      shapes (via Elemwise::deduce_shape) inserted at the returned axis
 * \return axis of \p dst where the index dims start: index[0].axis if the
 *      indexed axes are consecutive, 0 otherwise
 */
size_t IndexingMultiAxisVecBase::deduce_layout_fwd(
        const TensorLayout& data, const IndexDescLayoutOnly& index, TensorLayout& dst) {
    megdnn_assert(!index.empty());
    megdnn_assert(data.ndim >= index.size());
    dst.dtype = data.dtype;
    dst.ndim = data.ndim - index.size();

    size_t dst_axis = 0;
    // append the extents of data axes [begin, end) to dst
    auto copy_data_axes = [&](size_t begin, size_t end) {
        for (size_t i = begin; i < end; ++i) {
            dst.shape[dst_axis++] = data.shape[i];
        }
    };

    TensorShapeArray index_shapes;
    ptrdiff_t prev_axis = -1;
    for (size_t axis = 0; axis < index.size(); ++axis) {
        auto&& idx = index[axis];
        megdnn_assert(
                idx.layout.dtype == dtype::Int32(), "invalid index dtype: %s",
                idx.layout.dtype.name());
        // asserted expression kept token-for-token: megdnn_assert presumably
        // embeds its text in the failure message
        megdnn_assert(
                idx.axis<data.ndim&& static_cast<ptrdiff_t>(idx.axis)> prev_axis,
                "index %zu requests invalid axis %zu", axis, idx.axis);
        index_shapes.push_back(idx.layout);
        // data axes between the previous indexed axis and this one are kept
        copy_data_axes(prev_axis + 1, idx.axis);
        prev_axis = idx.axis;
    }
    // remaining data axes after the last indexed one
    copy_data_axes(prev_axis + 1, data.ndim);
    megdnn_assert(dst_axis == dst.ndim);

    // index dims go to index[0].axis when the indexed axes are consecutive,
    // and to the front otherwise
    size_t idx_axis = index[0].axis;
    for (size_t i = 1; i < index.size(); ++i) {
        if (index[i].axis != index[i - 1].axis + 1) {
            idx_axis = 0;
            break;
        }
    }

    TensorShape index_shape;
    Elemwise::deduce_shape(index_shapes, index_shape);
    for (size_t i = 0; i < index_shape.ndim; ++i) {
        const size_t pos = idx_axis + i;
        dst.add_axis_inplace(pos, 1, 0);
        dst.shape[pos] = index_shape.shape[i];
    }
    dst.init_contiguous_stride();
    return idx_axis;
}
/*!
 * \brief collect the axes in [0, src_ndim) that are not indexed
 *
 * \p index is walked once in step with the axis counter, so its entries must
 * be ordered by increasing axis; the final assertion fails if any entry was
 * not consumed.
 *
 * \param src_ndim number of dims of the source tensor
 * \param index index descriptor
 * \param[out] out receives the non-indexed axes in increasing order; must have
 *      room for src_ndim - index.size() entries
 * \return number of axes written to \p out
 */
size_t IndexingMultiAxisVecBase::get_nonindex_axes(
        size_t src_ndim, const IndexDesc& index, size_t* out) {
    size_t nr = 0;
    auto iter = index.begin();
    for (size_t axis = 0; axis != src_ndim; ++axis) {
        const bool is_indexed = iter != index.end() && iter->axis == axis;
        if (is_indexed) {
            ++iter;
            continue;
        }
        out[nr] = axis;
        ++nr;
    }
    megdnn_assert(nr + index.size() == src_ndim && iter == index.end());
    return nr;
}
/*!
 * \brief argument checks shared by the exec entry points, without the
 *      workspace size check
 *
 * Asserts that \p value has the shape deduced from \p data and \p index, and
 * that \p value collapses to a single dimension.
 *
 * \param[out] index_layout receives extract_index_layout(index)
 * \return ExecInfo holding the index axis reported by deduce_layout_fwd and
 *      the stride of the collapsed \p value layout
 */
IndexingMultiAxisVecBase::ExecInfo IndexingMultiAxisVecBase::check_exec_noworkspace(
        const TensorLayout& data, const TensorLayout& value, const IndexDesc& index,
        IndexDescLayoutOnly& index_layout) {
    index_layout = extract_index_layout(index);

    TensorLayout value_expect;
    const size_t idx_axis = deduce_layout_fwd(data, index_layout, value_expect);
    megdnn_assert_eq_shape(value_expect, value);

    auto value_contig = value.collapse_contiguous();
    megdnn_assert(
            value_contig.ndim == 1, "value layout must be 1-dim contiguous; got %s",
            value.to_string().c_str());

    ExecInfo info;
    info.idx_axis = idx_axis;
    info.value_stride = value_contig.stride[0];
    return info;
}
std::tuple<TensorLayout, size_t, TensorShape> IndexingMultiAxisVecBase::
get_value_iter_optimized_layout(
const TensorLayout& data, const TensorLayout& value,
const IndexDesc& index, size_t idx_axis) {
size_t data_axes[TensorLayout::MAX_NDIM],
nr_axes = get_nonindex_axes(data.ndim, index, data_axes);
TensorLayout index_shape;
{
TensorShapeArray index_shapes;
for (auto& idx : index) {
megdnn_assert(idx.vec.layout.dtype == dtype::Int32{});
index_shapes.push_back(idx.vec.layout);
}
Elemwise::deduce_shape(index_shapes, index_shape);
}
megdnn_assert(
nr_axes == value.ndim - index_shape.ndim && idx_axis < value.ndim &&
nr_axes + index.size() == data.ndim);
TensorLayout ret;
if (idx_axis) {
ret.ndim = idx_axis;
for (size_t i = 0; i < idx_axis; ++i) {
ret.shape[i] = data.shape[data_axes[i]];
ret.stride[i] = data.stride[data_axes[i]];
}
ret = ret.collapse_contiguous();
}
size_t ret_idx_axis = ret.ndim;
for (size_t i = 0; i < index_shape.ndim; ++i) {
ret.shape[ret.ndim] = value.shape[idx_axis + i];
ret.stride[ret.ndim] = 0;
++ret.ndim;
}
if (idx_axis < nr_axes) {
TensorLayout tail;
tail.ndim = nr_axes - idx_axis;
for (size_t i = idx_axis; i < nr_axes; ++i) {
tail.shape[i - idx_axis] = data.shape[data_axes[i]];
tail.stride[i - idx_axis] = data.stride[data_axes[i]];
}
tail = tail.collapse_contiguous();
for (size_t i = 0; i < tail.ndim; ++i) {
ret.shape[ret.ndim] = tail.shape[i];
ret.stride[ret.ndim] = tail.stride[i];
++ret.ndim;
}
}
return std::make_tuple(ret, ret_idx_axis, index_shape);
}
size_t IndexingMultiAxisVec::get_workspace_in_bytes(
const TensorShape& dst, const size_t* axes, size_t nr_axes, size_t idx_ndim) {
return get_workspace_in_bytes(
get_index_size_for_workspace(dst, axes, nr_axes, idx_ndim));
}
// Validate the arguments of an IndexingMultiAxisVec exec call.
// Runs check_exec_noworkspace on (src, dst, index), then asserts that the
// supplied workspace is large enough and that the stride of the collapsed dst
// layout is non-zero. Returns the ExecInfo from check_exec_noworkspace.
IndexingMultiAxisVec::ExecInfo IndexingMultiAxisVec::check_exec(
const TensorLayout& src, const IndexDesc& index, const TensorLayout& dst,
size_t workspace_in_bytes) {
IndexDescLayoutOnly index_layout;
auto ret = check_exec_noworkspace(src, dst, index, index_layout);
// NOTE(review): the requirement is computed from the single extent
// dst.shape[ret.idx_axis], whereas the shape-based get_workspace_in_bytes
// overload multiplies idx_ndim extents; confirm the weaker check is intended.
megdnn_assert(
workspace_in_bytes >= get_workspace_in_bytes(dst.shape[ret.idx_axis]));
// ret.value_stride is the stride of dst after collapsing to one dimension
megdnn_assert(ret.value_stride, "dst must be non-overlapping");
return ret;
}
size_t IndexingModifyMultiAxisVecBase::get_workspace_in_bytes(
const TensorShape& value, const size_t* axes, size_t nr_axes, size_t idx_ndim) {
return get_workspace_in_bytes(
get_index_size_for_workspace(value, axes, nr_axes, idx_ndim));
}
// Validate the arguments of an indexing-modify exec call.
// Asserts that data passes is_non_overlapping_strong(), runs
// check_exec_noworkspace on (data, value, index), then asserts that the
// supplied workspace is large enough. Returns the ExecInfo from
// check_exec_noworkspace.
IndexingModifyMultiAxisVecBase::ExecInfo IndexingModifyMultiAxisVecBase::check_exec(
const TensorLayout& data, const TensorLayout& value, const IndexDesc& index,
size_t workspace_in_bytes) {
// checked first, before any index/value validation
megdnn_assert(
data.is_non_overlapping_strong(), "data layout should not overlap: %s",
data.to_string().c_str());
IndexDescLayoutOnly index_layout;
auto ret = check_exec_noworkspace(data, value, index, index_layout);
// NOTE(review): the requirement is computed from the single extent
// value.shape[ret.idx_axis], whereas the shape-based get_workspace_in_bytes
// overload multiplies idx_ndim extents; confirm the weaker check is intended.
megdnn_assert(
workspace_in_bytes >= get_workspace_in_bytes(value.shape[ret.idx_axis]));
return ret;
}