#pragma once
#include "./output_recorder.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/internal/mixin_base.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/loop.h"
#include "megdnn/oprs.h"
#include <list>
namespace mgb {
namespace opr {
namespace intl {
/*!
 * \brief one entry of a loop's output record spec: a sub-graph var together
 *      with the recorder that consumes it
 *
 * Equality (is_same_st) and hash() are both derived from the sub-graph var
 * pointer and the recorder, so two items with the same var and equivalent
 * recorders collapse to one entry in a hash container.
 */
class LoopImpl::OutputRecordSpecItem final : public Hashable {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

    //! recorder returned by recorder() while this item is disabled
    static Desc::OutputRecorderBase* const m_dummy_recorder;

    bool m_enabled = true;
    //! var in the sub graph; set by the ctor and replaceable until bind()
    VarNode* m_var_sub;
    //! var given to bind(); stays null until bind() has been called
    VarNode* m_var_owner = nullptr;
    std::unique_ptr<Desc::OutputRecorderBase> m_recorder;

    //! same sub-graph var and recorders that report is_same()
    bool is_same_st(const Hashable& rhs) const override {
        const auto& other = static_cast<const OutputRecordSpecItem&>(rhs);
        if (m_var_sub != other.m_var_sub) {
            return false;
        }
        return m_recorder->is_same(*other.m_recorder);
    }

public:
    //! flag writable even through const access; this class itself never
    //! reads or writes it
    mutable bool user_data = false;

    /*!
     * \param sub symbol var whose node becomes the sub-graph var
     * \param recorder recorder to take ownership of
     */
    OutputRecordSpecItem(
            SymbolVar sub, std::unique_ptr<Desc::OutputRecorderBase> recorder)
            : m_var_sub(sub.node()), m_recorder(std::move(recorder)) {}

    //! combination of the sub-graph var pointer hash and the recorder hash
    size_t hash() const override {
        const size_t var_hash = std::hash<void*>{}(m_var_sub);
        const size_t recorder_hash = m_recorder->hash();
        return hash_pair_combine(var_hash, recorder_hash);
    }

    //! the owned recorder when enabled, the shared dummy recorder otherwise
    Desc::OutputRecorderBase* recorder() const {
        if (!m_enabled) {
            return m_dummy_recorder;
        }
        return m_recorder.get();
    }

    //! output mode of the owned recorder, regardless of the enabled flag
    Desc::OutputMode output_mode() const {
        return m_recorder->output_mode();
    }

    /*!
     * \brief attach the given var and forward the (sub var, given var) pair
     *      to the owned recorder; may be called at most once, with a
     *      non-null var
     */
    void bind(VarNode* var_owner) {
        mgb_assert(!m_var_owner && var_owner);
        m_var_owner = var_owner;
        m_recorder->bind_var(m_var_sub, var_owner);
    }

    VarNode* var_sub() const {
        return m_var_sub;
    }

    VarNode* var_owner() const {
        return m_var_owner;
    }

    bool enabled() const {
        return m_enabled;
    }

    //! set the enabled flag; returns *this so calls can be chained
    OutputRecordSpecItem& enable(bool flag) {
        m_enabled = flag;
        return *this;
    }

