#include "megdnn/oprs.h"
#include "src/common/utils.h"
namespace megdnn {
/**
 * Deduce the layouts of the LSTM forward outputs from the input layouts.
 *
 * - output: 3-D, dims 0 and 1 are taken from input, the last dim is
 *   D * hx.shape[2], where D is 2 when param().bidirectional is set and 1
 *   otherwise; dtype follows input.
 * - hy / cy: copies of hx / cx.
 * - reserve_space: 1-D layout with get_reserve_size_in_bytes(input)
 *   elements of input.dtype.
 *
 * The fourth (unnamed) layout argument is not used here.
 *
 * NOTE(review): input is presumably [seq_len, batch_size, input_size] and
 * hx presumably [D * num_layers, batch_size, hidden_size]; neither ndim is
 * checked in this function -- confirm callers validate them beforehand.
 */
void LSTM::deduce_layout(
        const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
        const TensorLayout&, TensorLayout& output, TensorLayout& hy, TensorLayout& cy,
        TensorLayout& reserve_space) {
    // One output feature slice per direction.
    const size_t num_directions = param().bidirectional ? 2 : 1;
    const size_t out_features = num_directions * hx.shape[2];

    output = TensorLayout(
            TensorShape{input.shape[0], input.shape[1], out_features}, input.dtype);

    // Final hidden / cell states share the layouts of the initial ones.
    hy = hx;
    cy = cx;

    reserve_space = {{get_reserve_size_in_bytes(input)}, input.dtype};
}
/**
 * Validate the layouts passed to the LSTM forward execution.
 *
 * With D = 2 if param().bidirectional else 1, the following is asserted:
 * - input and output are 3-D; output is
 *   [input.shape[0], input.shape[1], D * hidden_size];
 * - flatten_weights is laid out as [4 * hidden_size, size_dim1], where
 *   size_dim1 is the packed column count computed below;
 * - hx, cx, hy and cy are 3-D with shape
 *   [D * num_layers, input.shape[1], hidden_size].
 *
 * The two trailing unnamed arguments are not inspected. On failure the
 * assertion message lists every layout together with hidden_size,
 * num_layers and bidirectional.
 */
void LSTM::check_exec(
        const TensorLayout& input, const TensorLayout& hx, const TensorLayout& cx,
        const TensorLayout& flatten_weights, const TensorLayout& output,
        const TensorLayout& hy, const TensorLayout& cy, const TensorLayout&, size_t) {
    // Builds the diagnostic text attached to every assertion below.
    auto errmsg = [&]() {
        std::string msg;
        auto add = [&msg](const char* label, const std::string& value) {
            msg.append(label);
            msg.append(value);
        };
        add("input=", input.to_string());
        add(", output=", output.to_string());
        add(", hx=", hx.to_string());
        add(", cx=", cx.to_string());
        add(", hy=", hy.to_string());
        add(", cy=", cy.to_string());
        add(", flatten_weights=", flatten_weights.to_string());
        add(", hidden_size=", std::to_string(param().hidden_size));
        add(", num_layers=", std::to_string(param().num_layers));
        add(", bidirectional=", std::to_string(param().bidirectional));
        return msg;
    };

    const size_t D = param().bidirectional ? 2 : 1;
    const size_t b = param().bias ? 1 : 0;
    const size_t num_layers = param().num_layers;
    const size_t hidden_size = param().hidden_size;
    const size_t input_size = input.shape[2];
    const size_t gate_hidden_size = 4 * hidden_size;

    // Expected extent of flatten_weights along dim 1:
    //   first layer : D * (input_size + hidden_size)
    //   other layers: (num_layers - 1) * D * ((D + 1) * hidden_size)
    //   bias        : 2 * D * num_layers when param().bias is set
    // NOTE(review): presumably the input-to-hidden and hidden-to-hidden
    // weights of every layer/direction packed column-wise -- confirm against
    // the operator documentation.
    const size_t size_dim1 = D * (input_size + hidden_size) +
                             (num_layers - 1) * D * ((D + 1) * hidden_size) +
                             b * 2 * D * num_layers;

#define ASSERT_BRIEF(_content) megdnn_assert(_content, "%s", errmsg().c_str())
    ASSERT_BRIEF(input.ndim == 3);
    ASSERT_BRIEF(output.ndim == 3);

    // Rows hold the four gates; columns hold the packed weights and biases.
    // The column check used to test shape[0] a second time, which could only
    // pass when 4 * hidden_size happened to equal size_dim1.
    ASSERT_BRIEF(flatten_weights.shape[0] == gate_hidden_size);
    ASSERT_BRIEF(flatten_weights.shape[1] == size_dim1);

    ASSERT_BRIEF(output.shape[0] == input.shape[0]);
    ASSERT_BRIEF(output.shape[1] == input.shape[1]);
    ASSERT_BRIEF(output.shape[2] == D * hidden_size);

    ASSERT_BRIEF(hx.ndim == 3);
    ASSERT_BRIEF(hx.shape[0] == D * num_layers);
    ASSERT_BRIEF(hx.shape[1] == input.shape[1]);
    ASSERT_BRIEF(hx.shape[2] == hidden_size);

    ASSERT_BRIEF(cx.ndim == 3);
    ASSERT_BRIEF(cx.shape[0] == D * num_layers);
    ASSERT_BRIEF(cx.shape[1] == input.shape[1]);
    ASSERT_BRIEF(cx.shape[2] == hidden_size);

    ASSERT_BRIEF(hy.ndim == 3);
    ASSERT_BRIEF(hy.shape[0] == D * num_layers);
    ASSERT_BRIEF(hy.shape[1] == input.shape[1]);
    ASSERT_BRIEF(hy.shape[2] == hidden_size);

    ASSERT_BRIEF(cy.ndim == 3);
    ASSERT_BRIEF(cy.shape[0] == D * num_layers);
    ASSERT_BRIEF(cy.shape[1] == input.shape[1]);
    ASSERT_BRIEF(cy.shape[2] == hidden_size);
#undef ASSERT_BRIEF
}
/**
 * Deduce the layouts of the LSTM backward outputs.
 *
 * Each output layout is a copy of the layout of the tensor it is paired
 * with: dx <- x, dhx <- hx, dcx <- cx, dw <- flatten_weights.
 * The unnamed layout arguments are not used.
 */
void LSTMBackward::deduce_layout(
        const TensorLayout& x, const TensorLayout&, const TensorLayout& hx,
        const TensorLayout& cx, const TensorLayout&, const TensorLayout&,
        const TensorLayout&, const TensorLayout& flatten_weights, const TensorLayout&,
        TensorLayout& dx, TensorLayout& dhx, TensorLayout& dcx, TensorLayout& dw) {
    // Input and initial-state outputs.
    dx = x;
    dhx = hx;
    dcx = cx;

    // Weight output.
    dw = flatten_weights;
}
/**
 * Layout validation hook for the LSTM backward execution.
 *
 * Currently a no-op: all thirteen layout arguments and the trailing size_t
 * are unnamed and ignored, so no layout is checked here.
 */
void LSTMBackward::check_exec(
        const TensorLayout&, const TensorLayout&, const TensorLayout&,
        const TensorLayout&, const TensorLayout&, const TensorLayout&,
        const TensorLayout&, const TensorLayout&, const TensorLayout&,
        const TensorLayout&, const TensorLayout&, const TensorLayout&,
        const TensorLayout&, size_t) {}
}