#include "megbrain/imperative/ops/rng.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/graph/helper.h"
#include "megbrain/opr/rand.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"
namespace mgb::imperative::rng {
namespace {
/*!
 * \brief CRTP base that caches megdnn operators per handle and per op type.
 *
 * Handle creation and destruction are forwarded to \p HandleFactory
 * (do_new_handle / do_delete_handle). Operators created for a handle are kept
 * in m_handle2ops until delete_handle() is called for that handle or
 * on_comp_node_finalize() clears the whole cache.
 */
template <typename HandleFactory, typename THandle>
class DnnOpManagerT : public CompNodeDepedentObject, public NonCopyableObj {
public:
using DT = CompNode::DeviceType;
using Handle = THandle;
//! key distinguishing the operators cached under a single handle
using OpTypeInfo = size_t;
//! Create a handle by forwarding all arguments to
//! HandleFactory::do_new_handle().
template <typename... Args>
Handle new_handle(Args&&... args) {
return static_cast<HandleFactory*>(this)->do_new_handle(
std::forward<Args>(args)...);
}
//! Drop the operators cached for \p handle (skipped once is_finalized() is
//! true), then release the handle via HandleFactory::do_delete_handle().
//! \return number of cache entries erased for \p handle
size_t delete_handle(Handle handle) {
size_t removed = 0;
if (!is_finalized()) {
MGB_LOCK_GUARD(m_mtx);
removed = m_handle2ops.erase(handle);
}
static_cast<HandleFactory*>(this)->do_delete_handle(handle);
return removed;
}
//! Fetch the operator cached for (\p handle, \p tpinfo), creating it on the
//! megdnn handle of \p cn when none is cached yet.
//! \return tuple of (whether the operator already existed, pointer to the
//!         operator, std::unique_lock that holds the per-operator mutex)
template <typename DnnOp>
auto get_dnn_op(Handle handle, OpTypeInfo tpinfo, CompNode cn) {
mgb_assert(!is_finalized());
DnnOpWithMutex* dnn_op_with_mtx;
{
// m_mtx only guards the lookup / default-insertion of the cache slot
MGB_LOCK_GUARD(m_mtx);
dnn_op_with_mtx = &m_handle2ops[handle][tpinfo];
}
// NOTE(review): the slot pointer is used after m_mtx is released; this
// presumably relies on the handle not being deleted concurrently -- confirm.
auto dnn_handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
// the per-operator lock is handed to the caller through the returned tuple
std::unique_lock<std::mutex> lock(dnn_op_with_mtx->mtx);
bool initialized = false;
// the slot stores a megdnn::OperatorBase; cast down to the requested type
DnnOp* dnn_op = static_cast<DnnOp*>(dnn_op_with_mtx->op.get());
if (dnn_op != nullptr) {
// a cached operator must have been created on the same megdnn handle
mgb_assert(dnn_op->handle() == dnn_handle);
initialized = true;
} else {
auto new_op = dnn_handle->create_operator<DnnOp>();
dnn_op = new_op.get();
dnn_op_with_mtx->op = std::move(new_op);
}
return std::make_tuple(initialized, dnn_op, std::move(lock));
}
protected:
//! lets derived classes name this base without repeating template arguments
using DnnOpManagerBase = DnnOpManagerT<HandleFactory, Handle>;
DnnOpManagerT() = default;
private:
//! cache slot: an operator (null until first use) plus the mutex guarding it
struct DnnOpWithMutex {
std::mutex mtx;
std::unique_ptr<megdnn::OperatorBase> op;
DnnOpWithMutex() : op{nullptr} {}
};
//! Clears every cached operator under m_mtx; returns an empty shared_ptr.
std::shared_ptr<void> on_comp_node_finalize() override {
MGB_LOCK_GUARD(m_mtx);
m_handle2ops.clear();
return {};
}
//! handle -> (op type info -> cached operator slot)
std::unordered_map<Handle, std::unordered_map<OpTypeInfo, DnnOpWithMutex>>
m_handle2ops;
//! guards m_handle2ops
std::mutex m_mtx;
};
/*!
 * \brief Manager of RNG handles and of the megdnn operators cached for them.
 *
 * A Handle is the reinterpret_cast of a pointer to a pool-allocated
 * HandleData (comp node + seed). The class also keeps one lazily created
 * default handle per comp node, seeded with the global default seed.
 */
class RNGDnnOpManager final : public DnnOpManagerT<RNGDnnOpManager, Handle> {
public:
//! Create a handle for \p comp_node with \p seed while holding sm_mtx.
Handle new_handle(CompNode comp_node, uint64_t seed) {
MGB_LOCK_GUARD(sm_mtx);
return DnnOpManagerBase::new_handle(comp_node, seed);
}
//! Delete \p handle while holding sm_mtx; returns the base-class result.
size_t delete_handle(Handle handle) {
MGB_LOCK_GUARD(sm_mtx);
return DnnOpManagerBase::delete_handle(handle);
}
//! Allocate a HandleData from the pool and expose its address as a Handle.
Handle do_new_handle(CompNode comp_node, uint64_t seed) {
auto handle = m_handle_pool.alloc(comp_node, seed);
return reinterpret_cast<Handle>(handle);
}
//! Return the HandleData behind \p handle to the pool.
void do_delete_handle(Handle handle) {
m_handle_pool.free(reinterpret_cast<HandleData*>(handle));
}
//! Seed stored in \p handle; a null handle yields the global default seed.
//! Note that sm_mtx is not taken here.
static uint64_t get_seed(Handle handle) {
if (!handle) {
return glob_default_seed;
}
return reinterpret_cast<HandleData*>(handle)->seed;
}
//! Comp node stored in \p handle; asserts that \p handle is non-null.
static CompNode get_comp_node(Handle handle) {
mgb_assert(handle, "invalid handle");
return reinterpret_cast<HandleData*>(handle)->comp_node;
}
//! Default handle of \p comp_node, created on first request with the current
//! global default seed.
static Handle get_default_handle(CompNode comp_node) {
mgb_assert(comp_node.valid());
MGB_LOCK_GUARD(sm_mtx);
auto&& glob_handle = glob_default_handles[comp_node];
if (!glob_handle) {
// do_new_handle() is used directly: new_handle() would lock sm_mtx, which
// is already held here
glob_handle = inst().do_new_handle(comp_node, glob_default_seed);
}
mgb_assert(get_seed(glob_handle) == glob_default_seed);
return glob_handle;
}
//! Process-wide instance (function-local static).
static RNGDnnOpManager& inst() {
static RNGDnnOpManager mgr;
return mgr;
}
//! Replace every existing default handle by a new one seeded with \p seed and
//! record \p seed as the global default seed.
static void set_glob_default_seed(uint64_t seed) {
MGB_LOCK_GUARD(sm_mtx);
for (auto&& elem : glob_default_handles) {
mgb_assert(elem.first.valid());
if (elem.second) {
// base-class version on purpose: the derived delete_handle() would lock
// sm_mtx, which is already held here
inst().DnnOpManagerBase::delete_handle(elem.second);
}
elem.second = inst().do_new_handle(elem.first, seed);
}
glob_default_seed = seed;
}
//! Current global default seed, read under sm_mtx.
static uint64_t get_glob_default_seed() {
MGB_LOCK_GUARD(sm_mtx);
return glob_default_seed;
}
private:
//! payload that a Handle value points to
struct HandleData {
CompNode comp_node;
uint64_t seed;
HandleData(CompNode cn, uint64_t seed) : comp_node(cn), seed(seed) {}
};
//! pool from which HandleData objects are allocated and freed
MemPool<HandleData> m_handle_pool;
//! taken by new_handle(), delete_handle() and the default-handle/seed functions
static std::mutex sm_mtx;
//! comp node -> default handle
static CompNode::UnorderedMap<Handle> glob_default_handles;
//! seed used for default handles and reported for null handles
static uint64_t glob_default_seed;
};
// Out-of-class definitions of the static data members of RNGDnnOpManager.
// The global default seed starts at 0; the default-handle map starts empty.
uint64_t RNGDnnOpManager::glob_default_seed{0};
std::mutex RNGDnnOpManager::sm_mtx{};
CompNode::UnorderedMap<Handle> RNGDnnOpManager::glob_default_handles{};
//! Per-op traits, specialized below for each supported op. Every
//! specialization provides the megdnn operator type (DnnOp), its Param type,
//! the graph operator type (OpNode) and a make_param() function.
template <typename Op>
struct OpMeth;
//! Traits tying the imperative UniformRNG op to its megdnn operator and its
//! graph operator.
template <>
struct OpMeth<UniformRNG> {
    using DnnOp = megdnn::UniformRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::UniformRNG;

