/**
* \file dnn/src/cuda/reduce_helper/largeBC.cuinl
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/cuda/reduce_helper.cuh"
#include "src/cuda/cub/util_ptx.cuh"
#include <algorithm>
#include <cstdio>
namespace megdnn {
namespace cuda {
namespace reduce_intl {
/**
 * Launch geometry for one level of the reduction over an (A, B, C) layout,
 * where B is the axis that shrinks from one level to the next.
 *
 * The reduce kernel is launched with blockDim = (BX, BY).
 */
struct ExecPolicy {
    /**
     * \param A extent of the leading axis (mapped to NA)
     * \param B extent of the axis being reduced
     * \param C extent of the trailing axis (used to pick BX)
     */
    ExecPolicy(size_t A, size_t B, size_t C) : A(A), B(B), C(C) {
        // BX: smallest power of two that covers C, capped at 32
        BX = 1;
        while (BX < C && BX < 32) {
            BX <<= 1;
        }
        // BX * BY is always 512
        BY = 512 / BX;
        // B shrinks by this factor per level (4 * BY)
        factor = 4 * BY;

        NA = A;
        NB = DIVUP(B, factor);
        NC = DIVUP(C, BX);

        // count how many levels are needed until B collapses to one row;
        // at least one level is always reported
        size_t levels = 0;
        for (size_t rest = B; rest > 1; rest = DIVUP(rest, factor)) {
            ++levels;
        }
        nr_reduces = levels > 0 ? levels : 1;
    }

    /** Policy for the following level, i.e. with B already shrunk once. */
    ExecPolicy next() const {
        const size_t shrunk_B = DIVUP(B, factor);
        return ExecPolicy(A, shrunk_B, C);
    }

    size_t factor;
    size_t nr_reduces;
    size_t BY, BX;
    size_t NA, NB, NC;
    size_t A, B, C;
};
/**
 * One level of the reduction over an (A, B, C) layout: each thread block
 * folds up to 4 * blockDim.y rows of B for blockDim.x columns of C and writes
 * a single partial result per column into an (A, B2, C) output.
 *
 * Whenever blockIdx is referenced, bidy_offset and bidz_offset must be added.
 * This lets the host cover a large logical grid with several launches, each
 * handling a slice of the y/z block range.
 *
 * Expected launch configuration: blockDim = (BX, BY); blockIdx.x tiles C,
 * blockIdx.y (+ bidy_offset) tiles B, blockIdx.z (+ bidz_offset) selects a.
 * Static shared memory: BY * BX elements of wtype.
 *
 * \tparam sync_within_warp when true, a block-wide barrier is issued after
 *      every step of the final (single-warp) stage as well
 * \param A not read by the kernel body
 * \param B extent of the reduced axis in the input
 * \param B2 extent of the reduced axis in the output
 * \param C extent of the trailing axis
 *
 * \note All element indices are computed in uint32_t; nothing here guards
 *      against wrap-around for very large A * B * C.
 */
template <class Operator, class Reader, class Writer, typename wtype,
          uint32_t BX, uint32_t BY, bool sync_within_warp>
