#include "megdnn/oprs/general.h"
#include "src/common/utils.h"
namespace megdnn {
// Validation shared by the mesh-indexing operators.
//
// Requires that `origin` and `indexed` have the same dtype and the same number
// of dimensions, and that every index tensor listed in `desc` has dtype Int32.
// Any violation trips megdnn_assert.
void MeshBase::check_exec(
        const TensorLayout& origin, const TensorLayout& indexed,
        const IndexDesc& desc) {
    megdnn_assert(origin.dtype == indexed.dtype);
    megdnn_assert(origin.ndim == indexed.ndim);
    for (const auto& indexer : desc) {
        // Only the layout of the index tensor is inspected here.
        const auto& idx_layout = indexer.vec.layout;
        megdnn_assert(idx_layout.dtype == dtype::Int32());
    }
}
void NormalMeshBase::check_exec(
const TensorLayout& src, const TensorLayout& dst, const IndexDesc& desc) {
MeshBase::check_exec(src, dst, desc);
for (auto&& index : desc) {
size_t ndim = index.vec.layout.ndim;
megdnn_assert(ndim == 1, "index must be 1-dim vector, while dim %zu", ndim);
megdnn_assert(dst.shape[index.axis] == index.vec.layout[0]);
}
}
void BatchedMeshBase::check_exec(
const TensorLayout& src, const TensorLayout& dst, const IndexDesc& desc) {
MeshBase::check_exec(src, dst, desc);
megdnn_assert(src[0] == dst[0], "batch mismatch, src %zu, dst %zu", src[0], dst[0]);
for (auto&& index : desc) {
size_t ndim = index.vec.layout.ndim;
megdnn_assert(ndim == 2, "index must be a 2-dim matrix, while ndim %zu", ndim);
megdnn_assert(
dst[0] == index.vec.layout[0] && dst[index.axis] == index.vec.layout[1],
"require each index shape equals (%zu, %zu), but got "
"(%zu, %zu)",
dst[0], dst[index.axis], index.vec.layout[0], index.vec.layout[1]);
megdnn_assert(
index.axis != 0,
"index axis should be 0-th dim when executing "
"BatchedMeshIndexing");
}
}
void MeshIndexing::deduce_layout(
const TensorLayout& inp, const IndexDescLayoutOnly& layouts,
TensorLayout& out_layout) {
out_layout = inp;
for (auto&& index : layouts) {
megdnn_assert(
index.layout.ndim == 1,
"mesh indexing require index being 1-dim vector");
out_layout[index.axis] = index.layout[0];
}
out_layout.init_contiguous_stride();
}
void BatchedMeshIndexing::deduce_layout(
const TensorLayout& inp, const IndexDescLayoutOnly& layouts,
TensorLayout& out_layout) {
out_layout = inp;
for (auto&& index : layouts) {
megdnn_assert(
index.layout.ndim == 2,
"batch mesh indexing require index being 2-dim matrix");
out_layout[index.axis] = index.layout[1];
}
out_layout.init_contiguous_stride();
}
}