megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file dnn/src/common/rnn_cell.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/common/rnn_cell.h"
#include "megdnn/oprs.h"
#include "src/common/utils.h"

namespace megdnn {

void RNNCell::deduce_layout(
        const TensorLayout& input, const TensorLayout& weight_ih,
        const TensorLayout& /*bias_ih*/, const TensorLayout& hx,
        const TensorLayout& /*weight_hh*/, const TensorLayout& /*bias_hh*/,
        TensorLayout& dst) {
    size_t batch_size = hx.shape[0];
    size_t gate_hidden_size = weight_ih.shape[0];

    dst = TensorLayout(TensorShape({batch_size, gate_hidden_size}), input.dtype);
}

void RNNCell::check_exec(
        const TensorLayout& input, const TensorLayout& weight_ih,
        const TensorLayout& bias_ih, const TensorLayout& hx,
        const TensorLayout& weight_hh, const TensorLayout& bias_hh,
        const TensorLayout& dst, size_t workspace_in_bytes) {
    TensorLayout dst_expected;
    auto errmsg = [&]() {
        std::string msg;
        msg.append("input=");
        msg.append(input.to_string());
        msg.append(", weight_ih=");
        msg.append(weight_ih.to_string());
        msg.append(", bias_ih=");
        msg.append(bias_ih.to_string());
        msg.append(", hx=");
        msg.append(hx.to_string());
        msg.append(", weight_hh=");
        msg.append(weight_hh.to_string());
        msg.append(", bias_hh=");
        msg.append(bias_hh.to_string());
        msg.append(", dst=");
        msg.append(dst.to_string());
        return msg;
    };
#define ASSERT_BRIEF(_content) megdnn_assert(_content, "%s", errmsg().c_str());

    ASSERT_BRIEF(input.ndim == 2)
    ASSERT_BRIEF(hx.ndim == 2)
    ASSERT_BRIEF(hx.shape[0] == input.shape[0])  // batch
    ASSERT_BRIEF(input.shape[1] == weight_ih.shape[1])
    ASSERT_BRIEF(hx.shape[0] == dst.shape[0])  // batch
    ASSERT_BRIEF(hx.shape[1] == dst.shape[1])
    ASSERT_BRIEF(hx.shape[1] == weight_ih.shape[0])  // hidden_size
    ASSERT_BRIEF(weight_ih.shape[0] == weight_hh.shape[0])
    ASSERT_BRIEF(weight_hh.shape[0] == weight_hh.shape[1])
    ASSERT_BRIEF(bias_ih.shape[0] == bias_hh.shape[0])
#undef ASSERT_BRIEF
    megdnn_assert_eq_dtype(input, dst);
    megdnn_assert_eq_dtype(hx, dst);
    deduce_layout(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst_expected);
    megdnn_assert_eq_layout(dst_expected, dst);

    auto required_workspace_in_bytes = get_workspace_in_bytes(
            input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

}  // namespace megdnn

namespace megdnn {
namespace rnn_cell {

size_t get_workspace_in_bytes(
        const TensorLayout& input, const TensorLayout& weight_ih,
        const TensorLayout& /*bias_ih*/, const TensorLayout& hx,
        const TensorLayout& weight_hh, const TensorLayout& /*bias_hh*/,
        const TensorLayout& dst, Handle* handle) {
    auto opr = handle->create_operator<MatrixMulForward>();
    opr->param().transposeB = true;
    return dst.span().dist_byte() +
           std::max(
                   opr->get_workspace_in_bytes(hx, weight_hh, dst),
                   opr->get_workspace_in_bytes(input, weight_ih, dst));
}

void exec(
        _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih,
        _megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
        _megdnn_tensor_out dst, _megdnn_workspace workspace,
        param::RNNCell::NonlineMode nonline_mode, Handle* handle) {
    TensorND tmp{static_cast<void*>(workspace.raw_ptr), dst.layout};
    _megdnn_workspace new_workspace = {
            workspace.raw_ptr + dst.layout.span().dist_byte(),
            workspace.size - dst.layout.span().dist_byte()};
    auto opr = handle->create_operator<MatrixMulForward>();
    opr->param().transposeB = true;
    opr->exec(input, weight_ih, tmp, new_workspace);
    opr->exec(hx, weight_hh, dst, new_workspace);
    auto add_opr = handle->create_operator<ElemwiseForward>();
    add_opr->param().mode = Elemwise::Param::Mode::ADD;
    add_opr->exec({dst, tmp}, dst);
    add_opr->exec({dst, bias_ih}, dst);
    add_opr->exec({dst, bias_hh}, dst);

    // activation
    using NonlineMode = param::RNNCell::NonlineMode;

    switch (nonline_mode) {
#define cb(_mode)                                                    \
    case NonlineMode::_mode: {                                       \
        auto nonlinear = handle->create_operator<ElemwiseForward>(); \
        nonlinear->param().mode = Elemwise::Param::Mode::_mode;      \
        nonlinear->exec({dst}, dst);                                 \
        break;                                                       \
    }
        cb(RELU);
        cb(TANH);
#undef cb
        case NonlineMode::IDENTITY:
            break;
        default:
            megdnn_assert(false);
    }
}

}  // namespace rnn_cell
}  // namespace megdnn