#include "./opr_impl.h"
#include "./kern.cuh"
#include "src/common/utils.h"
using namespace megdnn;
using namespace cuda;
void AddUpdateForwardImpl::exec(_megdnn_tensor_inout dest, _megdnn_tensor_in delta) {
check_exec(dest.layout, delta.layout);
if (!dest.layout.is_contiguous()) {
return exec_noncontig(dest, delta);
}
ElemwiseOpParamN<1> param;
param[0] = delta;
param[0].layout = param[0].layout.broadcast(dest.layout);
param.init_from_given_tensor();
auto stream = cuda_stream(handle());
switch (dest.layout.dtype.enumv()) {
#define cb(_dt) \
case DTypeTrait<_dt>::enumv: { \
using ctype = DTypeTrait<_dt>::ctype; \
return run_elemwise<AddUpdateKernOp<ctype>, ctype, 1>( \
param, stream, {dest, m_param}); \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb
default:
megdnn_throw("unsupported dtype for AddUpdate");
}
}
void AddUpdateForwardImpl::exec_noncontig(
_megdnn_tensor_inout dest, _megdnn_tensor_in delta) {
ElemwiseOpParamN<2> param = make_param(dest, delta);
auto stream = cuda_stream(handle());
switch (dest.layout.dtype.enumv()) {
#define cb(_dt) \
case DTypeTrait<_dt>::enumv: { \
using ctype = DTypeTrait<_dt>::ctype; \
return run_elemwise<AddUpdateKernOpNonContig<ctype>, ctype, 2>( \
param, stream, {m_param}); \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb
default:
megdnn_throw("unsupported dtype for AddUpdate");
}
}