// megenginelite-sys 1.8.2
//
// A safe megenginelite wrapper in Rust
// Documentation
/**
 * \file dnn/src/rocm/convolution/forward/inplace_matmul_impl.cpp.hip
 *
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "./inplace_matmul_impl.h.hip"
#include "src/rocm/utils.h.hip"

using namespace megdnn;
using namespace rocm;

namespace {

//! Device-side reader that fetches floats through a HIP texture object.
struct BufferFetcherTexture {
    //! texture object the elements are fetched from
    hipTextureObject_t tex;

    //! Fetch the float stored at element index \p offset of the texture.
    __device__ __forceinline__ float get(uint32_t offset) {
        const float fetched = tex1Dfetch<float>(tex, offset);
        return fetched;
    }
};

//! Device-side reader that loads floats directly through a raw pointer.
struct BufferFetcherRaw {
    //! base address of the buffer that is read
    const float* ptr;

    //! Load the float located \p offset elements past \c ptr.
    __device__ __forceinline__ float get(uint32_t offset) {
        const float* const element = ptr + offset;
        return *element;
    }
};

/*!
 * \brief Host-side owner of the texture object used by BufferFetcherTexture.
 *
 * The constructor tries to create a texture object; \c init_succ records
 * whether that worked. When it did, the texture object is destroyed by
 * reset(), which the destructor calls.
 *
 * Because the destructor releases the handle stored in \c val, a copy of this
 * object would release the same handle a second time. Copying is therefore
 * disabled.
 */
struct BufferFetcherTextureHost {
    //! true while \c val.tex holds a texture object that still has to be
    //! destroyed
    bool init_succ;
    //! device-side fetcher; only meaningful while \c init_succ is true
    BufferFetcherTexture val;

    /*!
     * \param p start of the buffer to be exposed through the texture
     * \param n number of float elements in the buffer
     */
    BufferFetcherTextureHost(float* p, const size_t n);

    ~BufferFetcherTextureHost() { reset(); }

    //! non-copyable: the object owns the texture handle
    BufferFetcherTextureHost(const BufferFetcherTextureHost&) = delete;
    BufferFetcherTextureHost& operator=(const BufferFetcherTextureHost&) =
            delete;

    //! Destroy the texture object if one is held; safe to call repeatedly.
    void reset() {
        if (init_succ) {
            hip_check(hipDestroyTextureObject(val.tex));
            init_succ = false;
        }
    }
};

/*!
 * Try to wrap the \p n floats starting at \p p in a linear texture object with
 * a single 32-bit float channel. On failure \c init_succ stays false and the
 * pending HIP error is cleared so the caller can fall back to another path.
 */
BufferFetcherTextureHost::BufferFetcherTextureHost(float* p, const size_t n) {
    init_succ = false;

    // describe the linear memory region backing the texture
    hipResourceDesc resource;
    memset(&resource, 0, sizeof(hipResourceDesc));
    resource.resType = hipResourceTypeLinear;
    resource.res.linear.devPtr = static_cast<void*>(p);
    resource.res.linear.sizeInBytes = n * sizeof(float);
    resource.res.linear.desc =
            hipCreateChannelDesc(32, 0, 0, 0, hipChannelFormatKindFloat);

    // all-zero texture descriptor
    hipTextureDesc texture;
    memset(&texture, 0, sizeof(hipTextureDesc));

    hipTextureObject_t created;
    if (hipCreateTextureObject(&created, &resource, &texture, NULL) !=
        hipSuccess) {
        hipGetLastError();  // reset error
        return;
    }

    val.tex = created;
    init_succ = true;
}

//! Function-pointer type of conv_kernel for a given \p BufferFetcher; the
//! parameter names below mirror the kernel's parameter list.
template <class BufferFetcher>
struct KernelPtr {
    typedef void (*type)(BufferFetcher src, BufferFetcher filter, float* dst,
                         uint32_t INP_BS, uint32_t OUT_BS, uint32_t IC,
                         uint32_t IH, uint32_t IW, uint32_t OC, uint32_t OH,
                         uint32_t OW, uint32_t FH, uint32_t FW, uint32_t SH,
                         uint32_t SW, uint32_t PH, uint32_t PW);
};

