ct2rs 0.9.19

Rust bindings for OpenNMT/CTranslate2
Documentation
#include "ctranslate2/ops/gumbel_max.h"

#include <cmath>
#include <limits>

#include "ctranslate2/random.h"
#include "type_dispatch.h"

namespace ctranslate2 {
  namespace ops {

    template <Device D, typename T>
    void GumbelMax::add_gumbel_noise(const StorageView& x, StorageView& y) const {
      auto& generator = get_random_generator();

      const T* src = x.data<T>();
      T* dst = y.data<T>();

      std::uniform_real_distribution<float> distribution(std::numeric_limits<float>::min(), 1.f);
      for (dim_t i = 0; i < x.size(); ++i) {
        const float z = -std::log(distribution(generator));
        dst[i] = src[i] + z;
      }
    }

#define DECLARE_IMPL(T)                                                 \
    template void                                                       \
    GumbelMax::add_gumbel_noise<Device::CPU, T>(const StorageView& x,   \
                                                StorageView& y) const;

    DECLARE_IMPL(float)

  }
}