1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
/**
* \file src/mge/network_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "lite_build_config.h"
#include "megbrain/graph.h"
#if LITE_BUILD_WITH_MGE
#include "lite/network.h"
#include "network_impl_base.h"
#include "tensor_impl.h"
#include <memory>
#include <unordered_map>
#include "megbrain/gopt/inference.h"
#include "megbrain/graph/bases.h"
#include "megbrain/plugin/opr_io_dump.h"
#include "megbrain/plugin/profiler.h"
#include "megbrain/serialization/extern_c_opr.h"
#include "megbrain/serialization/file.h"
#include "megbrain/serialization/load_dump_config.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/utils/thin/hash_table.h"
namespace lite {
/*!
 * \brief default Network implementation backed by MegEngine (mgb)
 *
 * Holds the mgb-related state of a network: the computing graph created in
 * the constructor, the graph load config / load result, the compiled
 * executable and the user-supplied IO description and config.
 * get_backend_type() reports LiteBackend::LITE_DEFAULT.
 */
class NetworkImplDft final : public Network::NetworkImplBase {
    LITE_DYN_TYPE_OBJ_FINAL_DECL;

public:
    //! create a fresh mgb computing graph that the model will be loaded into
    NetworkImplDft() { m_load_config.comp_graph = mgb::ComputingGraph::make(); }

    using S = megdnn::param::ExecutionPolicy::Strategy;
    using Var = mgb::cg::SymbolVar;

    //! set the config of the network, including:
    //! the inference device
    //! the other inference options, such as record_level, weight_preprocess...
    void set_config(const Config& config) override;

    //! set the special io information; if it is not set, the default io
    //! tensors will be used. This is needed when an input/output is not a host
    //! tensor; by default the input/output tensors are host tensors
    void set_io(const NetworkIO& network_io) override;

    //! only compute the output tensors configured by the user
    void compute_only_configured_output() override {
        m_compute_configured_output_only = true;
    }

    //! get the network input and output tensor, whose layout is synced from
    //! the mge tensor
    std::shared_ptr<Tensor> get_io_tensor(
            std::string io_name,
            LiteTensorPhase phase = LiteTensorPhase::LITE_IO) override;

    //! get the input tensor by index in the load_result tensormap
    std::shared_ptr<Tensor> get_input_tensor(size_t index) override;

    //! get the output tensor by index in the load_result output_var_list
    std::shared_ptr<Tensor> get_output_tensor(size_t index) override;

    //! get all the input tensor names, in the order returned by load
    std::vector<const char*> get_all_input_name() const override;

    //! get all the output tensor names, in the order returned by load
    std::vector<const char*> get_all_output_name() const override;

    //! get the input tensor name at index, in the order returned by load
    const char* get_input_name(size_t index) const override;

    //! get the output tensor name at index, in the order returned by load
    const char* get_output_name(size_t index) const override;

    //! set the callback used in async mode
    void set_async_callback(const AsyncCallback& callback) override;

    //! set the start callback which will execute before network forward
    void set_start_callback(const StartCallback& callback) override {
        // `callback` is a const reference, so std::move would silently copy
        // anyway; copy explicitly so the intent is clear.
        m_start_callback = callback;
    }

    //! set the finish callback which will execute after network forward
    void set_finish_callback(const FinishCallback& callback) override {
        // see set_start_callback: a const reference can only be copied
        m_finish_callback = callback;
    }

    //! load the model and get the m_load_result
    void load_model(
            std::shared_ptr<void> model_mem, size_t size,
            std::unordered_map<std::string, LiteAny> separate_config_map = {}) override;

    //! forward the network with filled input data and fill the output data
    //! to the output tensor
    void forward() override;

    //! wait until the inference finishes
    //! NOTE(review): the original comment said "in sync model" — confirm which
    //! execution mode this applies to
    void wait() override;

    //! NOTE(review): m_user_config is dereferenced without a null check;
    //! presumably set_config() is always called before this — confirm
    LiteDeviceType get_device_type() const override {
        return m_user_config->device_type;
    }

    //! set CPU inplace mode when the device is CPU; on low-compute or
    //! single-core devices this mode can give good performance
    void set_cpu_inplace_mode();
    bool is_cpu_inplace_mode() const { return m_is_cpu_inplace_mode; }

    //! when the device is CPU, make the model to be loaded run in multi-thread
    //! mode with the given number of threads
    void set_cpu_threads_number(size_t nr_threads);
    size_t get_cpu_threads_number() const { return m_nr_threads; }

    //! set device id, default device id = 0
    void set_device_id(int device_id) override;
    int get_device_id() const override { return m_compnode_locator.device; }

    LiteBackend get_backend_type() const override { return LiteBackend::LITE_DEFAULT; }

    //! set stream id, default stream id = 0
    void set_stream_id(int stream_id) override;
    int get_stream_id() const override { return m_compnode_locator.stream; }

    //! enable tensorrt
    void use_tensorrt();

    //! enable profiling of the network; a JSON format file will be generated
    void enable_profile_performance(std::string profile_json_file_path) override;