    //! Build the megdnn param for \p rng. Asserts that the seed stored in the
    //! op equals the seed of its handle (RNGDnnOpManager::get_seed returns the
    //! global default seed for a null handle).
    static Param make_param(const UniformRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        // argument order follows the labels of the message: op seed first,
        // handle seed second
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", rng.seed,
                handle_seed);
        return {handle_seed, rng.dtype.enumv()};
    }
};
//! Traits tying the imperative PoissonRNG op to its megdnn operator and its
//! graph operator.
template <>
struct OpMeth<PoissonRNG> {
    using DnnOp = megdnn::PoissonRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::PoissonRNG;

    //! Build the megdnn param for \p rng. Asserts that the seed stored in the
    //! op equals the seed of its handle (RNGDnnOpManager::get_seed returns the
    //! global default seed for a null handle).
    static Param make_param(const PoissonRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        // argument order follows the labels of the message: op seed first,
        // handle seed second
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", rng.seed,
                handle_seed);
        return {handle_seed};
    }
};
//! Traits tying the imperative GaussianRNG op to its megdnn operator and its
//! graph operator.
template <>
struct OpMeth<GaussianRNG> {
    using DnnOp = megdnn::GaussianRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::GaussianRNG;

    //! Build the megdnn param (seed, mean, std, dtype) for \p rng. Asserts
    //! that the seed stored in the op equals the seed of its handle
    //! (RNGDnnOpManager::get_seed returns the global default seed for a null
    //! handle).
    static Param make_param(const GaussianRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        // argument order follows the labels of the message: op seed first,
        // handle seed second
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", rng.seed,
                handle_seed);
        return {handle_seed, rng.mean, rng.std, rng.dtype.enumv()};
    }
};
//! Traits tying the imperative GammaRNG op to its megdnn operator and its
//! graph operator.
template <>
struct OpMeth<GammaRNG> {
    using DnnOp = megdnn::GammaRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::GammaRNG;

