/**
* \file dnn/src/cuda/rng/kernel.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <cuda_runtime_api.h>
#include <stdint.h>
#include <curand.h>
#include <curand_kernel.h>
#include "megdnn/dtype.h"
#include "src/cuda/elemwise_helper.cuh"
#include "src/cuda/utils.cuh"
#if MEGDNN_CC_HOST
#include "megdnn/oprs.h"
#endif
namespace megdnn {
namespace cuda {
namespace random {
using Philox = curandStatePhilox4_32_10_t;

/*!
 * \brief Uniform float sample in [0, 1) from a Philox state.
 *
 * Per the cuRAND documentation, curand_uniform excludes 0.0 and includes 1.0.
 * This wrapper folds a returned 1.0 (or anything above it) onto 0.0, so callers
 * never see the upper endpoint.
 *
 * \param state Philox generator; advanced by one uniform draw.
 */
QUALIFIERS float _curand_uniform(Philox* state) {
    const float sample = curand_uniform(state);
    return (sample >= 1.0f) ? 0.0f : sample;
}
//! Primary template, only declared; the SFINAE-selected partial
//! specializations below provide the definitions for integral types.
template <typename ctype, typename = void>
struct RandomKernel;

//! Resolves to `void` only for integral types that are exactly 8 bytes wide.
template <typename ctype>
using enable_64bit = typename std::enable_if<
        (sizeof(ctype) == 8) && std::is_integral<ctype>::value, void>::type;

//! Resolves to `void` only for integral types that are at most 4 bytes wide.
template <typename ctype>
using enable_32bit = typename std::enable_if<
        (sizeof(ctype) <= 4) && std::is_integral<ctype>::value, void>::type;
/*!
 * \brief Elementwise random-bits generator for 8-byte integral types.
 *
 * Element `idx` is seeded independently with curand_init(seed, idx, offset).
 * One curand4 call supplies the bits: rand.x becomes the high 32 bits and
 * rand.y the low 32 bits; rand.z / rand.w are discarded. The 64-bit value is
 * then AND-ed with numeric_limits<ctype>::max(), which clears the sign bit
 * for signed types, so signed outputs are non-negative.
 *
 * NOTE(review): operator() has no bounds check; presumably the elemwise
 * launcher only calls it with idx < element count — confirm.
 */
template <typename ctype>
struct RandomKernel<ctype, enable_64bit<ctype>> {
    ctype* output;
    uint64_t seed, offset;
    uint64_t mask = static_cast<uint64_t>(std::numeric_limits<ctype>::max());
    __device__ void operator()(uint32_t idx) {
        Philox local_state;
        curand_init(seed, idx, offset, &local_state);
        uint4 rand = curand4(&local_state);
        uint64_t val = (static_cast<uint64_t>(rand.x) << 32) | rand.y;
        output[idx] = static_cast<ctype>(val & mask);
    }
#if MEGDNN_CC_HOST
    // Takes a non-const pointer to match the `output` member: a
    // `const ctype*` cannot initialize a `ctype*`, so the previous signature
    // was ill-formed whenever this constructor was instantiated.
    RandomKernel(ctype* output, uint64_t seed, uint64_t offset)
            : output{output}, seed{seed}, offset{offset} {}
#endif
};
/*!
 * \brief Elementwise random-bits generator for integral types of at most
 *      4 bytes.
 *
 * Element `idx` is seeded independently with curand_init(seed, idx, offset)
 * and consumes one 32-bit curand draw. The draw is AND-ed with
 * numeric_limits<ctype>::max(), which clears the sign bit for signed types
 * and keeps only the representable low bits for narrow types (e.g. 0x7F for
 * int8).
 *
 * NOTE(review): operator() has no bounds check; presumably the elemwise
 * launcher only calls it with idx < element count — confirm.
 */
template <typename ctype>
struct RandomKernel<ctype, enable_32bit<ctype>> {
    ctype* output;
    uint64_t seed, offset;
    uint32_t mask = static_cast<uint32_t>(std::numeric_limits<ctype>::max());
    __device__ void operator()(uint32_t idx) {
        Philox local_state;
        curand_init(seed, idx, offset, &local_state);
        uint32_t val = curand(&local_state);
        output[idx] = static_cast<ctype>(val & mask);
    }
#if MEGDNN_CC_HOST
    // Non-const pointer to match the `output` member; the former
    // `const ctype*` parameter could not legally initialize it.
    RandomKernel(ctype* output, uint64_t seed, uint64_t offset)
            : output{output}, seed{seed}, offset{offset} {}
#endif
};
/*!
 * \brief Elementwise fill: output[idx] = idx converted to ctype.
 */
