#include "src/common/lstm_cell.h"
#include "megdnn/oprs.h"
#include "src/common/utils.h"
namespace megdnn {
/*!
 * \brief Deduce the output layouts of an LSTM cell.
 *
 * \param[out] h_new layout built from the shape and dtype of \p hx
 * \param[out] c_new layout built from the shape and dtype of \p cx
 * \param[out] gates layout deduced by an RNNCellForward operator configured
 *      with the IDENTITY non-linearity, fed with the input/weight/bias/hx
 *      layouts
 */
void LSTMCell::deduce_layout(
        const TensorLayout& input, const TensorLayout& weight_ih,
        const TensorLayout& bias_ih, const TensorLayout& hx,
        const TensorLayout& weight_hh, const TensorLayout& bias_hh,
        const TensorLayout& cx, TensorLayout& h_new, TensorLayout& c_new,
        TensorLayout& gates) {
    // The new states mirror the shape/dtype of the incoming states.
    h_new = TensorLayout(hx, hx.dtype);
    c_new = TensorLayout(cx, cx.dtype);

    // The gate buffer layout is delegated to an RNN cell without activation.
    auto rnn_cell = handle()->create_operator<RNNCellForward>();
    auto& rnn_param = rnn_cell->param();
    rnn_param.nonlineMode = param::RNNCell::NonlineMode::IDENTITY;
    rnn_cell->deduce_layout(
            input, weight_ih, bias_ih, hx, weight_hh, bias_hh, gates);
}
// Validate the argument layouts and the workspace size before executing an
// LSTM cell.
//
// The checks run in this order:
//   1. shape relations between input, weights, biases, hx and cx;
//   2. h_new / c_new / gates must equal the layouts produced by
//      deduce_layout();
//   3. workspace_in_bytes must be at least get_workspace_in_bytes(...).
// Each failure is reported through megdnn_assert / megdnn_assert_eq_layout.
void LSTMCell::check_exec(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& cx, const TensorLayout& h_new, const TensorLayout& c_new,
const TensorLayout& gates, size_t workspace_in_bytes) {
TensorLayout h_new_expected, c_new_expected, gates_expected;
// Builds a description of all input-side layouts; it is attached to every
// ASSERT_BRIEF failure below.
auto errmsg = [&]() {
std::string msg;
msg.append("input=");
msg.append(input.to_string());
msg.append(", weight_ih=");
msg.append(weight_ih.to_string());
msg.append(", bias_ih=");
msg.append(bias_ih.to_string());
msg.append(", hx=");
msg.append(hx.to_string());
msg.append(", weight_hh=");
msg.append(weight_hh.to_string());
msg.append(", bias_hh=");
msg.append(bias_hh.to_string());
msg.append(", cx=");
msg.append(cx.to_string());
return msg;
};
// Local shorthand: assert the condition and print errmsg() on failure. The
// macro already ends with ';', so the uses below carry no semicolon.
#define ASSERT_BRIEF(_content) megdnn_assert(_content, "%s", errmsg().c_str());
// input is 2-D and its second dim matches the second dim of weight_ih
// (presumably (batch, input_size) vs (4 * hidden, input_size) -- TODO confirm).
ASSERT_BRIEF(input.ndim == 2)
ASSERT_BRIEF(input.shape[1] == weight_ih.shape[1])
// Both weights have the same number of rows, which must be four times the
// second dim of weight_hh (exec() treats the gate buffer as four blocks).
ASSERT_BRIEF(weight_ih.shape[0] == weight_hh.shape[0])
ASSERT_BRIEF(weight_hh.shape[0] == 4 * weight_hh.shape[1])
ASSERT_BRIEF(bias_ih.shape[0] == bias_hh.shape[0])
// hx and cx are 2-D, share the first dim with input, share their second dim
// with each other, and cx's second dim equals weight_hh.shape[1].
ASSERT_BRIEF(hx.ndim == 2)
ASSERT_BRIEF(hx.shape[0] == input.shape[0])
ASSERT_BRIEF(hx.shape[1] == cx.shape[1]) ASSERT_BRIEF(cx.ndim == 2)
ASSERT_BRIEF(cx.shape[0] == input.shape[0])
ASSERT_BRIEF(cx.shape[1] == weight_hh.shape[1])
#undef ASSERT_BRIEF
// The caller-provided output layouts must equal the deduced ones.
deduce_layout(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new_expected,
c_new_expected, gates_expected);
megdnn_assert_eq_layout(h_new_expected, h_new);
megdnn_assert_eq_layout(c_new_expected, c_new);
megdnn_assert_eq_layout(gates_expected, gates);
// The provided workspace must be large enough for this operator.
auto required_workspace_in_bytes = get_workspace_in_bytes(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, cx, h_new, c_new, gates);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
}
namespace megdnn {
namespace lstm_cell {
size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& , const TensorLayout& ,
const TensorLayout& , const TensorLayout& gates, Handle* handle) {
TensorLayout tmp_layout;
auto opr = handle->create_operator<RNNCellForward>();
opr->param().nonlineMode = param::RNNCell::NonlineMode::IDENTITY;
opr->deduce_layout(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, tmp_layout);
size_t rnn_cell_need = opr->get_workspace_in_bytes(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, gates);
size_t lstm_cell_need = 2 * tmp_layout.span().dist_byte();
return rnn_cell_need > lstm_cell_need ? rnn_cell_need : lstm_cell_need;
}
/*!
 * \brief LSTM cell forward pass composed from other megdnn operators.
 *
 * Steps, in execution order:
 *   1. RNNCellForward (IDENTITY non-linearity) writes the pre-activation gate
 *      values into \p gates.
 *   2. The gate buffer is regrouped from per-sample order (4 consecutive
 *      hidden-sized chunks per sample) into per-gate order (4 consecutive
 *      (batch, hidden) blocks) inside the first part of the workspace, then
 *      copied back into \p gates.
 *   3. Activations are written into a second gate-sized region of the
 *      workspace: SIGMOID on blocks 0-1, TANH on block 2, SIGMOID on block 3
 *      (named i, f, g, o below).
 *   4. c_new = FUSE_MUL_ADD4(f, cx, i, g)  (presumably f * cx + i * g),
 *      h_new = o * tanh(c_new).
 *
 * \p workspace must hold at least two gate-sized buffers; see
 * get_workspace_in_bytes().
 */
void exec(
        _megdnn_tensor_in input, _megdnn_tensor_in weight_ih,
        _megdnn_tensor_in bias_ih, _megdnn_tensor_in hx,
        _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
        _megdnn_tensor_in cx, _megdnn_tensor_out h_new, _megdnn_tensor_out c_new,
        _megdnn_tensor_out gates, _megdnn_workspace workspace, Handle* handle) {
    // --- 1. pre-activation gates ------------------------------------------
    auto rnn_cell = handle->create_operator<RNNCellForward>();
    rnn_cell->param().nonlineMode = param::RNNCell::NonlineMode::IDENTITY;
    rnn_cell->exec(
            input, weight_ih, bias_ih, hx, weight_hh, bias_hh, gates, workspace);

    const size_t batch_size = hx.layout.shape[0];
    const size_t hidden_size = hx.layout.shape[1];

    // --- 2. regroup gates: per-sample order -> per-gate order ---------------
    auto copier = handle->create_operator<TypeCvtForward>();
    TensorND regrouped{static_cast<void*>(workspace.raw_ptr), gates.layout};
    TensorLayout row_layout{TensorShape{hidden_size}, hx.layout.dtype};
    TensorLayout block_layout{
            TensorShape{batch_size, hidden_size}, hx.layout.dtype};
    const size_t row_bytes = row_layout.span().dist_byte();
    const size_t block_bytes = block_layout.span().dist_byte();
    uint8_t* const gates_base = static_cast<uint8_t*>(gates.raw_ptr());
    uint8_t* const regrouped_base = static_cast<uint8_t*>(regrouped.raw_ptr());
    for (size_t sample = 0; sample < batch_size; ++sample) {
        for (size_t gate = 0; gate < 4; ++gate) {
            // chunk `gate` of sample `sample` in the RNN cell output
            TensorND src_row{
                    gates_base + (4 * sample + gate) * row_bytes, row_layout};
            // row `sample` of block `gate` in the regrouped buffer
            TensorND dst_row{
                    regrouped_base + gate * block_bytes + sample * row_bytes,
                    row_layout};
            copier->exec(src_row, dst_row);
        }
    }
    // Second gate-sized region of the workspace, directly after `regrouped`.
    void* activated_ptr =
            workspace.raw_ptr + regrouped.layout.span().dist_byte();
    // Publish the regrouped pre-activation values through `gates`.
    copier->exec(regrouped, gates);

    // --- 3. activations -----------------------------------------------------
    TensorND activated{static_cast<void*>(activated_ptr), regrouped.layout};
    char* const pre_base = static_cast<char*>(regrouped.raw_ptr());
    char* const act_base = static_cast<char*>(activated.raw_ptr());

    // sigmoid over the first two blocks (i and f) in one call
    TensorLayout if_layout{
            TensorShape({batch_size, hidden_size * 2}), regrouped.layout.dtype};
    TensorND if_pre{regrouped.raw_ptr(), if_layout};
    TensorND if_act{activated.raw_ptr(), if_layout};
    auto sigmoid_opr = handle->create_operator<ElemwiseForward>();
    sigmoid_opr->param().mode = Elemwise::Param::Mode::SIGMOID;
    sigmoid_opr->exec({if_pre}, if_act);

    // tanh over the third block (g)
    TensorLayout single_gate_layout{
            TensorShape({batch_size, hidden_size}), regrouped.layout.dtype};
    const size_t g_offset = if_layout.span().dist_byte();
    TensorND g_pre{pre_base + g_offset, single_gate_layout};
    TensorND g{act_base + g_offset, single_gate_layout};
    auto tanh_opr = handle->create_operator<ElemwiseForward>();
    tanh_opr->param().mode = Elemwise::Param::Mode::TANH;
    tanh_opr->exec({g_pre}, g);

    // sigmoid over the fourth block (o)
    TensorLayout ifg_layout{
            TensorShape({batch_size, hidden_size * 3}), regrouped.layout.dtype};
    const size_t o_offset = ifg_layout.span().dist_byte();
    TensorND o_pre{pre_base + o_offset, single_gate_layout};
    TensorND o{act_base + o_offset, single_gate_layout};
    sigmoid_opr->exec({o_pre}, o);

    // --- 4. state update ----------------------------------------------------
    // i and f are the first and second activated blocks.
    TensorND i{act_base, single_gate_layout};
    TensorND f{act_base + single_gate_layout.span().dist_byte(),
               single_gate_layout};
    auto mul_add_opr = handle->create_operator<ElemwiseForward>();
    mul_add_opr->param().mode = Elemwise::Param::Mode::FUSE_MUL_ADD4;
    mul_add_opr->exec({f, cx, i, g}, c_new);

    // h_new first holds tanh(c_new), then is multiplied in place by o.
    tanh_opr->exec({c_new}, h_new);
    auto mul_opr = handle->create_operator<ElemwiseForward>();
    mul_opr->param().mode = Elemwise::Param::Mode::MUL;
    mul_opr->exec({o, h_new}, h_new);
}
} }