    //! Build the megdnn param for \p rng. Asserts that the seed stored in the
    //! op equals the seed of its handle (RNGDnnOpManager::get_seed returns the
    //! global default seed for a null handle).
    static Param make_param(const GammaRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        // argument order follows the labels of the message: op seed first,
        // handle seed second
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", rng.seed,
                handle_seed);
        return {handle_seed};
    }
};
//! Traits tying the imperative PermutationRNG op to its megdnn operator and
//! its graph operator.
template <>
struct OpMeth<PermutationRNG> {
    using DnnOp = megdnn::PermutationRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::PermutationRNG;

    //! Build the megdnn param (seed, dtype) for \p rng. Asserts that the seed
    //! stored in the op equals the seed of its handle
    //! (RNGDnnOpManager::get_seed returns the global default seed for a null
    //! handle).
    static Param make_param(const PermutationRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        // argument order follows the labels of the message: op seed first,
        // handle seed second
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", rng.seed,
                handle_seed);
        return {handle_seed, rng.dtype.enumv()};
    }
};
//! Traits tying the imperative BetaRNG op to its megdnn operator and its
//! graph operator.
template <>
struct OpMeth<BetaRNG> {
    using DnnOp = megdnn::BetaRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::BetaRNG;

    //! Build the megdnn param for \p rng. Asserts that the seed stored in the
    //! op equals the seed of its handle (RNGDnnOpManager::get_seed returns the
    //! global default seed for a null handle).
    static Param make_param(const BetaRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        // argument order follows the labels of the message: op seed first,
        // handle seed second
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", rng.seed,
                handle_seed);
        return {handle_seed};
    }
};
//! Traits tying the imperative ShuffleRNG op to its megdnn operator and its
//! graph operator.
template <>
struct OpMeth<ShuffleRNG> {
    using DnnOp = megdnn::ShuffleRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::ShuffleRNG;

