#include "megbrain/imperative/opr_utility.h"
#include "./mgb_cg_impl.h"
namespace {
mgb::OperatorNodeConfig setup_config_cn(
const mgb::OperatorNodeConfig& config_, const mgb::CompNode& cn) {
auto prev_cn = config_.get_single_comp_node();
mgb_assert(!prev_cn.valid() || cn == prev_cn);
auto config = config_;
config.comp_node(cn);
return config;
}
} namespace mgb {
namespace opr {
MGB_DYN_TYPE_OBJ_FINAL_IMPL(InputCallback);
InputCallback::InputCallback(
cg::ComputingGraph& graph, callback_t callback, const VarNodeArray& inputs,
const TensorShape& output_shape, const OperatorNodeConfig& config,
bool use_static_shape)
: Super(&graph, config, "input_callback", inputs),
m_output_shape(output_shape),
m_callback(callback),
m_use_static_shape(use_static_shape) {
for (VarNode* i : inputs) {
add_input({i});
}
DType dt = config.output_dtype();
mgb_assert(dt.valid());
if (m_use_static_shape) {
mgb_assert(m_output_shape.ndim);
}
add_output(None)
->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
.add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC)
.dtype(dt);
add_output(None)
->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
.add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC)
.dtype(DType::from_enum(DTypeEnum::Byte));
add_equivalence_component<ScalarHash<void*>>(this);
}
SymbolVarArray InputCallback::make(
cg::ComputingGraph& graph, callback_t callback, CompNode comp_node, DType dtype,
const TensorShape& shape, const SymbolVarArray& inputs, bool use_static_shape) {
mgb_assert(comp_node.valid());
mgb_assert(dtype.valid());
OperatorNodeConfig config;
config.comp_node(comp_node);
config.output_dtype(dtype);
auto vinputs = to_var_node_array(inputs);
auto opr = graph.insert_opr(std::make_unique<InputCallback>(
graph, callback, vinputs, shape, config, use_static_shape));
return to_symbol_var_array(opr->output());
}
void InputCallback::init_output_static_infer_desc() {
using namespace cg::static_infer;
if (m_use_static_shape) {
auto&& mgr = owner_graph()->static_infer_manager();
auto infer_shape = [this](TensorShape& dest, const InpVal&) {
dest = m_output_shape;
return true;
};
mgr.register_shape_infer(output(0), {SourceType::CONSTANT, {}, infer_shape});
} else {
if (m_output_shape.ndim) {
auto* var = output(0);
var->shape(m_output_shape);
auto&& mgr = cg::ComputingGraphImpl::downcast(owner_graph())
->static_infer_manager_impl();
auto* handle = mgr.get_tag_handler_for_shape(var);
handle->sync_from_var();
}
}
}
cg::OperatorNodeBase::NodeProp* InputCallback::do_make_node_prop() const {
NodeProp* prop = Super::do_make_node_prop();
prop->add_flag(NodeProp::Flag::NO_AUTOMATIC_DUP);
SmallVector<NodeProp::DepType> dep_types(
input().size(), NodeProp::DepType::DEV_COMP_ORDER);
prop->reset_dep_type(input(), dep_types);
return prop;
}
void InputCallback::scn_do_execute() {
auto dev_tensor = m_callback();
if (m_use_static_shape) {
mgb_assert(dev_tensor.shape().eq_shape(m_output_shape));
}
if (dev_tensor.empty()) {
auto layout = dev_tensor.layout();
layout.init_contiguous_stride();
dev_tensor.reset(dev_tensor.storage(), layout);
}
output(0)->force_assign_dev_tensor_from_tensor(dev_tensor);
}
cg::OperatorNodeBase* InputCallback::shallow_copy(
const serialization::OprShallowCopyContext& ctx,
const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
const OperatorNodeConfig& config) {
auto&& opr = opr_.cast_final_safe<InputCallback>();
auto* graph = ctx.owner_graph(opr, inputs);
return graph->insert_opr(std::make_unique<InputCallback>(
*graph, opr.m_callback, inputs, opr.m_output_shape, config,
opr.m_use_static_shape));
}
MGB_REG_OPR_SHALLOW_COPY(InputCallback, InputCallback::shallow_copy);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(OutputCallback);
OutputCallback::OutputCallback(
Param param, const VarNodeArray& inputs, const OperatorNodeConfig& config)
: Super(inputs[0]->owner_graph(),
setup_config_cn(config, inputs[0]->comp_node()), "output_callback",
inputs),
m_param(std::move(param)) {
for (VarNode* i : inputs) {
add_input({i});
}
if (!m_param.borrow) {
input(0)->add_flag(VarNode::Flag::NO_SYS_STATIC_MEM_ALLOC);
}
add_output(None)
->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
.add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC)
.dtype(DType::from_enum(DTypeEnum::Byte));
add_equivalence_component<ScalarHash<void*>>(this);
}
SymbolVar OutputCallback::make(Param param, const SymbolVarArray& inputs) {
mgb_assert(inputs.size() >= 1);
auto vinputs = to_var_node_array(inputs);
OperatorNodeConfig config;
return inputs[0].insert_single_output_opr<OutputCallback>(
std::move(param), vinputs, config);
}
void OutputCallback::init_output_static_infer_desc() {}
cg::OperatorNodeBase::NodeProp* OutputCallback::do_make_node_prop() const {
NodeProp* prop = Super::do_make_node_prop();
prop->add_flag(NodeProp::Flag::NO_AUTOMATIC_DUP);
prop->add_flag(NodeProp::Flag::CROSS_COMP_NODE_MEMORY);
SmallVector<NodeProp::DepType> dep_types(
input().size(), NodeProp::DepType::DEV_COMP_ORDER);
using IT = cg::static_infer::InferType;
auto host_value_avail = [&]() -> bool {
auto inp = input(0);
auto it = owner_graph()->static_infer_manager().get_infer_type(inp).value;
return it & (IT::CONST | IT::RT_STATIC | IT::MISSING_INP);
};
m_use_host_value = m_param.prefer_host_value && host_value_avail();
dep_types[0] = m_use_host_value ? NodeProp::DepType::HOST_VALUE
: NodeProp::DepType::DEV_VALUE;
dep_types[0] |= NodeProp::DepType::VALUE_ALLOW_EMPTY;
prop->reset_dep_type(input(), dep_types);
return prop;
}
void OutputCallback::add_input_layout_constraint() {
if (m_param.require_contiguous) {
input(0)->add_layout_constraint_contiguous();
}
}
void OutputCallback::scn_do_execute() {
if (m_use_host_value) {
m_param.callback(owner_graph()->static_infer_manager().infer_value(input(0)));
} else {
m_param.callback(input(0)->dev_tensor());
}
}
cg::OperatorNodeBase* OutputCallback::shallow_copy(
const serialization::OprShallowCopyContext& ctx,
const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
const OperatorNodeConfig& config) {
auto&& opr = opr_.cast_final_safe<OutputCallback>();
auto* graph = ctx.owner_graph(opr, inputs);
return graph->insert_opr(
std::make_unique<OutputCallback>(opr.m_param, inputs, config));
}
MGB_REG_OPR_SHALLOW_COPY(OutputCallback, OutputCallback::shallow_copy);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(NopCallback);
NopCallback::NopCallback(
cg::ComputingGraph& graph, callback_t callback, const VarNodeArray& inputs,
const OperatorNodeConfig& config)
: Super(&graph, config, "nop_callback", inputs), m_callback(callback) {
for (VarNode* i : inputs) {
add_input({i});
}
add_output(None)
->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
.add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC)
.dtype(DType::from_enum(DTypeEnum::Byte));
add_equivalence_component<ScalarHash<void*>>(this);
}
SymbolVar NopCallback::make(
cg::ComputingGraph& graph, callback_t callback, CompNode comp_node,
const SymbolVarArray& inputs) {
mgb_assert(comp_node.valid());
OperatorNodeConfig config;
config.comp_node(comp_node);
auto vinputs = to_var_node_array(inputs);
auto opr = graph.insert_opr(
std::make_unique<NopCallback>(graph, callback, vinputs, config));
return opr->output(0);
}
void NopCallback::init_output_static_infer_desc() {}
void NopCallback::on_output_comp_node_stream_changed() {}
void NopCallback::init_output_comp_node() {
auto cn = config().get_single_comp_node();
mgb_assert(cn.valid());
output(0)->comp_node(cn);
}
cg::OperatorNodeBase::NodeProp* NopCallback::do_make_node_prop() const {
NodeProp* prop = Super::do_make_node_prop();
SmallVector<NodeProp::DepType> dep_types(
input().size(), NodeProp::DepType::DEV_COMP_ORDER);
prop->reset_dep_type(input(), dep_types);
prop->add_flag(cg::OperatorNodeBase::NodeProp::Flag::CROSS_COMP_NODE_MEMORY);
return prop;
}
void NopCallback::do_execute(ExecEnv& env) {
auto cn = output(0)->comp_node();
auto runner = [this, cn] {
owner_graph()->event().signal_inplace<cg::event::BeforeKernel>(this, cn);
cn.activate();
m_callback();
owner_graph()->event().signal_inplace<cg::event::AfterKernel>(this, cn);
};
env.dispatch_on_comp_node(cn, runner);
}
MGB_DYN_TYPE_OBJ_FINAL_IMPL(MutableTensor);
MutableTensor::MutableTensor(
cg::ComputingGraph& graph, std::shared_ptr<DeviceTensorND> dev_tensor,
std::shared_ptr<HostTensorND> host_tensor, const OperatorNodeConfig& config)
: Super(&graph, config, {}, {}) {
m_dev_tensor = dev_tensor;
m_host_tensor = host_tensor;
add_output(None)
->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
.add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC)
.dtype(m_dev_tensor->dtype());
add_equivalence_component<ScalarHash<const void*>>(this);
}
SymbolVar MutableTensor::make(
cg::ComputingGraph& graph, std::shared_ptr<DeviceTensorND> dev_tensor,
std::shared_ptr<HostTensorND> host_tensor, const OperatorNodeConfig& config) {
return graph
.insert_opr(std::make_unique<MutableTensor>(
graph, dev_tensor, host_tensor, config))
->output(0);
}
void MutableTensor::init_output_comp_node() {
if (config().has_comp_node_set()) {
mgb_assert(
config().get_single_comp_node() == m_dev_tensor->comp_node(),
"comp_node mismatch");
}
comp_node(m_dev_tensor->comp_node());
}
cg::OperatorNodeBase::NodeProp* MutableTensor::do_make_node_prop() const {
auto ret = Super::do_make_node_prop();
ret->add_flag(NodeProp::Flag::IMPURE_OUTPUT_MEM_PLAN);
return ret;
}
void MutableTensor::scn_do_execute() {
output(0)->force_assign_dev_tensor_from_tensor(*m_dev_tensor);
}
void MutableTensor::init_output_static_infer_desc() {
using namespace cg::static_infer;
auto& mgr = owner_graph()->static_infer_manager();
auto infer_shape = [this](TensorShape& dest, const InpVal&) {
dest = m_dev_tensor->shape();
return true;
};
mgr.register_shape_infer(output(0), {SourceType::MUTABLE, {}, infer_shape});
if (m_host_tensor) {
auto infer_value = [this](DeviceTensorND& dest, const InpVal&) {
if (!m_host_tensor->layout().ndim) {
return false;
}
dest = m_host_tensor->proxy_to_default_cpu();
return true;
};
mgr.register_value_infer(output(0), {SourceType::MUTABLE, {}, infer_value});
}
}
} }