ct2rs 0.9.19

Rust bindings for OpenNMT/CTranslate2
Documentation
#include "ctranslate2/ops/topp_mask.h"

#include <algorithm>
#include <numeric>

#include "cpu/parallel.h"
#include "dispatch.h"

namespace ctranslate2 {
  namespace ops {

    template <typename T>
    static void topp_mask_kernel(const T* x,
                                 const T* probs,
                                 const float p,
                                 const float mask,
                                 const dim_t batch_size,
                                 const dim_t depth,
                                 T* y) {
      cpu::parallel_for(0, batch_size, 1, [&](dim_t begin, dim_t end) {
        for (dim_t i = begin; i < end; ++i) {
          const auto* x_i = x + (i * depth);
          const auto* probs_i = probs + (i * depth);
          auto* y_i = y + (i * depth);

          std::vector<dim_t> ids(depth);
          std::iota(ids.begin(), ids.end(), 0);
          std::sort(ids.begin(), ids.end(), [&probs_i](const dim_t i1, const dim_t i2) {
            return probs_i[i1] > probs_i[i2];
          });

          float total_p = 0;

          for (const auto id : ids) {
            y_i[id] = total_p < p ? x_i[id] : mask;
            total_p += probs_i[id];
          }
        }
      });
    }

    // Runs the top-p mask kernel on input, using probs to rank the entries of each row,
    // and writes the result to output.
    //
    // The last dimension of input is taken as the row length (depth); all leading
    // dimensions are flattened into the batch. The three storages are accessed with the
    // same flat layout, so probs and output are expected to hold at least as many
    // elements as input (this is not checked here).
    template <Device D, typename T>
    void TopPMask::compute(const StorageView& input,
                           const StorageView& probs,
                           StorageView& output) const {
      const dim_t depth = input.dim(-1);

      // With an empty last dimension there is nothing to mask, and the division below
      // would be an integer division by zero.
      if (depth == 0) {
        return;
      }

      const dim_t batch_size = input.size() / depth;

      topp_mask_kernel(input.data<T>(),
                       probs.data<T>(),
                       _p,
                       _mask_value,
                       batch_size,
                       depth,
                       output.data<T>());
    }

    // CPU specialization: the reported maximum number of classes is the largest value
    // representable by dim_t, i.e. no tighter bound is declared for this device.
    template<>
    dim_t TopPMask::max_num_classes<Device::CPU>() {
      const dim_t largest_dim = std::numeric_limits<dim_t>::max();
      return largest_dim;
    }

// Explicitly instantiates TopPMask::compute on the CPU device for the given element type.
#define DECLARE_IMPL(DataType)                                          \
    template void                                                       \
    TopPMask::compute<Device::CPU, DataType>(const StorageView&,        \
                                             const StorageView&,        \
                                             StorageView&) const;

    // float is the only element type instantiated here.
    DECLARE_IMPL(float)

  }
}