__global__ void kern_largeBC(
        Operator opr, Reader rdr, Writer wtr,
        uint32_t A, uint32_t B, uint32_t B2, uint32_t C,
        uint32_t bidy_offset, uint32_t bidz_offset)
{
    volatile __shared__ wtype shared[BY][BX];
    wtype s = opr.INIT;
    uint32_t c = threadIdx.x + blockIdx.x * blockDim.x;
    uint32_t a = blockIdx.z+bidz_offset;
    if (c < C) {
        // each thread accumulates up to four rows of B, blockDim.y apart
        uint32_t base = threadIdx.y + (blockIdx.y+bidy_offset)*4*blockDim.y;
        if (base + 0*blockDim.y < B) {
            s = opr.apply(s, rdr.read(a*B*C + (base + 0*blockDim.y)*C + c));
        }
        if (base + 1*blockDim.y < B) {
            s = opr.apply(s, rdr.read(a*B*C + (base + 1*blockDim.y)*C + c));
        }
        if (base + 2*blockDim.y < B) {
            s = opr.apply(s, rdr.read(a*B*C + (base + 2*blockDim.y)*C + c));
        }
        if (base + 3*blockDim.y < B) {
            s = opr.apply(s, rdr.read(a*B*C + (base + 3*blockDim.y)*C + c));
        }
    }
    // threads outside the valid range contribute opr.INIT
    shared[threadIdx.y][threadIdx.x] = s;
    __syncthreads();

    // number of y-rows that together make up 32 threads
    const uint32_t warp_y = 32 / BX;

    // tree reduction along y with a block-wide barrier after every step
#pragma unroll
    for (uint32_t k = 256; k > warp_y; k >>= 1) {
        if (BY >= k<<1) {
            if (threadIdx.y < k) {
                shared[threadIdx.y][threadIdx.x] = opr.apply(
                        shared[threadIdx.y][threadIdx.x],
                        shared[threadIdx.y+k][threadIdx.x]);
            }
            __syncthreads();
        }
    }

    // Final stage: only rows with threadIdx.y < warp_y still combine values.
    // sync_within_warp is a compile-time constant, so when it is set every
    // thread of the block enters this branch and the __syncthreads() below is
    // reached under a block-uniform condition, as CUDA requires. When it is
    // not set, only the first warp_y rows enter and no block barrier is used.
    if (sync_within_warp || threadIdx.y < warp_y) {
#pragma unroll
        for (uint32_t k = warp_y; k > 0; k >>= 1) {
            // k <= warp_y, so rows >= warp_y never modify shared memory here
            if (threadIdx.y < k) {
                shared[threadIdx.y][threadIdx.x] =
                        opr.apply(shared[threadIdx.y][threadIdx.x],
                                  shared[threadIdx.y + k][threadIdx.x]);
            }
            if (sync_within_warp) {
                __syncthreads();
            }
            /**
             * \warning Since CUDA 9.0, for Volta and Turing architecture,
             * applications that assume reads and writes are implicitly visible
             * to other threads in same warp need to insert the new __syncwarp()
             * warp-wide barrier synchronization instruction between steps where
             * data is exchanged between threads via global or shared memory.
             * For details, please refer to
             * https://docs.nvidia.com/cuda/volta-tuning-guide/index.html
             */
            cub::WARP_SYNC(0xffffffff);
        }
    }

    // row 0 holds the block's result for each column
    if (threadIdx.y == 0 && c < C) {
        uint32_t b2 = blockIdx.y+bidy_offset;
        wtr.write(a*B2*C + b2*C + c, shared[0][threadIdx.x]);
    }
}
/**
 * Launch kern_largeBC for the compile-time block width nBX, provided the
 * policy asks for exactly blockDim = (nBX, 512 / nBX); otherwise do nothing.
 *
 * The logical (NC, NB, NA) grid is covered by several launches of at most
 * 32768 blocks along y and along z; the starting block of each slice is
 * passed to the kernel as bidy_offset / bidz_offset.
 */
template <class Operator, class Reader, class Writer, typename wtype,
          bool sync_within_warp, uint32_t nBX>
void launch_largeBC_for_BX(const ExecPolicy &p,
                           const Operator &opr,
                           const Reader &rdr,
                           const Writer &wtr,
                           cudaStream_t stream)
{
    if (p.BX != nBX || p.BY != 512 / nBX) {
        return;
    }
    // 32768 thread blocks for each call
    const size_t max_blocks = 32768;
    for (size_t y_begin = 0; y_begin < p.NB; y_begin += max_blocks) {
        for (size_t z_begin = 0; z_begin < p.NA; z_begin += max_blocks) {
            dim3 blocks;
            blocks.x = p.NC;
            blocks.y = std::min<size_t>(max_blocks, p.NB - y_begin);
            blocks.z = std::min<size_t>(max_blocks, p.NA - z_begin);
            kern_largeBC<Operator, Reader, Writer, wtype, nBX, 512 / nBX,
                         sync_within_warp>
                    <<<blocks, dim3(p.BX, p.BY), 0, stream>>>(
                            opr, rdr, wtr, p.A, p.B, DIVUP(p.B, p.factor),
                            p.C, y_begin, z_begin);
        }
    }
}

