/**
* \file dnn/src/rocm/reduce_helper/largeBC.hipinl
*
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/rocm/reduce_helper.h.hip"
#include <algorithm>
#include <cstdio>
#include <cstdlib>
namespace megdnn {
namespace rocm {
namespace reduce_intl {
struct ExecPolicy {
    /**
     * Derive the launch configuration used to reduce the B axis of an
     * (A, B, C) problem. (BY, BX) is the blockDim of the reduce kernel.
     */
    ExecPolicy(size_t A, size_t B, size_t C) : A(A), B(B), C(C) {
        // BX doubles until it reaches C or the cap of 32 (stays 1 if C <= 1).
        for (BX = 1; BX < 32 && BX < C; BX *= 2) {
        }
        BY = 512 / BX;    // BX * BY == 512 threads per block
        factor = 4 * BY;  // every block folds 4 * BY rows of B into one
        NA = A;
        NB = DIVUP(B, factor);
        NC = DIVUP(C, BX);

        // Count the passes needed until B has been folded down to a single
        // row; at least one pass is always issued.
        size_t passes = 0;
        for (size_t rows = B; rows > 1; rows = DIVUP(rows, factor)) {
            ++passes;
        }
        nr_reduces = (passes == 0) ? 1 : passes;
    }

    //! Policy for the following pass, whose B is this pass's output row count.
    ExecPolicy next() const { return ExecPolicy(A, DIVUP(B, factor), C); }

    size_t factor;      //!< rows of B consumed per output row (4 * BY)
    size_t nr_reduces;  //!< number of kernel passes until B collapses to 1
    size_t BY, BX;      //!< blockDim.y, blockDim.x
    size_t NA, NB, NC;  //!< grid extents along A, B and C
    size_t A, B, C;     //!< problem shape of this pass
};
// Whenever blockIdx is referenced, bidy_offset and bidz_offset must be added.
// This works around the per-dimension grid size limit: the host splits the
// grid into several launches and passes each chunk's offset explicitly.
/**
 * Fold groups of 4 * BY rows of the B axis of an (A, B, C) input into one
 * row of an (A, B2, C) output, combining values with opr.apply.
 *
 * Launch layout (see invoke_kernel):
 *   blockDim = (BX, BY);
 *   blockIdx.x indexes chunks of BX columns of C;
 *   blockIdx.y + bidy_offset indexes the output row b2 (4 * BY input rows);
 *   blockIdx.z + bidz_offset indexes A.
 * Uses BY * BX elements of static shared memory.
 *
 * \tparam sync_within_warp when true, a block-wide barrier is issued between
 *         the final reduction steps; when false, the code relies on the
 *         threads with threadIdx.y < 32 / BX (32 threads in total) executing
 *         in lockstep through volatile shared memory.
 */
template <class Operator, class Reader, class Writer, typename wtype,
          uint32_t BX, uint32_t BY, bool sync_within_warp>
__global__ void kern_largeBC(
        Operator opr, Reader rdr, Writer wtr,
        uint32_t A, uint32_t B, uint32_t B2, uint32_t C,
        uint32_t bidy_offset, uint32_t bidz_offset)
{
    volatile __shared__ wtype shared[BY][BX];
    wtype s = opr.INIT;
    uint32_t c = threadIdx.x + blockIdx.x * blockDim.x;
    uint32_t a = blockIdx.z + bidz_offset;

    // Each thread accumulates up to 4 rows of B, strided by blockDim.y.
    // Out-of-range columns/rows contribute opr.INIT.
    if (c < C) {
        uint32_t base =
                threadIdx.y + (blockIdx.y + bidy_offset) * 4 * blockDim.y;
#pragma unroll
        for (uint32_t i = 0; i < 4; ++i) {
            uint32_t b = base + i * blockDim.y;
            if (b < B) {
                s = opr.apply(s, rdr.read(a * B * C + b * C + c));
            }
        }
    }
    shared[threadIdx.y][threadIdx.x] = s;
    __syncthreads();

    // Tree reduction over threadIdx.y with block-wide barriers, down to the
    // first 2 * warp_y rows. BY is a compile-time constant, so every branch
    // containing __syncthreads() is uniform across the block.
    const uint32_t warp_y = 32 / BX;
#pragma unroll
    for (uint32_t k = 256; k > warp_y; k >>= 1) {
        if (BY >= k << 1) {
            if (threadIdx.y < k) {
                shared[threadIdx.y][threadIdx.x] = opr.apply(
                        shared[threadIdx.y][threadIdx.x],
                        shared[threadIdx.y + k][threadIdx.x]);
            }
            __syncthreads();
        }
    }

    // Final steps over the remaining rows. The optional barrier must be
    // reached by every thread of the block, so it stays outside any
    // threadIdx-dependent guard; since k <= warp_y, the `threadIdx.y < k`
    // test alone restricts the arithmetic to the first warp_y rows.
#pragma unroll
    for (uint32_t k = warp_y; k > 0; k >>= 1) {
        if (threadIdx.y < k) {
            shared[threadIdx.y][threadIdx.x] =
                    opr.apply(shared[threadIdx.y][threadIdx.x],
                              shared[threadIdx.y + k][threadIdx.x]);
        }
        if (sync_within_warp) {
            __syncthreads();
        }
    }

    if (threadIdx.y == 0 && c < C) {
        uint32_t b2 = blockIdx.y + bidy_offset;
        wtr.write(a * B2 * C + b2 * C + c, shared[0][threadIdx.x]);
    }
}
/**
 * Launch kern_largeBC for one reduction pass described by \p p on \p stream.
 *
 * The grid's y (NB) and z (NA) extents are split into chunks of at most
 * 32768 blocks; each chunk is a separate launch that receives its chunk
 * offsets as bidy_offset / bidz_offset.
 *
 * \tparam Operator must have method wtype apply(wtype, wtype)
 * \tparam Operator must have const member INIT
 * \tparam Reader must have device method wtype read(idx); the kernel passes
 *         idx as uint32_t
 * \tparam Writer must have device method void write(idx, wtype); the kernel
 *         passes idx as uint32_t
 */
