#include "src/naive/reduce/opr_impl.h"
#include <climits>
#include <cmath>
#include <cstring>
#include <functional>
#include "src/common/reduce_helper.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
using namespace megdnn;
namespace {
using Mode = Reduce::Mode;

//! Per-mode reduction policy consumed by reduce_fwd:
//!  - visit(x):    maps one input element to a partial result;
//!  - apply(x, y): combines two partial results;
//!  - write(x, B): turns the combination of all B reduced elements into the
//!                 value that is stored in the output.
template <Mode mode, typename ctype>
struct Trait;

template <typename ctype>
struct Trait<Mode::SUM, ctype> {
    static ctype apply(ctype x, ctype y) { return x + y; }
    static ctype visit(ctype x) { return x; }
    static ctype write(ctype x, size_t) { return x; }
};

template <typename ctype>
struct Trait<Mode::MEAN, ctype> {
    static ctype apply(ctype x, ctype y) { return x + y; }
    static ctype visit(ctype x) { return x; }
    //! accumulated sum divided by the number of reduced elements (this is an
    //! integer division when ctype is integral)
    static ctype write(ctype x, size_t B) { return x / static_cast<ctype>(B); }
};

template <typename ctype>
struct Trait<Mode::SUM_SQR, ctype> {
    static ctype apply(ctype x, ctype y) { return x + y; }
    //! each element is squared before being summed
    static ctype visit(ctype x) { return x * x; }
    static ctype write(ctype x, size_t) { return x; }
};

template <typename ctype>
struct Trait<Mode::PRODUCT, ctype> {
    static ctype apply(ctype x, ctype y) { return x * y; }
    static ctype visit(ctype x) { return x; }
    static ctype write(ctype x, size_t) { return x; }
};

//! Generic MIN relies solely on operator< (no special NaN handling).
template <typename ctype>
struct Trait<Mode::MIN, ctype> {
    static ctype apply(ctype x, ctype y) { return x < y ? x : y; }
    static ctype visit(ctype x) { return x; }
    static ctype write(ctype x, size_t) { return x; }
};

//! float32 MIN propagates NaN: if either operand is NaN the result is NaN
//! (a NaN x is returned explicitly; a NaN y wins because `x < NaN` is false).
template <>
struct Trait<Mode::MIN, dt_float32> {
    using ctype = dt_float32;
    static ctype apply(ctype x, ctype y) {
        return (std::isnan(x) || x < y) ? x : y;
    }
    static ctype visit(ctype x) { return x; }
    static ctype write(ctype x, size_t) { return x; }
};

//! Generic MAX relies solely on operator> (no special NaN handling).
template <typename ctype>
struct Trait<Mode::MAX, ctype> {
    static ctype apply(ctype x, ctype y) { return x > y ? x : y; }
    static ctype visit(ctype x) { return x; }
    static ctype write(ctype x, size_t) { return x; }
};

//! float32 MAX propagates NaN: if either operand is NaN the result is NaN
//! (a NaN x is returned explicitly; a NaN y wins because `x > NaN` is false).
template <>
struct Trait<Mode::MAX, dt_float32> {
    using ctype = dt_float32;
    static ctype apply(ctype x, ctype y) {
        return (std::isnan(x) || x > y) ? x : y;
    }
    static ctype visit(ctype x) { return x; }
    static ctype write(ctype x, size_t) { return x; }
};
template <Mode mode, typename ctype>
void reduce_fwd(
const ctype* __restrict sptr, ctype* __restrict dptr, size_t A, size_t B,
size_t C) {
std::function<ctype(size_t, size_t, size_t, size_t)> func;
func = [&](size_t a, size_t c, size_t bl, size_t br) -> ctype {
if (bl + 1 < br) {
size_t mid = bl + (br - bl) / 2;
return Trait<mode, ctype>::apply(func(a, c, bl, mid), func(a, c, mid, br));
} else {
return Trait<mode, ctype>::visit(sptr[a * B * C + bl * C + c]);
}
};
for (size_t a = 0; a < A; ++a)
for (size_t c = 0; c < C; ++c) {
dptr[a * C + c] = Trait<mode, ctype>::write(func(a, c, 0, B), B);
}
}
//! Explicit specializations of reduce_fwd that reject the arithmetic modes
//! (SUM, MEAN, SUM_SQR, PRODUCT) on 8-bit quantized ctypes by throwing, so the
//! generic template body is never instantiated for these ctypes.
//! NOTE(review): ReduceForwardImpl::exec converts SUM/MEAN quantized inputs to
//! QuantizedS32 before dispatching, so the SUM/MEAN variants look unreachable
//! from exec; presumably they exist only so dispatch_dtype compiles for every
//! dtype - confirm.
#define REDUCE_FWD_REJECT_QUANTIZED(_mode, _ctype, _dtype_name)               \
    template <>                                                               \
    void reduce_fwd<Mode::_mode>(                                             \
            const _ctype* __restrict, _ctype* __restrict, size_t, size_t,     \
            size_t) {                                                         \
        megdnn_throw(                                                         \
                "Reduce (" #_mode ") with DEFAULT DataType is not supported " \
                "on " _dtype_name);                                           \
    }

REDUCE_FWD_REJECT_QUANTIZED(SUM, dt_quint8, "Quantized8Asymm")
REDUCE_FWD_REJECT_QUANTIZED(MEAN, dt_quint8, "Quantized8Asymm")
REDUCE_FWD_REJECT_QUANTIZED(SUM_SQR, dt_quint8, "Quantized8Asymm")
REDUCE_FWD_REJECT_QUANTIZED(PRODUCT, dt_quint8, "Quantized8Asymm")

REDUCE_FWD_REJECT_QUANTIZED(SUM, dt_qint8, "QuantizedS8")
REDUCE_FWD_REJECT_QUANTIZED(MEAN, dt_qint8, "QuantizedS8")
REDUCE_FWD_REJECT_QUANTIZED(SUM_SQR, dt_qint8, "QuantizedS8")
REDUCE_FWD_REJECT_QUANTIZED(PRODUCT, dt_qint8, "QuantizedS8")

#undef REDUCE_FWD_REJECT_QUANTIZED
//! Select the C++ storage type matching src's runtime dtype and enqueue
//! reduce_fwd<mode, ctype> on the handle through MEGDNN_DISPATCH_CPU_KERN.
//! dst is read with the same ctype, so it is expected to share src's dtype
//! (NOTE(review): exec passes tensors already converted to one comp dtype).
//! Dtypes outside the computing/quantized lists hit an internal assertion.
template <Mode mode>
void dispatch_dtype(
        megdnn::naive::HandleImpl* handle, const TensorND& src, const TensorND& dst,
        size_t A, size_t B, size_t C) {
    const auto src_enumv = src.layout.dtype.enumv();
#define cb(_dt)                                                                 \
    case DTypeTrait<_dt>::enumv: {                                              \
        using ctype = DTypeTrait<_dt>::ctype;                                   \
        MEGDNN_DISPATCH_CPU_KERN(                                               \
                handle, reduce_fwd<mode MEGDNN_COMMA ctype>(                    \
                                src.ptr<ctype>(), dst.ptr<ctype>(), A, B, C)); \
        return;                                                                 \
    }
    switch (src_enumv) {
        MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
        MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
        default:
            break;
    }
#undef cb
    // only reached for a dtype tag not covered by the lists above
    megdnn_assert_internal(false);
}
}
namespace megdnn {
namespace naive {
size_t ReduceForwardImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) {
MEGDNN_MARK_USED_VAR(src);
MEGDNN_MARK_USED_VAR(dst);
megdnn_assert(
param().data_type != Reduce::DataType::FLOAT_IO16xC32,
"FLOAT_IO16xC32 is deprecated");
DType comp_dtype = src.dtype;
if (param().mode == Mode::SUM || param().mode == Mode::MEAN) {
if (src.dtype.category() == DTypeCategory::QUANTIZED) {
float src_scale;
if (src.dtype.enumv() == DTypeEnum::QuantizedS8) {
src_scale = src.dtype.param<dtype::QuantizedS8>().scale;
} else if (src.dtype.enumv() == DTypeEnum::Quantized8Asymm) {
src_scale = src.dtype.param<dtype::Quantized8Asymm>().scale;
} else {
megdnn_assert_internal(0);
}
comp_dtype = dtype::QuantizedS32(src_scale);
} else if (param().data_type != Param::DataType::DEFAULT) {
comp_dtype = dtype::Float32();
}
} else if (param().data_type != Param::DataType::DEFAULT) {
comp_dtype = dtype::Float32();
}
size_t size = 0;
if (src.dtype != comp_dtype)
size += comp_dtype.size(src.total_nr_elems());
if (dst.dtype != comp_dtype)
size += comp_dtype.size(dst.total_nr_elems());
return size;
}
void ReduceForwardImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
using namespace reduce;
check_exec(src.layout, dst.layout, workspace.size);
size_t A, B, C;
get_ABC(src.layout, A, B, C, param().axis);
DType comp_dtype = src.layout.dtype;
if (param().mode == Mode::SUM || param().mode == Mode::MEAN) {
if (src.layout.dtype.category() == DTypeCategory::QUANTIZED) {
float src_scale;
if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
src_scale = src.layout.dtype.param<dtype::QuantizedS8>().scale;
} else if (src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) {
src_scale = src.layout.dtype.param<dtype::Quantized8Asymm>().scale;
} else {
megdnn_assert_internal(0);
}
comp_dtype = dtype::QuantizedS32(src_scale);
} else if (param().data_type != Param::DataType::DEFAULT) {
comp_dtype = dtype::Float32();
}
} else if (param().data_type != Param::DataType::DEFAULT) {
comp_dtype = dtype::Float32();
}
auto make_tensor = [&](DType comp_dtype, _megdnn_tensor_inout tensor,
dt_byte*& workspace_ptr) {
if (comp_dtype == tensor.layout.dtype)
return tensor;
auto layout = TensorLayout(tensor.layout, comp_dtype);
TensorND new_tensor{workspace_ptr, layout};
workspace_ptr += layout.span().dist_byte();
return new_tensor;
};
auto typecvt = handle()->create_operator<TypeCvt>();
auto copy_to = [&typecvt](const TensorND& from, const TensorND& to) {
if (from.raw_ptr() != to.raw_ptr())
typecvt->exec(from, to);
};
auto workspace_ptr = workspace.ptr<dt_byte>();
auto new_src = make_tensor(comp_dtype, src, workspace_ptr);
auto new_dst = make_tensor(comp_dtype, dst, workspace_ptr);
#define CASE(mode) \
case mode: { \
copy_to(src, new_src); \
dispatch_dtype<mode>( \
static_cast<HandleImpl*>(handle()), new_src, new_dst, A, B, C); \
copy_to(new_dst, dst); \
return; \
}
switch (param().mode) {
CASE(Mode::SUM);
CASE(Mode::SUM_SQR);
CASE(Mode::PRODUCT);
CASE(Mode::MIN);
CASE(Mode::MAX);
CASE(Mode::MEAN);
default:
megdnn_assert_internal(false);
}
#undef CASE
}
} }