    //! replace the sub-graph var; only allowed before bind()
    void var_sub(VarNode* var) {
        mgb_assert(!m_var_owner, "bind() must not be called");
        m_var_sub = var;
    }
};
/*!
 * \brief operator declared via MGB_DEFINE_OPR_CLASS with
 *      cg::SingleCNOperatorNodeBase as the second macro argument
 *
 * It keeps a pointer to an "orig" var and to a DescImplBase, and optionally an
 * "assignor" var. NOTE(review): presumably this forwards an owner-graph var
 * into the loop sub graph; the implementation is not visible in this header.
 */
MGB_DEFINE_OPR_CLASS(LoopImpl::InputMaker, cg::SingleCNOperatorNodeBase) public:
//! construction-time options; stored by value and never changed afterwards
struct Param {
//! NOTE(review): presumably turns off static value inference for the
//! output; not used in this header, confirm in the implementation
bool disable_value_infer;
//! must be true for set_assignor() to be allowed
bool has_assign;
};
InputMaker(DescImplBase* desc, VarNode* orig_var, const Param& param);
static SymbolVar make(DescImplBase* desc, SymbolVar orig_var, const Param& param);
//! record the assignor var; requires param().has_assign, a non-null var,
//! and that the assignor has not been committed yet
void set_assignor(VarNode* var) {
mgb_assert(m_param.has_assign && var && !m_assignor_committed);
m_assignor_var = var;
}
void commit_assignor();
//! assignor var; asserts that one was set, reporting this operator's name
//! and a dump of the orig var on failure
VarNode* assignor() const {
mgb_assert(
m_assignor_var,
"assignment value not set for "
"%s (orig: %s)",
cname(), cg::dump_var_info({m_orig_var}).c_str());
return m_assignor_var;
}
VarNode* orig_var() const { return m_orig_var; }
const Param& param() const { return m_param; }
//! reset per-run state: mark the next execution as the first one again and
//! replace the stored assignor value with a default-constructed tensor
void on_exec_end() {
m_first_exec = true;
m_assignor_value = {};
}
private:
const Param m_param;
//! true until reset by on_exec_end(); where it is cleared is not visible here
bool m_first_exec = true;
//! checked by set_assignor(); presumably set by commit_assignor() - confirm
bool m_assignor_committed = false;
VarNode* m_orig_var;
DescImplBase* m_desc;
VarNode* m_assignor_var = nullptr;
DeviceTensorND m_assignor_value;
NodeProp* do_make_node_prop() const override;
//! the output is placed on the comp node of the orig var
void init_output_comp_node() override { comp_node(m_orig_var->comp_node()); }
void init_output_static_infer_desc() override;
void init_output_mem_plan(bool dynamic) override;
void scn_do_execute() override;
};
/*!
 * \brief non-copyable helper that collects operators and InputMaker
 *      operators reachable from the vars passed to add()
 *
 * The list of input makers is sorted lazily: input_makers() re-sorts only
 * when the list size differs from the size recorded at the last sort.
 */
class LoopImpl::SubgraphDepIter : public NonCopyableObj {
    //! size of m_input_makers to compare against in input_makers();
    //! presumably updated by sort_input_makers() - confirm in the .cpp
    size_t m_input_makers_sorted_size = 0;
    VarNodeArray m_unresolved_assignors;
    std::vector<InputMaker*> m_input_makers;
    cg::OprNodeArray m_oprs;
    cg::DepOprIter m_dep_iter;

    void sort_input_makers();
    void dep_iter_cb(cg::OperatorNodeBase* opr);

public:
    SubgraphDepIter();
    ~SubgraphDepIter() noexcept;

    void add(VarNode* dest);

    //! input makers collected so far, sorted on demand
    std::vector<InputMaker*>& input_makers() {
        const bool sorted_up_to_date =
                (m_input_makers_sorted_size == m_input_makers.size());
        if (!sorted_up_to_date) {
            sort_input_makers();
        }
        return m_input_makers;
    }

