megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file dnn/src/rocm/relayout/relayout_contiguous.h.hip
 *
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "hip_header.h"
#include "megdnn/basic_types.h"
#include "src/rocm/int_fastdiv.h.hip"
#include "src/rocm/elemwise_helper.h.hip"

#include <stdio.h>

namespace megdnn {
namespace rocm {

/*!
 * \brief copy between two tensors whose last axis is contiguous
 *
 * Declaration only; the definition is not part of this header.
 *
 * \param dst destination tensor
 * \param src source tensor
 * \param contiguous_size NOTE(review): presumably the number of elements in
 *      one contiguous run along the last axis -- confirm against the
 *      definition
 * \param stream NOTE(review): presumably the HIP stream the copy is enqueued
 *      on -- confirm against the definition
 */
void copy_last_contiguous(const TensorND &dst, const TensorND &src,
                          size_t contiguous_size, hipStream_t stream);


/*!
 * \brief compute the launch configuration for a last-contiguous kernel
 *
 * Declaration only; the definition is not part of this header. Within this
 * header it is called with a pointer to hip_last_contiguous_kern, and the two
 * outputs are used directly as the grid and block dimensions of the launch.
 *
 * \param kern type-erased pointer to the kernel that will be launched
 * \param size total number of elements
 * \param contiguous_size size of the last contiguous elements
 * \param[out] grid_size number of blocks to launch
 * \param[out] block_size number of threads per block
 */
void get_last_contiguous_launch_spec(const void *kern, size_t size,
                                     size_t contiguous_size, int *grid_size,
                                     int *block_size);

//! internals for contiguous
namespace contiguous_intl {

#define devfunc __device__ __forceinline__
    /*!
     * \brief contiguity class of a tensor layout; selects the
     *      ParamElemVisitor specialization used for that tensor
     *
     * A layout for which is_contiguous() holds is handled as CONTIG_FULL;
     * every other layout is handled as CONTIG_OTHER.
     */
    enum ContigType {
        //! layout that is not fully contiguous; offsets are computed from
        //! per-axis strides
        CONTIG_OTHER,
        //! fully contiguous layout; the logical index is the offset
        CONTIG_FULL
    };

    /*!
     * \brief visitor to access an element in a tensor at given logical index
     * \tparam ndim number of dimensions of the visited tensor
     * \tparam ctype plain element ctype (i.e. ctype in DTypeTrait)
     * \tparam contig_type contiguity class of the tensor layout; selects
     *      the specialization
     *
     * host interface:
     *      void host_init(
     *              const TensorND &tensor, int grid_size, int block_size)
     *
     * device interface:
     *      void thread_init(uint32_t idx)
     *          called on thread entrance, with logical indexing; the index
     *          may go beyond buffer range
     *
     *      void next()
     *          advance hook; a no-op in the specializations of this file
     *
     *      ctype* ptr()
     *          return buffer pointer; can be used by specialized OpCaller
     *
     *      int offset(uint32_t idx)
     *          get physical offset from logical index
     *
     *      ctype& at(uint32_t idx)
     *          ptr()[offset(idx)]
     *
     */
    template<int ndim, typename ctype, ContigType contig_type>
    class ParamElemVisitor;

    /* f{{{ ParamElemVisitor specialization */

//! Device-side accessors shared by the ParamElemVisitor specializations:
//! ptr() returns the stored base pointer m_ptr and at(idx) returns a
//! reference to m_ptr[offset(idx)]. The class that expands this macro must
//! provide the m_ptr member and an offset(uint32_t) method.
#define PARAM_ELEM_VISITOR_COMMON_DEV \
            devfunc ctype* ptr() { \
                return m_ptr; \
            } \
            devfunc ctype& at(uint32_t idx) { \
                return m_ptr[offset(idx)]; \
            } \