template <class Operator, class Reader, class Writer, typename wtype,
          bool sync_within_warp>
void invoke_kernel(const ExecPolicy &p,
                   const Operator &opr,
                   const Reader &rdr,
                   const Writer &wtr,
                   hipStream_t stream)
{
    // Maximum number of thread blocks along grid y and grid z per launch.
    const size_t max_blocks = 32768;

    // Select the kern_largeBC instantiation matching the runtime block shape.
    // ExecPolicy only produces BX in {1, 2, 4, 8, 16, 32} with BY == 512 / BX.
    // Only the kernel pointer is chosen inside the macro; per the original
    // note, expanding the whole launch loop per block shape made the
    // compiler segfault.
    void (*kptr)(Operator op, Reader rdr, Writer wtr, uint32_t A, uint32_t B,
                 uint32_t B2, uint32_t C, uint32_t bidy_offset,
                 uint32_t bidz_offset) = nullptr;
#define SELECT_LARGEBC_KERNEL(nBX)                                    \
    if (kptr == nullptr && p.BX == nBX && p.BY == 512 / nBX) {        \
        kptr = kern_largeBC<Operator, Reader, Writer, wtype, nBX,     \
                            512 / nBX, sync_within_warp>;             \
    }
    SELECT_LARGEBC_KERNEL(1)
    SELECT_LARGEBC_KERNEL(2)
    SELECT_LARGEBC_KERNEL(4)
    SELECT_LARGEBC_KERNEL(8)
    SELECT_LARGEBC_KERNEL(16)
    SELECT_LARGEBC_KERNEL(32)
#undef SELECT_LARGEBC_KERNEL

    if (kptr == nullptr) {
        // Launching through an unset function pointer is undefined
        // behaviour; fail loudly instead of corrupting state.
        std::fprintf(stderr,
                     "reduce_intl::invoke_kernel: unsupported block shape "
                     "BX=%zu BY=%zu\n",
                     p.BX, p.BY);
        std::abort();
    }

    for (size_t bidy_offset = 0; bidy_offset < p.NB;
         bidy_offset += max_blocks) {
        for (size_t bidz_offset = 0; bidz_offset < p.NA;
             bidz_offset += max_blocks) {
            dim3 blocks;
            blocks.x = p.NC;
            blocks.y = std::min<size_t>(max_blocks, p.NB - bidy_offset);
            blocks.z = std::min<size_t>(max_blocks, p.NA - bidz_offset);
            hipLaunchKernelGGL((kptr), dim3(blocks), dim3(p.BX, p.BY), 0,
                               stream, opr, rdr, wtr, p.A, p.B,
                               DIVUP(p.B, p.factor), p.C, bidy_offset,
                               bidz_offset);
        }
    }
    after_kernel_launch();
}
/**
 * Reader adaptor that forwards element reads to a copy of the public
 * operator. It holds the operator by value; it does not inherit from it.
 */
template <class PublicOperator>
struct PublicReader {
    PublicOperator opr;
    typedef typename PublicOperator::wtype wtype;

    PublicReader(const PublicOperator &op) : opr(op) {}

    //! Element at flat index \p idx, as produced by PublicOperator::read.
    __device__ wtype read(uint32_t idx) {
        return opr.read(idx);
    }
};
/**
 * Reader adaptor that loads elements from a raw scratch buffer of wtype.
 */
template <typename wtype>
struct WorkspaceReader {
    wtype *workspace;

    WorkspaceReader(wtype *buf) : workspace(buf) {}

