ct2rs 0.9.19

Rust bindings for OpenNMT/CTranslate2
Documentation
#include "ctranslate2/ops/topk.h"

#include <algorithm>
#include <numeric>

#include "cpu/parallel.h"
#include "type_dispatch.h"

namespace ctranslate2 {
  namespace ops {

    // CPU implementation of TopK over the last dimension of `x`.
    //
    // `x` is processed as batch_size rows of depth = x.dim(-1) contiguous
    // elements. For each row, the _k largest values are written to `values`
    // and their positions within the row to `indices`, ordered by descending
    // value. Ties are broken in favor of the lower position, so the output is
    // deterministic and agrees with the _k == 1 fast path (std::max_element
    // returns the first maximum).
    //
    // NOTE(review): `values` and `indices` are assumed to already hold
    // batch_size * _k elements, and _k is assumed to be <= depth
    // (std::partial_sort requires middle <= last) — TODO confirm both are
    // guaranteed by the caller.
    template <Device D, typename DataType, typename IndexType>
    void TopK::compute(const StorageView& x,
                       StorageView& values,
                       StorageView& indices) const {
      const dim_t depth = x.dim(-1);
      const dim_t batch_size = x.size() / depth;

      const DataType* x_data = x.data<DataType>();
      DataType* v_data = values.data<DataType>();
      IndexType* i_data = indices.data<IndexType>();

      if (_k == 1) {
        // Fast path: one linear scan per row, no scratch memory needed.
        cpu::parallel_for(0, batch_size, 1, [&](dim_t begin, dim_t end) {
          for (dim_t i = begin; i < end; ++i) {
            const DataType* row = x_data + i * depth;
            const DataType* max = std::max_element(row, row + depth);
            v_data[i] = *max;
            i_data[i] = static_cast<IndexType>(std::distance(row, max));
          }
        });

      } else {
        cpu::parallel_for(0, batch_size, 1, [&](dim_t begin, dim_t end) {
          // Scratch buffer of row positions, allocated once per chunk of rows
          // instead of once per row.
          StorageView range({depth}, indices.dtype());
          auto* ids = range.data<IndexType>();

          for (dim_t i = begin; i < end; ++i) {
            const DataType* input = x_data + i * depth;
            DataType* val = v_data + i * _k;
            IndexType* ind = i_data + i * _k;

            // Reset to 0..depth-1: the previous row's partial_sort permuted ids.
            std::iota(ids, ids + depth, IndexType(0));
            std::partial_sort(ids, ids + _k, ids + depth,
                              [input](const IndexType a, const IndexType b) {
                                // Descending by value; lower position first on ties.
                                if (input[a] > input[b])
                                  return true;
                                if (input[b] > input[a])
                                  return false;
                                return a < b;
                              });
            for (dim_t j = 0; j < _k; ++j) {
              ind[j] = ids[j];
              val[j] = input[ids[j]];
            }
          }
        });
      }
    }

// Explicitly instantiates the CPU TopK kernel for element type T with int32_t
// indices, so the template definition above can stay in this translation unit.
#define DECLARE_IMPL(T)                                                 \
    template void                                                       \
    TopK::compute<Device::CPU, T, int32_t>(const StorageView& x,        \
                                           StorageView& values,         \
                                           StorageView& indices) const;

    // Presumably expands DECLARE_IMPL once per supported element type
    // (DECLARE_ALL_TYPES is defined outside this file, likely in type_dispatch.h).
    DECLARE_ALL_TYPES(DECLARE_IMPL)

  }
}