/**
 * Run one reduction level described by \p p on \p stream.
 *
 * \tparam Operator must have method wtype apply(wtype, wtype)
 * \tparam Operator must have const member INIT
 * \tparam Reader must have method wtype read(idx); the kernel passes a
 *      uint32_t element index
 * \tparam Writer must have method void write(idx, wtype); the kernel passes a
 *      uint32_t element index
 */
template <class Operator, class Reader, class Writer, typename wtype,
          bool sync_within_warp>
void invoke_kernel(const ExecPolicy &p,
                   const Operator &opr,
                   const Reader &rdr,
                   const Writer &wtr,
                   cudaStream_t stream)
{
    // try every supported block width; only the one matching p launches
    launch_largeBC_for_BX<Operator, Reader, Writer, wtype,
                          sync_within_warp, 1>(p, opr, rdr, wtr, stream);
    launch_largeBC_for_BX<Operator, Reader, Writer, wtype,
                          sync_within_warp, 2>(p, opr, rdr, wtr, stream);
    launch_largeBC_for_BX<Operator, Reader, Writer, wtype,
                          sync_within_warp, 4>(p, opr, rdr, wtr, stream);
    launch_largeBC_for_BX<Operator, Reader, Writer, wtype,
                          sync_within_warp, 8>(p, opr, rdr, wtr, stream);
    launch_largeBC_for_BX<Operator, Reader, Writer, wtype,
                          sync_within_warp, 16>(p, opr, rdr, wtr, stream);
    launch_largeBC_for_BX<Operator, Reader, Writer, wtype,
                          sync_within_warp, 32>(p, opr, rdr, wtr, stream);
    after_kernel_launch();
}
/**
 * Reader that keeps a copy of a PublicOperator and forwards every read to it.
 */
template <class PublicOperator>
struct PublicReader {
    typedef typename PublicOperator::wtype wtype;

    PublicOperator opr;

    PublicReader(const PublicOperator &source) : opr(source) {}

    /** Fetch the element at \p idx through the wrapped operator. */
    __device__ wtype read(uint32_t idx) {
        return opr.read(idx);
    }
};
/**
 * Reader that loads elements directly from a workspace buffer.
 */
template <typename wtype>
struct WorkspaceReader {
    wtype *workspace;

    WorkspaceReader(wtype *buffer) : workspace(buffer) {}

    /** Return the element stored at position \p idx of the buffer. */
    __device__ wtype read(uint32_t idx) {
        return workspace[idx];
    }
};
/**
 * Writer that keeps a copy of a PublicOperator and forwards every write to it.
 */
template <class PublicOperator>
struct PublicWriter {
    typedef typename PublicOperator::wtype wtype;

    PublicOperator opr;

    PublicWriter(const PublicOperator &target) : opr(target) {}

    /** Hand \p value for position \p idx to the wrapped operator. */
    __device__ void write(uint32_t idx, wtype value) {
        opr.write(idx, value);
    }
};
/**
 * Writer that stores elements directly into a workspace buffer.
 */
template <typename wtype>
struct WorkspaceWriter {
    wtype *workspace;

    WorkspaceWriter(wtype *buffer) : workspace(buffer) {}

