// megenginelite-sys 1.8.2
//
// A safe megenginelite wrapper in Rust
// Documentation
/**
 * \file dnn/src/cuda/relayout/param_visitor.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "megdnn/basic_types.h"
#include "src/cuda/int_fastdiv.cuh"
#include "src/cuda/integer_subbyte_utils.cuh"
#include "src/cuda/utils.cuh"

#pragma once

namespace megdnn {
namespace cuda {
//! shorthand qualifier for the device-only, force-inlined member functions
//! declared in this header. NOTE: any other definition of `devfunc` visible in
//! the same translation unit must use an identical replacement list, so do
//! not change the token sequence here in isolation.
#define devfunc __device__ __forceinline__

/*!
 * \brief contiguity classification of a tensor layout
 *
 * A fully contiguous layout is tagged CONTIG_FULL; every other layout is
 * tagged CONTIG_OTHER. The enumerator values are spelled out explicitly and
 * match declaration order (CONTIG_OTHER == 0, CONTIG_FULL == 1).
 */
enum ContigType {
    CONTIG_OTHER = 0,  //!< any layout that is not fully contiguous
    CONTIG_FULL = 1,   //!< fully contiguous layout
};

/* {{{ ParamElemVisitor specialization */
/*!
 * \brief visitor that accesses an element of a tensor at a given logical index
 * \tparam ndim number of dimensions used when decomposing a logical index
 * \tparam ctype plain element ctype (i.e. ctype in DTypeTrait)
 * \tparam contig_type contiguity tag of the visited tensor (see ContigType);
 *      selects which specialization is used
 *
 * host interface:
 *      void host_init(
 *              const TensorND &tensor, int grid_size, int block_size)
 *
 * device interface:
 *      void thread_init(uint32_t idx)
 *          called on thread entrance, with logical indexing; the index may
 *          go beyond the buffer range
 *
 *      ctype* ptr()
 *          return the buffer pointer; can be used by a specialized OpCaller
 *
 *      int offset(uint32_t idx)
 *          get the physical offset from a logical index
 *
 *      ctype& at(uint32_t idx)
 *          equivalent to ptr()[offset(idx)]
 */
template <int ndim, typename ctype, ContigType contig_type>
class ParamElemVisitor;

//! device accessors shared by the generic specializations; the expanding class
//! must declare a `ctype* m_ptr` member and an `int offset(uint32_t)` method
#define PARAM_ELEM_VISITOR_COMMON_DEV    \
    devfunc ctype& at(uint32_t idx) {    \
        return ptr()[offset(idx)];       \
    }                                    \
    devfunc ctype* ptr() { return m_ptr; }

//! specialization for CONTIG_OTHER: a logical index is decomposed into
//! per-dimension coordinates (innermost dimension ndim - 1 first) and mapped
//! to a physical element offset through m_stride
template <int ndim, typename ctype>
class ParamElemVisitor<ndim, ctype, CONTIG_OTHER> {
    ctype* __restrict m_ptr;
    int m_stride[ndim];

    //! m_shape_highdim[i] = original_shape[i + 1]
    //! Only entries [0, ndim - 2] are read by offset(). The bound is clamped
    //! to at least 1 because a zero-length array (ndim == 1) is not valid ISO
    //! C++ (it is a GNU extension and is rejected by MSVC); using the same
    //! portable bound on every compiler removes the need for a per-compiler
    //! #ifdef.
    Uint32Fastdiv m_shape_highdim[ndim > 1 ? ndim - 1 : 1];

public:
    static const int NDIM = ndim;

    void host_init(const TensorND& rv, int grid_size, int block_size);

#if MEGDNN_CC_CUDA
    devfunc void thread_init(uint32_t) {}

    devfunc void next() {}

    //! map a logical index to a physical element offset
    //! \param idx logical (row-major) element index
    //! \return sum over dimensions of coordinate * stride
    devfunc int offset(uint32_t idx) {
        int offset = 0;
#pragma unroll
        for (int i = ndim - 1; i >= 1; --i) {
            Uint32Fastdiv& shp = m_shape_highdim[i - 1];
            uint32_t idx_div = idx / shp;
            // idx - idx_div * shape == idx % shape, i.e. coordinate along dim i
            offset += (idx - idx_div * shp.divisor()) * m_stride[i];
            idx = idx_div;
        }
        // whatever remains is the coordinate along the outermost dimension
        offset += idx * m_stride[0];
        return offset;
    }

    PARAM_ELEM_VISITOR_COMMON_DEV
#endif
};

//! specialization for CONTIG_FULL: the tensor is fully contiguous, so the
//! physical offset of an element is its logical index
template <int ndim, typename ctype>
class ParamElemVisitor<ndim, ctype, CONTIG_FULL> {
    ctype* __restrict m_ptr;

public:
    static const int NDIM = ndim;

    void host_init(const TensorND& rv, int grid_size, int block_size);

#if MEGDNN_CC_CUDA
    //! identity mapping from logical index to physical offset
    devfunc int offset(uint32_t idx) { return static_cast<int>(idx); }

    //! no per-thread or per-step state to maintain
    devfunc void thread_init(uint32_t) {}
    devfunc void next() {}

    PARAM_ELEM_VISITOR_COMMON_DEV
#endif
};

#undef PARAM_ELEM_VISITOR_COMMON_DEV

