#include "src/fallback/tile/opr_impl.h"
#include <cstring>
#include <numeric>
#include "src/common/tile_repeat_helper.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
namespace megdnn {
namespace fallback {
size_t TileImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) {
auto workspace_size = get_workspace_in_bytes_fwd(src, dst);
return workspace_size;
}
void TileImpl::exec(
_megdnn_tensor_in src_, _megdnn_tensor_out dst_, _megdnn_workspace workspace) {
check_exec(src_.layout, dst_.layout, workspace.size);
TensorShape src, dst, times;
simplify_shape(src_.layout, dst_.layout, param().times, src, dst, times);
auto nr_reduces = count_not_ones_in_shape(times);
if (nr_reduces == 0) {
MEGDNN_DISPATCH_CPU_KERN_OPR(std::memcpy(
dst_.raw_ptr(), src_.raw_ptr(), sizeof(float) * dst.total_nr_elems()));
return;
}
auto kern = [=]() {
auto ndim = times.ndim;
WorkspaceBundle workspaces(
workspace.raw_ptr, {dst.total_nr_elems() * sizeof(float),
dst.total_nr_elems() * sizeof(float)});
auto workspace0 = static_cast<float*>(workspaces.get(0));
auto workspace1 = static_cast<float*>(workspaces.get(1));
float *current, *next;
size_t state;
init_tile_repeat_state(
src_.ptr<dt_float32>(), dst_.ptr<dt_float32>(), workspace0, workspace1,
current, next, state, nr_reduces);
for (size_t i = ndim; i > 0; --i) {
size_t j = i - 1;
if (times.shape[j] != 1) {
auto m = std::accumulate(
src.shape, src.shape + j, 1_z, SafeMultiplies<size_t>());
auto n = std::accumulate(
dst.shape + i, dst.shape + ndim, 1_z,
SafeMultiplies<size_t>()) *
src.shape[j];
tile_or_repeat_single_axis(current, next, m, n, times[j]);
update_tile_repeat_state(
src_.ptr<dt_float32>(), dst_.ptr<dt_float32>(), workspace0,
workspace1, current, next, state, nr_reduces);
}
}
};
MEGDNN_DISPATCH_CPU_KERN_OPR(kern());
}
} }