1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
/**
* \file imperative/src/impl/backward_graph_opt.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/backward_graph_opt.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/backward_graph.h"
using namespace mgb;
using namespace imperative;
/*!
 * \brief Split an encoded backward graph into a part that can be evaluated
 *        early (`precomp`) and the remaining backward graph (`backward`).
 *
 * `input_has_grad` is copied from `src.output_mask`.
 *
 * If the source graph has at most one expr, it is used as `backward`
 * unchanged, `save_for_backward` is `src.input_mask`, and `precomp` is left
 * untouched.
 *
 * Otherwise `src.input_mask` is treated as `input_size + 2 * output_size`
 * entries, where `input_size == src.output_mask.size()`. Entries at index
 * `>= input_size + output_size` are the values only available in backward
 * (the grads); the entries before that are the candidates for inputs of
 * `precomp`. `GetVarShape` exprs whose inputs are not all needed by backward
 * anyway are moved into `precomp`, so that only their results have to be kept
 * alive instead of their inputs.
 *
 * After construction:
 *  - `precomp` holds the moved exprs (or is reset to `{}` when it would
 *    produce nothing that backward consumes);
 *  - `backward` takes as inputs first the outputs of `precomp`, then every
 *    original input that is still referenced in backward;
 *  - `save_for_backward[i]` is true exactly for those original inputs.
 *
 * \param src encoded backward graph together with its input/output masks
 */
OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const EncodedSubgraph& src)
        : input_has_grad(src.output_mask) {
    if (src.graph.exprs.size() <= 1) {
        // backward graph only contains a single op
        backward = src.graph;
        save_for_backward = src.input_mask;
        return;
    }
    save_for_backward.resize(src.input_mask.size(), false);

    auto&& graph = src.graph;
    auto&& mask = src.input_mask;
    size_t input_size = src.output_mask.size();
    // The subtraction below is unsigned: without this guard a too-short mask
    // wraps around and the consistency check after it can still pass modulo
    // 2^64, which would let the loops below index `mask` out of bounds.
    mgb_assert(mask.size() >= input_size);
    size_t output_size = (mask.size() - input_size) / 2;
    mgb_assert(input_size + output_size * 2 == mask.size());

    auto& fgraph = precomp;
    auto& bgraph = backward;

    // optimization: move ops (e.g. GetVarShape) to forward to
    // reduce memory footprint

    struct VInfo {
        bool appears_in_backward = false;
    };
    // keyed by var id; operator[] default-constructs missing entries, so any
    // var never marked reads as "does not appear in backward"
    std::unordered_map<size_t, VInfo> vinfo;

    // step 1.1: ops not in whitelist must run in backward.
    // mark their inputs as always appears in backward
    for (auto&& [op, iv, ov] : graph.exprs) {
        if (!op->same_type<GetVarShape>()) {
            for (auto&& v : iv) {
                vinfo[v].appears_in_backward = true;
            }
        }
    }
    // step 1.2: inputs only available in backward (i.e. grads)
    // should be marked as always appears in backward.
    // `j` walks graph.inputs, which only has entries for mask[i] == true.
    for (size_t i = 0, j = 0; i < mask.size(); ++i) {
        if (!mask[i])
            continue;
        if (i >= input_size + output_size) {
            vinfo[graph.inputs[j]].appears_in_backward = true;
        }
        ++j;
    }

    // step 2: try to move ops to forward, if not all their inputs
    // are marked always appears in backward (otherwise no memory saving)
    // NOTE(review): an expr with mixed inputs (some backward-only, some not)
    // takes the forward branch here although fgraph never receives
    // backward-only inputs -- confirm whitelisted ops cannot have such inputs.
    for (auto&& expr : graph.exprs) {
        auto&& [op, iv, ov] = expr;
        const bool all_inputs_in_backward =
                std::all_of(iv.begin(), iv.end(), [&](auto&& v) {
                    return vinfo[v].appears_in_backward;
                });
        if (all_inputs_in_backward) {
            bgraph.exprs.push_back(expr);
            for (auto&& v : ov) {
                vinfo[v].appears_in_backward = true;
            }
            // logically should also mark all inputs as appears in backward
            // but clearly that's a no-op.
        } else {
            fgraph.exprs.push_back(expr);
            for (auto&& v : ov) {
                if (vinfo[v].appears_in_backward) {
                    // appears_in_backward won't change after this point
                    // so it is safe to set fgraph.outputs based on current value
                    fgraph.outputs.push_back(v);
                }
            }
        }
    }

    // initialize remaining parts

    // Var ids used for fgraph input slots that have no counterpart in the
    // source graph (mask[i] == false) are kBaseOfUnusedInputId + i.
    // Presumably chosen to be larger than any real var id -- not verified here.
    constexpr size_t kBaseOfUnusedInputId = 1000000000;

    fgraph.constants = graph.constants;
    fgraph.inputs.reserve(input_size + output_size);
    // fgraph gets one input slot per non-grad mask entry, present or not
    for (size_t i = 0, j = 0; i < input_size + output_size; ++i) {
        if (!mask[i]) {
            fgraph.inputs.push_back(kBaseOfUnusedInputId + i);
            continue;
        }
        fgraph.inputs.push_back(graph.inputs[j++]);
    }

    bgraph.constants = graph.constants;
    // NOTE(review): a graph output produced by an expr moved to fgraph but not
    // consumed by any backward expr would not be among bgraph.inputs --
    // confirm this cannot happen for whitelisted ops.
    bgraph.outputs = graph.outputs;
    // backward inputs: precomputed values first, then the surviving originals
    bgraph.inputs = fgraph.outputs;
    for (size_t i = 0, j = 0; i < mask.size(); ++i) {
        if (mask[i]) {
            auto&& v = graph.inputs[j++];
            if (vinfo[v].appears_in_backward) {
                save_for_backward[i] = true;
                bgraph.inputs.push_back(v);
            }
        }
    }

    // nothing computed early is consumed by backward: drop precomp entirely
    if (fgraph.outputs.empty()) {
        precomp = {};
    }
}