//! specialization of the strided (CONTIG_OTHER) visitor for dt_quint4: two
//! 4-bit elements share one uint8_t storage byte, so a physical element offset
//! is split into a byte index (offset >> 1) and a nibble lane (offset & 1)
template <int ndim>
class ParamElemVisitor<ndim, dt_quint4, CONTIG_OTHER> {
    using Storage = uint8_t;

    Storage* __restrict m_ptr;
    int m_stride[ndim];
    int m_shape[ndim];
    bool m_is_contiguous;
    bool m_is_physical_contiguous;
    bool m_is_min_stride_2;

    //! m_shape_highdim[i] = original_shape[i + 1]
    //! Both divisor arrays are read only at indices [0, ndim - 2] below, so
    //! they are sized max(ndim - 1, 1) on every compiler; the clamp keeps the
    //! array non-empty when ndim == 1. NOTE(review): assumes host_init writes
    //! at most ndim - 1 entries, as the MSVC build already required.
    Uint32Fastdiv m_shape_highdim[ndim > 1 ? ndim - 1 : 1];
    Uint32Fastdiv m_align_shape_highdim[ndim > 1 ? ndim - 1 : 1];

public:
    static const Storage kMask = 0xf;
    static const Storage kBits = 4;
    static const int NDIM = ndim;
    void host_init(const TensorND& rv, int grid_size, int block_size);

#if MEGDNN_CC_CUDA
    devfunc void thread_init(uint32_t) {}

    devfunc void next() {}

    //! decompose an access index into per-dimension coordinates using the
    //! m_align_shape_highdim divisors (innermost dimension first)
    devfunc void get_shape_from_access(uint32_t access_idx, int (&shape_idx)[ndim]) {
#pragma unroll
        for (int i = ndim - 1; i >= 1; --i) {
            Uint32Fastdiv& align_shp = m_align_shape_highdim[i - 1];
            uint32_t access_idx_div = access_idx / align_shp;
            shape_idx[i] = access_idx - access_idx_div * align_shp.divisor();
            access_idx = access_idx_div;
        }
        shape_idx[0] = access_idx;
    }

    //! map a logical element index to a physical element (nibble) offset
    devfunc int offset(uint32_t idx) {
        int offset = 0;
#pragma unroll
        for (int i = ndim - 1; i >= 1; --i) {
            Uint32Fastdiv& shp = m_shape_highdim[i - 1];
            uint32_t idx_div = idx / shp;
            // idx - idx_div * shape == idx % shape, i.e. coordinate along dim i
            offset += (idx - idx_div * shp.divisor()) * m_stride[i];
            idx = idx_div;
        }
        offset += idx * m_stride[0];
        return offset;
    }

    //! map an access index to a physical element offset; identity when
    //! m_is_contiguous is set, otherwise coordinates from
    //! get_shape_from_access() are combined with m_stride
    devfunc int offset_from_access(uint32_t access_idx) {
        int offset = 0;
        if (m_is_contiguous) {
            offset = access_idx;
        } else {
            int shape_idx[ndim];
            get_shape_from_access(access_idx, shape_idx);
#pragma unroll
            for (int i = ndim - 1; i >= 0; --i) {
                offset += shape_idx[i] * m_stride[i];
            }
        }
        return offset;
    }

    //! map an access index back to a row-major logical element index
    //! \return the logical index, or -1 when the access index does not
    //!     correspond to an element (a coordinate is out of m_shape, or an odd
    //!     access index in the min-stride-2 case)
    devfunc int idx(uint32_t access_idx) {
        int idx = 0;
        if (m_is_physical_contiguous) {
            idx = access_idx;
        } else if (!m_is_min_stride_2) {
            int shape_idx[ndim];
            bool valid = true;
            get_shape_from_access(access_idx, shape_idx);
#pragma unroll
            for (int i = 0; i < ndim; ++i) {
                valid &= (shape_idx[i] < m_shape[i]);
            }
#pragma unroll
            for (int i = 0; i < ndim - 1; ++i) {
                idx = (idx + shape_idx[i]) * m_shape[i + 1];
            }
            idx = valid ? idx + shape_idx[ndim - 1] : -1;
        } else {  // min_stride == 2
            // Shift in unsigned arithmetic before narrowing: casting first made
            // access indices >= 2^31 negative (0xFFFFFFFE even collapsed to
            // the -1 "invalid" sentinel).
            idx = ((access_idx & 0x1) == 0) ? static_cast<int>(access_idx >> 1)
                                            : -1;
        }
        return idx;
    }
    devfunc Storage* ptr() { return m_ptr; }

    //! read the 4-bit element at logical index idx
    devfunc Storage at(uint32_t idx) {
        int offset_ = offset(idx);
        int vec_idx = offset_ >> 1;   // storage byte holding the element
        int lane_idx = offset_ & 0x1;  // nibble within that byte

        // m_ptr is already Storage*, so it is indexed directly; the `false`
        // template argument presumably selects unsigned unpacking (quint4)
        Storage item = Storage(integer_subbyte::unpack_integer_4bits<false>(
                m_ptr[vec_idx], lane_idx * 4));

        return item;
    }

    using rwtype = typename elemwise_intl::VectTypeTrait<dt_quint4>::vect_type;

    devfunc rwtype make_vector(Storage x, Storage y) {
        return elemwise_intl::VectTypeTrait<dt_quint4>::make_vector(x, y);
    }
#endif
};

}  // namespace cuda
}  // namespace megdnn

// vim: ft=cpp syntax=cpp.doxygen