#include "megbrain/plugin/var_value_checker.h"
#include "megbrain/opr/io.h"
using namespace mgb;
void VarValueChecker::Checker::reset() {
m_func.reset();
}
void VarValueChecker::Checker::init(
VarNode* var, const std::shared_ptr<DeviceTensorND>& expected) {
if (!m_inp) {
m_inp = std::make_shared<DeviceTensorND>();
}
setup_inp(var);
auto graph = ComputingGraph::make();
auto ex = opr::SharedDeviceTensor::make(*graph, expected, {"expected"}),
get = opr::SharedDeviceTensor::make(
*graph, m_inp, {ssprintf("get:%s", cg::dump_var_info({var}).c_str())}),
out = opr::AssertEqual::make(ex, get, {false});
m_func = graph->compile({{out, {}}});
}
void VarValueChecker::Checker::check(VarNode* var) {
setup_inp(var);
m_func->clear_device_memory(); m_func->execute().wait();
}
void VarValueChecker::Checker::setup_inp(VarNode* var) {
auto&& val = var->dev_tensor();
if (val.layout().is_contiguous()) {
*m_inp = var->dev_tensor();
} else {
*m_inp = {};
m_inp->copy_from(val);
}
}
VarValueChecker::VarValueChecker(
ComputingGraph* graph, size_t var_switch_interval, size_t init_var_idx)
: PluginBase(graph),
m_init_var_idx{init_var_idx},
m_var_switch_interval{var_switch_interval} {
add_member_func_as_event_handler(&VarValueChecker::on_comp_seq_order_determined);
add_member_func_as_event_handler(&VarValueChecker::on_opr_kern_end);
add_member_func_as_event_handler(&VarValueChecker::on_comp_seq_exec_finished);
}
void VarValueChecker::on_comp_seq_order_determined(
const cg::event::CompSeqOrderDetermined& event) {
m_init_val_dumped = false;
m_var2val.clear();
m_vars.clear();
}
void VarValueChecker::on_opr_kern_end(const cg::event::OprExecKernelEnd& event) {
if (event.opr->same_type<opr::AssertEqual>()) {
event.opr->cast_final<opr::AssertEqual>().disable_throw_on_error();
}
for (VarNode* var : event.opr->output()) {
if (!var->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
auto callback = [this, var]() { on_var_computed(var); };
m_vars.push_back(var);
event.env->dispatch_on_comp_node(var->comp_node(), callback);
}
}
}
void VarValueChecker::on_comp_seq_exec_finished(const cg::event::CompSeqExecFinished&) {
if (!m_init_val_dumped) {
m_init_val_dumped = true;
mgb_assert(!m_vars.empty());
m_cur_var_idx = m_init_var_idx;
m_nr_exec = 0;
m_checker.reset();
} else {
++m_nr_exec;
if (m_nr_exec != m_var_switch_interval)
return;
m_nr_exec = 0;
++m_cur_var_idx;
}
if (m_cur_var_idx >= m_vars.size()) {
fprintf(stderr,
"VarValueChecker: all check passed; "
"start from beginning\n");
m_cur_var_idx = 0;
}
auto var = m_vars[m_cur_var_idx];
fprintf(stderr, "VarValueChecker: going to check #%zu: %s\n", m_cur_var_idx,
cg::dump_var_info({var}).c_str());
m_checker.reset();
}
void VarValueChecker::on_var_computed(VarNode* var) {
if (!var->dev_tensor_valid()) {
if (m_init_val_dumped && var == m_vars[m_cur_var_idx]) {
on_comp_seq_exec_finished({});
}
return;
}
if (!m_init_val_dumped) {
#if !__DEPLOY_ON_XP_SP2__
m_var2val_mtx.lock();
#endif
auto&& val = m_var2val[var];
#if !__DEPLOY_ON_XP_SP2__
m_var2val_mtx.unlock();
#endif
mgb_assert(!val);
val = std::make_shared<DeviceTensorND>();
val->copy_from(var->dev_tensor());
return;
}
if (var != m_vars[m_cur_var_idx])
return;
if (!m_checker.valid()) {
m_checker.init(var, m_var2val.at(var));
}
m_checker.check(var);
}