template <typename ctype>
struct RangeKernel {
    ctype* output;
    __device__ void operator()(uint32_t idx) { output[idx] = static_cast<ctype>(idx); }
#if MEGDNN_CC_HOST
    // Non-const pointer to match the `output` member this kernel writes to;
    // a `const ctype*` parameter cannot initialize a `ctype*` member.
    RangeKernel(ctype* output) : output{output} {}
#endif
};
/*!
 * \brief Elementwise masked conversion: output[idx] = (input[idx] & mask)
 *      converted to ctype_dst.
 *
 * When ctype_dst is integral, the mask is numeric_limits<ctype_dst>::max()
 * (converted to ctype_src). It keeps only the bits that fit the destination
 * and clears its sign bit for signed destinations. Otherwise the mask is
 * numeric_limits<ctype_src>::max(). Because of the bitwise AND, ctype_src
 * must be an integral type.
 */
template <typename ctype_src, typename ctype_dst>
struct AsTypeKernel {
    // The input is only read, so it is held as a pointer-to-const; this also
    // lets the const-taking host constructor initialize it legally.
    const ctype_src* input;
    ctype_dst* output;
    using ctype_mask = typename std::conditional<
            std::is_integral<ctype_dst>::value, ctype_dst, ctype_src>::type;
    ctype_src mask = static_cast<ctype_src>(std::numeric_limits<ctype_mask>::max());
    __device__ void operator()(uint32_t idx) {
        output[idx] = static_cast<ctype_dst>(input[idx] & mask);
    }
#if MEGDNN_CC_HOST
    // `output` is written, so it is taken non-const to match the member.
    AsTypeKernel(const ctype_src* input, ctype_dst* output)
            : input{input}, output{output} {}
#endif
};
/*!
 * \brief Elementwise Gamma sampler: output[idx] ~ Gamma(shape[idx], scale[idx]).
 *
 * Element `idx` is seeded independently with curand_init(seed, idx, offset).
 */
template <typename ctype>
struct GammaKernel {
    ctype* output;
    ctype* shape;
    ctype* scale;
    uint64_t seed, offset;
    /*!
     * \brief Draw one Gamma(shape = a, scale = b) sample using the
     *      Marsaglia-Tsang rejection method.
     *
     * For a < 1, the sample is drawn with shape a + 1 and multiplied by
     * U^(1/a), where U is uniform in [0, 1).
     *
     * \param a shape; any value that is not > 0 (including NaN) yields 0.
     * \param b scale, multiplied into the result.
     * \param state Philox state, advanced by every draw.
     * \return the sample as float.
     */
    static __device__ float sample_gamma(float a, float b, Philox* state) {
        // Written as !(a > 0) instead of a <= 0 so NaN is rejected as well.
        // A NaN shape would make d and c NaN, every acceptance test below
        // would be false, and the rejection loop would never terminate.
        if (!(a > 0.0f))
            return 0.f;
        float factor = b;
        if (a < 1.0f) {
            // Boost: Gamma(a) = Gamma(a + 1) * U^(1/a).
            factor *= powf(_curand_uniform(state), 1.0f / a);
            a += 1.0f;
        }
        const float d = a - 1.0f / 3.0f;
        const float c = 1.0f / sqrtf(9.0f * d);
        while (true) {
            const float x = curand_normal(state);
            const float y = 1.0f + c * x;
            if (y <= 0.0f)
                continue;
            const float v = y * y * y;
            const float u = _curand_uniform(state);
            const float xx = x * x;
            // Cheap squeeze test first; the log test only runs if it fails.
            if ((u < 1.0f - 0.0331f * xx * xx) ||
                logf(u) < 0.5f * xx + d * (1.0f - v + logf(v)))
                return factor * d * v;
        }
    }
    __device__ void operator()(uint32_t idx) {
        Philox local_state;
        curand_init(seed, idx, offset, &local_state);
        float a = static_cast<float>(shape[idx]);
        float b = static_cast<float>(scale[idx]);
        output[idx] = static_cast<ctype>(sample_gamma(a, b, &local_state));
    }
#if MEGDNN_CC_HOST
    GammaKernel(
            const TensorND& output, const TensorND& shape, const TensorND& scale,
            uint64_t seed, uint64_t offset)
            : output{output.ptr<ctype>()},
              shape{shape.ptr<ctype>()},
              scale{scale.ptr<ctype>()},
              seed{seed},
              offset{offset} {}
#endif
};
/*!
 * \brief Elementwise Poisson sampler: output[idx] ~ Poisson(lambda[idx]).
 *
 * The draw for element `idx` depends only on (seed, idx, offset) and
 * lambda[idx], because each element initializes its own Philox state with
 * curand_init(seed, idx, offset).
 */
