#include "megbrain/opr/tensor_manip.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "../async_releaser.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"
namespace mgb::imperative {
namespace get_var_shape {
// Graph mode: build an opr::GetVarShape over `inputs` using the OpDef's param
// and name, and return the operator node that owns the resulting var.
cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& get_shape = def.cast_final_safe<GetVarShape>();
    OperatorNodeConfig node_config{get_shape.make_name()};
    auto out_var = opr::GetVarShape::make(inputs, get_shape.param(), node_config);
    return out_var.node()->owner_opr();
}
// Choose how GetVarShape is dispatched: DEFAULT_CPU when every input already
// carries a non-empty value whose layout has a non-zero ndim, KERNEL as soon as
// any input lacks one. With no inputs at all the result is DEFAULT_CPU.
DispatchMode decide_dispatch_mode(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    for (auto&& desc : inputs) {
        // Short-circuit keeps layout() from being queried on an empty value.
        const bool value_known = !desc.value.empty() && desc.value.layout().ndim != 0;
        if (!value_known) {
            return KERNEL;
        }
    }
    return DEFAULT_CPU;
}
// Eagerly compute GetVarShape.
//
// The shape examined is the layout of the single input, or the
// Elemwise::deduce_shape result over all input layouts when there are several.
// The result is an Int32 host tensor on the default CPU comp node, exposed to
// the caller as a proxy DeviceTensorND in (*outputs)[0]:
//  - axis == INVALID_AXIS: a 1-d tensor holding every dimension of the shape;
//  - otherwise: a 1-element tensor holding the extent of `axis`
//    (negative axes are counted from the last dimension).
// `outputs` must hold at least one slot, and its first slot must already be on
// the default CPU comp node.
void apply_on_device_tensornd(
        const OpDef& def, const SmallVector<DeviceTensorND>& inputs,
        SmallVector<DeviceTensorND>* outputs) {
    auto&& op_def = def.cast_final_safe<GetVarShape>();
    mgb_assert(
            outputs && !outputs->empty(),
            "GetVarShape's apply_on_device_tensornd requires one output slot");
    TensorShape shp;
    if (inputs.size() == 1) {
        shp = inputs[0].layout();
    } else {
        TensorShapeArray src(inputs.size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            src[i] = inputs[i].layout();
        }
        megdnn::Elemwise::deduce_shape(src, shp);
    }
    mgb_assert(shp.ndim != 0, "input shape invalid");
    mgb_assert(
            (*outputs)[0].comp_node() == CompNode::default_cpu(),
            "GetVarShape's apply_on_device_tensornd should receive default_cpu "
            "outputs.");
    HostTensorND hv;
    if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) {
        hv = HostTensorND(CompNode::default_cpu(), {shp.ndim}, dtype::Int32());
        auto* ptr = hv.ptr<dt_int32>();
        for (size_t i = 0; i < shp.ndim; ++i) {
            ptr[i] = shp.shape[i];
        }
    } else {
        const int32_t ndim = static_cast<int32_t>(shp.ndim);
        int32_t axis = op_def.axis;
        if (axis < 0) {
            // Normalize a negative axis to count from the end.
            axis += ndim;
        }
        mgb_assert(
                axis >= 0 && axis < ndim,
                "GetVarShape: axis %d is out of range for shape %s",
                static_cast<int>(op_def.axis), shp.to_string().c_str());
        hv = HostTensorND(CompNode::default_cpu(), {1}, dtype::Int32());
        auto* ptr = hv.ptr<dt_int32>();
        ptr[0] = shp.shape[axis];
    }
    (*outputs)[0] = DeviceTensorND::make_proxy(hv);
}
// Run GetVarShape eagerly on the device tensors behind `inputs`.
// The computed Int32 result is returned as a HostTensorND proxy re-targeted
// (via proxy_to_comp_node) to the comp node of the first input.
HostTensorND get_var_shape_host_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    SmallVector<DeviceTensorND> device_inputs;
    device_inputs.reserve(inputs.size());
    for (size_t idx = 0; idx < inputs.size(); ++idx) {
        device_inputs.push_back(inputs[idx]->dev_tensor());
    }
    // One pre-seeded output slot on the default CPU, as apply_on_device_tensornd expects.
    SmallVector<DeviceTensorND> device_outputs = {
            {CompNode::default_cpu(), dtype::Int32()}};
    apply_on_device_tensornd(def, device_inputs, &device_outputs);
    auto host_result = HostTensorND::make_proxy(device_outputs[0]);
    return host_result.proxy_to_comp_node(inputs[0]->comp_node());
}
// Eager (physical tensor) entry point for GetVarShape. It computes the shape on
// the host and wraps it in a single new Tensor. `output_descs` and `validated`
// are not consulted by this implementation.
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    // The helper already yields a prvalue; wrapping it in std::move is redundant.
    return {Tensor::make(get_var_shape_host_tensor(def, inputs))};
}
// Infer GetVarShape's output from logical tensor descriptors.
//
// If every input shape is known, this returns {[desc], true}. The desc carries
// a concrete Int32 value (allocated on the default CPU comp node) and is
// labelled with the first input's comp node. The value is either the full
// shape or, when `axis` is set, the extent of that axis (negative axes count
// from the end).
// If any input shape is still unknown (ndim == 0), it instead returns an Int32
// desc with an unknown layout and `false`, so the caller can retry later.
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<GetVarShape>();
    mgb_assert(!inputs.empty(), "GetVarShape requires at least one input");
    auto&& desc = inputs[0];
    // Any unknown input shape makes the result unknown. Bail out before
    // Elemwise::deduce_shape: it is not meant to receive ndim == 0 shapes
    // (NOTE(review): believed to throw on them), which would turn a "not yet
    // inferable" situation into a hard error.
    for (auto&& inp : inputs) {
        if (!inp.layout.ndim) {
            return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false};
        }
    }
    TensorShape shp;
    if (inputs.size() == 1) {
        shp = desc.layout;
    } else {
        TensorShapeArray src(inputs.size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            src[i] = inputs[i].layout;
        }
        megdnn::Elemwise::deduce_shape(src, shp);
    }
    if (!shp.ndim) {
        return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false};
    }
    DeviceTensorND value;
    if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) {
        value = DeviceTensorND(CompNode::default_cpu(), {shp.ndim}, dtype::Int32());
        auto* ptr = value.ptr<dt_int32>();
        for (size_t i = 0; i < shp.ndim; ++i) {
            ptr[i] = shp[i];
        }
    } else {
        const int32_t ndim = static_cast<int32_t>(shp.ndim);
        int32_t axis = op_def.axis;
        if (axis < 0) {
            // Normalize a negative axis to count from the end.
            axis += ndim;
        }
        mgb_assert(
                axis >= 0 && axis < ndim,
                "GetVarShape: axis %d is out of range for shape %s",
                static_cast<int>(op_def.axis), shp.to_string().c_str());
        value = DeviceTensorND(CompNode::default_cpu(), {1}, dtype::Int32());
        auto* ptr = value.ptr<dt_int32>();
        ptr[0] = shp[axis];
    }
    return {{{value.layout(), desc.comp_node, std::move(value)}}, true};
}
// Recreate a GetVarShape OpDef from an existing opr::GetVarShape graph node by
// copying its param.
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto&& get_shape_opr = node_->cast_final_safe<opr::GetVarShape>();
    return GetVarShape::make(get_shape_opr.param());
}
// Register the imperative traits of GetVarShape (bound to opr::GetVarShape):
// conversion from a graph opr, dispatch-mode selection, fallible output
// inference, graph-mode construction, and device-tensor / physical-tensor
// execution. fallback() is applied last; presumably it supplies default
// implementations for traits not set here (confirm in op_trait.h).
OP_TRAIT_REG(GetVarShape, GetVarShape, opr::GetVarShape)
.make_from_op_node(make_from_op_node)
.decide_dispatch_mode(decide_dispatch_mode)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_var_node(apply_on_var_node)
.apply_on_device_tensornd(apply_on_device_tensornd)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
}
namespace param_pack {
// Convert the nested size_t vectors stored in a ParamPackSplit OpDef into a
// TensorShapeArray, producing one TensorShape per inner vector, in order.
TensorShapeArray get_shapes(const std::vector<std::vector<size_t>>& shapes) {
    TensorShapeArray result;
    result.reserve(shapes.size());
    for (size_t idx = 0; idx < shapes.size(); ++idx) {
        auto&& dims = shapes[idx];
        SmallVector<size_t> dim_vec(dims.begin(), dims.end());
        result.push_back(TensorShape(dim_vec));
    }
    return result;
}
// Graph mode: insert an opr::ParamPackSplit that splits inputs[0] using the
// OpDef's offsets and output shapes, and return the inserted operator.
cg::OperatorNodeBase* param_pack_split_apply_on_var_node(
        const OpDef& def, const VarNodeArray& inputs) {
    auto&& split_def = def.cast_final_safe<ParamPackSplit>();
    auto* owner_graph = inputs[0]->owner_graph();
    auto&& split_shapes = get_shapes(split_def.shapes);
    OperatorNodeConfig node_config(split_def.make_name());
    auto split_opr = std::make_unique<mgb::opr::ParamPackSplit>(
            inputs[0], split_def.offsets, split_shapes, node_config);
    return owner_graph->insert_opr(std::move(split_opr));
}
// Eager ParamPackSplit.
//
// The single input must be a 1-d packed tensor. For each i, output i is a
// zero-copy view (Tensor::sub) of that input. The view starts at element
// offset `offsets[2 * i]`, scaled by the dtype size to get a byte offset, and
// has shape `shapes[i]`.
// `offsets` holds two entries per output; only the first entry of each pair is
// read here. Each view is checked to lie within the input's element count.
SmallVector<TensorPtr> param_pack_split_apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    auto&& param = def.cast_final_safe<ParamPackSplit>();
    // %zu is the portable specifier for size_t (%lu is wrong on LLP64 targets).
    mgb_assert(
            inputs.size() == 1, "ParamPackSplit take 1 input, got %zu", inputs.size());
    auto&& inp = inputs[0];
    auto&& layout = inp->layout();
    mgb_assert(layout.ndim == 1, "ParamPackSplit input shape invalid, ndim should be 1");
    mgb_assert(
            param.shapes.size() * 2 == param.offsets.size(),
            "ParamPackSplit expects 2 offsets per output, got %zu offsets for %zu "
            "outputs",
            param.offsets.size(), param.shapes.size());
    auto&& shapes = get_shapes(param.shapes);
    const size_t total_elems = layout.shape[0];
    const size_t dtype_size = layout.dtype.size();
    SmallVector<TensorPtr> ret;
    ret.reserve(shapes.size());
    for (size_t i = 0; i < shapes.size(); ++i) {
        const size_t begin = static_cast<size_t>(param.offsets[i * 2]);
        const size_t nr_elems = shapes[i].total_nr_elems();
        // Written as two comparisons so that begin + nr_elems cannot wrap around.
        mgb_assert(
                begin <= total_elems && nr_elems <= total_elems - begin,
                "ParamPackSplit output %zu (offset %zu, %zu elems) exceeds input of "
                "%zu elems",
                i, begin, nr_elems, total_elems);
        ret.push_back(inp->sub(begin * dtype_size, shapes[i]));
    }
    return ret;
}
// Register ParamPackSplit (bound to mgb::opr::ParamPackSplit) with its
// graph-mode and physical-tensor implementations. fallback() presumably fills
// in default implementations for the remaining traits (confirm in op_trait.h).
OP_TRAIT_REG(ParamPackSplit, ParamPackSplit, mgb::opr::ParamPackSplit)
.apply_on_var_node(param_pack_split_apply_on_var_node)
.apply_on_physical_tensor(param_pack_split_apply_on_physical_tensor)
.fallback();
// Graph mode: insert an opr::ParamPackConcat and return the inserted operator.
// All vars except the last are the tensors to pack. The last var is passed to
// the operator separately; presumably it is the offsets table (confirm against
// opr::ParamPackConcat).
cg::OperatorNodeBase* param_pack_concat_apply_on_var_node(
        const OpDef& def, const VarNodeArray& inputs) {
    auto&& concat_def = def.cast_final_safe<ParamPackConcat>();
    auto* owner_graph = inputs[0]->owner_graph();
    VarNode* trailing_var = inputs.back();
    VarNodeArray packed_vars;
    for (size_t i = 0; i + 1 < inputs.size(); ++i) {
        packed_vars.push_back(inputs[i]);
    }
    OperatorNodeConfig node_config{concat_def.make_name()};
    return owner_graph->insert_opr(std::make_unique<mgb::opr::ParamPackConcat>(
            packed_vars, trailing_var, concat_def.offsets, node_config));
}
// Eager ParamPackConcat: packs inputs[0 .. n-2] into one new 1-d tensor whose
// element count is the sum of their element counts. The last input is not
// packed; it is passed to the megdnn kernel (and its workspace query) as a
// separate tensor argument.
// All packed inputs must share the comp node and dtype of inputs[0]; the
// trailing input is not checked against them.
// Returns a single-element vector holding the packed output tensor.
SmallVector<TensorPtr> param_pack_concat_apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
// Result is discarded; the call only validates that `def` is a ParamPackConcat.
def.cast_final_safe<ParamPackConcat>();
mgb_assert(inputs.size() > 1, "param_pack should have at least one input");
auto comp_node = inputs.front()->comp_node();
auto dtype = inputs.front()->dtype();
// The last input is excluded from packing.
size_t nr_inputs = inputs.size() - 1;
size_t nr_elems = 0;
for (size_t i = 0; i < nr_inputs; ++i) {
auto& input = inputs[i];
mgb_assert(
comp_node == input->comp_node(),
"inputs for param_pack_concat must in same comp_node");
mgb_assert(
dtype == input->dtype(),
"inputs for param_pack_concat must have same dtype");
nr_elems += input->layout().total_nr_elems();
}
auto dest_layout = TensorLayout({nr_elems}, dtype);
auto output = Tensor::make(dest_layout, comp_node);
auto caller = DnnOprCaller<megdnn::ParamPackConcat>(comp_node);
// Host-side table of raw device pointers to the packed inputs. It is
// allocated with comp_node.alloc_host and owned by a shared_ptr whose deleter
// calls comp_node.free_host.
size_t srcs_size = sizeof(void*) * nr_inputs;
void** srcs_raw_ptr = (void**)comp_node.alloc_host(srcs_size);
std::shared_ptr<dt_byte> srcs_ptr = {
(dt_byte*)srcs_raw_ptr,
[comp_node](dt_byte* ptr) { comp_node.free_host(ptr); }};
// NOTE(review): the pointer table is described with an Int32 layout of
// nr_inputs elements, although each slot holds a void* (srcs_size bytes in
// total). Presumably the kernel reinterprets the raw pointer; confirm.
TensorLayout srcs_layout = TensorLayout{{nr_inputs}, dtype::Int32()};
size_t ws_size;
{
TensorShapeArray src_shapes;
for (size_t i = 0; i < nr_inputs; ++i) {
src_shapes.push_back(inputs[i]->shape());
}
// The dst shape is passed as an empty TensorShape.
// NOTE(review): assumes the workspace size does not depend on it.
ws_size = caller.op->get_workspace_in_bytes(
src_shapes, inputs.back()->shape(), TensorShape{});
}
for (size_t i = 0; i < nr_inputs; ++i) {
srcs_raw_ptr[i] = inputs[i]->dev_tensor().as_megdnn().raw_ptr();
}
// Wrap the shared pointer table in a HostTensorStorage so its lifetime can be
// handed to AsyncReleaser after the kernel is issued.
HostTensorStorage srcs_storage;
srcs_storage.reset(comp_node, srcs_size, srcs_ptr);
caller.op->exec(
{srcs_raw_ptr, srcs_layout}, inputs.back()->dev_tensor().as_megdnn(),
output->dev_tensor().as_megdnn(),
caller.create_workspace({{ws_size}, dtype::Byte()}));
// The kernel may still read the pointer table after exec() returns. Handing
// the storage to AsyncReleaser presumably defers freeing it until the comp
// node has finished the queued work (confirm in async_releaser.h).
AsyncReleaser::inst()->add(
HostTensorND{comp_node, srcs_layout}.storage(srcs_storage));
return {output};
}
// Register ParamPackConcat (bound to mgb::opr::ParamPackConcat) with its
// graph-mode and physical-tensor implementations. fallback() presumably fills
// in default implementations for the remaining traits (confirm in op_trait.h).
OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat)
.apply_on_var_node(param_pack_concat_apply_on_var_node)
.apply_on_physical_tensor(param_pack_concat_apply_on_physical_tensor)
.fallback();
}
namespace split {
// Recreate a Split OpDef from an existing opr::Split graph node.
// Only SPECIFY-mode splits (explicit partition vars) are supported. They are
// mapped to a Split OpDef with the same axis and nsections == 0.
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    using Options = opr::Split::Options;
    auto&& split_opr = node_->cast_final_safe<opr::Split>();
    auto&& options = split_opr.options();
    mgb_assert(
            options.method == Options::Method::SPECIFY,
            "only Split with SPECIFY output shapes is supported");
    mgb_assert(options.partition.size() == options.nr_part);
    const int split_axis = options.axis;
    return Split::make(split_axis, 0);
}
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
using Options = opr::Split::Options;
auto&& sp = static_cast<const Split&>(def);
OperatorNodeConfig config{sp.make_name()};
opr::Split::Options opt;
if (sp.nsections) {
opt = Options::make_average(sp.axis, sp.nsections);
opt.method = Options::Method::CALL_BACK;
} else {
opt.axis = sp.axis;
opt.method = Options::Method::SPECIFY;
mgb_assert(inputs.size() > 1);
opt.nr_part = inputs.size() - 1;
opt.partition.resize(opt.nr_part);
for (size_t i = 1; i < inputs.size(); ++i)
opt.partition[i - 1] = inputs[i];
}
return opr::Split::make(inputs[0], opt, config);
}
// Register Split (bound to opr::Split) with conversion from a graph opr and a
// graph-mode implementation. fallback() presumably supplies default
// implementations for the remaining traits (confirm in op_trait.h).
OP_TRAIT_REG(Split, Split, opr::Split)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.fallback();
}
}