#include "src/naive/split/opr_impl.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
#include <numeric>
namespace megdnn {
namespace naive {
/*!
 * \brief Naive split kernel: copies slices of \p src into each tensor of \p dsts.
 *
 * The indexing below treats \p src as a row-major (A, B, C) volume and
 * destination i as a row-major (A, Bv[i], C) volume, where B is the sum of
 * all Bv[i]. Destination i receives the rows [offset_i, offset_i + Bv[i]) of
 * the middle axis, offset_i being the sum of the extents of the destinations
 * before it.
 *
 * \param src source tensor, read through src.ptr<T>()
 * \param dsts destination tensors, written through dsts[i].ptr<T>()
 * \param workspace its raw_ptr is used as a size_t array that get_ABC() fills
 *      with the per-destination extents Bv; at least dsts.size() entries are
 *      read back from it here
 */
template <typename T>
void SplitForwardImpl::exec_internal(
        _megdnn_tensor_in src, const TensorNDArray& dsts, _megdnn_workspace workspace) {
    auto dsts_layout = apply_vector<TensorLayout>(m_get_layout, dsts);
    check_exec(src.layout, dsts_layout, workspace.size);
    auto dsts_shape = apply_vector<TensorShape>(m_get_shape, dsts_layout);

    size_t A, C;
    size_t* Bv = reinterpret_cast<size_t*>(workspace.raw_ptr);
    get_ABC(dsts_shape, A, Bv, C);

    // The accumulator type of std::accumulate is the type of the initial
    // value, so it has to be size_t: an `0u` seed would truncate the total to
    // unsigned int.
    const size_t B = std::accumulate(Bv, Bv + dsts.size(), size_t{0});

    const T* sptr = src.ptr<T>();
    // Position of the current destination's first row on the source B axis.
    size_t b_begin = 0;
    for (size_t i = 0; i < dsts.size(); ++i) {
        T* dptr = dsts[i].ptr<T>();
        // For a fixed `a`, the Bv[i] rows of C elements belonging to this
        // destination are contiguous in both the source and the destination.
        const size_t run = Bv[i] * C;
        for (size_t a = 0; a < A; ++a) {
            const T* s = sptr + (a * B + b_begin) * C;
            T* d = dptr + a * run;
            for (size_t k = 0; k < run; ++k) {
                d[k] = s[k];
            }
        }
        // A destination with extent 0 simply contributes nothing and leaves
        // the offset unchanged.
        b_begin += Bv[i];
    }
}
/*!
 * \brief Entry point of the naive split operator: selects the element type
 *      from the dtype of \p src and dispatches exec_internal<ctype> as a CPU
 *      kernel.
 *
 * One `if` is generated per dtype listed by MEGDNN_FOREACH_COMPUTING_DTYPE and
 * MEGDNN_FOREACH_QUANTIZED_DTYPE. The generated branches have no `else` and do
 * not return, so every listed dtype is compared; as far as this function
 * shows, a dtype that matches none of them results in no kernel being
 * dispatched.
 */
void SplitForwardImpl::exec(
        _megdnn_tensor_in src, const TensorNDArray& dsts, _megdnn_workspace workspace) {
    // Runtime dtype tag of the source, looked up once for all comparisons.
    const auto src_enumv = src.layout.dtype.enumv();

#define SPLIT_DISPATCH_CASE(_dt)                                                  \
    if (DTypeTrait<_dt>::enumv == src_enumv) {                                    \
        using ctype = typename DTypeTrait<_dt>::ctype;                            \
        MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal<ctype>(src, dsts, workspace)); \
    }

    MEGDNN_FOREACH_COMPUTING_DTYPE(SPLIT_DISPATCH_CASE)
    MEGDNN_FOREACH_QUANTIZED_DTYPE(SPLIT_DISPATCH_CASE)

#undef SPLIT_DISPATCH_CASE
}
} }