    //! operators collected so far
    const cg::OprNodeArray& oprs() const {
        return m_oprs;
    }
};
/*!
 * \brief common implementation base for loop descriptors, derived from
 *      LoopImpl::Desc
 *
 * Holds the sub graph, the output record specs (deduplicated list plus a
 * non-deduplicated pointer list), the loop condition manager and the counter
 * var / counter provider handles.
 */
class LoopImpl::DescImplBase : public LoopImpl::Desc {
public:
//! NOTE(review): a std::list is presumably used so that the item pointers
//! stored in the other containers stay valid on insertion - confirm
using OutputRecordSpec = std::list<OutputRecordSpecItem>;
class CounterProvider;
//! stores the loop condition var; the evaluation logic is declared here but
//! implemented elsewhere
class LoopCondManager final : NonCopyableObj {
SymbolVar m_var;
class GetCondOpr;
GetCondOpr* m_get_cond_opr = nullptr;
public:
//! the condition var given to setup(); default-constructed before that
SymbolVar var() const { return m_var; }
//! remember the condition var; returns *this for chaining
LoopCondManager& setup(SymbolVar var) {
m_var = var;
return *this;
}
ComputingGraph::OutputSpec::value_type subgraph_outspec_item();
bool should_loop();
};
DescImplBase();
//! counter var of the loop
//! \throw GraphError if the counter var has no node yet (the message says
//!     this requires at least one input to have been added)
SymbolVar get_counter_var() override {
mgb_throw_if(
!m_counter_var.node(), GraphError,
"could only get counter var "
"when there is at least one input");
return m_counter_var;
}
//! set the loop condition var
//! \throw GraphError if \p cond does not belong to the sub graph
Desc& set_loop_condition(SymbolVar cond) override {
mgb_throw_if(
!check_in_sub_graph(cond), GraphError,
"loop condition must be in the sub graph");
m_loop_cond_manager.setup(cond);
return *this;
}
//! set the owning loop operator; asserts that none has been set before
void set_loop_opr(LoopImpl* opr) {
mgb_assert(!m_owner_loop_opr);
m_owner_loop_opr = opr;
}
ComputingGraph* owner_graph() const { return m_owner_graph; }
ComputingGraph* sub_graph() const { return m_sub_graph.get(); }
std::unique_ptr<cg::AsyncExecutable> compile();
auto&& output_record_spec() const { return m_output_record_spec; }
auto&& output_record_spec_no_dedup() const { return m_output_record_spec_no_dedup; }
auto&& loop_cond_manager() { return m_loop_cond_manager; }
//! value held by m_cur_func_input
//! NOTE(review): Maybe::val() presumably requires the value to be present;
//! where it gets set is not visible in this header
const std::vector<InputMaker*>& cur_func_input() const {
return m_cur_func_input.val();
}
virtual const std::vector<InputMaker*>& all_inputs() = 0;
cg::static_infer::SubgraphStaticInferHelper& sub_graph_static_infer_helper() {
return *m_sub_graph_static_infer_helper;
}
virtual void reset_counter_provider();
virtual void update_counter_provider();
CounterProvider* counter_provider() const { return m_counter_provider; }
SymbolVar do_add_input(SymbolVar inp, const InputMaker::Param& param);
protected:
LoopImpl* m_owner_loop_opr = nullptr;
std::shared_ptr<cg::ComputingGraph> m_sub_graph;
OutputRecordSpec m_output_record_spec;
std::vector<OutputRecordSpecItem*> m_output_record_spec_no_dedup;
//! whether \p var belongs to the owner graph; var.node() is dereferenced
//! without a null check
bool check_in_owner_graph(SymbolVar var) {
return m_owner_graph == var.node()->owner_graph();
}
//! whether \p var belongs to the sub graph; var.node() is dereferenced
//! without a null check
bool check_in_sub_graph(SymbolVar var) {
return m_sub_graph.get() == var.node()->owner_graph();
}
size_t do_add_output(
SymbolVar val, std::unique_ptr<OutputRecorderBase> recorder) override;
//! hook with an empty default body; NOTE(review): judging by the name it
//! is called while compiling the sub graph function - confirm call site
virtual void on_sub_graph_func_compile(ComputingGraph::OutputSpec& out_spec) {}
private:
//! pointer wrapper that compares and hashes by the pointed-to item
//! (is_same() / hash()), used as the key type of the dedup set below
struct OutputRecordSpecPtr {
OutputRecordSpecItem* p;
bool operator==(const OutputRecordSpecPtr& rhs) const {
return p->is_same(*rhs.p);
}
struct Hash {
size_t operator()(const OutputRecordSpecPtr& ptr) const {
return ptr.p->hash();
}
};
};
Maybe<std::vector<InputMaker*>> m_cur_func_input;
cg::ComputingGraph* m_owner_graph = nullptr;
//! created eagerly through the helper's make() factory
std::unique_ptr<cg::static_infer::SubgraphStaticInferHelper>
m_sub_graph_static_infer_helper =
cg::static_infer::SubgraphStaticInferHelper::make();
std::unordered_set<OutputRecordSpecPtr, OutputRecordSpecPtr::Hash>
m_output_record_spec_dedup;
//! has no node until set; see get_counter_var()
SymbolVar m_counter_var;
CounterProvider* m_counter_provider = nullptr;
LoopCondManager m_loop_cond_manager;
void on_first_input_added(SymbolVar inp);
};
/*!
 * \brief operator declared via MGB_DEFINE_OPR_CLASS with
 *      cg::SingleCNOperatorNodeBase as the second macro argument
 *
 * Keeps an integer "next value" and "delta" together with host and device
 * tensors of matching names, plus a megdnn::AddUpdate operator.
 * NOTE(review): presumably the output is the loop counter, advanced by delta
 * on each update; the implementation is not visible in this header.
 */
MGB_DEFINE_OPR_CLASS(
LoopImpl::DescImplBase::CounterProvider, cg::SingleCNOperatorNodeBase) HostTensorND m_delta_host, m_next_val_host;
DeviceTensorND m_delta_dev, m_next_val_dev;
//! no in-class initializer; presumably set by the ctor - confirm
int m_delta, m_next_val;
std::unique_ptr<megdnn::AddUpdate> m_add_update;
void init_output_comp_node() override;
void init_output_mem_plan(bool dynamic) override;
void scn_do_execute() override;
void init_output_static_infer_desc() override;
NodeProp* do_make_node_prop() const override;
public:
CounterProvider(ComputingGraph& graph, const OperatorNodeConfig& config);
static CounterProvider* make(
ComputingGraph& graph, const OperatorNodeConfig& config);
void update_next_val();
//! setter overload; the getter overload below returns m_next_val
void next_val(int v);
void delta(int v);
//! current value of m_next_val
int next_val() { return m_next_val; }
};
/*!
 * \brief base class declared via MGB_DEFINE_CLS_WITH_SUPER with
 *      cg::SingleCNOperatorNodeBase as the second macro argument
 *
 * It fixes init_output_static_infer_desc() (marked final) and can only be
 * constructed by derived classes, since the ctor is protected.
 */
MGB_DEFINE_CLS_WITH_SUPER(
MultidepProxyOperatorNodeBase, cg::SingleCNOperatorNodeBase) void init_output_static_infer_desc() override final;
protected:
MultidepProxyOperatorNodeBase(const OperatorNodeBaseCtorParam& opr);
};
/*!
 * \brief operator declared via MGB_DEFINE_OPR_CLASS on top of
 *      MultidepProxyOperatorNodeBase
 *
 * It is constructed from a destination tensor pointer and/or a shared
 * AccumulatorState, plus a value var and a dep var.
 * NOTE(review): presumably it writes or accumulates \p val into the
 * destination on execution; the implementation is not visible here.
 */
