// ct2rs 0.9.19
//
// Rust bindings for OpenNMT/CTranslate2
// Documentation
#include "ctranslate2/ops/topp_mask.h"

#include "ctranslate2/ops/softmax.h"

#include "dispatch.h"

namespace ctranslate2 {
  namespace ops {

    TopPMask::TopPMask(const float p, const float mask_value)
      : _p(p)
      , _mask_value(mask_value)
    {
    }

    // Runs the op on `input` and writes the result into `output`.
    //
    // Steps performed here:
    //   1. SoftMax is applied to `input`, with the result stored in a temporary
    //      StorageView created with the input's dtype and device.
    //   2. `output` is resized to match `input`.
    //   3. compute<D, T>(input, probabilities, output) is invoked through
    //      DEVICE_AND_FLOAT_DISPATCH, which selects D and T from the runtime
    //      device and dtype (presumably restricted to floating-point dtypes,
    //      per the macro name -- see dispatch.h).
    //
    // The actual masking logic lives in compute(), which is defined elsewhere.
    void TopPMask::operator()(const StorageView& input, StorageView& output) const {
      PROFILE("TopPMask");

      const Device device = input.device();
      const DataType dtype = input.dtype();

      // Temporary buffer receiving the softmax of the input.
      StorageView probabilities(dtype, device);
      ops::SoftMax softmax_op;
      softmax_op(input, probabilities);

      output.resize_as(input);

      DEVICE_AND_FLOAT_DISPATCH("TopPMask",
                                device,
                                dtype,
                                (compute<D, T>(input, probabilities, output)));
    }

    // Returns the value of max_num_classes<D>() for the device-specific
    // template parameter D that DEVICE_DISPATCH selects from the runtime
    // `device` argument.
    //
    // The result starts at 0 and is overwritten by the dispatched statement;
    // what happens for a device the macro does not handle depends on
    // DEVICE_DISPATCH (see dispatch.h).
    dim_t TopPMask::max_num_classes(const Device device) {
      dim_t limit = 0;
      DEVICE_DISPATCH(device, limit = max_num_classes<D>());
      return limit;
    }

  }
}