#pragma once
#include "megbrain/graph.h"
#include "megbrain/opr/internal/identical_fwd.h"
namespace mgb {
namespace opr {
namespace intl {
/*!
 * \brief Shared implementation of the lock operators.
 *
 * LockAcquire and LockRelease both derive from this class (through
 * LockMaker). The base is ForwardInputToOutput; going by that name the
 * wrapped var is presumably passed through to the output unchanged
 * (NOTE(review): confirm against identical_fwd.h).
 */
MGB_DEFINE_CLS_WITH_SUPER(LockBase, ForwardInputToOutput)  // {
public:
    //! which lock operation a node performs; supplied to the constructor
    enum Action { ACQUIRE, RELEASE };

    //! identification of the lock a node operates on
    struct LockParam {
        size_t lock_id;   //!< presumably the lock identifier (TODO confirm)
        size_t group_id;  //!< presumably the lock's group identifier (TODO confirm)
    };

    LockBase(
            const OperatorNodeBaseCtorParam& opr_param, VarNode* var,
            const LockParam& param, Action action);
    ~LockBase();

private:
    // Forward declarations only; the definitions are outside this header.
    struct LockPool;
    struct LockGroup;
    class LockGroupSet;

    // Overrides of virtual hooks declared by a base class.
    void scn_do_execute_finish(const DeviceTensorND& val) override;
    void add_input_layout_constraint() override;

    //! one pool shared by every LockBase instance (static data member)
    static LockPool sm_lock_pool;

    const LockParam m_param;
    Action m_action;
    //! raw pointer, null by default; ownership is not shown in this header
    //! (presumably the group lives in sm_lock_pool -- TODO confirm)
    LockGroup* m_cur_group = nullptr;
};
template <typename Opr>
MGB_DEFINE_CLS_WITH_SUPER(LockMaker, LockBase) protected:
using Super::Super;
public:
static SymbolVar make(
SymbolVar var, const LockParam& param,
const OperatorNodeConfig& config = {});
};
}
/*!
 * \brief Concrete lock operator for the acquire side.
 *
 * Presumably constructs its LockBase with Action ACQUIRE; the constructor
 * definition is not visible in this header (TODO confirm). Nodes are
 * created through the inherited LockMaker<LockAcquire>::make().
 */
MGB_DEFINE_OPR_CLASS(LockAcquire, intl::LockMaker<LockAcquire>)  // {
public:
    LockAcquire(
            VarNode* var, const LockParam& param,
            const OperatorNodeConfig& config);
};
/*!
 * \brief Concrete lock operator for the release side.
 *
 * Presumably constructs its LockBase with Action RELEASE; the constructor
 * definition is not visible in this header (TODO confirm). Nodes are
 * created through the inherited LockMaker<LockRelease>::make().
 */
MGB_DEFINE_OPR_CLASS(LockRelease, intl::LockMaker<LockRelease>)  // {
public:
    LockRelease(
            VarNode* var, const LockParam& param,
            const OperatorNodeConfig& config);
};
} }