    //! Element stored at offset \p idx of the scratch buffer.
    __device__ wtype read(uint32_t idx) {
        return *(workspace + idx);
    }
};
/**
 * Writer adaptor that forwards element writes to a copy of the public
 * operator. It holds the operator by value; it does not inherit from it.
 */
template <class PublicOperator>
struct PublicWriter {
    PublicOperator opr;
    typedef typename PublicOperator::wtype wtype;

    PublicWriter(const PublicOperator &op) : opr(op) {}

    //! Hand \p value for flat index \p idx to PublicOperator::write.
    __device__ void write(uint32_t idx, wtype value) {
        opr.write(idx, value);
    }
};
/**
 * Writer adaptor that stores elements into a raw scratch buffer of wtype.
 */
template <typename wtype>
struct WorkspaceWriter {
    wtype *workspace;

    WorkspaceWriter(wtype *buf) : workspace(buf) {}

    //! Store \p value at offset \p idx of the scratch buffer.
    __device__ void write(uint32_t idx, wtype value) {
        *(workspace + idx) = value;
    }
};
/**
* \tparam PublicOperator
* must have typedef for wtype
* must have const static member wtype INIT
* must have method wtype read(uint32_t idx)
* must have method wtype apply(const wtype &, const wtype &)
* must have method void write(uint32_t idx, const wtype &)
*/
template <class PublicOperator, bool sync_within_warp>
void run_largeBC(typename PublicOperator::wtype *workspace,
size_t A, size_t B, size_t C,
hipStream_t stream, const PublicOperator &opr)
{
typedef typename PublicOperator::wtype wtype;
using namespace reduce_intl;
ExecPolicy p(A, B, C);
if (p.nr_reduces == 1) {
PublicReader<PublicOperator> rdr(opr);
PublicWriter<PublicOperator> wtr(opr);
invoke_kernel<PublicOperator,
PublicReader<PublicOperator>,
PublicWriter<PublicOperator>,
wtype,
sync_within_warp>(p, opr, rdr, wtr, stream);
} else if (p.nr_reduces == 2) {
PublicReader<PublicOperator> rdr1(opr);
WorkspaceWriter<wtype> wtr1(workspace);
WorkspaceReader<wtype> rdr2(workspace);
PublicWriter<PublicOperator> wtr2(opr);
invoke_kernel<PublicOperator,
PublicReader<PublicOperator>,
WorkspaceWriter<wtype>,
wtype,
sync_within_warp>(p, opr, rdr1, wtr1, stream);
p = p.next();
invoke_kernel<PublicOperator,
WorkspaceReader<wtype>,
PublicWriter<PublicOperator>,
wtype,
sync_within_warp>(p, opr, rdr2, wtr2, stream);
} else {
wtype *workspace1 = workspace;
size_t B2 = DIVUP(B, p.factor);
wtype *workspace2 = workspace + A * B2 * C;
size_t nr_reduces = p.nr_reduces;
{
PublicReader<PublicOperator> rdr(opr);
WorkspaceWriter<wtype> wtr(workspace1);
invoke_kernel<PublicOperator,
PublicReader<PublicOperator>,
WorkspaceWriter<wtype>,
wtype,
sync_within_warp>(p, opr, rdr, wtr, stream);
}
p = p.next();
wtype *current = workspace1;
wtype *next = workspace2;
for (size_t i = 1; i < nr_reduces; ++i) {
WorkspaceReader<wtype> rdr(current);
if (i + 1 == nr_reduces) {
PublicWriter<PublicOperator> wtr(opr);
invoke_kernel<PublicOperator,
WorkspaceReader<wtype>,
PublicWriter<PublicOperator>,
wtype,
sync_within_warp>(p, opr, rdr, wtr, stream);
} else {
WorkspaceWriter<wtype> wtr(next);
invoke_kernel<PublicOperator,
WorkspaceReader<wtype>,
WorkspaceWriter<wtype>,
wtype,
sync_within_warp>(p, opr, rdr, wtr, stream);
}
std::swap(next, current);
p = p.next();
}
}
}
template <typename wtype>
size_t get_workspace_largeBC(size_t A, size_t B, size_t C)
{
using namespace reduce_intl;
ExecPolicy p(A, B, C);
if (p.nr_reduces == 1) {
// direct reduce
return 0;
} else if (p.nr_reduces == 2) {
// src->workspace->dst
size_t B2 = DIVUP(B, p.factor);
return sizeof(wtype) * A * B2 * C;
} else {
// src->workspace1->workspace2->dst
size_t B2 = DIVUP(B, p.factor);
size_t B3 = DIVUP(B2, p.factor);
return sizeof(wtype) * A * B2 * C + sizeof(wtype) * A * B3 * C;
}
}
} // namespace reduce_intl
} // namespace rocm
} // namespace megdnn
// vim: ft=cpp syntax=cpp.doxygen