megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file dnn/src/cuda/mesh_indexing/mesh_indexing.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include <cstdio>
#include "megdnn/basic_types.h"
#include "src/cuda/error_info.cuh"
#include "src/cuda/utils.cuh"

namespace megdnn {
namespace cuda {
namespace mesh_indexing {

// template <int ndim>
/*!
 * \brief Translates a flat element index of the indexed tensor into offsets
 *      into the origin tensor and the indexed tensor.
 *
 * The struct is filled on the host by the constructor and read in device code
 * by convert_indxer(). Only the first \c ndim entries of each per-axis array
 * are written by the constructor.
 */
struct KernIndexer {
    //! number of axes, taken from the origin layout
    int ndim;
    //! per-axis index array; a null entry means the axis position is used
    //! directly. Non-null entries are dereferenced in device code.
    int* ptrs[TensorShape::MAX_NDIM];
    //! per-axis stride of the origin layout
    int origin_stride[TensorShape::MAX_NDIM];
    //! per-axis stride of the indexed layout
    int indexed_strde[TensorShape::MAX_NDIM];
    //! stride[0] of the layout describing each index array; only used to
    //! select the row of the index array in batch mode
    int desc_stride[TensorShape::MAX_NDIM];
    //! per-axis extent of the indexed layout
    uint32_t indexed_shape[TensorShape::MAX_NDIM];
    //! per-axis extent of the origin layout
    uint32_t origin_shape[TensorShape::MAX_NDIM];

    //! forwarded unchanged to set_async_error_info()
    void* error_tracker;
    //! forwarded unchanged to set_async_error_info()
    megcore::AsyncErrorInfo* error_info;
    //! when true, the axis-0 coordinate selects a row in every index array
    bool batch_mode;
    //! not read anywhere in this header; zero-initialised by the constructor
    uint32_t batch_stride;
    //! total number of elements of the indexed layout
    uint32_t size;

    /*!
     * \param origin_layout layout of the tensor being indexed into; its ndim
     *      determines how many entries of every per-axis argument are read
     * \param indexed_layout layout of the gathered/scattered value tensor
     * \param _ptrs array of at least origin_layout.ndim index pointers
     *      (entries may be null)
     * \param desc_layouts array of at least origin_layout.ndim layouts, one
     *      per index array
     * \param _err_tracker opaque pointer passed through to error reporting
     * \param _err_info async error record passed through to error reporting
     * \param _batch_mode enables per-batch rows in the index arrays
     */
    KernIndexer(
            const TensorLayout& origin_layout, const TensorLayout& indexed_layout,
            int** _ptrs, const TensorLayout* desc_layouts, void* _err_tracker = nullptr,
            megcore::AsyncErrorInfo* _err_info = nullptr, bool _batch_mode = false)
            : error_tracker(_err_tracker),
              error_info(_err_info),
              batch_mode(_batch_mode),
              // previously left indeterminate; give it a defined value
              batch_stride(0),
              size(indexed_layout.total_nr_elems()) {
        ndim = origin_layout.ndim;
        for (int i = 0; i < ndim; ++i) {
            origin_stride[i] = origin_layout.stride[i];
            indexed_strde[i] = indexed_layout.stride[i];
            origin_shape[i] = origin_layout[i];
            indexed_shape[i] = indexed_layout[i];
            ptrs[i] = _ptrs[i];
            desc_stride[i] = desc_layouts[i].stride[0];
        }
    }

    /*!
     * \brief Convert a flat index over the indexed shape into tensor offsets.
     *
     * \param[in,out] index on entry, the flat index decomposed against
     *      indexed_shape (innermost axis first); on return, the offset into
     *      the indexed tensor computed with indexed_strde
     * \return offset into the origin tensor computed with origin_stride
     *
     * Negative values read from an index array are wrapped once by adding the
     * axis extent. A position that is still out of range is reported through
     * set_async_error_info() and then replaced by 0 so that the returned
     * offset never points outside the origin tensor.
     */
    int __device__ __forceinline__ convert_indxer(uint32_t& index) const {
        int data_offset = 0;
        int value_offset = 0;
        uint32_t n = 0;
        if (batch_mode) {
            // recover the axis-0 coordinate of the indexed tensor
            n = index;
            for (int i = ndim - 1; i >= 1; --i) {
                n /= indexed_shape[i];
            }
            n %= indexed_shape[0];
        }
        for (int i = ndim - 1; i >= 0; --i) {
            int pos = index % indexed_shape[i];
            value_offset += pos * indexed_strde[i];
            if (ptrs[i]) {
                // look the position up in the index array (row n in batch mode)
                pos += n * desc_stride[i];
                pos = ptrs[i][pos];
                // wrap negative indices once
                pos += (pos < 0 ? origin_shape[i] : 0);
            }
            // the unsigned compare also rejects positions that are still negative
            if (static_cast<uint32_t>(pos) >= origin_shape[i]) {
                set_async_error_info(
                        error_info, error_tracker,
                        "invalid mesh indexing: "
                        "indexer=%d idx=%d shape=%d",
                        i, pos, origin_shape[i]);
                // keep the returned offset in range; the error is already
                // recorded for the host to pick up
                pos = 0;
            }
            data_offset += pos * origin_stride[i];
            index /= indexed_shape[i];
        }

        index = value_offset;
        return data_offset;
    }
};

/*!
 * \brief Entry point for the mesh indexing computation; declaration only,
 *      the definition is not part of this header.
 *
 * \tparam T element type of the \p origin and \p indexed buffers
 * \tparam Opr NOTE(review): presumably the per-element operation applied
 *      between the two buffers -- confirm against the definition
 * \param origin buffer addressed by the offset KernIndexer::convert_indxer
 *      returns -- TODO confirm
 * \param indexed buffer addressed by the offset convert_indxer writes back
 *      into its argument -- TODO confirm
 * \param indexer index translator describing both layouts
 * \param stream CUDA stream; presumably the stream the work is enqueued on
 */
template <typename T, class Opr>
void mesh_indexing_proxy(
        T* origin, T* indexed, KernIndexer* indexer, cudaStream_t stream);
}  // namespace mesh_indexing
}  // namespace cuda
}  // namespace megdnn