    /*!
     * \brief specialization for CONTIG_OTHER: the physical offset is
     *      computed from per-axis strides
     */
    template<int ndim, typename ctype>
    class ParamElemVisitor<ndim, ctype, CONTIG_OTHER> {
        ctype * __restrict m_ptr;
        int m_stride[ndim];

        //! m_shape_highdim[i] = original_shape[i + 1]
#ifndef _MSC_VER
        Uint32Fastdiv m_shape_highdim[ndim - 1];
#else
        Uint32Fastdiv m_shape_highdim[ndim > 1 ? ndim - 1 : 1];
#endif

        public:
            static const int NDIM = ndim;

            void host_init(const TensorND &rv, int grid_size, int block_size);

#if MEGDNN_CC_CUDA
            //! no per-thread state to set up
            devfunc void thread_init(uint32_t) {}

            //! no per-step state to advance
            devfunc void next() {}

            /*!
             * \brief translate a logical index into a physical offset
             *
             * The logical index is peeled apart from the innermost axis
             * outwards: every step divides by the extent of the current
             * axis, weights the remainder with that axis' stride and
             * carries the quotient to the next outer axis. Whatever is left
             * after axis 1 is the coordinate on axis 0.
             */
            devfunc int offset(uint32_t idx) {
                int phys = 0;
#pragma unroll
                for (int axis = ndim - 1; axis >= 1; -- axis) {
                    Uint32Fastdiv &extent = m_shape_highdim[axis - 1];
                    uint32_t quotient = idx / extent;
                    phys += (idx - quotient * extent.divisor()) * m_stride[axis];
                    idx = quotient;
                }
                phys += idx * m_stride[0];
                return phys;
            }

            PARAM_ELEM_VISITOR_COMMON_DEV
#endif
    };

    /*!
     * \brief specialization for CONTIG_FULL: the logical index is used
     *      directly as the physical offset
     */
    template<int ndim, typename ctype>
    class ParamElemVisitor<ndim, ctype, CONTIG_FULL> {
        ctype * __restrict m_ptr;

        public:
            static const int NDIM = ndim;

            void host_init(const TensorND &rv, int grid_size, int block_size);

#if MEGDNN_CC_CUDA
            //! no per-thread state to set up
            devfunc void thread_init(uint32_t) {}

            //! no per-step state to advance
            devfunc void next() {}

            //! identity mapping from logical index to physical offset
            devfunc int offset(uint32_t idx) {
                return static_cast<int>(idx);
            }

            PARAM_ELEM_VISITOR_COMMON_DEV
#endif

    };

#undef PARAM_ELEM_VISITOR_COMMON_DEV

    /* f}}} */ 

    /*!
     * \brief pair of element visitors passed by value to the copy kernels
     * \tparam PVis0 visitor type of the tensor the kernels write to
     * \tparam PVis1 visitor type of the tensor the kernels read from
     */
    template <class PVis0, class PVis1>
    struct OpCallerBinaryContiguous {
        //! visitor of the written tensor
        PVis0 par0;
        //! visitor of the read tensor
        PVis1 par1;
    };

    /* f{{{ hip  kern */

#if MEGDNN_CC_CUDA
    /*!
     * \brief hip kern for the last axis stride is contiguous and
     *     contiguous_size is small
     *
     * Every thread copies one group of up to \p contiguous_size consecutive
     * logical elements from par1 to par0, starting at logical index
     * (global thread id) * \p contiguous_size. Threads whose start index is
     * not below \p size do nothing.
     *
     * \param op_caller visitor pair; par0 is written, par1 is read
     * \param contiguous_size the size of last contiguous elements
     * \param size total number of elements
     */
    template <typename OpCaller>
    __global__ void hip_last_contiguous_kern(OpCaller op_caller,
                                              uint32_t contiguous_size,
                                              uint32_t size) {
        // logical index of the first element handled by this thread
        uint32_t first = blockIdx.x * blockDim.x + threadIdx.x;
        first *= contiguous_size;
        if (first >= size) return;

        // the last group may be shorter than contiguous_size
        size_t count = size - first;
        if (count > contiguous_size) count = contiguous_size;

        int dst_off = op_caller.par0.offset(first);
        int src_off = op_caller.par1.offset(first);

    #pragma unroll
        for (size_t k = 0; k < count; ++k) {
            op_caller.par0.ptr()[dst_off++] = op_caller.par1.ptr()[src_off++];
        }
    }

