#include "src/common/rnn_cell.h"
#include "megdnn/oprs.h"
#include "src/common/utils.h"
namespace megdnn {
void RNNCell::deduce_layout(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& , const TensorLayout& hx,
const TensorLayout& , const TensorLayout& ,
TensorLayout& dst) {
size_t batch_size = hx.shape[0];
size_t gate_hidden_size = weight_ih.shape[0];
dst = TensorLayout(TensorShape({batch_size, gate_hidden_size}), input.dtype);
}
void RNNCell::check_exec(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& bias_ih, const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& bias_hh,
const TensorLayout& dst, size_t workspace_in_bytes) {
TensorLayout dst_expected;
auto errmsg = [&]() {
std::string msg;
msg.append("input=");
msg.append(input.to_string());
msg.append(", weight_ih=");
msg.append(weight_ih.to_string());
msg.append(", bias_ih=");
msg.append(bias_ih.to_string());
msg.append(", hx=");
msg.append(hx.to_string());
msg.append(", weight_hh=");
msg.append(weight_hh.to_string());
msg.append(", bias_hh=");
msg.append(bias_hh.to_string());
msg.append(", dst=");
msg.append(dst.to_string());
return msg;
};
#define ASSERT_BRIEF(_content) megdnn_assert(_content, "%s", errmsg().c_str());
ASSERT_BRIEF(input.ndim == 2)
ASSERT_BRIEF(hx.ndim == 2)
ASSERT_BRIEF(hx.shape[0] == input.shape[0]) ASSERT_BRIEF(input.shape[1] == weight_ih.shape[1])
ASSERT_BRIEF(hx.shape[0] == dst.shape[0]) ASSERT_BRIEF(hx.shape[1] == dst.shape[1])
ASSERT_BRIEF(hx.shape[1] == weight_ih.shape[0]) ASSERT_BRIEF(weight_ih.shape[0] == weight_hh.shape[0])
ASSERT_BRIEF(weight_hh.shape[0] == weight_hh.shape[1])
ASSERT_BRIEF(bias_ih.shape[0] == bias_hh.shape[0])
#undef ASSERT_BRIEF
megdnn_assert_eq_dtype(input, dst);
megdnn_assert_eq_dtype(hx, dst);
deduce_layout(input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst_expected);
megdnn_assert_eq_layout(dst_expected, dst);
auto required_workspace_in_bytes = get_workspace_in_bytes(
input, weight_ih, bias_ih, hx, weight_hh, bias_hh, dst);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
}
namespace megdnn {
namespace rnn_cell {
size_t get_workspace_in_bytes(
const TensorLayout& input, const TensorLayout& weight_ih,
const TensorLayout& , const TensorLayout& hx,
const TensorLayout& weight_hh, const TensorLayout& ,
const TensorLayout& dst, Handle* handle) {
auto opr = handle->create_operator<MatrixMulForward>();
opr->param().transposeB = true;
return dst.span().dist_byte() +
std::max(
opr->get_workspace_in_bytes(hx, weight_hh, dst),
opr->get_workspace_in_bytes(input, weight_ih, dst));
}
void exec(
_megdnn_tensor_in input, _megdnn_tensor_in weight_ih, _megdnn_tensor_in bias_ih,
_megdnn_tensor_in hx, _megdnn_tensor_in weight_hh, _megdnn_tensor_in bias_hh,
_megdnn_tensor_out dst, _megdnn_workspace workspace,
param::RNNCell::NonlineMode nonline_mode, Handle* handle) {
TensorND tmp{static_cast<void*>(workspace.raw_ptr), dst.layout};
_megdnn_workspace new_workspace = {
workspace.raw_ptr + dst.layout.span().dist_byte(),
workspace.size - dst.layout.span().dist_byte()};
auto opr = handle->create_operator<MatrixMulForward>();
opr->param().transposeB = true;
opr->exec(input, weight_ih, tmp, new_workspace);
opr->exec(hx, weight_hh, dst, new_workspace);
auto add_opr = handle->create_operator<ElemwiseForward>();
add_opr->param().mode = Elemwise::Param::Mode::ADD;
add_opr->exec({dst, tmp}, dst);
add_opr->exec({dst, bias_ih}, dst);
add_opr->exec({dst, bias_hh}, dst);
using NonlineMode = param::RNNCell::NonlineMode;
switch (nonline_mode) {
#define cb(_mode) \
case NonlineMode::_mode: { \
auto nonlinear = handle->create_operator<ElemwiseForward>(); \
nonlinear->param().mode = Elemwise::Param::Mode::_mode; \
nonlinear->exec({dst}, dst); \
break; \
}
cb(RELU);
cb(TANH);
#undef cb
case NonlineMode::IDENTITY:
break;
default:
megdnn_assert(false);
}
}
} }