//! Expand a truth value into a full-width bit mask:
//! any non-zero \p cond -> 0xffffffff, zero -> 0x00000000
__device__ __forceinline__ uint32_t bool_as_mask(uint32_t cond) {
    // unsigned wrap-around: 0 - 1 == 0xffffffff, 0 - 0 == 0
    const uint32_t is_set = static_cast<uint32_t>(cond != 0u);
    return 0u - is_set;
}

//! Overlays a float and a 32-bit unsigned integer on the same storage so the
//! bits of a float can be combined with an integer mask (see visit_with_mask).
union FloatAndU32 {
    float f;     //!< value viewed as a float
    uint32_t u;  //!< the same storage viewed as a 32-bit unsigned integer
};

/*!
 * \brief Branch-free conditional fetch.
 *
 * With an all-ones \p mask this returns buf.get(offset). With an all-zero
 * \p mask the fetch is redirected to element 0 and the result is forced to
 * +0.0f, so \p offset itself is never dereferenced.
 *
 * \p mask must be either all 1 or 0 bits
 */
template <class BufferFetcher>
__device__ __forceinline__ float visit_with_mask(BufferFetcher buf,
                                                 uint32_t offset,
                                                 uint32_t mask) {
    const uint32_t safe_offset = offset & mask;
    FloatAndU32 bits;
    bits.f = buf.get(safe_offset);
    bits.u = bits.u & mask;
    return bits.f;
}

/*!
 * \brief Forward convolution evaluated as a tiled matrix product whose right
 *      operand is gathered from \p src on the fly.
 *
 * For the batch item n = blockIdx.z the filter is read as a matrix A with
 * OC rows and IC*FH*FW columns; matrix B has IC*FH*FW rows and OH*OW columns
 * and its entries are fetched from \p src (zero where the tap falls into the
 * padding). The product is stored to \p dst.
 *
 * Every thread accumulates a 4x4 patch of the result: rows posy*4 .. posy*4+3
 * and columns posx*4 .. posx*4+3, where posx/posy are the global thread
 * coordinates in x/y.
 *
 * \tparam BY first extent of the shared tile localA
 * \tparam BX second extent of the shared tile localB
 * \tparam is_xcorr when false, the filter tap coordinates used to address
 *      \p src are mirrored (FH-1-fh, FW-1-fw)
 * \tparam BufferFetcher type providing `float get(uint32_t offset)`
 *
 * \param src fetcher for the input
 * \param filter fetcher for the filter
 * \param dst output buffer
 * \param INP_BS element offset between consecutive batch items in \p src
 * \param OUT_BS element offset between consecutive batch items in \p dst
 * \param IC,IH,IW input channels / height / width
 * \param OC,OH,OW output channels / height / width
 * \param FH,FW filter height / width
 * \param SH,SW strides applied to the output row / column index
 * \param PH,PW padding subtracted from the input row / column index
 *
 * NOTE(review): localA/localB are indexed with threadIdx.y / threadIdx.x, so
 * this appears to require blockDim.y <= BY and blockDim.x <= BX -- confirm at
 * the launch site.
 */
