#include "./opr_impl.h"
#include "src/naive/add_update/opr_impl.h"
#include "src/common/utils.h"
#include "src/fallback/handle.h"
namespace {
using namespace megdnn;

/*!
 * \brief Contiguous elementwise AddUpdate kernel, evaluated in ctype T.
 *
 * For every element i of the flattened tensors computes
 *   dest[i] = alpha * dest[i] + beta * delta[i] + bias
 * writing the result back into \p dest in place.
 *
 * \param dest   tensor updated in place; its element count drives the loop
 * \param delta  tensor read alongside \p dest; assumed (by the caller) to be
 *               contiguous and to have the same shape as \p dest
 * \param param  supplies alpha/beta/bias, each converted to T before use
 *               (NOTE(review): if the Param fields are floating point, this
 *               conversion truncates for integer T — confirm intended)
 */
template <typename T>
void forward(
        _megdnn_tensor_inout dest, _megdnn_tensor_in delta,
        const AddUpdate::Param& param) {
    const T alpha(param.alpha);
    const T beta(param.beta);
    const T bias(param.bias);

    T* const dst = dest.ptr<T>();
    // delta is only read, so view it through a pointer-to-const.
    const T* const src = delta.ptr<T>();

    const size_t total = dest.layout.total_nr_elems();
    for (size_t idx = 0; idx < total; ++idx) {
        dst[idx] = alpha * dst[idx] + beta * src[idx] + bias;
    }
}
}  // anonymous namespace
namespace megdnn {
namespace fallback {

/*!
 * \brief Fallback AddUpdate: dest = alpha * dest + beta * delta + bias.
 *
 * Uses a flat elementwise kernel when both tensors are contiguous and share
 * the same shape. Any other layout is delegated to the naive implementation
 * (presumably this covers a broadcast delta that check_exec accepted; confirm
 * against check_exec).
 */
void AddUpdateImpl::exec(_megdnn_tensor_inout dest, _megdnn_tensor_in delta) {
    check_exec(dest.layout, delta.layout);

    const bool flat_kernel_ok = dest.layout.is_contiguous() &&
                                delta.layout.is_contiguous() &&
                                dest.layout.eq_shape(delta.layout);
    if (!flat_kernel_ok) {
        return naive::AddUpdateForwardImpl::exec(dest, delta);
    }

    // Copy the parameters into a local so the dispatched kernel captures a
    // value rather than reading m_param through `this` (the kernel may run
    // later on the handle's dispatcher — presumably; verify in the macro).
    const auto kern_param = m_param;

    // For each computing dtype, run the typed kernel if it matches dest.
#define cb(DType)                                                            \
    if (dest.layout.dtype == DType()) {                                      \
        using ctype = typename DTypeTrait<DType>::ctype;                     \
        MEGDNN_DISPATCH_CPU_KERN_OPR(forward<ctype>(dest, delta, kern_param)); \
        return;                                                              \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb

    // No computing dtype matched: unsupported dtype reached the fallback.
    megdnn_assert_internal(0);
}

}  // namespace fallback
}  // namespace megdnn