    //! Build the megdnn param for \p rng. Asserts that the seed stored in the
    //! op equals the seed of its handle (RNGDnnOpManager::get_seed returns the
    //! global default seed for a null handle).
    static Param make_param(const ShuffleRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        // argument order follows the labels of the message: op seed first,
        // handle seed second
        mgb_assert(
                handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu", rng.seed,
                handle_seed);
        return {handle_seed};
    }
};
//! Traits tying the imperative Dropout op to its megdnn operator and its graph
//! operator.
template <>
struct OpMeth<Dropout> {
    using DnnOp = megdnn::Dropout;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::Dropout;

    //! Build the megdnn param (drop probability, seed) for \p opdef. Asserts
    //! that the seed stored in the op equals the seed of its handle
    //! (RNGDnnOpManager::get_seed returns the global default seed for a null
    //! handle).
    static Param make_param(const Dropout& opdef) {
        auto handle_seed = RNGDnnOpManager::get_seed(opdef.handle);
        // argument order follows the labels of the message: op seed first,
        // handle seed second
        mgb_assert(
                handle_seed == opdef.seed,
                "inconsistent dropout seed: dropout op: %lu handle: %lu", opdef.seed,
                handle_seed);
        return {opdef.drop_prob, handle_seed};
    }
};
//! Output layout inference; the bool selects between the two specializations
//! defined below.
template <bool>
struct _InferLayout;
//! Builds the graph operator for an op taking nr_in graph inputs; specialized
//! through _INST_RNG_MAKER.
template <int nr_in>
struct _RNGOprMaker;
//! Runs a megdnn operator with nr_in inputs and nr_out outputs; specialized
//! through _INST_RNG_INVOLKER.
template <int nr_in, int nr_out>
struct _RNGOprInvoker;
//! Layout inference for ops whose single input carries the target shape (the
//! callers select this specialization when the megdnn operator has no inputs).
template <>
struct _InferLayout<true> {
    //! Read the target shape from the value of \p inp and combine it with the
    //! dtype of \p rng.
    template <typename Op>
    static TensorLayout do_infer(const TensorPtr& inp, const Op& rng) {
        TensorShape tshape;
        auto hv = inp->get_value().proxy_to_default_cpu();
        cg::copy_tensor_value_to_shape(tshape, hv);
        return TensorLayout(tshape, rng.dtype);
    }

    //! Infer the layout from a logical descriptor. When the layout or the
    //! value of \p inp is not available, a layout with ndim == 0 (and the
    //! dtype of \p rng) is returned. Otherwise \p inp must be 1-dimensional
    //! and its value is read as int32 shape entries.
    template <typename Op>
    static TensorLayout do_infer(const LogicalTensorDesc& inp, const Op& rng) {
        TensorLayout out_layout = inp.layout;
        out_layout.dtype = rng.dtype;
        if (inp.layout.ndim == 0 || inp.value.empty()) {
            out_layout.ndim = 0;
            return out_layout;
        }
        mgb_assert(
                inp.layout.ndim == 1,
                "target shape of %s expects ndim=1; got ndim=%zu actually",
                rng.dyn_typeinfo()->name, inp.layout.ndim);
        size_t target_ndim = inp.layout.shape[0];
        // shape[] holds at most MAX_NDIM entries; reject longer target shapes
        // instead of writing past the end of the array
        constexpr size_t max_ndim = TensorShape::MAX_NDIM;
        mgb_assert(
                target_ndim <= max_ndim,
                "target shape of %s has %zu dims; at most %zu are supported",
                rng.dyn_typeinfo()->name, target_ndim, max_ndim);
        out_layout.ndim = target_ndim;
        auto* ptr = inp.value.ptr<dt_int32>();
        for (size_t i = 0; i < target_ndim; ++i) {
            out_layout.shape[i] = ptr[i];
        }
        out_layout.init_contiguous_stride();
        return out_layout;
    }
};
//! Layout inference for ops whose output layout is that of their first input.
template <>
struct _InferLayout<false> {
    //! The output layout is the layout of the physical input tensor.
    template <typename Op>
    static TensorLayout do_infer(const TensorPtr& inp, const Op& rng) {
        TensorLayout out_layout = inp->layout();
        return out_layout;
    }

