#include "megdnn/oprs.h"
#include "src/common/utils.h"
#include <numeric>
namespace megdnn {
/**
 * Validates a tile/repeat forward layout pair.
 *
 * Checks, in order:
 *   - src and dst are contiguous;
 *   - src.ndim and dst.ndim both equal param().times.ndim;
 *   - dst.shape[i] == param().times[i] * src.shape[i] for every axis i;
 *   - src and dst share the same dtype.
 * Every check that takes a message reports the same diagnostic string, which
 * describes src, dst and times.
 *
 * @param src input layout of the forward operator
 * @param dst output layout of the forward operator
 */
void TileRepeatBase::check_layout_fwd(
        const TensorLayout& src, const TensorLayout& dst) {
    auto errmsg = megdnn_layout_msg(src) + ", " + megdnn_layout_msg(dst) + ", " +
                  "times=" + param().times.to_string();
    auto errmsg_c = errmsg.c_str();
    // NOTE(review): presumably silences unused-variable warnings in builds where
    // megdnn_assert drops its message arguments — confirm.
    MEGDNN_MARK_USED_VAR(errmsg_c);
    megdnn_assert_contiguous(src);
    megdnn_assert_contiguous(dst);
    auto expected_ndim = param().times.ndim;
    megdnn_assert(expected_ndim == src.ndim, "%s", errmsg_c);
    megdnn_assert(expected_ndim == dst.ndim, "%s", errmsg_c);
    rep(i, expected_ndim) {
        megdnn_assert(dst.shape[i] == param().times[i] * src.shape[i], "%s", errmsg_c);
    }
    // Attach the same operand diagnostic as the shape checks, so a dtype
    // mismatch also reports which layouts were involved.
    megdnn_assert(src.dtype == dst.dtype, "%s", errmsg_c);
}
void TileRepeatBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) {
dst.ndim = src.ndim;
rep(i, src.ndim) { dst.shape[i] = src.shape[i] * param().times[i]; }
dst.dtype = src.dtype;
dst.init_contiguous_stride();
check_layout_fwd(src, dst);
}
size_t TileRepeatBase::get_workspace_in_bytes_fwd(
const TensorShape& , const TensorShape& dst, const TensorShape& times,
DType dtype) {
size_t nr_workspace = 0;
auto nr_reduces = count_not_ones_in_shape(times);
if (nr_reduces == 0) {
nr_workspace = 0;
} else if (nr_reduces == 1) {
nr_workspace = 0;
} else if (nr_reduces == 2) {
nr_workspace = 1;
} else {
nr_workspace = 2;
}
if (nr_workspace == 0) {
return 0;
} else {
WorkspaceBundle workspaces{
nullptr, {nr_workspace, dst.total_nr_elems() * dtype.size()}};
return workspaces.total_size_in_bytes();
}
}
/**
 * Collapses axes for tiling by folding every axis with times == 1 into the
 * axis kept just before it. The first axis is always kept, even when its
 * times is 1.
 *
 * Folding multiplies the kept axis's src/dst extents by the folded axis's
 * extents and leaves its times unchanged.
 *
 * @param src, dst, times original shapes (assumed to share src.ndim axes)
 * @param src2, dst2, times2 receive the collapsed shapes; all get the same ndim
 */
void TileBase::simplify_shape(
        const TensorShape& src, const TensorShape& dst, const TensorShape& times,
        TensorShape& src2, TensorShape& dst2, TensorShape& times2) {
    size_t kept = 0;
    for (size_t axis = 0; axis < src.ndim; ++axis) {
        const bool fold_into_prev = kept != 0 && times.shape[axis] == 1;
        if (!fold_into_prev) {
            // Start a new collapsed axis.
            src2.shape[kept] = src.shape[axis];
            dst2.shape[kept] = dst.shape[axis];
            times2.shape[kept] = times.shape[axis];
            ++kept;
            continue;
        }
        // Unit multiplier: merge into the previously kept axis.
        src2.shape[kept - 1] *= src.shape[axis];
        dst2.shape[kept - 1] *= dst.shape[axis];
    }
    src2.ndim = dst2.ndim = times2.ndim = kept;
}
size_t TileBase::get_workspace_in_bytes_fwd(
const TensorLayout& src_, const TensorLayout& dst_) {
TensorShape src, dst, times;
simplify_shape(src_, dst_, param().times, src, dst, times);
return TileRepeatBase::get_workspace_in_bytes_fwd(src, dst, times, src_.dtype);
}
void TileForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
deduce_layout_fwd(src, dst);
}
void TileForward::check_exec(
const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) {
check_layout_fwd(src, dst);
auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
void TileBackward::check_exec(
const TensorLayout& diff, const TensorLayout& grad, size_t workspace_in_bytes) {
check_layout_fwd(grad, diff);
auto required_workspace_in_bytes = get_workspace_in_bytes(diff, grad);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
/**
 * Collapses axes for repeating.
 *
 * Each run of consecutive times == 1 axes is grouped with the first following
 * axis whose times != 1. A trailing run of unit axes forms its own group.
 * For each group:
 *   src2  = product of the group's src extents (overflow-checked multiply);
 *   times2 = times of the group's last axis;
 *   dst2  = src2 * times2.
 * The dst argument is not read; dst2 is recomputed from src2 and times2.
 */
void RepeatBase::simplify_shape(
        const TensorShape& src, const TensorShape& /* dst */, const TensorShape& times,
        TensorShape& src2, TensorShape& dst2, TensorShape& times2) {
    size_t groups = 0;
    for (size_t begin = 0; begin < times.ndim;) {
        // Swallow leading unit-times axes...
        size_t end = begin;
        while (end < times.ndim && times.shape[end] == 1) {
            ++end;
        }
        // ...then include the first non-unit axis, if one remains.
        if (end != times.ndim) {
            ++end;
        }
        const size_t merged = std::accumulate(
                src.shape + begin, src.shape + end, 1_z, SafeMultiplies<size_t>());
        src2.shape[groups] = merged;
        times2.shape[groups] = times.shape[end - 1];
        dst2.shape[groups] = merged * times2.shape[groups];
        ++groups;
        begin = end;
    }
    src2.ndim = dst2.ndim = times2.ndim = groups;
}
size_t RepeatBase::get_workspace_in_bytes_fwd(
const TensorLayout& src_, const TensorLayout& dst_) {
TensorShape src, dst, times;
simplify_shape(src_, dst_, param().times, src, dst, times);
return TileRepeatBase::get_workspace_in_bytes_fwd(src, dst, times, src_.dtype);
}
void RepeatForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
deduce_layout_fwd(src, dst);
}
void RepeatForward::check_exec(
const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) {
check_layout_fwd(src, dst);
auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
void RepeatBackward::check_exec(
const TensorLayout& diff, const TensorLayout& grad, size_t workspace_in_bytes) {
check_layout_fwd(grad, diff);
auto required_workspace_in_bytes = get_workspace_in_bytes(diff, grad);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
}