megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file dnn/src/fallback/tile/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/fallback/tile/opr_impl.h"

#include <cstring>
#include <numeric>
#include "src/common/tile_repeat_helper.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"

namespace megdnn {
namespace fallback {

/*!
 * \brief Workspace size, in bytes, for the given source/destination layouts.
 *
 * Simply forwards both layouts to get_workspace_in_bytes_fwd() and returns
 * whatever it reports.
 *
 * \param src layout of the input tensor
 * \param dst layout of the output tensor
 * \return the value produced by get_workspace_in_bytes_fwd(src, dst)
 */
size_t TileImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& dst) {
    return get_workspace_in_bytes_fwd(src, dst);
}

/*!
 * \brief Run the tile operator on CPU, treating tensor data as float32.
 *
 * Steps:
 *  1. check_exec() is called with both layouts and the workspace size.
 *  2. simplify_shape() fills three shapes (input, output, per-axis repeat
 *     counts) from the layouts and param().times.
 *  3. If count_not_ones_in_shape() reports zero for the repeat counts, the
 *     output is a plain byte copy of the input.
 *  4. Otherwise every axis whose repeat count is not 1 is processed, last
 *     axis first, with one tile_or_repeat_single_axis() call per axis.
 *
 * All byte counts use sizeof(float) and all typed accesses use dt_float32.
 *
 * \param src_ input tensor
 * \param dst_ output tensor
 * \param workspace scratch memory; it is split into two buffers of
 *        dst-element-count floats each
 */
void TileImpl::exec(
        _megdnn_tensor_in src_, _megdnn_tensor_out dst_, _megdnn_workspace workspace) {
    check_exec(src_.layout, dst_.layout, workspace.size);

    TensorShape in_shape, out_shape, reps;
    simplify_shape(src_.layout, dst_.layout, param().times, in_shape, out_shape, reps);

    const auto nr_active_axes = count_not_ones_in_shape(reps);
    // Size in bytes of one full float output buffer.
    const size_t buf_bytes = out_shape.total_nr_elems() * sizeof(float);

    // No axis needs repeating: the result is a plain copy of the input.
    if (nr_active_axes == 0) {
        MEGDNN_DISPATCH_CPU_KERN_OPR(
                std::memcpy(dst_.raw_ptr(), src_.raw_ptr(), buf_bytes));
        return;
    }

    auto kern = [=]() {
        // Carve the workspace into two equally sized float scratch buffers.
        WorkspaceBundle bundle(workspace.raw_ptr, {buf_bytes, buf_bytes});
        float* const scratch0 = static_cast<float*>(bundle.get(0));
        float* const scratch1 = static_cast<float*>(bundle.get(1));

        // cur / nxt / step_state are assigned by the *_tile_repeat_state
        // helpers. NOTE(review): presumably they rotate between src, the two
        // scratch buffers and dst so the last step writes into dst -- confirm
        // against tile_repeat_helper.
        float* cur;
        float* nxt;
        size_t step_state;
        init_tile_repeat_state(
                src_.ptr<dt_float32>(), dst_.ptr<dt_float32>(), scratch0, scratch1,
                cur, nxt, step_state, nr_active_axes);

        const size_t ndim = reps.ndim;
        // Walk the axes from the innermost (ndim - 1) down to 0.
        for (size_t axis = ndim; axis-- > 0;) {
            if (reps.shape[axis] == 1) {
                continue;
            }
            // outer = in_shape[0] * ... * in_shape[axis - 1]
            const size_t outer = std::accumulate(
                    in_shape.shape, in_shape.shape + axis, 1_z,
                    SafeMultiplies<size_t>());
            // inner = in_shape[axis] * out_shape[axis + 1] * ... * out_shape[ndim - 1]
            const size_t inner = std::accumulate(
                                         out_shape.shape + axis + 1,
                                         out_shape.shape + ndim, 1_z,
                                         SafeMultiplies<size_t>()) *
                                 in_shape.shape[axis];
            // forward pass: repeat an (outer, inner) block reps[axis] times
            tile_or_repeat_single_axis(cur, nxt, outer, inner, reps[axis]);
            update_tile_repeat_state(
                    src_.ptr<dt_float32>(), dst_.ptr<dt_float32>(), scratch0,
                    scratch1, cur, nxt, step_state, nr_active_axes);
        }
    };
    MEGDNN_DISPATCH_CPU_KERN_OPR(kern());
}

}  // namespace fallback
}  // namespace megdnn

// vim: syntax=cpp.doxygen