    //! The output layout is the layout of the descriptor, which must be known
    //! (ndim != 0).
    template <typename Op>
    static TensorLayout do_infer(const LogicalTensorDesc& inp, const Op& rng) {
        mgb_assert(inp.layout.ndim);
        TensorLayout out_layout = inp.layout;
        return out_layout;
    }
};
// Defines the _RNGOprInvoker<DNN_NR_INPUTS, DNN_NR_OUTPUTS> specialization.
// Its exec() queries the workspace size from the input/output layouts,
// allocates a Blob of that size on the comp node of outputs[0] and runs the
// megdnn operator on the device tensors.
// _FOR_EACH_IN / _FOR_EACH_OUT must be defined when this macro is expanded;
// they expand to the argument list of inputs / outputs with the given suffix
// appended to every element.
#define _INST_RNG_INVOLKER(DNN_NR_INPUTS, DNN_NR_OUTPUTS) \
template <> \
struct _RNGOprInvoker<DNN_NR_INPUTS, DNN_NR_OUTPUTS> { \
template <typename Opr> \
static void exec( \
Opr* dnn_op, const SmallVector<TensorPtr>& inputs, \
const SmallVector<TensorPtr>& outputs) { \
size_t wk_size = 0; \
wk_size = dnn_op->get_workspace_in_bytes( \
_FOR_EACH_IN(->layout()) _FOR_EACH_OUT(->layout())); \
auto workspace = Blob::make(outputs[0]->comp_node(), wk_size); \
megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size); \
dnn_op->exec( \
_FOR_EACH_IN(->dev_tensor().as_megdnn()) \
_FOR_EACH_OUT(->dev_tensor().as_megdnn()), \
dnn_wk); \
} \
};
// Defines the _RNGOprMaker<MGB_NR_INPUTS> specialization. Its make() builds
// the graph operator through OpMeth<Op>::OpNode::make, passing the inputs
// listed by _FOR_EACH_IN, the param from OpMeth<Op>::make_param and a config.
// The config always carries rng.make_name(); when rng.handle is non-null it
// also carries the comp node stored in that handle.
// _FOR_EACH_IN must be defined when this macro is expanded.
#define _INST_RNG_MAKER(MGB_NR_INPUTS) \
template <> \
struct _RNGOprMaker<MGB_NR_INPUTS> { \
template <typename Op> \
static auto make(const VarNodeArray& inputs, const Op& rng) { \
auto param = OpMeth<Op>::make_param(rng); \
OperatorNodeConfig config; \
if (rng.handle) { \
config = { \
rng.make_name(), RNGDnnOpManager::get_comp_node(rng.handle)}; \
} else { \
config = {rng.make_name()}; \
} \
return OpMeth<Op>::OpNode::make(_FOR_EACH_IN() param, config); \
} \
};
// Instantiations of the invoker / maker specializations. Before each group
// _FOR_EACH_IN and _FOR_EACH_OUT are redefined to list the matching number of
// inputs / outputs.

// invoker: 0 inputs, 1 output
#define _FOR_EACH_IN(subfix)
#define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOLKER(0, 1)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN
// invoker: 1 input, 1 output
#define _FOR_EACH_IN(subfix) inputs[0] subfix,
#define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOLKER(1, 1)
#undef _FOR_EACH_OUT
// invoker: 1 input, 2 outputs; maker: 1 graph input
#define _FOR_EACH_OUT(subfix) outputs[0] subfix, outputs[1] subfix
_INST_RNG_INVOLKER(1, 2)
_INST_RNG_MAKER(1)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN
// invoker: 2 inputs, 1 output; maker: 2 graph inputs
#define _FOR_EACH_IN(subfix) inputs[0] subfix, inputs[1] subfix,
#define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOLKER(2, 1)
_INST_RNG_MAKER(2)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN
// the helper macros are not needed past this point
#undef _INST_RNG_INVOLKER
#undef _INST_RNG_MAKER
template <typename Op>
void exec(
const OpDef& op, const SmallVector<TensorPtr>& inputs,
const SmallVector<TensorPtr>& outputs,
const SmallVector<TensorPtr>& workspace) {
auto&& rng = op.cast_final_safe<Op>();
auto dest = outputs[0];
if (dest->layout().is_empty())
return;
auto cn = dest->comp_node();
auto handle = rng.handle;
if (!handle) {
handle = RNGDnnOpManager::get_default_handle(cn);
}
auto dnn_op_thread_safe =
RNGDnnOpManager::inst().get_dnn_op<typename OpMeth<Op>::DnnOp>(
handle, reinterpret_cast<size_t>(op.dyn_typeinfo()), cn);
auto initialized = std::get<0>(dnn_op_thread_safe);
auto dnn_op = std::get<1>(dnn_op_thread_safe);
if (initialized) {
auto handle_seed = RNGDnnOpManager::get_seed(handle);
mgb_assert(
dnn_op->param().seed == handle_seed,
"inconsistent rng seed: handle: %lu, dnn_op: %lu", handle_seed,
dnn_op->param().seed);
}
dnn_op->param() = OpMeth<Op>::make_param(rng);
_RNGOprInvoker<OpMeth<Op>::DnnOp::NR_INPUTS, OpMeth<Op>::DnnOp::NR_OUTPUTS>::exec(
dnn_op, inputs, outputs);
}
/*!
 * \brief Infer comp node and layout of the single output from physical inputs.
 *
 * The output comp node is the one stored in the handle of the op, or the comp
 * node of inputs[0] for a null handle. For ops whose megdnn operator takes
 * inputs, every input must live on that comp node.
 */
