megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file dnn/src/rocm/convolution/chanwise/bwd_data.cpp.hip
 *
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "hip_header.h"
#include "./kern.h.hip"
#include "./kern_helper.h.hip"

using namespace megdnn;
using namespace rocm;
using namespace convolution;
using namespace chanwise;

namespace {

// grid idx is (inp_chl, worker_index)
// each y-slice of a block works on an (N, IH, IW) spatial image at given
// inp_chl
template <typename T, int CHL_MUL_SET, int FH_SET, int FW_SET, int SH_SET,
          int SW_SET>
__global__ void kern_bwd_data(T* src_grad, const T* dst_grad, const T* flt_tot,
                              Param param) {
    extern __shared__ uint8_t flt_storage[];

    T* const flt = reinterpret_cast<T*>(flt_storage);

    const uint32_t N = param.batch, IC = param.src_chl, ic = blockIdx.x,
                   IH = param.src_h, IW = param.src_w,
                   CHL_MUL = CHL_MUL_SET ? CHL_MUL_SET : param.chl_mul,
                   FH = FH_SET ? FH_SET : param.flt_h,
                   FW = FW_SET ? FW_SET : param.flt_w, FSIZE = FH * FW,
                   PH = param.pad_h, PW = param.pad_w,
                   SH = SH_SET ? SH_SET : param.stride_h,
                   SW = SW_SET ? SW_SET : param.stride_w, OH = param.out_h,
                   OW = param.out_w, TOT_OUT = N * IH * IW;

    block_memcpy(flt, flt_tot + ic * FSIZE * CHL_MUL, FSIZE * CHL_MUL);
    dst_grad += ic * CHL_MUL * OH * OW;
    src_grad += ic * IH * IW;

    uint32_t out_idx_ = blockIdx.y * blockDim.x + threadIdx.x,
             nr_out_per_launch = blockDim.x * gridDim.y;
    for (; out_idx_ < TOT_OUT; out_idx_ += nr_out_per_launch) {
        uint32_t out_idx = out_idx_, n, ih, iw;
        out_idx = div_mod(out_idx, IW, iw);
        out_idx = div_mod(out_idx, IH, ih);
        n = out_idx;

        const T* dst_grad_base = dst_grad + n * (IC * CHL_MUL * OH * OW);

        T sum(0);

        // o >= max(0, floor_div((i+P-F+1), S))
        uint32_t ohmin = max(int32_t(ih + PH - FH + SH), 0) / SH,
                 owmin = max(int32_t(iw + PW - FW + SW), 0) / SW,
                 ohmax = min((ih + PH) / SH, OH - 1),
                 owmax = min((iw + PW) / SW, OW - 1);
        if (SH_SET == 1 && SW_SET == 1 && FH_SET && FW_SET) {
#pragma unroll
            for (uint32_t doh = 0; doh < FH; ++doh) {
                uint32_t oh = ohmin + doh;
                if (oh <= ohmax) {
                    uint32_t fh = ih - oh * SH + PH;
#pragma unroll
                    for (uint32_t dow = 0; dow < FW; ++dow) {
                        uint32_t ow = owmin + dow;
                        if (ow <= owmax) {
                            uint32_t fw = iw - ow * SW + PW;
                            const T* pd = dst_grad_base + oh * OW + ow;
                            const T* pf = flt + fh * FW + fw;
#pragma unroll
                            for (uint32_t chl_mul = 0; chl_mul < CHL_MUL;
                                 ++chl_mul) {
                                sum += *pd * *pf;
                                pd += OH * OW;
                                pf += FSIZE;
                            }
                        }
                    }
                }
            }
        } else {
            for (uint32_t oh = ohmin; oh <= ohmax; ++oh) {
                uint32_t fh = ih - oh * SH + PH;
                for (uint32_t ow = owmin; ow <= owmax; ++ow) {
                    uint32_t fw = iw - ow * SW + PW;
                    const T* pd = dst_grad_base + oh * OW + ow;
                    const T* pf = flt + fh * FW + fw;
#pragma unroll
                    for (uint32_t chl_mul = 0; chl_mul < CHL_MUL; ++chl_mul) {
                        sum += *pd * *pf;
                        pd += OH * OW;
                        pf += FSIZE;
                    }
                }
            }
        }

        src_grad[(n * (IC * IH) + ih) * IW + iw] = sum;
    }
}

// Maps run-time convolution parameters to a kern_bwd_data specialization.
// Combinations without a dedicated specialization use template argument 0,
// which makes the kernel read the value from Param at run time.
template <typename T>
class KernDispatch {
public:
    using kern_ptr_t = void (*)(T*, const T*, const T*, Param);