template <uint32_t BY, uint32_t BX, bool is_xcorr, class BufferFetcher>
__global__ void conv_kernel(BufferFetcher src, BufferFetcher filter, float* dst,
                            const uint32_t INP_BS, const uint32_t OUT_BS,
                            const uint32_t IC, const uint32_t IH,
                            const uint32_t IW, const uint32_t OC,
                            const uint32_t OH, const uint32_t OW,
                            const uint32_t FH, const uint32_t FW,
                            const uint32_t SH, const uint32_t SW,
                            const uint32_t PH, const uint32_t PW) {
    // depth of the reduction dimension processed per loop iteration
    const uint32_t BM = BY < BX ? BY : BX;
    const uint32_t n = blockIdx.z;
    const uint32_t tidx = threadIdx.x;
    const uint32_t tidy = threadIdx.y;
    const uint32_t posx = blockIdx.x * blockDim.x + threadIdx.x;
    const uint32_t posy = blockIdx.y * blockDim.y + threadIdx.y;
    // first output column / row of this thread's 4x4 patch
    const uint32_t posx2 = posx << 2;
    const uint32_t posy2 = posy << 2;
    const uint32_t heightA = OC;
    const uint32_t widthA = IC * FH * FW;
    const uint32_t heightB = widthA;
    const uint32_t widthB = OH * OW;
    // for each of the 4 output columns: strided row (ohK), strided column
    // (owK) and the resulting element offset inside one input channel (opK)
    const uint32_t oh0 = (posx2 + 0) / OW * SH;
    const uint32_t ow0 = (posx2 + 0) % OW * SW;
    const uint32_t op0 = oh0 * IW + ow0;
    const uint32_t oh1 = (posx2 + 1) / OW * SH;
    const uint32_t ow1 = (posx2 + 1) % OW * SW;
    const uint32_t op1 = oh1 * IW + ow1;
    const uint32_t oh2 = (posx2 + 2) / OW * SH;
    const uint32_t ow2 = (posx2 + 2) % OW * SW;
    const uint32_t op2 = oh2 * IW + ow2;
    const uint32_t oh3 = (posx2 + 3) / OW * SH;
    const uint32_t ow3 = (posx2 + 3) % OW * SW;
    const uint32_t op3 = oh3 * IW + ow3;
    const uint32_t FP = FH * FW;
    // shared tiles; each float4 packs 4 rows of A resp. 4 columns of B
    __shared__ float4 localA[BY][BM];
    __shared__ float4 localB[BM][BX];
    uint32_t i = 0u;
    uint32_t offsetA = posy2 * widthA + tidx;
    // start of batch item n shifted back by the padding; the unsigned
    // subtraction may wrap, every use below goes through visit_with_mask
    uint32_t offsetB = n * INP_BS - PH * IW - PW;
    float4 sum0 = {0.0f, 0.0f, 0.0f, 0.0f}, sum1 = {0.0f, 0.0f, 0.0f, 0.0f},
           sum2 = {0.0f, 0.0f, 0.0f, 0.0f}, sum3 = {0.0f, 0.0f, 0.0f, 0.0f};
    // decomposition of the B row index (tidy + i) into filter row, filter
    // column, input channel and position inside the FH*FW window
    uint32_t fh = tidy / FW % FH;
    uint32_t fw = tidy % FW;
    uint32_t ic = tidy / (FH * FW);
    uint32_t icm = tidy % (FH * FW);

    // the same decomposition applied to the per-iteration step BM, so the
    // indices above can be advanced with adds and carries only
    const uint32_t fhs = BM / FW % FH;
    const uint32_t fws = BM % FW;
    const uint32_t ics = BM / (FH * FW);
    const uint32_t icms = BM % (FH * FW);

    for (; i < widthA; i += BM, offsetA += BM) {
        // load localA
        // NOTE(review): these fetches carry no bounds mask; for rows >= OC or
        // columns >= widthA they read past the logical tile. The matching
        // localB entries are zeroed via `ok` and rows >= OC are never stored,
        // but confirm this is acceptable for the raw-pointer fetcher.
        if (tidx < BM) {
            localA[tidy][tidx].x = filter.get(offsetA + 0 * widthA);
            localA[tidy][tidx].y = filter.get(offsetA + 1 * widthA);
            localA[tidy][tidx].z = filter.get(offsetA + 2 * widthA);
            localA[tidy][tidx].w = filter.get(offsetA + 3 * widthA);
        }

        // load localB
        uint32_t fh2, fw2;
        if (is_xcorr) {
            fh2 = fh;
            fw2 = fw;
        } else {
            // mirror the tap position inside the filter window
            fh2 = FH - fh - 1;
            fw2 = FW - fw - 1;
        }

        if (tidy < BM) {
            // ok: row tidy + i is inside B; pK: the tap for output column K
            // lands inside the unpadded input
            uint32_t tmp = offsetB + (ic * IH + (fh2)) * IW + (fw2),
                     ok = bool_as_mask(tidy + i < heightB),
                     p0 = bool_as_mask(fh2 + oh0 >= PH && fh2 + oh0 < IH + PH &&
                                       fw2 + ow0 >= PW && fw2 + ow0 < IW + PW),
                     p1 = bool_as_mask(fh2 + oh1 >= PH && fh2 + oh1 < IH + PH &&
                                       fw2 + ow1 >= PW && fw2 + ow1 < IW + PW),
                     p2 = bool_as_mask(fh2 + oh2 >= PH && fh2 + oh2 < IH + PH &&
                                       fw2 + ow2 >= PW && fw2 + ow2 < IW + PW),
                     p3 = bool_as_mask(fh2 + oh3 >= PH && fh2 + oh3 < IH + PH &&
                                       fw2 + ow3 >= PW && fw2 + ow3 < IW + PW);
            localB[tidy][tidx].x = visit_with_mask(src, tmp + op0, ok & p0);
            localB[tidy][tidx].y = visit_with_mask(src, tmp + op1, ok & p1);
            localB[tidy][tidx].z = visit_with_mask(src, tmp + op2, ok & p2);
            localB[tidy][tidx].w = visit_with_mask(src, tmp + op3, ok & p3);
        }

        // both tiles must be completely written before any thread reads them
        __syncthreads();

        // 4x4 outer-product accumulation over the BM-deep tile
        for (uint32_t j = 0u; j < BM; ++j) {
            float4 tmpA = localA[tidy][j];
            float4 tmpB = localB[j][tidx];
            sum0.x += tmpA.x * tmpB.x;
            sum0.y += tmpA.x * tmpB.y;
            sum0.z += tmpA.x * tmpB.z;
            sum0.w += tmpA.x * tmpB.w;
            sum1.x += tmpA.y * tmpB.x;
            sum1.y += tmpA.y * tmpB.y;
            sum1.z += tmpA.y * tmpB.z;
            sum1.w += tmpA.y * tmpB.w;
            sum2.x += tmpA.z * tmpB.x;
            sum2.y += tmpA.z * tmpB.y;
            sum2.z += tmpA.z * tmpB.z;
            sum2.w += tmpA.z * tmpB.w;
            sum3.x += tmpA.w * tmpB.x;
            sum3.y += tmpA.w * tmpB.y;
            sum3.z += tmpA.w * tmpB.z;
            sum3.w += tmpA.w * tmpB.w;
        }

        // advance (fh, fw) by BM positions: add the step, carry the column
        // overflow into the row, then wrap both back into range
        fw += fws;
        fh += fhs;
        fh += (fw >= FW);
        fh -= (fh >= FH) * FH;
        fw -= (fw >= FW) * FW;

        // advance (ic, icm) the same way, carrying window overflow into ic
        ic += ics;
        icm += icms;
        ic += (icm >= FP);
        icm -= (icm >= FP) * FP;
        // all reads of the tiles must finish before the next iteration
        // overwrites them
        __syncthreads();
    }
    // store the 4x4 patch, skipping rows >= OC and columns >= OH*OW
    const uint32_t dst_idx = n * OUT_BS + posy2 * widthB + posx2;
    bool y0 = (posy2 + 0 < heightA);
    bool y1 = (posy2 + 1 < heightA);
    bool y2 = (posy2 + 2 < heightA);
    bool y3 = (posy2 + 3 < heightA);
    bool x0 = (posx2 + 0 < widthB);
    bool x1 = (posx2 + 1 < widthB);
    bool x2 = (posx2 + 2 < widthB);
    bool x3 = (posx2 + 3 < widthB);
    if (y0) {
        if (x0)
            dst[dst_idx + 0 * widthB + 0] = sum0.x;
        if (x1)
            dst[dst_idx + 0 * widthB + 1] = sum0.y;
        if (x2)
            dst[dst_idx + 0 * widthB + 2] = sum0.z;
        if (x3)
            dst[dst_idx + 0 * widthB + 3] = sum0.w;
    }
    if (y1) {
        if (x0)
            dst[dst_idx + 1 * widthB + 0] = sum1.x;
        if (x1)
            dst[dst_idx + 1 * widthB + 1] = sum1.y;
        if (x2)
            dst[dst_idx + 1 * widthB + 2] = sum1.z;
        if (x3)
            dst[dst_idx + 1 * widthB + 3] = sum1.w;
    }
    if (y2) {
        if (x0)
            dst[dst_idx + 2 * widthB + 0] = sum2.x;
        if (x1)
            dst[dst_idx + 2 * widthB + 1] = sum2.y;
        if (x2)
            dst[dst_idx + 2 * widthB + 2] = sum2.z;
        if (x3)
            dst[dst_idx + 2 * widthB + 3] = sum2.w;
    }
    if (y3) {
        if (x0)
            dst[dst_idx + 3 * widthB + 0] = sum3.x;
        if (x1)
            dst[dst_idx + 3 * widthB + 1] = sum3.y;
        if (x2)
            dst[dst_idx + 3 * widthB + 2] = sum3.z;
        if (x3)
            dst[dst_idx + 3 * widthB + 3] = sum3.w;
    }
}

}  // anonymous namespace