template <typename Op>
SmallVector<LogicalTensorDesc> infer_output_attrs(
        const OpDef& op, const SmallVector<TensorPtr>& inputs) {
    LogicalTensorDesc dest;
    auto&& rng = op.cast_final_safe<Op>();
    auto handle = rng.handle;
    if (handle) {
        dest.comp_node = RNGDnnOpManager::get_comp_node(handle);
    } else {
        dest.comp_node = inputs[0]->comp_node();
    }
    constexpr bool rng_with_shape = OpMeth<Op>::DnnOp::NR_INPUTS == 0;
    if constexpr (!rng_with_shape) {
        // unsigned index: inputs.size() is unsigned
        for (size_t i = 0; i < inputs.size(); ++i) {
            mgb_assert(
                    inputs[i]->comp_node() == dest.comp_node,
                    "%s expects the device of inputs[%zu] to be same as the device of "
                    "handle; "
                    "got %s and %s actually",
                    rng.dyn_typeinfo()->name, i,
                    inputs[i]->comp_node().to_string().c_str(),
                    dest.comp_node.to_string().c_str());
        }
    }
    dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], rng);
    return {dest};
}
/*!
 * \brief Output attributes of ShuffleRNG on physical inputs.
 *
 * Output 0 has the layout of inputs[0]. Output 1 is a 1-dimensional Int32
 * tensor with as many elements as the first axis of inputs[0]. Both live on
 * the comp node of the handle, or on that of inputs[0] for a null handle.
 */
template <>
SmallVector<LogicalTensorDesc> infer_output_attrs<ShuffleRNG>(
        const OpDef& op, const SmallVector<TensorPtr>& inputs) {
    SmallVector<LogicalTensorDesc> dests(2);
    auto&& rng = op.cast_final_safe<ShuffleRNG>();
    LogicalTensorDesc& out0 = dests[0];
    LogicalTensorDesc& out1 = dests[1];

    if (auto handle = rng.handle) {
        out0.comp_node = RNGDnnOpManager::get_comp_node(handle);
        out1.comp_node = RNGDnnOpManager::get_comp_node(handle);
    } else {
        out0.comp_node = inputs[0]->comp_node();
        out1.comp_node = inputs[0]->comp_node();
    }

    out0.layout = TensorLayout(inputs[0]->layout());
    out0.layout.dtype = inputs[0]->layout().dtype;

    TensorShape first_axis({inputs[0]->layout()[0]});
    out1.layout = TensorLayout(first_axis, dtype::Int32());
    return dests;
}
/*!
 * \brief Output attributes of Dropout on physical inputs.
 *
 * Output 0 has the layout of inputs[0]. Output 1 is a 1-dimensional Byte
 * tensor whose length is the mask size reported by a megdnn Dropout operator
 * for that layout. Both live on the comp node of inputs[0].
 */
