ct2rs 0.9.19

Rust bindings for OpenNMT/CTranslate2
Documentation
#include "ctranslate2/ops/multinomial.h"

#include "ctranslate2/random.h"

namespace ctranslate2 {
  namespace ops {

    template <typename In, typename Out>
    static void multinomial_kernel(const In* input,
                                   dim_t batch_size,
                                   dim_t class_size,
                                   dim_t sample_size,
                                   Out* output) {
      auto& generator = get_random_generator();

      for (dim_t i = 0; i < batch_size; ++i) {
        const In* input_data = input + i * class_size;
        Out* output_data = output + i * sample_size;

        std::discrete_distribution<Out> distribution(input_data, input_data + class_size);
        for (dim_t j = 0; j < sample_size; ++j)
          output_data[j] = distribution(generator);
      }
    }

    template <Device D, typename T>
    void Multinomial::compute(const StorageView& input, StorageView& output) const {
      const dim_t class_size = input.dim(-1);
      const dim_t batch_size = input.size() / class_size;
      multinomial_kernel(input.data<T>(),
                         batch_size,
                         class_size,
                         _sample_size,
                         output.data<int32_t>());
    }

#define DECLARE_IMPL(T)                                                 \
    template void                                                       \
    Multinomial::compute<Device::CPU, T>(const StorageView& input,      \
                                         StorageView& output) const;

    DECLARE_IMPL(float)

  }
}