    // Selects the kernel for the given channel multiplier, filter size
    // (fh x fw) and strides (sh, sw). Only chl_mul == 1 with a 3x3 or 4x4
    // filter gets compile-time filter constants.
    static kern_ptr_t dispatch(int chl_mul, int fh, int fw, int sh, int sw) {
        if (chl_mul == 1 && fh == fw) {
            switch (fh) {
                case 3:
                    return d1<1, 3, 3>(sh, sw);
                case 4:
                    return d1<1, 4, 4>(sh, sw);
                default:
                    break;
            }
        }
        return d1<0, 0, 0>(sh, sw);
    }

private:
    // Second dispatch stage: fixes the strides at compile time when both are
    // 1 or 2, and otherwise falls back to run-time strides (0, 0).
    template <int CHL_MUL, int FH, int FW>
    static kern_ptr_t d1(int sh, int sw) {
        if (sh == 1) {
            if (sw == 1)
                return kern_bwd_data<T, CHL_MUL, FH, FW, 1, 1>;
            if (sw == 2)
                return kern_bwd_data<T, CHL_MUL, FH, FW, 1, 2>;
        } else if (sh == 2) {
            if (sw == 1)
                return kern_bwd_data<T, CHL_MUL, FH, FW, 2, 1>;
            if (sw == 2)
                return kern_bwd_data<T, CHL_MUL, FH, FW, 2, 2>;
        }
        return kern_bwd_data<T, CHL_MUL, FH, FW, 0, 0>;
    }
};

}  // anonymous namespace

template <typename T>
void chanwise::run_bwd_data(T* src_grad, const T* dst_grad, const T* flt,
                            const Param& param, hipStream_t stream) {
    typename KernDispatch<T>::kern_ptr_t kern =
            KernDispatch<T>::dispatch(param.chl_mul, param.flt_h, param.flt_w,
                                      param.stride_h, param.stride_w);
    int nr_thread = 256, nr_out_dimx = param.src_h * param.src_w * param.batch;
    dim3 nr_block(param.src_chl,
                  std::min(512, max(nr_out_dimx / (nr_thread * 4), 1)));
    uint32_t shared = param.chl_mul * param.flt_h * param.flt_w * sizeof(T);
    hipLaunchKernelGGL(kern, nr_block, nr_thread, shared, stream, src_grad, dst_grad, flt,
                                                  param);
    after_kernel_launch();
}

namespace megdnn {
namespace rocm {
namespace convolution {
namespace chanwise {

// Explicit instantiations of run_bwd_data. INST(_dt) instantiates the template
// for the C type DTypeTrait<_dt>::ctype; MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT
// (defined elsewhere) is expected to invoke INST once per floating-point
// computing dtype.
#define INST(_dt)                                                   \
    template void run_bwd_data(                                     \
            DTypeTrait<_dt>::ctype*, const DTypeTrait<_dt>::ctype*, \
            const DTypeTrait<_dt>::ctype*, const Param&, hipStream_t);
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(INST)
#undef INST

} // namespace chanwise
} // namespace convolution
} // namespace rocm
} // namespace megdnn

// vim: syntax=cuda.doxygen