MGB_DEFINE_OPR_CLASS(LoopImpl::DepTensorUpdator, MultidepProxyOperatorNodeBase) public:
//! state shared (via shared_ptr) between operators built from it
struct AccumulatorState {
DeviceTensorND* dest = nullptr;
bool first_sum = true;
intl::UniqPtrWithCN<megdnn::Elemwise> adder;
//! only restores first_sum; dest and adder are left untouched
void reset() { first_sum = true; }
};
DepTensorUpdator(
DeviceTensorND* dest, const std::shared_ptr<AccumulatorState>& accum_state,
VarNode* val, VarNode* dep, const OperatorNodeConfig& config = {});
//! overload taking only a destination tensor
static SymbolVar make(DeviceTensorND* dest, SymbolVar val, SymbolVar dep);
//! overload taking a shared accumulator state
static SymbolVar make(
const std::shared_ptr<AccumulatorState>& state, SymbolVar val,
SymbolVar dep);
//! not marked override here; NOTE(review): presumably invoked by a shallow
//! copy registration elsewhere - confirm
cg::OperatorNodeBase* shallow_copy(
const VarNodeArray& inputs, const OperatorNodeConfig& config) const;
private:
DeviceTensorND* const m_dest;
std::shared_ptr<AccumulatorState> const m_accum_state;
void scn_do_execute() override;
NodeProp* do_make_node_prop() const override;
};
/*!
 * \brief final descriptor class built on DescImplBase that implements the
 *      add_input / add_output / assign interface
 *
 * Owns a SubgraphDepIter that is released by on_sub_graph_optimized().
 */
