megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file dnn/src/common/lstm_cell.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/common/lstm_cell.h"
#include "megdnn/oprs.h"
#include "src/common/utils.h"

namespace megdnn {

void LSTMCell::deduce_layout(
        const TensorLayout& input, const TensorLayout& weight_ih,
        const TensorLayout& bias_ih, const TensorLayout& hx,
        const TensorLayout& weight_hh, const TensorLayout& bias_hh,
        const TensorLayout& cx, TensorLayout& h_new, TensorLayout& c_new,
        TensorLayout& gates) {
    h_new = TensorLayout(hx, hx.dtype);
    c_new = TensorLayout(cx, cx.dtype);
    auto opr = handle()->create_operator<RNNCellForward>();
    opr->param().nonlineMode = param::RNNCell::NonlineMode::IDENTITY;
    opr->deduce_layout(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, gates);
}

void LSTMCell::check_exec(
        const TensorLayout& input, const TensorLayout& weight_ih,
        const TensorLayout& bias_ih, const TensorLayout& hx,
        const TensorLayout& weight_hh, const TensorLayout& bias_hh,
        const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new,
        const TensorLayout& gates, size_t workspace_in_bytes) {
    TensorLayout h_new_expected, c_new_expected, gates_expected;
    auto errmsg = [&]() {
        std::string msg;
        msg.append("input=");
        msg.append(input.to_string());
        msg.append(", weight_ih=");
        msg.append(weight_ih.to_string());
        msg.append(", bias_ih=");
        msg.append(bias_ih.to_string());
        msg.append(", hx=");
        msg.append(hx.to_string());
        msg.append(", weight_hh=");
        msg.append(weight_hh.to_string());
        msg.append(", bias_hh=");
        msg.append(bias_hh.to_string());
        msg.append(", cx=");
        msg.append(cx.to_string());
        return msg;
    };
#define ASSERT_BRIEF(_content) megdnn_assert(_content, "%s", errmsg().c_str());

    ASSERT_BRIEF(input.ndim == 2)
    ASSERT_BRIEF(input.shape[1] == weight_ih.shape[1])
    ASSERT_BRIEF(weight_ih.shape[0] == weight_hh.shape[0])
    ASSERT_BRIEF(weight_hh.shape[0] == 4 * weight_hh.shape[1])
    ASSERT_BRIEF(bias_ih.shape[0] == bias_hh.shape[0])
    ASSERT_BRIEF(hx.ndim == 2)
    ASSERT_BRIEF(hx.shape[0] == input.shape[0])
    ASSERT_BRIEF(hx.shape[1] == cx.shape[1])  // hidden_size
    ASSERT_BRIEF(cx.ndim == 2)
    ASSERT_BRIEF(cx.shape[0] == input.shape[0])
    ASSERT_BRIEF(cx.shape[1] == weight_hh.shape[1])
#undef ASSERT_BRIEF

    deduce_layout(
            input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new_expected,
            c_new_expected, gates_expected);
    megdnn_assert_eq_layout(h_new_expected, h_new);
    megdnn_assert_eq_layout(c_new_expected, c_new);
    megdnn_assert_eq_layout(gates_expected, gates);

    auto required_workspace_in_bytes = get_workspace_in_bytes(
            input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

}  // namespace megdnn

namespace megdnn {
namespace lstm_cell {

size_t get_workspace_in_bytes(
        const TensorLayout& input, const TensorLayout& weight_ih,
        const TensorLayout& bias_ih, const TensorLayout& hx,
        const TensorLayout& weight_hh, const TensorLayout& bias_hh,
        const TensorLayout& /*cx*/, const TensorLayout& /*h_new*/,
        const TensorLayout& /*c_new*/, const TensorLayout& gates, Handle* handle) {
    TensorLayout tmp_layout;
    auto opr = handle->create_operator<RNNCellForward>();
    opr->param().nonlineMode = param::RNNCell::NonlineMode::IDENTITY;
    opr->deduce_layout(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, tmp_layout);
    size_t rnn_cell_need = opr->get_workspace_in_bytes(
            input, weight_ih, bias_ih, hx, weight_hh, bias_hh, gates);
    size_t lstm_cell_need = 2 * tmp_layout.span().dist_byte();
    return rnn_cell_need > lstm_cell_need ? rnn_cell_need : lstm_cell_need;
}

void exec(
        _megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih,
        _megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
        _megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
        _megdnn_tensor_out gates, _megdnn_workspace workspace, Handle* handle) {
    auto opr = handle->create_operator<RNNCellForward>();
    opr->param().nonlineMode = param::RNNCell::NonlineMode::IDENTITY;
    opr->exec(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, gates, workspace);
    // activation
    size_t batch_size = hx.layout.shape[0];
    size_t hidden_size = hx.layout.shape[1];

    auto copy_opr = handle->create_operator<TypeCvtForward>();
    TensorND copy_gates{static_cast<void*>(workspace.raw_ptr), gates.layout};
    TensorLayout hidden_layout{TensorShape{hidden_size}, hx.layout.dtype};
    TensorLayout gateinfo_layout{TensorShape{batch_size, hidden_size}, hx.layout.dtype};
    for (size_t i = 0; i < batch_size; i++) {
        for (size_t j = 0; j < 4; j++) {
            TensorND half_step_states{
                    // output
                    static_cast<uint8_t*>(gates.raw_ptr()) +
                            (4 * i + j) * hidden_layout.span().dist_byte(),
                    hidden_layout};
            TensorND half_step_output{
                    static_cast<uint8_t*>(copy_gates.raw_ptr()) +
                            j * gateinfo_layout.span().dist_byte() +
                            i * hidden_layout.span().dist_byte(),
                    hidden_layout};
            copy_opr->exec(half_step_states, half_step_output);
        }
    }
    void* workspace_ptr = workspace.raw_ptr + copy_gates.layout.span().dist_byte();
    copy_opr->exec(copy_gates, gates);

    // sigmoid: i f
    TensorND tmp{static_cast<void*>(workspace_ptr), copy_gates.layout};
    TensorLayout gates_ifo_layout{
            TensorShape({batch_size, hidden_size * 2}), copy_gates.layout.dtype};
    TensorND gates_ifo_origin{copy_gates.raw_ptr(), gates_ifo_layout};
    TensorND gates_ifo{tmp.raw_ptr(), gates_ifo_layout};
    auto sigmoid = handle->create_operator<ElemwiseForward>();
    sigmoid->param().mode = Elemwise::Param::Mode::SIGMOID;
    sigmoid->exec({gates_ifo_origin}, gates_ifo);
    // tanh: g
    TensorLayout g_layout{
            TensorShape({batch_size, hidden_size}), copy_gates.layout.dtype};
    TensorND g_origin{
            static_cast<char*>(copy_gates.raw_ptr()) +
                    gates_ifo_layout.span().dist_byte(),
            g_layout};
    TensorND g{
            static_cast<char*>(tmp.raw_ptr()) + gates_ifo_layout.span().dist_byte(),
            g_layout};
    auto tanh = handle->create_operator<ElemwiseForward>();
    tanh->param().mode = Elemwise::Param::Mode::TANH;
    tanh->exec({g_origin}, g);
    // sigmoid: o
    TensorLayout three_gates_ifo_layout{
            TensorShape({batch_size, hidden_size * 3}), copy_gates.layout.dtype};
    TensorLayout o_layout{
            TensorShape({batch_size, hidden_size}), copy_gates.layout.dtype};
    TensorND o_origin{
            static_cast<char*>(copy_gates.raw_ptr()) +
                    three_gates_ifo_layout.span().dist_byte(),
            o_layout};
    TensorND o{
            static_cast<char*>(tmp.raw_ptr()) +
                    three_gates_ifo_layout.span().dist_byte(),
            o_layout};
    sigmoid->exec({o_origin}, o);
    // extract i f o
    TensorND i{static_cast<char*>(tmp.raw_ptr()), g_layout};
    TensorND f{
            static_cast<char*>(tmp.raw_ptr()) + g_layout.span().dist_byte(), g_layout};
    // calculate new cell state
    auto elewise_mul_add = handle->create_operator<ElemwiseForward>();
    elewise_mul_add->param().mode = Elemwise::Param::Mode::FUSE_MUL_ADD4;
    elewise_mul_add->exec({f, cx, i, g}, c_new);
    // calculate new hidden state
    tanh->exec({c_new}, h_new);
    auto elewise_mul = handle->create_operator<ElemwiseForward>();
    elewise_mul->param().mode = Elemwise::Param::Mode::MUL;
    elewise_mul->exec({o, h_new}, h_new);
}

}  // namespace lstm_cell
}  // namespace megdnn