#pragma once
#include "megbrain/comp_node.h"
#include "megbrain/graph/cg.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/imperative.h"
#include "megbrain/imperative/ops/backward_graph.h"
namespace mgb {
namespace imperative {
//! Non-copyable helper declared for the imperative runtime. Only the interface
//! is visible in this header; the nested helper types and all out-of-line
//! members are defined elsewhere.
class ProxyGraph : public NonCopyableObj {
public:
    //! Returns a pointer to a ProxyGraph instance.
    //! NOTE(review): presumably a shared default instance that the caller must
    //! not delete -- lifetime/ownership is not visible here, confirm in the
    //! implementation file.
    static ProxyGraph* get_default_graph();

    //! Hands the error stored in the calling thread's tm_async_error slot over
    //! to the caller. The slot is moved from, so it holds nullptr afterwards;
    //! the result is nullptr when nothing was stored.
    static std::unique_ptr<MegBrainError> get_async_error() {
        std::unique_ptr<MegBrainError> pending{std::move(tm_async_error)};
        return pending;
    }

    //! Produces an EncodedSubgraph for \p opdef from the given input
    //! descriptions and per-input / per-output boolean flags.
    //! NOTE(review): presumably input_requires_grad[i] marks inputs that need a
    //! gradient and output_has_grad[j] marks outputs whose gradient is
    //! available -- confirm against the definition.
    EncodedSubgraph make_backward_graph(
            const OpDef& opdef, const SmallVector<LogicalTensorDesc>& input_descs,
            const SmallVector<bool>& input_requires_grad,
            const SmallVector<bool>& output_has_grad);

private:
    // Nested helper types; only forward-declared here.
    class CurOprGuard;
    class InputPlaceholder;
    class ProxyGraphImpl;
    class SeqCompNodeOptimizer;
    class StaticInferManager;
    struct ProxyGraphInst;

    //! Private: instances can only be created from within the class (or by
    //! friends, none of which are declared in this header).
    ProxyGraph();

    // Internal state management; semantics live in the implementation file.
    void reset();
    void cleanup();

    //! Creates graph variables corresponding to the given tensor descriptions.
    cg::VarNodeArray make_input_place_holders(
            const SmallVector<LogicalTensorDesc>& inputs);

    //! Converts \p opr into a TensorPtr; \p share defaults to true.
    //! NOTE(review): presumably `share` controls whether storage is shared
    //! rather than copied -- confirm in the definition.
    TensorPtr as_tensor(cg::OperatorNodeBase* opr, bool share = true);

    //! Per-thread slot drained by get_async_error().
    static thread_local std::unique_ptr<MegBrainError> tm_async_error;

    // Non-static data members. Declaration order is significant: it fixes the
    // construction order (and the reverse destruction order).
    cg::OperatorNodeBase* m_cur_opr{nullptr};
    std::unique_ptr<ProxyGraphImpl> m_graph;
    //! NOTE(review): looks like an upper bound on accumulated operators before
    //! some reset/cleanup is triggered -- confirm where it is read.
    size_t m_max_op_cnt{100};
    std::unique_ptr<StaticInferManager> m_static_infer_manager;
    std::unique_ptr<SeqCompNodeOptimizer> m_seq_comp_node_optimizer;
};
} }