megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file dnn/src/common/mesh_indexing.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megdnn/oprs/general.h"
#include "src/common/utils.h"

namespace megdnn {

/* ============================== MeshIndexing ============================= */

/*!
 * \brief Sanity checks applied before executing a mesh indexing operator.
 *
 * Asserts, in order, that:
 *  - \p origin and \p indexed carry the same dtype;
 *  - \p origin and \p indexed have the same number of dimensions;
 *  - the index tensor of every entry in \p desc has dtype Int32.
 *
 * \param origin  layout of the tensor being indexed
 * \param indexed layout of the tensor on the other side of the indexing
 * \param desc    per-axis index descriptors
 */
void MeshBase::check_exec(
        const TensorLayout& origin, const TensorLayout& indexed,
        const IndexDesc& desc) {
    // The two tensor layouts must agree on element type and on rank.
    megdnn_assert(origin.dtype == indexed.dtype);
    megdnn_assert(origin.ndim == indexed.ndim);

    // Every index tensor is required to hold Int32 values.
    for (const auto& index : desc) {
        megdnn_assert(index.vec.layout.dtype == dtype::Int32());
    }
}

void NormalMeshBase::check_exec(
        const TensorLayout& src, const TensorLayout& dst, const IndexDesc& desc) {
    MeshBase::check_exec(src, dst, desc);
    for (auto&& index : desc) {
        size_t ndim = index.vec.layout.ndim;
        megdnn_assert(ndim == 1, "index must be 1-dim vector, while dim %zu", ndim);
        megdnn_assert(dst.shape[index.axis] == index.vec.layout[0]);
    }
}

/*!
 * \brief Execution checks for the batched mesh indexing operators.
 *
 * Runs the checks of MeshBase::check_exec first, then requires that
 * \p src and \p dst agree on dimension 0. For every entry of \p desc it
 * further requires that:
 *  - the index tensor is two-dimensional;
 *  - the indexed axis is not dimension 0;
 *  - the index tensor has shape (dst[0], dst[axis]).
 *
 * \param src  layout passed through to MeshBase::check_exec as the origin
 * \param dst  layout whose extents are compared to the index shapes
 * \param desc per-axis index descriptors
 */
void BatchedMeshBase::check_exec(
        const TensorLayout& src, const TensorLayout& dst, const IndexDesc& desc) {
    MeshBase::check_exec(src, dst, desc);
    megdnn_assert(src[0] == dst[0], "batch mismatch, src %zu, dst %zu", src[0], dst[0]);
    for (auto&& index : desc) {
        size_t ndim = index.vec.layout.ndim;
        megdnn_assert(ndim == 2, "index must be a 2-dim matrix, while ndim %zu", ndim);
        // Checked before the shape comparison so that indexing dimension 0
        // is reported as an axis error rather than as a shape mismatch.
        // The message previously read "should be 0-th dim", contradicting
        // the condition actually enforced here.
        megdnn_assert(
                index.axis != 0,
                "index axis should not be 0-th dim when executing "
                "BatchedMeshIndexing");
        megdnn_assert(
                dst[0] == index.vec.layout[0] && dst[index.axis] == index.vec.layout[1],
                "require each index shape equals (%zu, %zu), but got "
                "(%zu, %zu)",
                dst[0], dst[index.axis], index.vec.layout[0], index.vec.layout[1]);
    }
}

/*!
 * \brief Deduce the output layout of a mesh indexing operation.
 *
 * Starts from a copy of \p inp, replaces the extent of each indexed axis
 * with the length of the corresponding index vector, and finally
 * re-initializes the strides via init_contiguous_stride().
 *
 * \param inp        layout of the input tensor
 * \param layouts    per-axis index layouts; each must be one-dimensional
 * \param out_layout receives the deduced layout
 */
void MeshIndexing::deduce_layout(
        const TensorLayout& inp, const IndexDescLayoutOnly& layouts,
        TensorLayout& out_layout) {
    out_layout = inp;

    for (const auto& index : layouts) {
        megdnn_assert(
                index.layout.ndim == 1,
                "mesh indexing require index being 1-dim vector");

        // The indexed axis takes the length of its index vector.
        const auto indexed_extent = index.layout[0];
        out_layout[index.axis] = indexed_extent;
    }

    out_layout.init_contiguous_stride();
}

/*!
 * \brief Deduce the output layout of a batched mesh indexing operation.
 *
 * Starts from a copy of \p inp, replaces the extent of each indexed axis
 * with the second dimension of the corresponding index matrix, and finally
 * re-initializes the strides via init_contiguous_stride().
 *
 * \param inp        layout of the input tensor
 * \param layouts    per-axis index layouts; each must be two-dimensional
 * \param out_layout receives the deduced layout
 */
void BatchedMeshIndexing::deduce_layout(
        const TensorLayout& inp, const IndexDescLayoutOnly& layouts,
        TensorLayout& out_layout) {
    out_layout = inp;

    for (const auto& index : layouts) {
        megdnn_assert(
                index.layout.ndim == 2,
                "batch mesh indexing require index being 2-dim matrix");

        // The indexed axis takes the size of the index matrix's dimension 1.
        const auto indexed_extent = index.layout[1];
        out_layout[index.axis] = indexed_extent;
    }

    out_layout.init_contiguous_stride();
}

}  // namespace megdnn