1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
/**
* \file src/core/impl/graph/eager_eval.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "./impl_common.h"
#include "./var_node_mem_mgr.h"
#include "megbrain/graph/cg.h"
namespace mgb {
namespace cg {
#if MGB_BUILD_SLIM_SERVING
class EagerEvalManager {
public:
explicit EagerEvalManager(ComputingGraphImpl*) {}
void on_opr_insert(OperatorNodeBase*) {}
const ComputingGraph::VarReceiverInfo* var_receiver_info(const VarNode*) const {
return nullptr;
}
GraphExecutable::ExecEnv* exec_env() { return nullptr; }
const CompSeqExtraInfo* comp_seq_extra_info() { return nullptr; }
bool enabled() const { return false; }
size_t get_var_nr_readers(VarNode*) const { return REFCNT_INF; }
};
#else
/*!
 * \brief evaluates operators as soon as they are inserted into a graph
 *
 * Active only while the owner graph's options().eager_evaluation is set.
 * Each inserted operator gets a VersionTrait that tracks whether it (and the
 * operators reading its outputs) must be re-evaluated when re-inserted.
 */
class EagerEvalManager {
    class EagerExecEnv;

    //! re-evaluation bookkeeping attached to a single operator
    struct VersionTrait {
        enum Flag : uint8_t {
            // never re-eval, all outputs of this operator could be treated
            // as constant; conflicts with MUTABLE
            CONST = 1 << 0,
            // always re-eval; conflicts with CONST
            MUTABLE = 1 << 1,
            // always re-eval and would mark all readers of this op as MUTABLE;
            // used together with MUTABLE
            MUTABLE_SOURCE = 1 << 2
        };
        Flag flag = static_cast<Flag>(0);

        //! whether the operator must be re-evaluated; given an explicit
        //! initializer (like flag) so a freshly allocated trait never holds
        //! an indeterminate value
        bool need_reeval = false;

        //! must only be called while need_reeval is set: clears need_reeval
        //! unless the MUTABLE flag is present, then marks every reader trait
        //! as needing re-evaluation
        void update_version() {
            mgb_assert(need_reeval);
            if (!(flag & Flag::MUTABLE)) {
                need_reeval = false;
            }
            for (auto&& i : readers) {
                i->need_reeval = true;
            }
        }

        //! traits whose need_reeval is raised by update_version()
        SmallVector<VersionTrait*> readers;
    };

    //! -1: uninitialized (before first opr insertion); 0/1: disabled/enabled
    int m_first_opr_enable_status = -1;
    ComputingGraph* const m_owner_graph;
    std::unique_ptr<EagerExecEnv> m_exec_env;
    CompSeqExtraInfo m_comp_seq_extra_info;
    MemPool<CompNodeSyncManager> m_var_sync_mgr_pool;
    MemPool<VersionTrait> m_version_trait_pool;
    ThinHashMap<OperatorNodeBase*, VersionTrait*> m_opr2version;
    bool m_record_mode = false;
    ThinHashSet<OperatorNodeBase*> m_record_oprs;
    ThinHashMap<VarNode*, size_t> m_var2nr_readers;

    //! run ID used for static memory allocator and would not get increased
    size_t m_run_id = 1;

    void do_on_opr_insert(OperatorNodeBase* opr);
    void update_static_infer_result(OperatorNodeBase* opr);
    void prepare_for_exec(OperatorNodeBase* opr);
    void alloc_output_mem(OperatorNodeBase* opr);
    void init_waiting_spec(OperatorNodeBase* opr);

    //! copy var tensor as contiguous if layout constraint is not satisfied
    void ensure_input_layout(VarNode* var);

    //! check version of the given operator and return opr's current status
    //! -1: uninitialized / 0: version unchanged / 1: version changed
    int check_version(OperatorNodeBase* opr);

public:
    explicit EagerEvalManager(ComputingGraph* graph);
    ~EagerEvalManager() noexcept;

    //! whether eager evaluation is enabled in the owner graph's options
    bool enabled() const { return m_owner_graph->options().eager_evaluation; }

    //! called after an operator is inserted; output vars would be evaluated if
    //! eager_eval is enabled
    //! re-evaluation would be triggered if a previously inserted operator
    //! was reinserted and its version was changed
    void on_opr_insert(OperatorNodeBase* opr);

    /*!
     * \brief return faked VarReceiverInfo; or nullptr if not enabled
     *
     * VarReceiverInfo should be faked so that all vars would be considered as
     * being used
     */
    const ComputingGraph::VarReceiverInfo* var_receiver_info(const VarNode* var) const;

    /*!
     * \brief get corresponding ExecEnv if enabled; return nullptr if not
     * enabled
     */
    GraphExecutable::ExecEnv* exec_env();

    /*!
     * \brief get a suitable CompSeqExtraInfo if enabled; return nullptr if not
     * enabled
     */
    const CompSeqExtraInfo* comp_seq_extra_info() {
        if (enabled()) {
            return &m_comp_seq_extra_info;
        }
        return nullptr;
    }

    /*!
     * \brief record oprs instead of executing them when they are inserted
     * into the graph; only used when computing symbolic gradients.
     *
     * \return the record mode that was active before this call
     */
    bool enter_record_mode() {
        bool old = m_record_mode;
        mgb_assert(old || m_record_oprs.empty());
        m_record_mode = true;
        return old;
    }

    /*!
     * \brief flush all recorded oprs and execute those on which dest_vars
     * depend. Note that this also turns off record mode.
     */
    void flush_record_oprs(const VarNodeArray& dest_vars);

    /*!
     * \brief get the number of readers of a var; return REFCNT_INF if the var
     * is not an intermediate result when calculating grad.
     */
    size_t get_var_nr_readers(VarNode* var) const;
};
#endif // MGB_BUILD_SLIM_SERVING
} // namespace cg
} // namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}