template <typename ctype>
struct PoissonKernel {
    ctype* output;
    ctype* lambda;
    uint64_t seed, offset;
    __device__ void operator()(uint32_t idx) {
        const float rate = static_cast<float>(lambda[idx]);
        Philox rng;
        curand_init(seed, idx, offset, &rng);
        output[idx] = static_cast<ctype>(curand_poisson(&rng, rate));
    }
#if MEGDNN_CC_HOST
    PoissonKernel(
            const TensorND& output, const TensorND& lambda, uint64_t seed,
            uint64_t offset)
            : output{output.ptr<ctype>()},
              lambda{lambda.ptr<ctype>()},
              seed{seed},
              offset{offset} {}
#endif
};
/*!
 * \brief Elementwise Beta sampler: output[idx] ~ Beta(alpha[idx], beta[idx]).
 *
 * - If either parameter is not > 0 (including NaN), the output is 0.
 * - If both parameters are < 1, Johnk's rejection algorithm is used.
 * - Otherwise the sample is Ga / (Ga + Gb), with Ga ~ Gamma(alpha, 1) and
 *   Gb ~ Gamma(beta, 1).
 *
 * Element `idx` is seeded independently with curand_init(seed, idx, offset).
 */
template <typename ctype>
struct BetaKernel {
    ctype* output;
    ctype* alpha;
    ctype* beta;
    uint64_t seed, offset;
    __device__ void operator()(uint32_t idx) {
        Philox local_state;
        curand_init(seed, idx, offset, &local_state);
        float a = static_cast<float>(alpha[idx]);
        float b = static_cast<float>(beta[idx]);
        // !(x > 0) rather than x <= 0 so NaN parameters are rejected here too.
        // Otherwise a NaN would reach sample_gamma, whose rejection loop never
        // terminates for a NaN shape.
        if (!(a > 0.0f) || !(b > 0.0f)) {
            output[idx] = 0;
            return;
        }
        if (a < 1.0f && b < 1.0f) {
            // Johnk: X = U^(1/a), Y = V^(1/b); accept when X + Y < 1.
            while (true) {
                float u = _curand_uniform(&local_state);
                float v = _curand_uniform(&local_state);
                float x = powf(u, 1.0f / a);
                float y = powf(v, 1.0f / b);
                float sum = x + y;
                if (sum < 1.0f) {
                    if (sum > 0.0f) {
                        output[idx] = static_cast<ctype>(x / sum);
                    } else {
                        // Both powers underflowed to 0: redo the ratio
                        // X / (X + Y) in log space, shifted by the max for
                        // stability.
                        float logx = logf(u) / a;
                        float logy = logf(v) / b;
                        float log_max = logx > logy ? logx : logy;
                        logx -= log_max;
                        logy -= log_max;
                        output[idx] = static_cast<ctype>(
                                expf(logx - logf(expf(logx) + expf(logy))));
                    }
                    return;
                }
            }
        }
        float ga = GammaKernel<float>::sample_gamma(a, 1.0f, &local_state);
        float gb = GammaKernel<float>::sample_gamma(b, 1.0f, &local_state);
        output[idx] = static_cast<ctype>(ga / (ga + gb));
    }
#if MEGDNN_CC_HOST
    BetaKernel(
            const TensorND& output, const TensorND& alpha, const TensorND& beta,
            uint64_t seed, uint64_t offset)
            : output{output.ptr<ctype>()},
              alpha{alpha.ptr<ctype>()},
              beta{beta.ptr<ctype>()},
              seed{seed},
              offset{offset} {}
#endif
};
// NOTE(review): these are prototypes only; the definitions live in another
// translation unit. The notes below are inferred from the signatures and names
// and should be confirmed against the implementations.

// Presumably fills dst[0, size) with a random permutation, using `workspace`
// as scratch and the same Philox seed/offset convention as the kernels above.
// Work is expected to be enqueued on `stream` — TODO confirm.
template <typename ctype>
void permutation_forward(
ctype* dst, void* workspace, size_t size, uint64_t seed, uint64_t offset,
cudaStream_t stream);
// Presumably the scratch size, in bytes, that permutation_forward needs for N
// elements — TODO confirm.
size_t get_permutation_workspace_in_bytes(size_t N);
// Shuffle helpers. Assumptions to verify against the definitions:
// sptr/dptr are the source/destination buffers, iptr is an int32 index
// buffer, `len` is the number of slices, and `step` is the element count per
// slice. shuffle_backward takes the buffers in (dptr, iptr, sptr) order.
template <typename T>
void shuffle_forward(
T* sptr, T* dptr, dt_int32* iptr, size_t len, size_t step, cudaStream_t stream);
template <typename T>
void shuffle_backward(
T* dptr, dt_int32* iptr, T* sptr, size_t len, size_t step, cudaStream_t stream);
// X-macro: expands `cb` once each for float and int32_t, plus dt_float16
// wrapped in DNN_INC_FLOAT16 (presumably emitted only when float16 support is
// compiled in — confirm). NOTE(review): despite the ARGSORT_ prefix, this sits
// in the random namespace; presumably it lists the element types used to
// instantiate the templates declared above.
#define ARGSORT_FOREACH_CTYPE(cb) cb(float) cb(int32_t) DNN_INC_FLOAT16(cb(dt_float16))
} // namespace random
} // namespace cuda
} // namespace megdnn