#include "megbrain/imperative/profiler_plugin.h"
#include "megbrain/graph.h"
#include "megbrain/graph/event.h"
#include "./profiler/events.h"
namespace mgb::imperative::interpreter::intl {
// Attaches profiling hooks to `graph`. The receivers registered at the end of
// this constructor translate computing-graph events (sequence start/finish,
// operator start/kernel end, kernel launch, order determination) into
// profiler records: tensor declare/produce/usage/release/erase, op
// dispatch/execute, and kernel launch events.
ProfilerPlugin::ProfilerPlugin(cg::ComputingGraph* graph) : PluginBase(graph) {
using namespace cg;
using namespace cg::event;
using namespace profiler;
// Before a compiled sequence executes: lazily build the opr/var tables, declare
// constant vars, and record one OpDispatchEvent per operator in sequence order.
auto on_seq_start = [this](CompSeqExecBeforeStart const& event) {
mgb_assert(!event.graph->options().imperative_proxy_graph);
// When the "profile_device" option is set, record a device timestamp on each
// comp node visited by CompNode::foreach.
CompNode::foreach ([](CompNode device) {
MGB_RECORD_EVENT_IF(
(Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
Timer::record_device(device));
});
// Both tables are empty after on_graph_compile clears them; rebuild them from
// the sequence about to run.
if (m_opr_dict.empty() && m_var_dict.empty()) {
init_seq(event.exec);
}
// "Constants" scope: constant vars (inputs not produced by an earlier operator,
// see init_seq) get a fresh tensor id and are declared and produced up front.
// Non-const vars only have their runtime consumer count reset from ref_cnt.
Profiler::record<ScopeEvent>("Constants");
for (auto&& [var, var_info] : m_var_dict) {
if (var_info->is_const) {
bool valid = var->dev_tensor_valid();
// An invalid device tensor is reported with an empty layout and null pointer.
auto layout = valid ? var->layout() : TensorLayout();
var_info->id = Profiler::next_id();
Profiler::record<TensorDeclareEvent>(var_info->id, var->name());
Profiler::record<TensorProduceEvent>(
var_info->id, layout, var->comp_node(),
valid ? var->dev_tensor().raw_ptr() : nullptr);
} else {
var_info->rt_ref_cnt = var_info->ref_cnt;
}
}
Profiler::record<ScopeFinishEvent>("Constants");
// "DispatchOprs" scope: for each operator in execution order, declare its
// outputs with fresh ids, then record the dispatch with a fresh operator id
// (stored in opr_info.id for the later execute/kernel events).
Profiler::record<ScopeEvent>("DispatchOprs");
event.exec->iter_opr_seq([this](OperatorNodeBase* opr) -> bool {
auto& opr_info = get_opr_info(opr);
for (auto output : opr->output()) {
auto& var_id = get_var_info(output).id;
var_id = Profiler::next_id();
Profiler::record<TensorDeclareEvent>(var_id, output->name());
}
auto opr_name = opr->dyn_typeinfo()->name;
// Shares the params map (a shared_ptr built in register_opr) and copies it
// only when the profiler invokes this callable.
auto copy_params = [params = opr_info.params] { return *params; };
SmallVector<uint64_t> inputs, outputs;
for (auto input : opr->input()) {
inputs.push_back(get_var_info(input).id);
}
for (auto output : opr->output()) {
outputs.push_back(get_var_info(output).id);
}
Profiler::record<OpDispatchEvent>(
opr_info.id = Profiler::next_id(), opr_name, copy_params, inputs,
outputs);
return true;
});
Profiler::record<ScopeFinishEvent>("DispatchOprs");
};
// When an operator starts: the execute and per-input events are not recorded
// here directly. They are handed to event.env->dispatch_on_comp_node for the
// operator's comp node, presumably so they are ordered with that comp node's work.
// NOTE(review): the runners capture opr_info/var_info by reference, which relies
// on the table entries outliving the dispatched callbacks (on_graph_compile
// clears the tables) — confirm.
auto on_opr_start = [this](OprExecStart const& event) {
OperatorNodeBase* opr = event.opr;
auto& opr_info = get_opr_info(opr);
auto comp_node = opr_info.comp_node;
auto runner = [&opr_info] { Profiler::record<OpExecuteEvent>(opr_info.id); };
event.env->dispatch_on_comp_node(comp_node, runner);
auto inputs = opr->input();
for (auto&& input : inputs) {
auto& var_info = get_var_info(input);
auto runner = [&var_info, input] {
auto inp_id = var_info.id;
Profiler::record<OpInputEvent>(inp_id, input->shape());
Profiler::record<TensorUsageEvent>(inp_id);
Profiler::record<OpInputFinishEvent>(inp_id, input->shape());
};
event.env->dispatch_on_comp_node(comp_node, runner);
}
};
// When an operator's kernels end, dispatch onto its comp node:
//  - for each non-const input, decrement the runtime consumer count and release
//    the tensor once it reaches zero;
//  - for each output, record its production, and release it immediately if no
//    operator consumes it (static ref_cnt == 0);
//  - finally, the OpExecuteFinishEvent.
auto on_opr_finish = [this](OprExecKernelEnd const& event) {
OperatorNodeBase* opr = event.opr;
auto& opr_info = get_opr_info(opr);
auto comp_node = opr_info.comp_node;
auto inputs = opr->input();
auto outputs = opr->output();
for (auto input : inputs) {
auto& var_info = get_var_info(input);
auto runner = [&var_info] {
if (!var_info.is_const) {
if (--var_info.rt_ref_cnt == 0) {
Profiler::record<TensorReleaseEvent>(var_info.id);
}
}
};
event.env->dispatch_on_comp_node(comp_node, runner);
}
for (auto output : outputs) {
auto& var_info = get_var_info(output);
mgb_assert(comp_node == output->comp_node(), "opr comp_node mismatch");
auto runner = [&var_info, output] {
auto out_id = var_info.id;
bool valid = output->dev_tensor_valid();
auto layout = valid ? output->layout() : TensorLayout();
Profiler::record<OpOutputEvent>(out_id, output->shape());
Profiler::record<TensorProduceEvent>(
out_id, layout, output->comp_node(),
valid ? output->dev_tensor().raw_ptr() : nullptr);
if (!var_info.ref_cnt) {
Profiler::record<TensorReleaseEvent>(var_info.id);
}
Profiler::record<OpOutputFinishEvent>(out_id, output->shape());
};
event.env->dispatch_on_comp_node(comp_node, runner);
}
auto runner = [&opr_info]() {
Profiler::record<OpExecuteFinishEvent>(opr_info.id);
};
event.env->dispatch_on_comp_node(comp_node, runner);
};
// Kernel launch bracket: the operator id is passed for both id fields. When
// "profile_device" is set, a device timestamp is taken right after the launch
// event (here) and right before the finish event (on_after_kern).
auto on_before_kern = [this](BeforeKernel const& event) {
OperatorNodeBase* opr = event.opr;
Profiler::record<KernelLaunchEvent>(
get_opr_info(opr).id, get_opr_info(opr).id, event.comp_node);
MGB_RECORD_EVENT_IF(
(Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
Timer::record_device(event.comp_node));
};
auto on_after_kern = [this](AfterKernel const& event) {
OperatorNodeBase* opr = event.opr;
MGB_RECORD_EVENT_IF(
(Profiler::get_option("profile_device", 0)), RecordDeviceEvent,
Timer::record_device(event.comp_node));
Profiler::record<KernelLaunchFinishEvent>(
get_opr_info(opr).id, get_opr_info(opr).id, event.comp_node);
};
// A newly determined execution order invalidates the cached tables; they are
// rebuilt by on_seq_start on the next run.
auto on_graph_compile = [this](const CompSeqOrderDetermined&) {
m_opr_dict.clear();
m_var_dict.clear();
};
// After the sequence finishes:
//  - release constant vars (non-const ones are released by on_opr_finish);
//  - erase every tensor, reporting its static ref_cnt;
//  - reset ids to 0 (on_seq_start assigns fresh ones next run).
auto on_seq_finish = [this](CompSeqExecFinished const& event) {
for (auto&& [var, var_info] : m_var_dict) {
MGB_MARK_USED_VAR(var);
if (var_info->is_const) {
Profiler::record<TensorReleaseEvent>(var_info->id);
}
Profiler::record<TensorEraseEvent>(var_info->id, var_info->ref_cnt);
var_info->id = 0;
}
};
// Register all receivers. add_event_handler keeps the returned handles,
// presumably tying the registrations' lifetime to this plugin — confirm in
// PluginBase.
add_event_handler(
graph->event().register_receiver<CompSeqExecBeforeStart>(on_seq_start));
add_event_handler(graph->event().register_receiver<OprExecStart>(on_opr_start));
add_event_handler(
graph->event().register_receiver<OprExecKernelEnd>(on_opr_finish));
add_event_handler(graph->event().register_receiver<BeforeKernel>(on_before_kern));
add_event_handler(graph->event().register_receiver<AfterKernel>(on_after_kern));
add_event_handler(
graph->event().register_receiver<CompSeqOrderDetermined>(on_graph_compile));
add_event_handler(
graph->event().register_receiver<CompSeqExecFinished>(on_seq_finish));
}
// Builds the operator and variable tables for `comp_seq`, walking its
// operators in execution order. Both tables must be empty on entry.
//
// - Every operator is registered. register_opr asserts that it runs on exactly
//   one comp node.
// - An input var not yet in m_var_dict was not produced by an earlier operator
//   in the sequence, so it is registered as a constant.
// - Every use of an input var counts towards its static ref_cnt. This holds for
//   constants (including their first use) and produced vars alike, so the use
//   count reported by TensorEraseEvent is consistent for both kinds.
// - Outputs are registered as non-constant vars.
void ProfilerPlugin::init_seq(cg::AsyncExecutable* comp_seq) {
    mgb_assert(m_opr_dict.empty());
    mgb_assert(m_var_dict.empty());
    comp_seq->iter_opr_seq([this](cg::OperatorNodeBase* opr) {
        // register_opr performs the single-comp-node check.
        register_opr(opr);
        for (auto&& input : opr->input()) {
            if (m_var_dict.count(input) == 0) {
                // Not produced by any earlier operator: treat as a constant.
                register_var(input).is_const = true;
            }
            // Count every use, including the first use of a constant.
            get_var_info(input).ref_cnt++;
        }
        for (auto&& output : opr->output()) {
            register_var(output).is_const = false;
        }
        return true;
    });
}
// Creates and stores the OprInfo entry for `opr`, returning a reference to it.
//
// - The operator's JSON parameters are snapshotted as string key/value pairs.
//   They are held in a shared_ptr so that dispatch events can copy them lazily.
// - The operator's single comp node is recorded.
//
// Asserts that the operator runs on exactly one comp node. Also asserts that
// it was not registered before, matching the "var exists" check in
// register_var.
ProfilerPlugin::OprInfo& ProfilerPlugin::register_opr(cg::OperatorNodeBase* opr) {
    // Check the comp node constraint before doing the JSON conversion.
    auto comp_nodes = cg::get_opr_comp_node_set(opr);
    mgb_assert(comp_nodes.size() == 1, "only support single comp_node opr");

    OprInfo info;
    // The real id is assigned when the sequence is dispatched. Start from 0
    // (as register_var does) instead of storing an indeterminate value.
    info.id = 0;

    auto params = std::make_shared<std::unordered_map<std::string, std::string>>();
    auto params_json = opr->to_json();
    for (auto&& [k, v] : params_json->cast_final<json::Object>().get_impl()) {
        params->insert({k.get_impl(), v->to_string()});
    }
    info.params = std::move(params);
    info.comp_node = *comp_nodes.begin();

    auto [iter, inserted] = m_opr_dict.insert({opr, std::move(info)});
    mgb_assert(inserted, "opr exists");
    return iter->second;
}
// Registers a fresh, zeroed VarInfo for `var` and returns a reference to it.
// Asserts that `var` has not been registered yet. The VarInfo is heap-owned by
// m_var_dict, so the returned reference stays valid while the entry exists.
ProfilerPlugin::VarInfo& ProfilerPlugin::register_var(cg::VarNode* var) {
    mgb_assert(m_var_dict.count(var) == 0, "var exists");
    auto fresh = std::make_unique<VarInfo>();
    fresh->is_const = false;
    fresh->id = 0;
    fresh->rt_ref_cnt = 0;
    fresh->ref_cnt = 0;
    auto& slot = m_var_dict[var];
    slot = std::move(fresh);
    return *slot;
}
// Returns the OprInfo previously registered for `opr`. The lookup goes
// through at(), which throws for an unregistered operator (std::out_of_range
// for standard maps).
ProfilerPlugin::OprInfo& ProfilerPlugin::get_opr_info(cg::OperatorNodeBase* opr) {
    OprInfo& found = m_opr_dict.at(opr);
    return found;
}
// Returns the VarInfo previously registered for `var`. The lookup goes through
// at(), which throws for an unregistered var (std::out_of_range for standard
// maps).
ProfilerPlugin::VarInfo& ProfilerPlugin::get_var_info(cg::VarNode* var) {
    auto& owner = m_var_dict.at(var);
    return *owner;
}
}