#include <limits>
#include "./kern.cuh"
#include "src/common/cond_take/predicate.cuh"
#include "src/cuda/cumsum/kern_impl.cuinl"
#include "src/cuda/query_blocksize.cuh"
using namespace megdnn;
using namespace megdnn::cond_take;
using namespace megdnn::cuda::cond_take;
namespace {
/*!
 * \brief operator object that gen_idx passes to cumsum::run_kern
 *
 * For element \c idx it evaluates the predicate selected by \p mode on
 * \c data[idx] and yields the result converted to IdxType; values are combined
 * by plain addition starting from 0.
 *
 * NOTE(review): the predicate result is presumably 0 or 1, so that the scan
 * becomes a running count of selected elements -- confirm against Pred.
 *
 * \tparam mode predicate mode forwarded to Pred
 * \tparam T element type of the scanned input
 */
template <uint32_t mode, typename T>
struct IdxGetter {
    //! operator type produced by make_contig() for an IdxType buffer
    typedef ::megdnn::cuda::cumsum::SumOp<IdxType> ContigOp;

    //! input values the predicate is evaluated on
    const T* data;
    //! predicate functor built from the kernel parameters
    Pred<mode, T> pred;

    IdxGetter(const T* src, const ::megdnn::cond_take::KParam& param)
            : data(src), pred(param) {}

    //! starting value of the summation
    __host__ __device__ static IdxType init() {
        return 0;
    }

    //! combine two partial results by addition
    __device__ static IdxType apply(IdxType lhs, IdxType rhs) {
        return lhs + rhs;
    }

    //! predicate result for the element at \p idx, converted to IdxType
    __device__ IdxType visit(uint32_t idx) const {
        const T& elem = data[idx];
        return pred(elem);
    }

    //! wrap a buffer of IdxType values in a ContigOp
    static ContigOp make_contig(const IdxType* contig_data) {
        return ContigOp(contig_data);
    }
};
/*!
 * \brief scatter selected elements and their source positions to dense outputs
 *
 * Each thread handles one input position \c tid (1D launch). The position is
 * treated as selected when \c src_idx[tid] is greater than \c src_idx[tid - 1];
 * in that case \c src_data[tid] and \c tid itself are stored at output slot
 * \c src_idx[tid] - 1.
 *
 * \param dest_data receives the selected values
 * \param dest_idx receives the source position of every selected value
 * \param src_data input values, at least \p size elements
 * \param src_idx per-position scan values; the element at index -1 is read by
 *      the thread handling position 0, so it must be valid memory (the
 *      launcher in this file passes a pointer advanced by one element past a
 *      zeroed slot)
 * \param size number of input positions; threads at or beyond it do nothing
 */
template <typename T>
__global__ void copy_kern(
        T* dest_data, IdxType* dest_idx, const T* src_data, const IdxType* src_idx,
        uint32_t size) {
    // The index has to be signed, because position 0 reads offset -1, and it
    // has to be wider than 32 bits: a global thread id >= 2^31 stored in an
    // int becomes negative, still passes an unsigned comparison against size
    // and then indexes the buffers at a negative offset.
    const long long tid =
            static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (tid >= static_cast<long long>(size)) {
        return;
    }

    // load each scan value once; cur is needed for both the test and the slot
    const IdxType cur = src_idx[tid];
    const IdxType prev = src_idx[tid - 1];
    if (cur > prev) {
        const uint32_t v = cur - 1;
        dest_data[v] = src_data[tid];
        dest_idx[v] = static_cast<IdxType>(tid);
    }
}
//! store 0 into the element \p dest points at; every launched thread performs
//! the same store (this file launches it with a single thread)
__global__ void set_zero(IdxType* dest) {
    *dest = 0;
}
}
/*!
 * \brief run the predicate scan over \p mask and fetch its final value
 *
 * Dispatches on \p mode to cumsum::run_kern with the matching IdxGetter; the
 * scan output starts at \c dest_idx[1], and \c dest_idx[0] is not written
 * here. Afterwards \c dest_idx[size] is copied to the host and \p stream is
 * synchronized, so this call blocks the calling host thread until the work
 * queued on \p stream has finished.
 *
 * \param workspace scratch buffer forwarded to cumsum::run_kern
 * \param workspace_size size of \p workspace as expected by run_kern
 * \param dest_idx device buffer; element \c size is read back, so it must
 *      hold at least \c size + 1 elements
 * \param mask device input the predicate is evaluated on
 * \param size number of elements in \p mask
 * \param mode one of the PEnum values enumerated by
 *      MEGDNN_FOREACH_COND_TAKE_MODE; any other value reaches megdnn_trap()
 * \param kparam parameters used to construct the predicate
 * \param stream stream used for the kernels and the copy
 * \return the value read from \c dest_idx[size] (presumably the number of
 *      elements that satisfy the predicate -- confirm against run_kern)
 */