void convolution::exec_inplace_matmul_fwd(
        const float* src, const float* filter, float* dst, size_t N,
        size_t INP_BS, size_t OUT_BS, size_t IC, size_t IH, size_t IW,
        size_t OC, size_t OH, size_t OW, size_t FH, size_t FW, size_t PH,
        size_t PW, size_t SH, size_t SW, bool is_xcorr, hipStream_t stream) {
    BufferFetcherTextureHost src_tex(const_cast<float*>(src), N * INP_BS),
            filter_tex(const_cast<float*>(filter), OC * IC * FH * FW);

    BufferFetcherRaw src_buf, filter_buf;
    src_buf.ptr = src;
    filter_buf.ptr = filter;
    if (!src_tex.init_succ || !filter_tex.init_succ) {
        src_tex.reset();
        filter_tex.reset();
    }
    int m = OC;
    int n = OH * OW;
    int BY = 1;
    int BX = 1;
    if (m <= 64) {
        while (BY < 16 && (BY << 2) < m)
            BY <<= 1;
        BX = 256 / BY;
    } else if (n <= 64) {
        while (BX < 16 && (BX << 2) < n)
            BX <<= 1;
        BY = 256 / BX;
    } else {
        BX = BY = 16;
    }
    dim3 blocks((OH * OW + BX * 4 - 1) / (BX * 4), (OC + BY * 4 - 1) / (BY * 4),
                N);
    dim3 threads(BX, BY);
#define DISPATCH_BX_BY(BX, BY)                                                \
    do {                                                                      \
        if (src_tex.init_succ) {                                              \
            KernelPtr<BufferFetcherTexture>::type kptr;                       \
            if (is_xcorr) {                                                   \
                kptr = conv_kernel<BY, BX, true, BufferFetcherTexture>;       \
            } else {                                                          \
                kptr = conv_kernel<BY, BX, false, BufferFetcherTexture>;      \
            }                                                                 \
            hipLaunchKernelGGL(kptr, blocks, threads, 0, stream,                    \
                    src_tex.val, filter_tex.val, dst, INP_BS, OUT_BS, IC, IH, \
                    IW, OC, OH, OW, FH, FW, SH, SW, PH, PW);                  \
        } else {                                                              \
            KernelPtr<BufferFetcherRaw>::type kptr;                           \
            if (is_xcorr) {                                                   \
                kptr = conv_kernel<BY, BX, true, BufferFetcherRaw>;           \
            } else {                                                          \
                kptr = conv_kernel<BY, BX, false, BufferFetcherRaw>;          \
            }                                                                 \
            hipLaunchKernelGGL(kptr, blocks, threads, 0, stream,              \
                    src_buf, filter_buf, dst, INP_BS, OUT_BS, IC, IH, IW, OC, \
                    OH, OW, FH, FW, SH, SW, PH, PW);                          \
        }                                                                     \
    } while (0)
#define DISPATCH_BX(BX)               \
    do {                              \
        DISPATCH_BX_BY(BX, 256 / BX); \
    } while (0)
#define DISPATCH()                                \
    do {                                          \
        switch (BX) {                             \
            case 1:                               \
                DISPATCH_BX(1);                   \
                break;                            \
            case 2:                               \
                DISPATCH_BX(2);                   \
                break;                            \
            case 4:                               \
                DISPATCH_BX(4);                   \
                break;                            \
            case 8:                               \
                DISPATCH_BX(8);                   \
                break;                            \
            case 16:                              \
                DISPATCH_BX(16);                  \
                break;                            \
            case 32:                              \
                DISPATCH_BX(32);                  \
                break;                            \
            case 64:                              \
                DISPATCH_BX(64);                  \
                break;                            \
            case 128:                             \
                DISPATCH_BX(128);                 \
                break;                            \
            case 256:                             \
                DISPATCH_BX(256);                 \
                break;                            \
            default:                              \
                report_error("no usable kernel"); \
        }                                         \
    } while (0)
    DISPATCH();
#undef DISPATCH
#undef DISPATCH_BX
#undef DISPATCH_BX_BY
    after_kernel_launch();
}

// vim: syntax=cpp.doxygen