#pragma once
#include "./eager_eval.h"
#include "./grad_manager.h"
#include "./graph_opt.h"
#include "./seq_comp_node_opt_impl.h"
#include "./seq_dtr.h"
#include "./seq_sublinear_memory.h"
#include "./static_infer_impl.h"
#include "./swap/memory_swap.h"
#include "./topo_sort.h"
#include "./var_node_mem_mgr.h"
#include "megbrain/utils/mempool.h"
namespace mgb {
namespace cg {
class ComputingGraphImpl final : public ComputingGraph {
class CallbackCaller;
class RecordedComputingSequence;
class MegDNNDtorCheck;
class MultiPartCompiler;
friend class GradManager;
struct CompileState {
CompSeqExtraInfo extra_info;
const OprNodeArray* opr_seq = nullptr;
VarNodeArray dest_vars;
};
struct CallbackCallerKey {
OperatorNodeBase* opr;
CompNode comp_node;
bool operator==(const CallbackCallerKey& rhs) const {
return opr == rhs.opr && comp_node == rhs.comp_node;
}
struct Hash {
size_t operator()(const CallbackCallerKey& b) const {
return hash_pair_combine(mgb::hash(b.opr), mgb::hash(b.comp_node));
}
};
};
struct CallbackCallerVal {
SmallVector<VarNode*> vars;
SmallVector<SmallVector<size_t>> indexs;
};
struct Components : public NonCopyableObj {
TopoSorter topo_sorter;
VarNodeMemManager var_node_mem_manager;
SeqCompNodeOptimizerImpl seq_comp_node_opt;
static_infer::StaticInferManagerImpl static_infer_manager;
static_infer::CompSeqManager static_infer_comp_seq_manager;
GradManager grad_manager;
GraphOptimizer graph_optimizer;
#if MGB_ENABLE_SUBLINEAR
SeqModifierForSublinearMemory seq_modifier_for_sublinear_memory;
#endif
#if MGB_ENABLE_DTR
SeqModifierForDTR seq_modifier_for_dtr;
#endif
#if MGB_ENABLE_MEMORY_SWAP
swap::MemorySwap memory_swap_support;
#endif
EagerEvalManager eager_eval_manager;
explicit Components(ComputingGraphImpl* owner);
};
std::unique_ptr<MegDNNDtorCheck> m_recorded_seq_level2_dtor_chk;
MemPool<VarNode> m_var_node_pool;
ComputingGraphImpl* m_parent_graph = nullptr;
std::vector<ComputingGraphImpl*> m_subgraphs;
AsyncExecutable* m_current_comp_seq = nullptr;
std::shared_ptr<size_t> m_node_id_counter = std::make_shared<size_t>();
std::vector<std::unique_ptr<OperatorNodeBase>> m_opr_refkeeper;
ThinHashMap<VarNode*, OprNodeArray> m_var_receiver;
std::aligned_storage_t<sizeof(Components), alignof(Components)>
m_components_storage;
VarNodeArray get_dest_vars_from_out_spec(
const OutputSpec& spec, SpecialOprStat& sopr_stat);
void cleanup();
std::shared_ptr<void> on_comp_node_finalize() override;
Components& components() {
return reinterpret_cast<Components&>(m_components_storage);
}
const Components& components() const {
return reinterpret_cast<const Components&>(m_components_storage);
}
CompileState compile_prepare(const OutputSpec& out_spec);
std::unique_ptr<AsyncExecutable> compile_commit(CompileState state);
void dest_var_optimize(VarNodeArray& dest_vars);
public:
class ComputingSequence;
MGE_WIN_DECLSPEC_FUC ComputingGraphImpl();
MGE_WIN_DECLSPEC_FUC ~ComputingGraphImpl();
template <typename T>
static ComputingGraphImpl* downcast(T* ptr) = delete;
inline static ComputingGraphImpl* downcast(ComputingGraph* graph) {
mgb_assert(!graph->options().imperative_proxy_graph);
return static_cast<ComputingGraphImpl*>(graph);
}
friend struct ComputingGraph::Options;
std::unique_ptr<AsyncExecutable> compile(const OutputSpec& out_spec) override;
SmallVector<std::unique_ptr<AsyncExecutable>> compile_multi_part(
const SmallVector<OutputSpec>& out_specs) override;
MGE_WIN_DECLSPEC_FUC OperatorNodeBase* insert_opr(
std::unique_ptr<OperatorNodeBase> opr) override;
void* alloc_varnode_storage() override;
void free_varnode_storage(void* ptr) override;
const VarReceiverInfo& var_receiver_in_current_comp_seq(
const VarNode* var) const override;
const OprNodeArray& var_receiver(VarNode* var) const {
return m_var_receiver.at(var);
}
std::string get_mem_allocation_info() const override;
VarNode* find_var_by_id(size_t id) const override;
TopoSorter& topo_sorter() { return components().topo_sorter; }
size_t next_node_id() override { return (*m_node_id_counter)++; }
VarNodeMemManager& var_node_mem_manager() {
return components().var_node_mem_manager;
}
SeqCompNodeOptimizer& seq_comp_node_optimizer() override {
return components().seq_comp_node_opt;
}
static_infer::StaticInferManager& static_infer_manager() override {
return components().static_infer_manager;
}
static_infer::StaticInferManagerImpl& static_infer_manager_impl() {
return components().static_infer_manager;
}
static_infer::CompSeqManager& static_infer_comp_seq_manager() {
return components().static_infer_comp_seq_manager;
}
GraphOptimizer& graph_optimizer() { return components().graph_optimizer; }
EagerEvalManager& eager_eval_manager() { return components().eager_eval_manager; }
#if MGB_ENABLE_SUBLINEAR
SeqModifierForSublinearMemory& seq_modifier_for_sublinear_memory();
#endif
#if MGB_ENABLE_DTR
SeqModifierForDTR& seq_modifier_for_dtr();
#endif
void share_device_memory_with(ComputingGraph& other) override;
void set_device_memory_allocator(
std::shared_ptr<DeviceMemoryAllocator> allocator) override;
size_t get_device_memory_size(CompNode cn) override;
size_t clear_device_memory() override;
void set_as_subgraph(ComputingGraph& par_graph) override;
void record_async_error(std::unique_ptr<MegBrainError> async_exc) override;
AsyncExecutable* current_comp_seq() override {
return static_cast<AsyncExecutable*>(m_current_comp_seq);
}
GraphExecutable::ExecEnv* current_exec_env();
Maybe<size_t> opr_step_num_in_cur_comp_seq(OperatorNodeBase* opr);
const CompSeqExtraInfo& current_comp_seq_extra_info();
GradManager& grad_manager() { return components().grad_manager; }
auto&& all_oprs() const { return m_opr_refkeeper; }
size_t nr_oprs_in_graph() const override { return m_opr_refkeeper.size(); }
auto&& var_node_pool() { return m_var_node_pool; }
};
} }