#include "megdnn/oprs/general.h"
#include "src/common/utils.h"
#include <cmath>
using namespace megdnn;
/*!
 * Derive the layouts of the \p values and \p indices outputs from the input
 * layout \p data and the requested \p k.
 *
 * Preconditions (asserted): \p k is non-zero, \p data is 2-dimensional and
 * its second dimension has stride 1.
 *
 * Results, depending on param().mode:
 *  - KTH_ONLY: \p values is a contiguous 1-d layout of length data[0];
 *    \p indices.ndim is set to 0.
 *  - VALUE_IDX_NOSORT / VALUE_IDX_SORTED: \p values is a contiguous
 *    (data[0], min(|k|, data.shape[1])) layout and \p indices takes the
 *    same shape.
 *  - any other mode: raises via megdnn_throw.
 *
 * In every case values.dtype is copied from \p data and indices.dtype is
 * set to Int32.
 */
void TopK::deduce_layout(
        int k, const TensorLayout& data, TensorLayout& values, TensorLayout& indices) {
    megdnn_assert(
            k && data.ndim == 2 && data.stride[1] == 1, "invalid k=%d data=%s", k,
            data.to_string().c_str());

    values.dtype = data.dtype;
    indices.dtype = dtype::Int32{};

    const auto mode = param().mode;

    if (mode == Param::Mode::KTH_ONLY) {
        // A single value per row, and no index output at all.
        values.init_contiguous_stride({data[0]});
        indices.ndim = 0;
        return;
    }

    if (mode == Param::Mode::VALUE_IDX_NOSORT ||
        mode == Param::Mode::VALUE_IDX_SORTED) {
        // Only the magnitude of k matters for the output width, and it can
        // never exceed the number of columns available in the input.
        const size_t nr_col = data.shape[1];
        values.init_contiguous_stride(
                {data[0], std::min<size_t>(std::abs(k), nr_col)});
        indices.init_contiguous_stride(values);
        return;
    }

    megdnn_throw("invalid TopK mode");
}
/*!
 * Validate the arguments of a TopK invocation and forward to do_exec().
 *
 * Steps performed:
 *  1. Deduce the expected output layouts from \p k and \p data and assert
 *     that \p values matches.
 *  2. In KTH_ONLY mode assert that \p indices has an empty shape and pass a
 *     null index pointer on; otherwise fetch the int32 pointer of
 *     \p indices and assert that its layout matches the deduced one.
 *  3. Assert that \p workspace is at least get_workspace_in_bytes() large.
 *  4. Clamp the magnitude of \p k to data.layout[1], keeping its sign.
 *  5. Call do_exec() with the (possibly clamped) k.
 */
void TopK::exec(
        int k, _megdnn_tensor_in data, _megdnn_tensor_out values,
        _megdnn_tensor_out indices, _megdnn_workspace workspace) {
    TensorLayout oval, oidx;
    deduce_layout(k, data.layout, oval, oidx);
    megdnn_assert_eq_layout(oval, values.layout);

    // Stays null when no index output is produced (KTH_ONLY).
    int32_t* index_out = nullptr;
    if (param().mode != Param::Mode::KTH_ONLY) {
        index_out = indices.ptr<int32_t>();
        megdnn_assert_eq_layout(oidx, indices.layout);
    } else {
        megdnn_assert_eq_shape(indices.layout, TensorShape{});
    }

    // The workspace requirement is queried with the caller-supplied k, i.e.
    // before any clamping below.
    megdnn_assert(
            workspace.size >=
            get_workspace_in_bytes(k, data.layout, values.layout, indices.layout));

    // Limit |k| to the row length while preserving the sign of k.
    const size_t row_len = data.layout[1];
    if (static_cast<size_t>(std::abs(k)) > row_len) {
        const int bound = static_cast<int>(row_len);
        k = (k > 0) ? bound : -bound;
    }

    do_exec(k, data, values, index_out, workspace);
}