// ct2rs 0.9.19
// Rust bindings for OpenNMT/CTranslate2 (Documentation)
#include "ctranslate2/ops/gumbel_max.h"

#include "dispatch.h"

namespace ctranslate2 {
  namespace ops {

    // Stores the requested number of samples and configures the internal
    // top-k operator with that same count.
    GumbelMax::GumbelMax(dim_t num_samples)
      : _num_samples(num_samples),
        _topk_op(num_samples) {
    }

    // Perturbs `x` with Gumbel noise and runs the top-k operator on the result.
    //
    // x:       input tensor; it is only read.
    // values:  receives the values selected by the top-k operator. These come
    //          from the noise-perturbed tensor, not from `x` itself.
    // indices: receives the indices selected by the top-k operator.
    //
    // Both outputs are reshaped to the shape of `x` with its last dimension
    // replaced by `_num_samples`.
    void GumbelMax::operator()(const StorageView& x,
                               StorageView& values,
                               StorageView& indices) const {
      PROFILE("GumbelMax");

      const auto device = x.device();
      const auto dtype = x.dtype();

      // Scratch tensor with the same shape, dtype and device as the input;
      // add_gumbel_noise writes the perturbed scores into it.
      StorageView perturbed(x.shape(), dtype, device);
      DEVICE_AND_FLOAT_DISPATCH("GumbelMax", device, dtype,
                                (add_gumbel_noise<D, T>(x, perturbed)));

      _topk_op(perturbed, values, indices);

      // The shape is read after the top-k call, as in the original ordering.
      Shape sampled_shape = x.shape();
      sampled_shape.back() = _num_samples;
      values.reshape(sampled_shape);
      indices.reshape(sampled_shape);
    }

    // Overload for callers that only need the selected indices: the values
    // produced by the three-argument overload are written to a temporary
    // (created with the dtype and device of `x`) and discarded on return.
    void GumbelMax::operator()(const StorageView& x, StorageView& indices) const {
      StorageView unused_values(x.dtype(), x.device());
      (*this)(x, unused_values, indices);
    }

  }
}