/**
* \file dnn/src/cuda/cumsum/kern_impl.cuinl
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./kern.cuh"
#include "./kern_helper.cuh"
#include "megdnn/dtype.h"
#include "src/cuda/cub/device/device_scan.cuh"
#include "src/cuda/cub/util_ptx.cuh"
namespace megdnn {
namespace cuda {
namespace cumsum {
namespace detail {
/**
* src shape is (A, B, C), performing blockwise scan over B axis.
* Each CUDA block calculates a blockwise scan result of size (BY2, BX).
* The block area corresponds to a 2-D area on (B, C) dimension of src.
*
* Per-block prefix sum is stored in dst (dst has the same shape as src).
*
* The whole scan result of each block as a single value is stored in
* block_sum (of shape (A, ceil(B/BY2), C)).
*
* block_sum can be NULL.
*
* src and dst can be inplace.
*
* We need to launch ceil(C/BX)*ceil(B/BY2)*A blocks in total.
* Because in CUDA the number of launched blocks over the y and z axes is
* limited (at most 65535), we launch all blocks over axis x.
*
* Param: exclusive
* This flag specifies whether the scan is inclusive or exclusive, namely
* whether src_i influences dst_i.
*
* Param: reverse
* This flag specifies whether the scan is forward or backward.
*
* Example:
* !exclusive && !reverse: dst_i = op(src_0, src_1, ..., src_i)
* !exclusive && reverse: dst_i = op(src_i, src_{i+1}, ..., src_{n-1})
* exclusive && !reverse: dst_i = op(src_0, src_1, ..., src_{i-1})
* exclusive && reverse: dst_i = op(src_{i+1}, src_{i+2}, ..., src_{n-1})
*
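* Op should have the following methods:
* static T init()
* static T apply(T lhs, T rhs)
* In addition, the kernel calls op.visit(idx) on the (const) Op instance it
* receives to read the source element at flat offset idx of the (A, B, C)
* tensor.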
*/
template <typename T, typename Op, bool exclusive, bool reverse,
uint32_t BY, uint32_t BX>
__global__ void scan_kernel(T *dst, T *block_sum,
uint32_t A, uint32_t B, uint32_t C, const Op op) {
// Launch layout consumed by this kernel: a 1-D grid whose blockIdx.x encodes
// (bz, by, bx) with bx varying fastest, and a 2-D thread block addressed as
// (threadIdx.x, threadIdx.y) = (tx, ty). Every thread handles two rows of the
// block's tile: row ty and row ty + BY.
// Note: A is not referenced anywhere in the body; bz is derived from
// blockIdx.x alone.
constexpr size_t warp_size = 32;
// number of rows (along B) covered by one block
const uint32_t BY2 = BY*2;
// number of tiles along B and along C (both rounded up)
const uint32_t B_ = (B+BY2-1) / BY2;
const uint32_t C_ = (C+BX-1) / BX;
const uint32_t GX = C_;
const uint32_t GY = B_;
// src, dst: (A, B, C)
// block_sum: (A, B_, C)
// shared: (BY2+1, BX)
// decompose the flat block index into tile coordinates (bz, by, bx)
const uint32_t bx = blockIdx.x % GX;
const uint32_t by = blockIdx.x / GX % GY;
const uint32_t bz = blockIdx.x / GX / GY;
const uint32_t tx = threadIdx.x;
const uint32_t ty = threadIdx.y;
// TODO: shared memory bank conflict optimization
// shared_idx skews a flat index by one extra slot per 32 elements
// (presumably to spread accesses over shared-memory banks, see TODO above).
// NOTE(review): shared_idx is never #undef'd in this file, so the macro stays
// defined for everything that follows.
#define shared_idx(x) ((x) + ((x) >> 5))
volatile __shared__ T cache[shared_idx((BY2+1)*BX)];
// flat offset of this tile's first element inside the (A, B, C) tensor
uint32_t base_offset = (bz)*B*C + (by*BY2)*C + (bx*BX);
// from here on dst is indexed relative to the tile, whereas op.visit() below
// still receives absolute offsets
dst += base_offset;
// load to cache
// Cache row 0 is reserved for the identity; source row r is stored in cache
// row r+1 (forward) or mirrored into cache row BY2-r (reverse), so that the
// same forward scan over cache rows serves both directions. Elements outside
// the (B, C) bounds are replaced by Op::init().
if (reverse) {
cache[shared_idx((BY2-ty)*BX+tx)] = ty+by*BY2 < B && tx+bx*BX < C ?
op.visit(base_offset + ty*C + tx) : Op::init();
} else {
cache[shared_idx((ty+1)*BX+tx)] = ty+by*BY2 < B && tx+bx*BX < C ?
op.visit(base_offset + ty*C + tx) : Op::init();
}
// second row handled by this thread: source row ty + BY
if (reverse) {
cache[shared_idx((BY-ty)*BX+tx)] =
(ty+BY) + by*BY2 < B && tx+bx*BX < C ?
op.visit(base_offset + (ty+BY)*C + tx) : Op::init();
} else {
cache[shared_idx((ty+BY+1)*BX+tx)] =
(ty+BY) + by*BY2 < B && tx+bx*BX < C ?
op.visit(base_offset + (ty+BY)*C + tx) : Op::init();
}
// fill the reserved row 0 with the identity element
if (ty == 0) {
cache[shared_idx(tx)] = Op::init();
}
__syncthreads();
uint32_t total, stride;
// first pass
// Pairwise tree combine: at each level the first `total` thread rows fold
// cache row stride*(2*ty+1) into cache row stride*(2*ty+2); `total` halves
// and `stride` doubles until a single combine remains, leaving the combined
// value of all rows in cache row BY2.
#pragma unroll
for (total = BY, stride = 1;
total > 0;
total >>= 1, stride <<= 1)
{
if (ty < total) {
uint32_t ai = shared_idx(stride * (2*ty+1) * BX + tx);
uint32_t bi = shared_idx(stride * (2*ty+2) * BX + tx);
cache[bi] = Op::apply(cache[bi], cache[ai]);
}
// The barrier choice depends only on `total`, which is identical for all
// threads, so every thread takes the same branch. A block-wide barrier is
// used while more than warp_size/BX thread rows are active; otherwise
// cub::WARP_SYNC is used (presumably a warp-level barrier, relying on the
// total*BX active threads all belonging to one warp -- TODO confirm).
if (total > warp_size/BX) __syncthreads();
else cub::WARP_SYNC(0xffffffff);
}
// second pass
// Walk back down the tree: `total` doubles and `stride` halves; at each level
// cache row stride*(2*ty) is folded into cache row stride*(2*ty+1). The
// synchronization precedes the combine here so that values written by the
// previous level are visible before they are read.
#pragma unroll
for (total = 1, stride = BY;
stride > 0;
total <<= 1, stride >>= 1)
{
if (total > warp_size/BX) __syncthreads();
else cub::WARP_SYNC(0xffffffff);
if (ty < total) {
uint32_t ai = shared_idx(stride * (2*ty+0) * BX + tx);
uint32_t bi = shared_idx(stride * (2*ty+1) * BX + tx);
cache[bi] = Op::apply(cache[bi], cache[ai]);
}
}
__syncthreads();
// Write back. Because cache row 0 holds the identity, reading cache row k
// gives the exclusive result and row k+1 the inclusive one; ty_offset selects
// between them. In reverse mode the row index is mirrored again to undo the
// mirrored load.
uint32_t ty_offset = (exclusive ? 0 : 1);
if (ty+by*BY2 < B && tx+bx*BX < C) {
if (reverse) {
dst[ty*C + tx] = cache[shared_idx((BY2-1-ty+ty_offset)*BX + tx)];
} else {
dst[ty*C + tx] = cache[shared_idx((ty+ty_offset)*BX + tx)];
}
}
// same for the second row handled by this thread (ty + BY)
if (ty+BY+by*BY2 < B && tx+bx*BX < C) {
if (reverse) {
dst[(ty+BY)*C + tx] =
cache[shared_idx((BY2-1-(ty+BY)+ty_offset)*BX + tx)];
} else {
dst[(ty+BY)*C + tx] =
cache[shared_idx((ty+BY+ty_offset)*BX + tx)];
}
}
// Optionally export the tile total (last cache row, BY2) to
// block_sum[bz][by][bx*BX + tx]; one thread row (ty == 0) does the store.
if (block_sum && ty == 0 && bx*BX+tx < C) {
block_sum[(bz)*B_*C + (by)*C + (bx*BX) + tx] =
cache[shared_idx(BY2*BX + tx)];
}
}
/**
 * Fold one per-tile value into every element of the corresponding tile.
 *
 * The block/thread decomposition mirrors scan_kernel: blockIdx.x encodes
 * (a, tile along B, tile along C) with the C tile varying fastest, a tile
 * spans 2*BY rows and BX columns, and each thread touches rows
 * threadIdx.y and threadIdx.y + BY of its tile.
 *
 * For every in-bounds element:
 *     dst[a][row][col] = Op::apply(dst[a][row][col], delta[a][tile_b][col])
 *
 * \param dst   tensor of shape (A, B, C), updated in place
 * \param delta tensor of shape (A, ceil(B / (2*BY)), C)
 * \param A     not referenced by the kernel body
 * \param B     extent of the scanned axis
 * \param C     extent of the innermost axis
 */
template <typename T, typename Op, uint32_t BY, uint32_t BX>
__global__ void update_kernel(T *dst, const T *delta,
        uint32_t A, uint32_t B, uint32_t C) {
    const uint32_t rows_per_tile = BY * 2;
    const uint32_t tiles_b = (B + rows_per_tile - 1) / rows_per_tile;
    const uint32_t tiles_c = (C + BX - 1) / BX;

    // unpack the flat block index
    const uint32_t tile_c = blockIdx.x % tiles_c;
    const uint32_t tile_b = blockIdx.x / tiles_c % tiles_b;
    const uint32_t tile_a = blockIdx.x / tiles_c / tiles_b;

    // column handled by this thread; threads past the C boundary have no work
    const uint32_t col = threadIdx.x + tile_c * BX;
    if (col >= C)
        return;

    T delta_v = delta[tile_a * tiles_b * C + tile_b * C + col];

    // half == 0 -> row threadIdx.y, half == 1 -> row threadIdx.y + BY
    for (uint32_t half = 0; half < 2; ++half) {
        const uint32_t row = threadIdx.y + half * BY + tile_b * rows_per_tile;
        if (row < B) {
            T &res = dst[tile_a * B * C + row * C + col];
            res = Op::apply(res, delta_v);
        }
    }
}
// Forward declaration: do_run_kern() below calls run_kern_multiAC() to scan
// the per-block sums, while run_kern_multiAC() itself dispatches to
// do_run_kern(), so it has to be declared ahead of its definition.
template <typename T, typename Op, bool exclusive, bool reverse>
void run_kern_multiAC(T* dst, T* workspace, uint32_t A, uint32_t B,
uint32_t C, const Op& op, cudaStream_t stream);
/**
 * Run one level of the blockwise scan with a fixed (BX, BY) thread block and
 * stitch the per-block results together when B spans several blocks.
 *
 * Steps, all enqueued on \p stream:
 *  1. scan_kernel scans each tile of 2*BY rows by BX columns. When
 *     B > 2*BY it also stores each tile's total into \p workspace, laid out
 *     as (A, B_, C) with B_ = ceil(B / (2*BY)).
 *  2. If B fits in a single tile (B <= 2*BY) the result is complete.
 *  3. Otherwise the tile totals are scanned exclusively along their B_ axis
 *     by re-entering run_kern_multiAC, in place inside \p workspace and
 *     using the region starting A*B_*C elements into \p workspace as the
 *     scratch space of that nested call.
 *  4. update_kernel folds the scanned tile totals back into \p dst.
 *
 * Nothing is launched when A, B or C is zero. Launch failures of either
 * kernel are reported through cuda_check.
 *
 * \param dst       output tensor of shape (A, B, C)
 * \param workspace scratch buffer for the tile totals of this and all nested
 *                  levels; only touched when B > 2*BY
 * \param op        source accessor / combiner passed by value to scan_kernel
 * \param stream    CUDA stream on which all work is enqueued
 */
template <typename T, typename Op, bool exclusive, bool reverse,
        uint32_t BX, uint32_t BY>
void do_run_kern(T *dst, T *workspace,
        uint32_t A, uint32_t B, uint32_t C, const Op &op, cudaStream_t stream) {
    // An empty tensor would produce a zero-sized grid, which is an invalid
    // launch configuration; there is nothing to compute in that case.
    if (A == 0 || B == 0 || C == 0)
        return;

    const uint32_t BY2 = BY * 2;
    const uint32_t B_ = (B + BY2 - 1) / BY2;
    const uint32_t C_ = (C + BX - 1) / BX;
    // more than one tile along B => tile totals must be exported and merged
    const bool multi_tile = B > BY2;

    // all tiles are enumerated along grid axis x (see scan_kernel)
    dim3 blocks(C_ * B_ * A);
    dim3 threads(BX, BY);

    scan_kernel<T, Op, exclusive, reverse, BY, BX>
            <<<blocks, threads, 0, stream>>>(
                    dst, multi_tile ? workspace : nullptr, A, B, C, op);
    // kernel launches do not return a status; fetch launch errors explicitly
    cuda_check(cudaGetLastError());
    if (!multi_tile)
        return;

    // exclusive scan of the tile totals, regardless of this level's mode
    run_kern_multiAC<T, typename Op::ContigOp, true, reverse>(
            workspace, workspace + A * B_ * C, A, B_, C,
            Op::make_contig(workspace), stream);

    update_kernel<T, Op, BY, BX><<<blocks, threads, 0, stream>>>(
            dst, workspace, A, B, C);
    cuda_check(cudaGetLastError());
}
/**
 * Pick the thread-block shape for the general (A, B, C) scan and forward to
 * the matching do_run_kern instantiation.
 *
 * get_BX_BY chooses (BX, BY) at run time; the value pair is mapped onto one
 * of the precompiled template instantiations listed below. A pair outside
 * that list ends in megdnn_trap().
 */
template <typename T, typename Op, bool exclusive, bool reverse>
void run_kern_multiAC(T* dst, T* workspace, uint32_t A, uint32_t B, uint32_t C,
        const Op& op, cudaStream_t stream) {
    uint32_t vBX, vBY;
    get_BX_BY(A, B, C, vBX, vBY);

    // one case per supported BX; the paired BY must match as well
#define MEGDNN_CUMSUM_DISPATCH_CASE(BX_, BY_)                          \
    case BX_:                                                          \
        if (vBY == BY_) {                                              \
            do_run_kern<T, Op, exclusive, reverse, BX_, BY_>(          \
                    dst, workspace, A, B, C, op, stream);              \
            return;                                                    \
        }                                                              \
        break;

    switch (vBX) {
        MEGDNN_CUMSUM_DISPATCH_CASE(1, 512)
        MEGDNN_CUMSUM_DISPATCH_CASE(2, 256)
        MEGDNN_CUMSUM_DISPATCH_CASE(4, 128)
        MEGDNN_CUMSUM_DISPATCH_CASE(8, 64)
        MEGDNN_CUMSUM_DISPATCH_CASE(16, 32)
        MEGDNN_CUMSUM_DISPATCH_CASE(32, 16)
        default:
            break;
    }
#undef MEGDNN_CUMSUM_DISPATCH_CASE

    // no instantiation matches the requested block shape
    megdnn_trap();
}
//! wrap cub library for 1-dim scan
namespace cubwrap {
/**
 * Read-only indexable view handed to cub as the scan input.
 *
 * Element i of the view is op.visit(i), or op.visit(len - 1 - i) when
 * \p reverse is set, so that a forward scan over the view walks the source
 * backwards. The two-argument constructor carries no execution-space
 * qualifier (host side); the three-argument constructor and all accessors
 * are __device__.
 */
template <typename T, typename Op, bool reverse>
class InputIterator : public std::iterator<std::random_access_iterator_tag, T> {
    int m_offset, m_len;
    Op m_op;

public:
    //! view over \p len elements starting at logical position 0
    InputIterator(Op op, int len) : m_offset(0), m_len(len), m_op(op) {}

    //! view shifted by \p offset logical positions
    __device__ InputIterator(int offset, int len, Op op)
            : m_offset(offset), m_len(len), m_op(op) {}

    //! fetch the element at logical position m_offset + idx
    __device__ T operator[](int idx) {
        const int pos = m_offset + idx;
        return m_op.visit(reverse ? m_len - 1 - pos : pos);
    }

    //! advance the view; the mirroring, if any, is applied on access
    __device__ InputIterator operator+(int offset) {
        return InputIterator(m_offset + offset, m_len, m_op);
    }
};
/**
 * Writable indexable view handed to cub as the scan output.
 *
 * Element i of the view refers to dst[i], or dst[len - 1 - i] when
 * \p reverse is set. The two-argument constructor carries no
 * execution-space qualifier (host side); the three-argument constructor and
 * all accessors are __device__.
 */
template <typename T, bool reverse>
class OutputIterator
        : public std::iterator<std::random_access_iterator_tag, T> {
    int m_offset, m_len;
    T* m_dst;

public:
    //! view over \p len elements of \p dst starting at logical position 0
    OutputIterator(T* dst, int len) : m_offset(0), m_len(len), m_dst(dst) {}

    //! view shifted by \p offset logical positions
    __device__ OutputIterator(int offset, int len, T* dst)
            : m_offset(offset), m_len(len), m_dst(dst) {}

    //! reference to the element at logical position m_offset + idx
    __device__ T& operator[](int idx) {
        const int pos = m_offset + idx;
        return m_dst[reverse ? m_len - 1 - pos : pos];
    }

    //! advance the view; the mirroring, if any, is applied on access
    __device__ OutputIterator operator+(int offset) {
        return OutputIterator(m_offset + offset, m_len, m_dst);
    }
};
/**
 * Binary functor passed to cub::DeviceScan; in device code it forwards to
 * Op::apply(a, b).
 */
template <typename T, typename Op>
struct ScanOp {
    __device__ __host__ T operator()(T a, T b) {
        // cub requires a __device__ __host__ callable, but MegDNN places no
        // such constraint on Op::apply. The host compilation pass therefore
        // only traps; Op::apply is reached in the device pass alone.
#ifndef __CUDA_ARCH__
        megdnn_trap();
#else
        return Op::apply(a, b);
#endif
    }
};
/**
 * Scan a 1-dim sequence of \p len elements with cub::DeviceScan.
 *
 * The source is read through InputIterator (op.visit) and the result is
 * written through OutputIterator into \p dst; both views are mirrored when
 * \p reverse is set. ExclusiveScan is seeded with Op::init(), InclusiveScan
 * takes no seed. Any error status returned by cub goes through cuda_check.
 *
 * \param workspace temporary storage forwarded to cub
 * \param wk_size   size of \p workspace as forwarded to cub
 * \param stream    CUDA stream forwarded to cub
 */
template <typename T, typename Op, bool exclusive, bool reverse>
void invoke(T* dst, void* workspace, size_t wk_size, const Op& op, uint32_t len,
        cudaStream_t stream) {
    ScanOp<T, Op> scan_op;
    OutputIterator<T, reverse> out_iter(dst, len);
    InputIterator<T, Op, reverse> inp_iter(op, len);

    if (!exclusive) {
        cuda_check(cub::DeviceScan::InclusiveScan(
                workspace, wk_size, inp_iter, out_iter, scan_op, len, stream));
        return;
    }
    cuda_check(cub::DeviceScan::ExclusiveScan(workspace, wk_size, inp_iter,
                                              out_iter, scan_op, Op::init(),
                                              len, stream));
}
} // namespace cubwrap
} // namespace detail
/**
 * Entry point of the scan over the B axis of an (A, B, C) tensor.
 *
 * When both A and C equal 1 the problem is a plain 1-dim scan of length B
 * and is delegated to the cub wrapper, which receives \p workspace and
 * \p workspace_size untouched. Every other shape goes through the blockwise
 * kernels, which use \p workspace as an array of T.
 */
template <typename T, typename Op, bool exclusive, bool reverse>
void run_kern(T* dst, void* workspace, uint32_t workspace_size, uint32_t A,
        uint32_t B, uint32_t C, const Op& op, cudaStream_t stream) {
    const bool is_flat_scan = (A == 1) && (C == 1);
    if (!is_flat_scan) {
        detail::run_kern_multiAC<T, Op, exclusive, reverse>(
                dst, static_cast<T*>(workspace), A, B, C, op, stream);
        return;
    }
    detail::cubwrap::invoke<T, Op, exclusive, reverse>(
            dst, workspace, workspace_size, op, B, stream);
}
} // namespace cumsum
} // namespace cuda
} // namespace megdnn
// vim: ft=cuda syntax=cuda.doxygen