    /********************** mge special function ************************/
    //! load a new network which will share weights with src network
    void shared_weight_with(const NetworkImplBase* src_network);

    //! share the runtime memory with another network; weights are not shared
    void share_runtime_memory_with(NetworkImplBase* network);

    //! set threads affinity callback
    void set_runtime_thread_affinity(
            const ThreadAffinityCallback& thread_affinity_callback);

    //! set the network memory allocator; the allocator is defined by the user
    void set_memory_allocator(std::shared_ptr<Allocator> user_allocator);

    //! set the opr algorithm selection strategy in the network
    void set_network_algo_policy(
            LiteAlgoSelectStrategy strategy, uint32_t shared_batch_size,
            bool binary_equal_between_batch);

    //! set workspace_limit for oprs with multiple algorithms; limiting the
    //! workspace can save memory but may influence the performance
    void set_network_algo_workspace_limit(size_t workspace_limit);

    //! dump input/output values of all internal variables to the output file,
    //! in text format
    void enable_io_txt_dump(std::string io_txt_out_file);

    //! dump input/output values of all internal variables to the output
    //! directory, in binary format
    void enable_io_bin_dump(std::string io_bin_out_dir);

    //! get the static peak memory info shown by graph visualization
    void get_static_memory_alloc_info(
            const std::string& log_dir = "logs/test") const override;

    //! enable global layout transform optimization for the network
    void enable_global_layout_transform();

    //! dump the network after global layout transform optimization
    void dump_layout_transform_model(std::string optimized_model_path);

private:
    //! construct the output spec according to m_network_io, and set the
    //! callback on the output spec
    void make_output_spec();

    //! do the global layout transform for the given platform target
    void global_layout_transform();

    //! modify the execution policy
    void modify_exection_policy();

    //! if the input is a dev tensor, this pass replaces the H2D Opr with a
    //! VolatileSharedDeviceTensor Opr
    void replace_dev_input_pass();

    //! check whether the model spans multiple compnodes
    void cross_compnode_model_detect();

    //! once the model has been loaded, update the IO; if no NetworkIO was set,
    //! fill the NetworkIO with the IO of the loaded model
    void update_io();
    void update_input();
    void update_output();

    //! once the model info has been loaded, update the config according to
    //! the model info; it is finally used in the compute graph
    void application_config();

    //! after forwarding the network, output the plugin results to file
    void output_plugin_result() const;

    //! called when the network finishes forwarding
    void finish() const;

    //! called before forwarding the network
    void start() const;

    //! compile the graph to get the execute function
    void compile_graph();

    //! try to infer the output tensor layout
    void try_infer_tensor_layout(std::shared_ptr<Tensor> tensor, Var var);

    //! optimized output tensor copy
    void output_tensor_copy_optimize(Var var, std::shared_ptr<Tensor> tensor);

    //! adapt option validity; it should be called after update_io
    void adapt_option_valid();

private:
    bool m_async = false;
    bool m_is_cpu_inplace_mode = false;
    int m_nr_device_type = 0;
    size_t m_nr_threads = 1;
    bool m_compute_configured_output_only = false;
    bool m_set_layout_transform = false;
    mgb::CompNode::Locator m_compnode_locator;

    AsyncCallback m_async_callback = nullptr;
    std::unique_ptr<NetworkIOInner> m_network_io;
    std::unique_ptr<Config> m_user_config;
    std::unique_ptr<mgb::cg::AsyncExecutable> m_execute_func;

    //! The model load related data
    S m_execution_policy = static_cast<S>(0);
    std::unique_ptr<mgb::serialization::InputFile> m_input_file;
    mgb::Maybe<mgb::serialization::GraphDumpFormat> m_format;
    //! value-initialized so it never holds an indeterminate value before being
    //! assigned
    mgb::gopt::GraphTuningOptions::Target m_layout_transform_target{};
    mgb::serialization::GraphLoadConfig m_load_config;
    mgb::serialization::GraphLoader::LoadResult m_load_result;
    mgb::ComputingGraph::OutputSpec m_output_spec;
    std::shared_ptr<mgb::serialization::GraphLoader> m_loader;

    //! start and finish callback
    StartCallback m_start_callback = nullptr;
    FinishCallback m_finish_callback = nullptr;

    //! profile and io dump related data
#if MGB_ENABLE_JSON
    std::unique_ptr<mgb::GraphProfiler> m_profiler;
    std::string m_profiler_output_file;
#endif
    std::unique_ptr<mgb::OprIODumpBase> m_iodump;
};
/*!
 * \brief query the IO information of a model file before the model is loaded
 * by a Network
 * \param model_path path of the model file
 * \param config the network config
 */
NetworkIO get_model_io_info_dft(const std::string& model_path, const Config& config);

/*!
 * \brief query the IO information of an in-memory model before the model is
 * loaded by a Network
 * \param model_mem pointer to the model memory
 * \param size size of the model memory (presumably in bytes — confirm)
 * \param config the network config
 */
NetworkIO get_model_io_info_dft(const void* model_mem, size_t size, const Config& config);
} // namespace lite
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}