ct2rs 0.9.18

Rust bindings for OpenNMT/CTranslate2
Documentation
#pragma once

#include "ctranslate2/layers/decoder.h"
#include "ctranslate2/layers/encoder.h"
#include "ctranslate2/models/model.h"
#include "ctranslate2/scoring.h"
#include "ctranslate2/translation.h"
#include "ctranslate2/vocabulary_map.h"

namespace ctranslate2 {
  namespace models {

    class SequenceToSequenceModel : public Model {
    public:
      size_t num_source_vocabularies() const;
      const Vocabulary& get_source_vocabulary(size_t index = 0) const;
      const Vocabulary& get_target_vocabulary() const;
      const VocabularyMap* get_vocabulary_map() const;

      bool with_source_bos() const {
        return config["add_source_bos"];
      }

      bool with_source_eos() const {
        return config["add_source_eos"];
      }

      const std::string* decoder_start_token() const {
        auto& start_token = config["decoder_start_token"];
        return start_token.is_null() ? nullptr : start_token.get_ptr<const std::string*>();
      }

    protected:
      virtual void initialize(ModelReader& model_reader) override;

    private:
      std::vector<std::shared_ptr<const Vocabulary>> _source_vocabularies;
      std::shared_ptr<const Vocabulary> _target_vocabulary;
      std::shared_ptr<const VocabularyMap> _vocabulary_map;

      void load_vocabularies(ModelReader& model_reader);
    };


    class SequenceToSequenceReplica : public ModelReplica {
    public:
      SequenceToSequenceReplica(const std::shared_ptr<const Model>& model)
        : ModelReplica(model)
      {
      }

      static std::unique_ptr<SequenceToSequenceReplica> create_from_model(const Model& model) {
        return model.as_sequence_to_sequence();
      }

      std::vector<ScoringResult>
      score(const std::vector<std::vector<std::string>>& source,
            const std::vector<std::vector<std::string>>& target,
            const ScoringOptions& options = ScoringOptions());

      std::vector<TranslationResult>
      translate(const std::vector<std::vector<std::string>>& source,
                const std::vector<std::vector<std::string>>& target_prefix = {},
                const TranslationOptions& options = TranslationOptions());

    protected:
      virtual bool skip_scoring(const std::vector<std::string>& source,
                                const std::vector<std::string>& target,
                                const ScoringOptions& options,
                                ScoringResult& result) {
        (void)source;
        (void)target;
        (void)options;
        (void)result;
        return false;
      }

      virtual bool skip_translation(const std::vector<std::string>& source,
                                    const std::vector<std::string>& target_prefix,
                                    const TranslationOptions& options,
                                    TranslationResult& result) {
        (void)source;
        (void)target_prefix;
        (void)options;
        (void)result;
        return false;
      }

      virtual std::vector<ScoringResult>
      run_scoring(const std::vector<std::vector<std::string>>& source,
                  const std::vector<std::vector<std::string>>& target,
                  const ScoringOptions& options) = 0;

      virtual std::vector<TranslationResult>
      run_translation(const std::vector<std::vector<std::string>>& source,
                      const std::vector<std::vector<std::string>>& target_prefix,
                      const TranslationOptions& options) = 0;
    };


    class EncoderDecoderReplica : public SequenceToSequenceReplica {
    public:
      EncoderDecoderReplica(const std::shared_ptr<const SequenceToSequenceModel>& model,
                            std::unique_ptr<layers::Encoder> encoder,
                            std::unique_ptr<layers::Decoder> decoder);

      layers::Encoder& encoder() {
        return *_encoder;
      }

      layers::Decoder& decoder() {
        return *_decoder;
      }

    protected:
      bool skip_scoring(const std::vector<std::string>& source,
                        const std::vector<std::string>& target,
                        const ScoringOptions& options,
                        ScoringResult& result) override;

      bool skip_translation(const std::vector<std::string>& source,
                            const std::vector<std::string>& target_prefix,
                            const TranslationOptions& options,
                            TranslationResult& result) override;

      std::vector<ScoringResult>
      run_scoring(const std::vector<std::vector<std::string>>& source,
                  const std::vector<std::vector<std::string>>& target,
                  const ScoringOptions& options) override;

      std::vector<TranslationResult>
      run_translation(const std::vector<std::vector<std::string>>& source,
                      const std::vector<std::vector<std::string>>& target_prefix,
                      const TranslationOptions& options) override;

    private:
      std::vector<std::vector<std::vector<size_t>>>
      make_source_ids(const std::vector<std::vector<std::vector<std::string>>>& source_features,
                      size_t max_length = 0) const;

      std::vector<std::vector<size_t>>
      make_target_ids(const std::vector<std::vector<std::string>>& target,
                      size_t max_length = 0,
                      bool is_prefix = false) const;

      size_t get_source_length(const std::vector<std::string>& source,
                               bool include_special_tokens) const;

      void encode(const std::vector<std::vector<std::vector<size_t>>>& ids,
                  StorageView& memory,
                  StorageView& memory_lengths);

      const std::shared_ptr<const SequenceToSequenceModel> _model;
      const std::unique_ptr<layers::Encoder> _encoder;
      const std::unique_ptr<layers::Decoder> _decoder;
    };

  }
}