template <>
SmallVector<LogicalTensorDesc> infer_output_attrs<Dropout>(
        const OpDef& op, const SmallVector<TensorPtr>& inputs) {
    SmallVector<LogicalTensorDesc> dests(2);
    auto&& cn = inputs[0]->comp_node();

    LogicalTensorDesc& out0 = dests[0];
    out0.comp_node = cn;
    out0.layout = TensorLayout(inputs[0]->layout());
    out0.layout.dtype = inputs[0]->layout().dtype;

    LogicalTensorDesc& mask = dests[1];
    mask.comp_node = cn;
    // a temporary operator is created only to query the mask size
    auto dnn_handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
    const size_t mask_bytes =
            dnn_handle->create_operator<megdnn::Dropout>()->get_mask_size_in_bytes(
                    inputs[0]->layout());
    mask.layout = TensorLayout(TensorShape({mask_bytes}), dtype::Byte());
    return dests;
}
/*!
 * \brief Eager execution: allocate one tensor per inferred output descriptor
 *        and run the op on them.
 *
 * \p output_descs and \p validated are not used; the output attributes are
 * always inferred from \p inputs.
 */
template <typename Op>
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    SmallVector<TensorPtr> outputs;
    for (auto&& out_desc : infer_output_attrs<Op>(def, inputs)) {
        outputs.push_back(Tensor::make(out_desc.layout, out_desc.comp_node));
    }
    exec<Op>(def, inputs, outputs, {});
    return outputs;
}
/*!
 * \brief Build the graph operator of \p Op on \p inputs.
 *
 * Ops whose megdnn operator has no inputs take exactly one graph input; this
 * is asserted. For all other ops the number of graph inputs equals the number
 * of megdnn inputs.
 */
template <typename Op, typename Output>
Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    constexpr size_t dnn_nr_inp = OpMeth<Op>::DnnOp::NR_INPUTS;
    auto&& rng = def.cast_final_safe<Op>();
    if constexpr (dnn_nr_inp == 0) {
        size_t nr_inp = inputs.size();
        // %zu is the printf conversion for size_t
        mgb_assert(
                nr_inp == 1, "%s expects 1 inputs; got %zu actually",
                rng.dyn_typeinfo()->name, nr_inp);
    }
    // 0 megdnn inputs -> 1 graph input; otherwise the counts are equal
    constexpr size_t mgb_nr_inp = dnn_nr_inp + !dnn_nr_inp;
    return _RNGOprMaker<mgb_nr_inp>::make(inputs, rng);
}
/*!
 * \brief Infer the single output descriptor from logical input descriptors.
 *
 * The output is placed on the comp node of inputs[0]. Ops whose megdnn
 * operator has no inputs must receive exactly one input.
 * \return the descriptor and a flag that is true iff the layout of inputs[0]
 *         is known (ndim != 0)
 */
template <typename Op>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    LogicalTensorDesc dest;
    auto&& xxx_rng_def = def.cast_final_safe<Op>();
    constexpr bool rng_with_shape = OpMeth<Op>::DnnOp::NR_INPUTS == 0;
    if constexpr (rng_with_shape) {
        size_t nr_inp = inputs.size();
        // %zu is the printf conversion for size_t
        mgb_assert(
                nr_inp == 1, "%s expects 1 inputs; got %zu actually",
                xxx_rng_def.dyn_typeinfo()->name, nr_inp);
    }
    dest.comp_node = inputs[0].comp_node;
    dest.layout = _InferLayout<rng_with_shape>::do_infer(inputs[0], xxx_rng_def);
    // NOTE(review): the flag is derived from the input layout, not from
    // dest.layout; confirm this is intended for the case where the inferred
    // layout has ndim == 0.
    return {{dest}, inputs[0].layout.ndim != 0};
}
/*!
 * \brief Output descriptors of ShuffleRNG from logical input descriptors.
 *
 * Output 0 copies the layout of inputs[0]. Output 1 is Int32; when the input
 * layout is known it is 1-dimensional with shape[0] of the input as length,
 * otherwise only its dtype is set. Both use the comp node of inputs[0].
 * \return the descriptors and whether the input layout was known
 */
