ct2rs 0.9.18

Rust bindings for OpenNMT/CTranslate2
Documentation
#pragma once

#include "ops/gather.h"
#include "storage_view.h"

namespace ctranslate2 {

  // This class can be used to dynamically remove or add padding.
  // This is useful to save on computation when lengths are very different.
  //
  // NOTE(review): remove_padding() and add_padding() appear to be intended as
  // a pair (flatten -> compute -> restore layout); confirm against callers.
  class Padder {
  public:
    // Returns true when padding removal may be used for this configuration:
    // always on Device::CPU, and on any other device only when the compute
    // type is not ComputeType::FLOAT16.
    // NOTE(review): the reason FLOAT16 is excluded off-CPU is not visible in
    // this header.
    static inline bool allow_padding_removal(const Device device,
                                             const ComputeType compute_type) {
      return device == Device::CPU || compute_type != ComputeType::FLOAT16;
    }

    // Prepares a padder for a batch described by `lengths`.
    //
    // lengths: per-example sequence lengths (assumed to be a 1D tensor with
    //   one entry per batch example — TODO confirm against the implementation).
    // max_time: size of the padded time dimension. If max_time is negative,
    //   it is set to the maximum length.
    // pad_batch_to_multiple: NOTE(review): judging by the name, presumably
    //   rounds the padded size up to a multiple of this value (1 = no extra
    //   rounding) — confirm in the implementation.
    //
    // Marked explicit: since every parameter after `lengths` has a default,
    // this would otherwise be an implicit converting constructor, letting a
    // StorageView silently turn into a Padder.
    explicit Padder(const StorageView& lengths,
                    const dim_t max_time = -1,
                    const dim_t pad_batch_to_multiple = 1);

    // Merge batch and time dimensions and remove padding.
    // `x` is modified in place.
    void remove_padding(StorageView& x) const;

    // Split first dimension into batch and time dimensions and add padding.
    // `x` is modified in place.
    void add_padding(StorageView& x) const;

  private:
    // NOTE(review): presumably the batch size and padded time dimension
    // derived in the constructor from `lengths` / `max_time` — confirm.
    dim_t _batch_size;
    dim_t _max_time;
    // Index tensors for moving between the padded and flat layouts;
    // presumably consumed by _gather_op — confirm in the implementation.
    StorageView _padded_to_flat;
    StorageView _flat_to_padded;
    // Being const, this member makes Padder non-assignable (its implicit
    // copy/move assignment operators are deleted).
    const ops::Gather _gather_op;
  };

}  // namespace ctranslate2