// megenginelite-sys 1.8.2
//
// A safe megenginelite wrapper in Rust
// Documentation
/**
 * \file dnn/src/rocm/convolution/im2col.cpp.hip
 *
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
 */
#include "./im2col.h.hip"
#include "megdnn/dtype.h"
#include "src/rocm/utils.h.hip"

using namespace megdnn;
using namespace rocm;

namespace {

/**
 * \brief Unfold an input batch into the column buffer (im2col).
 *
 * Index decomposition used by this kernel:
 *   - blockIdx.x encodes (ic, fh, fw, oh), with oh varying fastest;
 *   - threadIdx.x together with blockIdx.y enumerates the batch index n;
 *   - threadIdx.y together with blockIdx.z enumerates the output column ow.
 *
 * Every thread that passes the bounds check writes exactly one element,
 *   col[(blockIdx.x * OW + ow) * N + n],
 * which is the input value at
 *   im[n * INP_BS + ic * IH * IW + ih * IW + iw]
 * with ih = oh * SH + fh * DH - PH and iw = ow * SW + fw * DW - PW, or zero
 * when (ih, iw) lies outside [0, IH) x [0, IW).
 *
 * \param im      source buffer (read only)
 * \param col     destination column buffer
 * \param N       extent of the n index
 * \param INP_BS  element stride between consecutive n in \p im
 * \param IC      not read by this kernel; ic is recovered from blockIdx.x
 * \param IH,IW   extents of ih / iw in \p im
 * \param FH,FW   extents of fh / fw
 * \param OH,OW   extents of oh / ow
 * \param PH,PW   offsets subtracted when mapping to ih / iw
 * \param SH,SW   multipliers applied to oh / ow
 * \param DH,DW   multipliers applied to fh / fw
 */
template <typename T>
__global__ void im2col_kernel(const T* im, T* col, uint32_t N, uint32_t INP_BS,
                              uint32_t IC, uint32_t IH, uint32_t IW,
                              uint32_t FH, uint32_t FW, uint32_t OH,
                              uint32_t OW, uint32_t PH, uint32_t PW,
                              uint32_t SH, uint32_t SW, uint32_t DH,
                              uint32_t DW) {
    uint32_t n = threadIdx.x + blockIdx.y * blockDim.x;
    uint32_t ow = threadIdx.y + blockIdx.z * blockDim.y;
    uint32_t bx = blockIdx.x;
    uint32_t oh = bx % OH;
    uint32_t fw = bx / OH % FW;
    uint32_t fh = bx / OH / FW % FH;
    uint32_t ic = bx / OH / FW / FH;
    if (n < N && ow < OW) {
        // Flat offsets are formed in size_t: the 32-bit product
        // blockIdx.x * OW * N would wrap modulo 2^32 for large buffers and
        // silently address the wrong element.
        size_t didx = (static_cast<size_t>(bx) * OW + ow) * N + n;
        // ih / iw intentionally stay 32-bit unsigned: a position above or
        // left of the image wraps around to a huge value, so the single
        // "< IH" / "< IW" comparison rejects both out-of-range directions.
        uint32_t ih = -PH + oh * SH + fh * DH;
        uint32_t iw = -PW + ow * SW + fw * DW;
        col[didx] = (ih < IH && iw < IW
                             ? im[static_cast<size_t>(n) * INP_BS +
                                  (static_cast<size_t>(ic) * IH + ih) * IW + iw]
                             : T(0.0f));
    }
}

/**
 * \brief Fold the column buffer back into an image batch (col2im).
 *
 * Index decomposition used by this kernel:
 *   - blockIdx.x encodes (n, ic), with ic varying fastest;
 *   - threadIdx.x together with blockIdx.y enumerates iw;
 *   - threadIdx.y together with blockIdx.z enumerates ih.
 *
 * Every thread that passes the bounds check owns one element (n, ic, ih, iw).
 * It sums all column entries (ic, fh, fw, oh, ow, n) for which
 *   oh * SH + fh * DH - PH == ih  and  ow * SW + fw * DW - PW == iw,
 * and stores the sum to im[n * INP_BS + ic * IH * IW + ih * IW + iw].
 * The store overwrites the previous content of \p im.
 *
 * \param col     source column buffer (read only), indexed as
 *                ((((ic * FH + fh) * FW + fw) * OH + oh) * OW + ow) * N + n
 * \param im      destination buffer
 * \param N       extent of the n index
 * \param INP_BS  element stride between consecutive n in \p im
 * \param IC      extent of the ic index
 * \param IH,IW   extents of ih / iw in \p im
 * \param FH,FW   extents of fh / fw
 * \param OH,OW   extents of oh / ow
 * \param PH,PW   offsets added when mapping (ih, iw) back to anchors
 * \param SH,SW   divisors mapping anchors to oh / ow
 * \param DH,DW   multipliers applied to fh / fw
 */
template <typename T>
__global__ void col2im_kernel(const T* col, T* im, uint32_t N, uint32_t INP_BS,
                              uint32_t IC, uint32_t IH, uint32_t IW,
                              uint32_t FH, uint32_t FW, uint32_t OH,
                              uint32_t OW, uint32_t PH, uint32_t PW,
                              uint32_t SH, uint32_t SW, uint32_t DH,
                              uint32_t DW) {
    uint32_t iw = threadIdx.x + blockIdx.y * blockDim.x;
    uint32_t ih = threadIdx.y + blockIdx.z * blockDim.y;
    uint32_t ic = blockIdx.x % IC;
    uint32_t n = blockIdx.x / IC;
    if (iw < IW && ih < IH) {
        T res(0);
        for (uint32_t fh = 0; fh < FH; ++fh) {
            // Kept in 32-bit unsigned on purpose: when fh * DH exceeds
            // ih + PH the subtraction wraps to a huge value, which the
            // "< OH * SH" test below rejects.
            uint32_t anchorh = ih + PH - fh * DH;
            if (anchorh < OH * SH && anchorh % SH == 0) {
                uint32_t oh = anchorh / SH;
                for (uint32_t fw = 0; fw < FW; ++fw) {
                    uint32_t anchorw = iw + PW - fw * DW;
                    if (anchorw < OW * SW && anchorw % SW == 0) {
                        uint32_t ow = anchorw / SW;
                        // Flat offset formed in size_t: the 32-bit product
                        // ic * FH * FW * OH * OW * N would wrap modulo 2^32
                        // for large column buffers.
                        size_t cidx =
                                ((((static_cast<size_t>(ic) * FH + fh) * FW +
                                   fw) * OH +
                                  oh) * OW +
                                 ow) * N +
                                n;
                        res += col[cidx];
                    }
                }
            }
        }
        im[static_cast<size_t>(n) * INP_BS +
           (static_cast<size_t>(ic) * IH + ih) * IW + iw] = res;
    }
}

}  // anonymous namespace