template <typename T>
size_t cuda::cond_take::gen_idx(
        void* workspace, size_t workspace_size, IdxType* dest_idx, const T* mask,
        size_t size, uint32_t mode, const KParam& kparam, cudaStream_t stream) {
    // first element of the scan output; one past the start of dest_idx
    IdxType* const scan_out = dest_idx + 1;

#define MEGDNN_COND_TAKE_SCAN_CASE(_m)                                    \
    case PEnum::_m: {                                                     \
        typedef IdxGetter<PEnum::_m, T> Getter;                           \
        cuda::cumsum::run_kern<IdxType, Getter, false, false>(            \
                scan_out, workspace, workspace_size, 1, size, 1,          \
                Getter(mask, kparam), stream);                            \
        break;                                                            \
    }
    switch (mode) {
        MEGDNN_FOREACH_COND_TAKE_MODE(MEGDNN_COND_TAKE_SCAN_CASE)
        default:
            megdnn_trap();
    }
#undef MEGDNN_COND_TAKE_SCAN_CASE

    // read back the last scan element; the copy is asynchronous, so the
    // stream has to be synchronized before the host value may be used
    IdxType nr_taken;
    cuda_check(cudaMemcpyAsync(
            &nr_taken, dest_idx + size, sizeof(IdxType), cudaMemcpyDeviceToHost,
            stream));
    cuda_check(cudaStreamSynchronize(stream));
    return nr_taken;
}
/*!
 * \brief queue the kernels that compact the selected elements on \p stream
 *
 * First stores 0 into \c src_idx[0], then launches copy_kern with
 * \c src_idx + 1 so that the kernel's read of index -1 hits that zeroed
 * element. Nothing is synchronized here; the work is only enqueued.
 *
 * \param dest_data device output for the selected values
 * \param dest_idx device output for the source positions of selected values
 * \param src_data device input values, \p size elements
 * \param src_idx device buffer whose element 0 is overwritten with 0 and whose
 *      elements 1..size are read by copy_kern (presumably the buffer filled by
 *      gen_idx -- verify against caller)
 * \param size number of input elements; 0 is allowed and only performs the
 *      zero store
 * \param stream stream the kernels are launched on
 */
template <typename T>
void cuda::cond_take::copy_output(
        T* dest_data, IdxType* dest_idx, const T* src_data, IdxType* src_idx,
        uint32_t size, cudaStream_t stream) {
    int nr_thread = query_blocksize_for_kernel(copy_kern<T>);
    int nr_block = DIVUP(size, nr_thread);
    set_zero<<<1, 1, 0, stream>>>(src_idx);
    // An empty input would give a grid with zero blocks, which is an invalid
    // launch configuration; there is nothing to copy in that case anyway.
    if (size > 0) {
        copy_kern<T><<<nr_block, nr_thread, 0, stream>>>(
                dest_data, dest_idx, src_data, src_idx + 1, size);
    }
    after_kernel_launch();
}
// Helper macros for explicitly instantiating the templates defined in this
// file. Nothing is instantiated here; the macros are only defined.
// NOTE(review): macro definitions are not scoped by the surrounding namespaces;
// the expansions name gen_idx / copy_output unqualified, so they are presumably
// meant to be invoked inside megdnn::cuda::cond_take by the including
// translation unit -- confirm against the users of this file.
namespace megdnn {
namespace cuda {
namespace cond_take {
// inst_genidx(dt): explicit instantiation of gen_idx for the ctype of dtype dt
#define inst_genidx(dt) \
template size_t gen_idx( \
void*, size_t, IdxType*, const DTypeTrait<dt>::ctype*, size_t, uint32_t, \
const KParam&, cudaStream_t);
// inst_copy_(ct): explicit instantiation of copy_output for the C++ type ct
#define inst_copy_(ct) \
template void copy_output( \
ct*, IdxType*, const ct*, IdxType*, uint32_t, cudaStream_t);
// inst_copy(dt): same as inst_copy_, taking a dtype and mapping it to its ctype
#define inst_copy(dt) inst_copy_(DTypeTrait<dt>::ctype)
} } }