    /*!
     * \brief hip kern for the last axis stride is contiguous and
     *     contiguous_size is large
     *
     * Every block copies one chunk of at most \p contig_size_each_block
     * elements that lies inside a single run of \p contiguous_size
     * elements; the threads of the block walk over the chunk with a step of
     * blockDim.x. Blocks whose chunk starts at or beyond \p size do nothing.
     *
     * \param op_caller user op caller; par0 is written, par1 is read
     * \param contiguous_size the size of last contiguous elements
     * \param size total number of elements
     * \param contig_size_each_block the contiguous elements visited in each
     *     block
     * \param contig_block_size the number of blocks used to visit one run
     *     of contiguous elements
     */
    template <typename OpCaller>
    __global__ void hip_last_contiguous_large_kern(
            OpCaller op_caller, uint32_t contiguous_size, uint32_t size,
            uint32_t contig_size_each_block, uint32_t contig_block_size) {
        // Every block manipulate a sub contiguous elements
        //! The \p contig_idx contiguous elements
        uint32_t contig_idx = blockIdx.x / contig_block_size;
        //! The idx in the current contiguous elemwise
        uint32_t contig_block_idx = blockIdx.x - contig_idx * contig_block_size;
        //! Start of this block's chunk, relative to the start of its run
        uint32_t chunk_begin = contig_block_idx * contig_size_each_block;
        uint32_t idx = contig_idx * contiguous_size + chunk_begin;
        if (idx >= size) return;

        // the last chunk of a run may be shorter than a full chunk
        uint32_t remain = contiguous_size - chunk_begin;
        if (remain > contig_size_each_block) remain = contig_size_each_block;

        // offset() yields a signed int. Keep the sign by widening to a signed
        // 64-bit value: storing it in an unsigned 32-bit variable would turn
        // a negative offset into a huge positive index once it is
        // zero-extended for the pointer arithmetic below.
        const int64_t physical_idx0 = op_caller.par0.offset(idx);
        const int64_t physical_idx1 = op_caller.par1.offset(idx);

        for (uint32_t i = threadIdx.x; i < remain; i += blockDim.x) {
            op_caller.par0.ptr()[physical_idx0 + i] =
                op_caller.par1.ptr()[physical_idx1 + i];
        }
    }

    /* f}}} */


//! Defines the function whose header is `_cb_header(_ndim)`. Its body
//! returns `_cb_dispatch(_ndim, CONTIG_FULL)` when `_layout.is_contiguous()`
//! holds and `_cb_dispatch(_ndim, CONTIG_OTHER)` otherwise.
#define DEFINE_CONTIG_RECEIVER(_ndim, _cb_header, _cb_dispatch, _layout) \
    _cb_header(_ndim) { \
        if (_layout.is_contiguous()) { \
            return _cb_dispatch(_ndim, CONTIG_FULL); \
        } \
        return _cb_dispatch(_ndim, CONTIG_OTHER); \
    } \

    /*!
     * \brief host-side invoker that selects visitor types and launches the
     *      kernel; specialized on \p arity
     * \tparam ctype element type handed to ParamElemVisitor
     * \tparam arity number of tensor params
     */
    template<typename ctype, int arity>
    class UserOpInvoker;

    /* f{{{ UserOpInvoker specializations */