    /** Store \p value at position \p idx of the buffer. */
    __device__ void write(uint32_t idx, wtype value) {
        workspace[idx] = value;
    }
};
/**
 * Reduce the B axis of an (A, B, C) layout, using as many kernel levels as
 * ExecPolicy reports. A single level reads and writes through \p opr
 * directly; with more levels the intermediate results alternate between two
 * regions of \p workspace and only the last level writes through \p opr.
 *
 * \tparam PublicOperator
 *      must have typedef for wtype
 *      must have const static member wtype INIT
 *      must have method wtype read(uint32_t idx)
 *      must have method wtype apply(const wtype &, const wtype &)
 *      must have method void write(uint32_t idx, const wtype &)
 * \param workspace scratch buffer; only used when more than one level is
 *      needed (see get_workspace_largeBC for the size computed in this file)
 */
template <class PublicOperator, bool sync_within_warp>
void run_largeBC(typename PublicOperator::wtype *workspace,
                 size_t A, size_t B, size_t C,
                 cudaStream_t stream, const PublicOperator &opr)
{
    using namespace reduce_intl;
    typedef typename PublicOperator::wtype wtype;
    typedef PublicReader<PublicOperator> SrcReader;
    typedef PublicWriter<PublicOperator> DstWriter;
    typedef WorkspaceReader<wtype> TmpReader;
    typedef WorkspaceWriter<wtype> TmpWriter;

    ExecPolicy policy(A, B, C);
    // remember the level count of the initial policy; policy is advanced below
    const size_t total_levels = policy.nr_reduces;

    if (total_levels == 1) {
        // src -> dst in one go
        SrcReader rdr(opr);
        DstWriter wtr(opr);
        invoke_kernel<PublicOperator, SrcReader, DstWriter, wtype,
                      sync_within_warp>(policy, opr, rdr, wtr, stream);
        return;
    }

    // two scratch regions laid out back to back; the second one starts right
    // after the output of the first level
    wtype *read_buf = workspace;
    wtype *write_buf = workspace + A * DIVUP(B, policy.factor) * C;

    // first level: src -> first scratch region
    {
        SrcReader rdr(opr);
        TmpWriter wtr(read_buf);
        invoke_kernel<PublicOperator, SrcReader, TmpWriter, wtype,
                      sync_within_warp>(policy, opr, rdr, wtr, stream);
    }
    policy = policy.next();

    // remaining levels: scratch -> scratch, except the last one which goes
    // scratch -> dst
    for (size_t level = 1; level < total_levels; ++level) {
        TmpReader rdr(read_buf);
        const bool is_last = (level + 1 == total_levels);
        if (!is_last) {
            TmpWriter wtr(write_buf);
            invoke_kernel<PublicOperator, TmpReader, TmpWriter, wtype,
                          sync_within_warp>(policy, opr, rdr, wtr, stream);
        } else {
            DstWriter wtr(opr);
            invoke_kernel<PublicOperator, TmpReader, DstWriter, wtype,
                          sync_within_warp>(policy, opr, rdr, wtr, stream);
        }
        std::swap(write_buf, read_buf);
        policy = policy.next();
    }
}
/**
 * Number of workspace bytes needed for reducing an (A, B, C) layout with
 * elements of type wtype.
 *
 * \return 0 when a single level suffices; otherwise the size of the first
 *      level's output, plus the size of the second level's output when more
 *      than two levels are needed
 */
template <typename wtype>
size_t get_workspace_largeBC(size_t A, size_t B, size_t C)
{
    using namespace reduce_intl;
    const ExecPolicy policy(A, B, C);

    // direct reduce: src->dst
    if (policy.nr_reduces == 1) {
        return 0;
    }

    // src->workspace->dst
    const size_t rows_level1 = DIVUP(B, policy.factor);
    size_t nr_elems = A * rows_level1 * C;

    // src->workspace1->workspace2->dst
    if (policy.nr_reduces != 2) {
        const size_t rows_level2 = DIVUP(rows_level1, policy.factor);
        nr_elems += A * rows_level2 * C;
    }
    return sizeof(wtype) * nr_elems;
}
} // namespace reduce_intl
} // namespace cuda
} // namespace megdnn
// vim: ft=cpp syntax=cpp.doxygen