ct2rs 0.8.2

Rust bindings for OpenNMT/CTranslate2
Documentation
#include "ctranslate2/generator.h"

#include <spdlog/spdlog.h>

namespace ctranslate2 {

  std::vector<std::future<GenerationResult>>
  Generator::generate_batch_async(const std::vector<std::vector<std::string>>& start_tokens,
                                  const GenerationOptions& options,
                                  const size_t max_batch_size,
                                  const BatchType batch_type) {
    return post_examples<GenerationResult>(
      load_examples({start_tokens}),
      max_batch_size,
      batch_type,
      [options](models::SequenceGeneratorReplica& generator, const Batch& batch) {
        spdlog::debug("Running batch generation on {} examples", batch.num_examples());
        auto results = generator.generate(
          batch.get_stream(0),
          restore_batch_ids_in_callback(options, batch.example_index));
        spdlog::debug("Finished batch generation");
        return results;
      });
  }

  std::vector<std::future<ScoringResult>>
  Generator::score_batch_async(const std::vector<std::vector<std::string>>& tokens,
                               const ScoringOptions& options,
                               const size_t max_batch_size,
                               const BatchType batch_type) {
    return post_examples<ScoringResult>(
      load_examples({tokens}),
      max_batch_size,
      batch_type,
      [options](models::SequenceGeneratorReplica& generator, const Batch& batch) {
        spdlog::debug("Running batch scoring on {} examples", batch.num_examples());
        auto results = generator.score(batch.get_stream(0), options);
        spdlog::debug("Finished batch scoring");
        return results;
      });
  }

  std::future<StorageView>
  Generator::forward_batch_async(std::vector<std::vector<std::string>> tokens,
                                 const bool return_log_probs) {
    return post<StorageView>(
      [tokens = std::move(tokens), return_log_probs]
      (models::SequenceGeneratorReplica& generator) {
        return generator.forward(tokens, return_log_probs);
      });
  }

  std::future<StorageView>
  Generator::forward_batch_async(std::vector<std::vector<size_t>> ids,
                                 const bool return_log_probs) {
    return post<StorageView>(
      [ids = std::move(ids), return_log_probs]
      (models::SequenceGeneratorReplica& generator) {
        return generator.forward(ids, return_log_probs);
      });
  }

  std::future<StorageView>
  Generator::forward_batch_async(StorageView ids,
                                 StorageView lengths,
                                 const bool return_log_probs) {
    return post<StorageView>(
      [ids = std::move(ids), lengths = std::move(lengths), return_log_probs]
      (models::SequenceGeneratorReplica& generator) {
        return generator.forward(ids, lengths, return_log_probs);
      });
  }

}