#include "megdnn/oprs.h"
#include "src/common/utils.h"
#include <numeric>
namespace megdnn {
/*!
 * \brief Construct the shared concat/split base on top of \p handle.
 *
 * Besides forwarding \p handle to OperatorBase, this installs two small
 * projection callbacks:
 *  - m_get_layout: returns a copy of the layout stored in a TensorND;
 *  - m_get_shape: returns the TensorShape built from a TensorLayout.
 */
ConcatSplitBase::ConcatSplitBase(Handle* handle)
        : OperatorBase(handle),
          m_get_layout([](const TensorND& nd) -> TensorLayout { return nd.layout; }),
          m_get_shape([](const TensorLayout& ly) -> TensorShape {
              return TensorShape(ly);
          }) {}
/*!
 * \brief Validate the layouts of a concat/split operation.
 *
 * The same checks serve both directions: ConcatForward passes its inputs as
 * \p srcs and its output as \p dst, while SplitForward passes its outputs as
 * \p srcs and its input as \p dst.
 *
 * Asserted conditions, in order:
 *  - every entry of \p srcs has the same dtype as \p dst;
 *  - every entry of \p srcs, and \p dst itself, is contiguous;
 *  - every entry of \p srcs has the same ndim as \p dst;
 *  - param().axis lies in [0, dst.ndim);
 *  - along the axis, the extents of \p srcs sum up to the extent of \p dst;
 *  - along every other dimension, each entry of \p srcs matches \p dst.
 *
 * \param srcs layouts of the pieces (concat inputs / split outputs)
 * \param dst  layout of the whole tensor (concat output / split input)
 */
void ConcatSplitBase::check_layout_common(
        const TensorLayoutArray& srcs, const TensorLayout& dst) {
    for (auto&& src : srcs) {
        megdnn_assert(src.dtype == dst.dtype);
    }
    for (auto&& src : srcs) {
        megdnn_assert_contiguous(src);
    }
    megdnn_assert_contiguous(dst);
    auto ndim = dst.ndim;
    for (auto&& src : srcs) {
        megdnn_assert_eq_size_t(src.ndim, ndim);
    }
    // The axis is a signed value, so both bounds must be checked: a negative
    // axis would otherwise pass the upper-bound test, never match any
    // dimension index below, and later be used as a negative array offset.
    // It is also printed with %d so a negative value is reported as such.
    const int32_t axis = param().axis;
    megdnn_assert(
            axis >= 0 && axis < static_cast<int32_t>(ndim),
            "param().axis=%d, ndim=%zu", axis, ndim);
    for (size_t i = 0; i < ndim; ++i) {
        if (i == static_cast<size_t>(axis)) {
            // Along the concat/split axis the pieces must add up to the whole.
            size_t sum = 0_z;
            for (auto&& src : srcs)
                sum += src.shape[i];
            megdnn_assert_eq_size_t(sum, dst.shape[i]);
        } else {
            // Every other dimension must agree exactly.
            for (auto&& src : srcs) {
                megdnn_assert_eq_size_t(src.shape[i], dst.shape[i]);
            }
        }
    }
}
void ConcatSplitBase::get_ABC(
const TensorShapeArray& srcs, size_t& A, size_t* B, size_t& C) {
auto axis = param().axis;
auto shape_arr = srcs[0].shape;
auto ndim = srcs[0].ndim;
A = std::accumulate(shape_arr, shape_arr + axis, 1_z, SafeMultiplies<size_t>());
for (size_t i = 0u; i < srcs.size(); ++i) {
B[i] = srcs[i].shape[axis];
}
C = std::accumulate(
shape_arr + (axis + 1), shape_arr + ndim, 1_z, SafeMultiplies<size_t>());
}
/*!
 * \brief Deduce the output layout of a concatenation.
 *
 * \p dst starts as a copy of srcs[0]; its extent along param().axis is then
 * replaced by the sum of the extents of all \p srcs along that axis, and its
 * strides are re-initialised as contiguous.
 *
 * \param srcs input layouts; srcs[0] is read unconditionally, so the array
 *      must not be empty
 * \param[out] dst deduced output layout
 */
void ConcatForward::deduce_layout(const TensorLayoutArray& srcs, TensorLayout& dst) {
    const auto axis = param().axis;
    dst = srcs[0];
    // Accumulate directly in the destination slot.
    auto& concat_extent = dst.shape[axis];
    concat_extent = 0u;
    for (const auto& layout : srcs) {
        concat_extent += layout.shape[axis];
    }
    dst.init_contiguous_stride();
}
/*!
 * \brief Pre-execution checks for concat.
 *
 * Validates the layouts through check_layout_common() and then asserts that
 * the caller-provided workspace is at least as large as what
 * get_workspace_in_bytes() reports for the same layouts.
 *
 * \param srcs input layouts
 * \param dst output layout
 * \param workspace_in_bytes size of the workspace supplied by the caller
 */
void ConcatForward::check_exec(
        const TensorLayoutArray& srcs, const TensorLayout& dst,
        size_t workspace_in_bytes) {
    // Layouts are validated before the workspace requirement is queried.
    check_layout_common(srcs, dst);

    const size_t required_workspace_in_bytes = get_workspace_in_bytes(srcs, dst);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
/*!
 * \brief Pre-execution checks for split.
 *
 * Reuses check_layout_common() with the roles swapped: the split outputs
 * \p dsts are passed as the pieces and the split input \p src as the whole.
 * Afterwards asserts that the caller-provided workspace is at least as large
 * as what get_workspace_in_bytes() reports for the same layouts.
 *
 * \param src input layout
 * \param dsts output layouts
 * \param workspace_in_bytes size of the workspace supplied by the caller
 */
void SplitForward::check_exec(
        const TensorLayout& src, const TensorLayoutArray& dsts,
        size_t workspace_in_bytes) {
    // Layouts are validated before the workspace requirement is queried.
    check_layout_common(dsts, src);

    const size_t required_workspace_in_bytes = get_workspace_in_bytes(src, dsts);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
}