template <>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<
        ShuffleRNG>(const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    const LogicalTensorDesc& src = inputs[0];
    bool success = src.layout.ndim != 0;

    SmallVector<LogicalTensorDesc> dests(2);
    LogicalTensorDesc& out0 = dests[0];
    out0.comp_node = src.comp_node;
    out0.layout = TensorLayout(src.layout);
    out0.layout.dtype = src.layout.dtype;

    LogicalTensorDesc& out1 = dests[1];
    out1.comp_node = src.comp_node;
    out1.layout = success
                        ? TensorLayout(TensorShape({src.layout.shape[0]}), dtype::Int32())
                        : TensorLayout(dtype::Int32());
    return {dests, success};
}
/*!
 * \brief Output descriptors of Dropout from logical input descriptors.
 *
 * Output 0 copies the layout of inputs[0]. Output 1 is of dtype Byte; when the
 * input layout is known it is 1-dimensional with the mask size reported by a
 * megdnn Dropout operator as length, otherwise only its dtype is set. Both use
 * the comp node of inputs[0].
 * \return the descriptors and whether the input layout was known
 */
template <>
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<Dropout>(
        const OpDef& op, const SmallVector<LogicalTensorDesc>& inputs) {
    const LogicalTensorDesc& src = inputs[0];
    bool success = src.layout.ndim != 0;
    auto cn = src.comp_node;

    SmallVector<LogicalTensorDesc> dests(2);
    LogicalTensorDesc& out0 = dests[0];
    out0.comp_node = cn;
    out0.layout = TensorLayout(src.layout);
    out0.layout.dtype = src.layout.dtype;

    LogicalTensorDesc& mask = dests[1];
    mask.comp_node = cn;
    if (!success) {
        mask.layout = TensorLayout(dtype::Byte());
    } else {
        // a temporary operator is created only to query the mask size
        auto dnn_handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
        const size_t mask_bytes =
                dnn_handle->create_operator<megdnn::Dropout>()->get_mask_size_in_bytes(
                        src.layout);
        mask.layout = TensorLayout(TensorShape({mask_bytes}), dtype::Byte());
    }
    return {dests, success};
}
//! Returns one default-constructed layout constraint callback per input.
template <typename Op>
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    using Callbacks = SmallVector<VarNode::LayoutConstraintCallback>;
    return Callbacks(inputs.size());
}
}
//! Create an RNG handle for \p comp_node with \p seed through the
//! RNGDnnOpManager instance.
Handle new_handle(CompNode comp_node, uint64_t seed) {
    auto& manager = RNGDnnOpManager::inst();
    return manager.new_handle(comp_node, seed);
}
//! Delete \p handle through the RNGDnnOpManager instance and return the
//! manager's result.
size_t delete_handle(Handle handle) {
    auto& manager = RNGDnnOpManager::inst();
    return manager.delete_handle(handle);
}
void set_global_rng_seed(uint64_t seed) {
RNGDnnOpManager::set_glob_default_seed(seed);
}
//! Return the current global default RNG seed.
uint64_t get_global_rng_seed() {
    const uint64_t seed = RNGDnnOpManager::get_glob_default_seed();
    return seed;
}
//! Return the comp node stored in \p handle (the manager asserts that the
//! handle is non-null).
CompNode get_rng_handle_compnode(Handle handle) {
    CompNode cn = RNGDnnOpManager::get_comp_node(handle);
    return cn;
}
// Registers the op trait of NAME inside an anonymous namespace: the templates
// defined above are instantiated for NAME and installed as apply_on_var_node
// (with Output as its return type), apply_on_physical_tensor,
// infer_output_attrs_fallible and get_input_layout_constraint, followed by
// fallback().
#define REG_RNG_OP(NAME, Output) \
namespace { \
OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \
.apply_on_var_node(apply_on_var_node<NAME, Output>) \
.apply_on_physical_tensor(apply_on_physical_tensor<NAME>) \
.infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
.get_input_layout_constraint(get_input_layout_constraint<NAME>) \
.fallback(); \
}
// Ops registered with a single graph output (SymbolVar).
REG_RNG_OP(UniformRNG, SymbolVar)
REG_RNG_OP(GaussianRNG, SymbolVar)
REG_RNG_OP(GammaRNG, SymbolVar)
REG_RNG_OP(PermutationRNG, SymbolVar)
REG_RNG_OP(PoissonRNG, SymbolVar)
REG_RNG_OP(BetaRNG, SymbolVar)
// Ops registered with an array of graph outputs (SymbolVarArray).
REG_RNG_OP(ShuffleRNG, SymbolVarArray)
REG_RNG_OP(Dropout, SymbolVarArray)
#undef REG_RNG_OP
}