    /*!
     * \brief specialization for binary opr: copies param[1] into param[0]
     *
     * The constructor performs the whole operation. dispatch0() and the
     * generated dispatch1_N() pick the ParamElemVisitor type for param[0]
     * from its ndim and contiguity, dispatch2() and the generated
     * dispatch3_N() do the same for param[1], and do_run() launches one of
     * the two last-contiguous kernels with the resulting visitor pair.
     */
    template<typename ctype>
    class UserOpInvoker<ctype, 2> {
        //! set by do_run(); the constructor asserts it to make sure the
        //! dispatch chain reached a kernel launch
        bool m_invoked;
        const ElemwiseOpParamN<2> &m_param;
        hipStream_t m_stream;
        //! size of the last contiguous elements, forwarded to the kernels
        const size_t m_contiguous_size;

        //! first dispatch level: branch on the ndim of param[0]
        void dispatch0() {
            switch(m_param[0].layout.ndim) {
                // there is no default case: an ndim that matches no case
                // returns without launching and leaves m_invoked false
#define cb(ndim) \
                case ndim: return dispatch1_##ndim();
                MEGDNN_FOREACH_TENSOR_NDIM(cb)
#undef cb
            }
        }

        // second dispatch level: dispatch1_N() branches on the contiguity of
        // param[0] and fixes the first visitor type
#define cb_header(ndim) void dispatch1_##ndim()
#define cb_dispatch(ndim, contig_mask) \
        dispatch2<ParamElemVisitor<ndim, ctype, contig_mask> >()
DEFINE_CONTIG_RECEIVER(1, cb_header, cb_dispatch, m_param[0].layout)
DEFINE_CONTIG_RECEIVER(2, cb_header, cb_dispatch, m_param[0].layout)
DEFINE_CONTIG_RECEIVER(3, cb_header, cb_dispatch, m_param[0].layout)
DEFINE_CONTIG_RECEIVER(4, cb_header, cb_dispatch, m_param[0].layout)
DEFINE_CONTIG_RECEIVER(5, cb_header, cb_dispatch, m_param[0].layout)
DEFINE_CONTIG_RECEIVER(6, cb_header, cb_dispatch, m_param[0].layout)
DEFINE_CONTIG_RECEIVER(7, cb_header, cb_dispatch, m_param[0].layout)
#undef cb_header
#undef cb_dispatch


        //! third dispatch level: branch on the ndim of param[1]
        template<class PVis0>
        void dispatch2() {
            switch(m_param[1].layout.ndim) {
#define cb(ndim) \
                case ndim: return dispatch3_##ndim<PVis0>();
                MEGDNN_FOREACH_TENSOR_NDIM(cb)
#undef cb
            }
        }

        // fourth dispatch level: dispatch3_N() branches on the contiguity of
        // param[1], fixes the second visitor type and calls do_run()
#define cb_header(ndim) \
    template<class PVis0> \
    void dispatch3_##ndim()
#define cb_dispatch(ndim, contig_mask) \
        do_run<PVis0, ParamElemVisitor<ndim, ctype, contig_mask> >()
DEFINE_CONTIG_RECEIVER(1, cb_header, cb_dispatch, m_param[1].layout)
DEFINE_CONTIG_RECEIVER(2, cb_header, cb_dispatch, m_param[1].layout)
DEFINE_CONTIG_RECEIVER(3, cb_header, cb_dispatch, m_param[1].layout)
DEFINE_CONTIG_RECEIVER(4, cb_header, cb_dispatch, m_param[1].layout)
DEFINE_CONTIG_RECEIVER(5, cb_header, cb_dispatch, m_param[1].layout)
DEFINE_CONTIG_RECEIVER(6, cb_header, cb_dispatch, m_param[1].layout)
DEFINE_CONTIG_RECEIVER(7, cb_header, cb_dispatch, m_param[1].layout)
#undef cb_header
#undef cb_dispatch