template <typename T>
void convolution::im2col(const T* im, T* col, size_t N, size_t INP_BS,
                         size_t IC, size_t IH, size_t IW, size_t FH, size_t FW,
                         size_t OH, size_t OW, size_t PH, size_t PW, size_t SH,
                         size_t SW, size_t DH, size_t DW, hipStream_t stream) {
    dim3 threads(NR_THREADS_X, NR_THREADS_Y);
    dim3 blocks(IC * FH * FW * OH, DIVUP(N, NR_THREADS_X),
                DIVUP(OW, NR_THREADS_Y));
    hipLaunchKernelGGL(im2col_kernel<T>, blocks, threads, 0, stream, im, col, N,
                       INP_BS, IC, IH, IW, FH, FW, OH, OW, PH, PW, SH, SW, DH,
                       DW);
    after_kernel_launch();
}

template <typename T>
void convolution::col2im(const T* col, T* im, size_t N, size_t INP_BS,
                         size_t IC, size_t IH, size_t IW, size_t FH, size_t FW,
                         size_t OH, size_t OW, size_t PH, size_t PW, size_t SH,
                         size_t SW, size_t DH, size_t DW, hipStream_t stream) {
    dim3 threads(NR_THREADS_X, NR_THREADS_Y);
    dim3 blocks(N * IC, DIVUP(IW, NR_THREADS_X), DIVUP(IH, NR_THREADS_Y));
    hipLaunchKernelGGL(col2im_kernel<T>, blocks, threads, 0, stream, col, im, N,
                       INP_BS, IC, IH, IW, FH, FW, OH, OW, PH, PW, SH, SW, DH,
                       DW);
    after_kernel_launch();
}

// Explicit template instantiations of convolution::im2col and
// convolution::col2im, so their definitions in this file are emitted for the
// element types requested below.
namespace megdnn {
namespace rocm {
namespace convolution {

// DO_INST(T): emits one explicit instantiation of im2col<T> and one of
// col2im<T>, with signatures matching the definitions in this file.
#define DO_INST(T)                                                        \
    template void im2col<T>(const T* im, T* col, size_t N, size_t INP_BS, \
                            size_t IC, size_t IH, size_t IW, size_t FH,   \
                            size_t FW, size_t OH, size_t OW, size_t PH,   \
                            size_t PW, size_t SH, size_t SW, size_t DH,   \
                            size_t DW, hipStream_t stream);               \
    template void col2im<T>(const T* col, T* im, size_t N, size_t INP_BS, \
                            size_t IC, size_t IH, size_t IW, size_t FH,   \
                            size_t FW, size_t OH, size_t OW, size_t PH,   \
                            size_t PW, size_t SH, size_t SW, size_t DH,   \
                            size_t DW, hipStream_t stream);

// INST(_dt): forwards DTypeTrait<_dt>::ctype to DO_INST.
// NOTE(review): presumably _dt is a megdnn dtype tag and ctype its C++
// storage type - confirm in megdnn/dtype.h.
#define INST(_dt) DO_INST(DTypeTrait<_dt>::ctype)

// NOTE(review): presumably applies INST once per floating-point computing
// dtype; the macro is defined outside this file - confirm the exact list.
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(INST);

// Keep the helper macros local to this file.
#undef DO_INST
#undef INST

} // namespace convolution
} // namespace rocm
} // namespace megdnn

// vim: syntax=cpp.doxygen