class LoopImpl::FwdDesc final : public LoopImpl::DescImplBase {
    ThinHashMap<VarNode*, bool> m_input_assigned;
    ThinHashMap<VarNode*, VarNode*> m_input_no_assign_dedup;
    ThinHashMap<VarNode*, OutputRecordSpecItem*> m_output_record_spec_mode_all;
    std::unique_ptr<SubgraphDepIter> m_dep_iter;

public:
    SymbolVar add_input(SymbolVar inp, bool has_assign) override;
    size_t add_output(SymbolVar val, OutputMode mode) override;
    Desc& assign(SymbolVar dest, SymbolVar val) override;

    VarNode* owner_graph_output_at(size_t idx) const;
    SymbolVarArray user_output_vars_including_dup() const;

    //! read-only view of the var -> record spec item map
    const ThinHashMap<VarNode*, OutputRecordSpecItem*>&
    output_record_spec_mode_all() const {
        return m_output_record_spec_mode_all;
    }

    const std::vector<InputMaker*>& all_inputs() override;

    /*!
     * \brief operators held by the dep iterator
     *
     * all_inputs() is called first and its result discarded.
     * NOTE(review): presumably that call populates the iterator; the
     * iterator is dereferenced unconditionally, so this must not run after
     * on_sub_graph_optimized() unless all_inputs() recreates it - confirm.
     */
    const cg::OprNodeArray& sub_graph_oprs() {
        all_inputs();
        SubgraphDepIter& dep_iter = *m_dep_iter;
        return dep_iter.oprs();
    }

    //! drop the dep iterator
    void on_sub_graph_optimized() {
        m_dep_iter.reset();
    }
};
/*!
 * \brief bookkeeping for vars whose value and/or shape are to be recorded,
 *      keyed by VarNode pointer
 *
 * NOTE(review): judging by the method names (enable_for_grad,
 * get_state_for_grad, on_fwd_*, on_grad_finish) this saves forward state for
 * the gradient pass; the implementation is not visible in this header.
 */
class LoopImpl::MutableStateSaver {
    class Recorder;
    class ValueUpdator;
    class ShapeUpdator;

    //! per-var record: what is needed and the objects that provide it
    struct SavedVarInfo {
        VarNode* var = nullptr;
        bool need_value = false;
        bool need_shape = false;
        std::unique_ptr<Recorder> recorder;
        SymbolVar value_updator;
        SymbolVar shape_updator;
    };

    Loop* const m_owner_opr;
    bool m_slowcopy_warn_printed = false;
    bool m_enabled = true;
    //! defaults to 5; overwritten by swap_interval()
    int m_swap_interval_setting = 5;
    ThinHashMap<VarNode*, SavedVarInfo> m_var2info;
    ThinHashSet<VarNode*> m_recorded_vars;

    void print_slowcopy_warn(const char* msg);
    inline VarNode* get_user_recorded_output_all(VarNode* var);

public:
    MutableStateSaver(Loop* owner_opr);
    ~MutableStateSaver();

    //! store the swap interval setting
    void swap_interval(int v) {
        m_swap_interval_setting = v;
    }

    void add_var_to_record(VarNode* var);

    bool enabled() const {
        return m_enabled;
    }

    void disable();
    void enable_for_grad(cg::AsyncExecutable* seq);

    //! whether \p var is in the set of recorded vars
    bool is_var_recorded(VarNode* var) const {
        return m_recorded_vars.count(var) != 0;
    }

    VarNode* get_state_for_grad(VarNode* fwd_var, DescImplBase* grad_desc);
    void update_subgraph_outspec(ComputingGraph::OutputSpec& spec);

    void on_fwd_begin();
    void on_fwd_finish();
    void on_grad_finish();

    //! NOTE(review): the test_ prefix suggests this is for tests only
    ThinHashMap<VarNode*, bool> test_get_var_rec_spec();
};
} } }