        /*!
         * \brief choose the kernel and launch configuration, then launch
         * \tparam PVis0 visitor type for param[0] (written)
         * \tparam PVis1 visitor type for param[1] (read)
         */
        template<class PVis0, class PVis1>
        void do_run() {
            megdnn_assert(!m_invoked);
            m_invoked = true;
            typedef OpCallerBinaryContiguous<PVis0, PVis1> Caller;
            size_t size = m_param.size;
            int grid_size, block_size;
            if (m_contiguous_size > 32) {
                // long runs: several threads cooperate on each run
                void (*fptr)(Caller, uint32_t, uint32_t, uint32_t, uint32_t);
                fptr = hip_last_contiguous_large_kern<Caller>;
                // NOTE(review): size is only passed to safe_size_in_kern on
                // this branch; presumably get_last_contiguous_launch_spec
                // validates it on the other branch -- confirm
                safe_size_in_kern(size);
                // clamp the run length to one of 256 / 128 / 64 / 32 threads
                block_size = m_contiguous_size;
                if (block_size > 256) {
                    block_size = 256;
                } else if (block_size > 128) {
                    block_size = 128;
                } else if (block_size > 64) {
                    block_size = 64;
                } else {
                    block_size = 32;
                }

                // a run longer than MAX_CONTIG_SIZE_EACH_BLOCK is split into
                // contig_block_size chunks, one block per chunk
                const uint32_t MAX_CONTIG_SIZE_EACH_BLOCK = 1024;
                uint32_t contig_block_size = 1;
                uint32_t contig_size_each_block = m_contiguous_size;
                if (m_contiguous_size > MAX_CONTIG_SIZE_EACH_BLOCK) {
                    contig_size_each_block = MAX_CONTIG_SIZE_EACH_BLOCK;
                    contig_block_size =
                            (m_contiguous_size + contig_size_each_block - 1) /
                            contig_size_each_block;
                }
                // (number of runs, rounded up) * (blocks per run)
                grid_size = (size + m_contiguous_size - 1) / m_contiguous_size *
                            contig_block_size;
                Caller caller;
                caller.par0.host_init(m_param[0], grid_size, block_size);
                caller.par1.host_init(m_param[1], grid_size, block_size);
                hipLaunchKernelGGL(fptr, dim3(grid_size), dim3(block_size), 0,
                                   m_stream, caller, m_contiguous_size, size,
                                   contig_size_each_block, contig_block_size);
            } else {
                // short runs: one thread per run, launch configuration is
                // obtained from get_last_contiguous_launch_spec
                void (*fptr)(Caller, uint32_t, uint32_t);
                fptr = hip_last_contiguous_kern<Caller>;
                get_last_contiguous_launch_spec(
                        reinterpret_cast<const void *>(fptr),
                        size, m_contiguous_size, &grid_size, &block_size);
                Caller caller;
                caller.par0.host_init(m_param[0], grid_size, block_size);
                caller.par1.host_init(m_param[1], grid_size, block_size);
                hipLaunchKernelGGL(fptr, dim3(grid_size), dim3(block_size), 0,
                                   m_stream, caller, m_contiguous_size, size);
            }
            after_kernel_launch();
        }

        public:
            /*!
             * \brief dispatch on the two params and launch the copy kernel
             * \param param the two tensors; param[0] is written and
             *      param[1] is read by the kernels
             * \param stream stream the kernel is launched on
             * \param contiguous_size size of the last contiguous elements
             *
             * Asserts that a kernel launch was reached by the dispatch.
             */
            UserOpInvoker(const ElemwiseOpParamN<2> &param, hipStream_t stream,
                    const size_t contiguous_size):
                m_param(param), m_stream(stream),
                m_contiguous_size(contiguous_size)
            {
                m_invoked = false;
                dispatch0();
                megdnn_assert(m_invoked);
            }
    };

#undef DEFINE_CONTIG_RECEIVER

    /* f}}} */ 

#endif

#undef devfunc

} // namespace contiguous_intl

} // rocm
} // megdnn

// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}