ct2rs 0.9.18

Rust bindings for OpenNMT/CTranslate2
Documentation
#pragma once

#include "op.h"

namespace ctranslate2 {
  namespace ops {

    // Implements multinomial sampling with replacement.
    class Multinomial : public UnaryOp {
    public:
      Multinomial(dim_t sample_size = 1);
      void operator()(const StorageView& input, StorageView& output) const override;

    private:
      dim_t _sample_size;

      void dispatch(const StorageView& input, StorageView& output) const;

      template <Device D, typename T>
      void compute(const StorageView& input, StorageView& output) const;
    };

  }
}