ct2rs 0.9.19

Rust bindings for OpenNMT/CTranslate2
Documentation
#include "ctranslate2/ops/gather.h"

#include "cpu/parallel.h"
#include "type_dispatch.h"

namespace ctranslate2 {
  namespace ops {

    template <Device D, typename T>
    void Gather::compute(const StorageView& data,
                         const StorageView& input,
                         const dim_t axis,
                         const dim_t batch_dims,
                         StorageView& output) const {
      const auto* indices = input.data<int32_t>();
      const T* src = data.data<T>();
      T* dst = output.data<T>();

      if (axis == batch_dims) {
        const dim_t copy_size = data.stride(axis);
        const dim_t batch_stride = axis > 0 ? data.stride(axis - 1) : data.size();
        const dim_t batch_size = data.size() / batch_stride;
        const dim_t num_indices = input.size();
        const dim_t num_indices_per_batch = num_indices / batch_size;

        const dim_t grain_size = cpu::get_minimum_batch_copies_per_thread<T>(copy_size);
        cpu::parallel_for(0, num_indices, grain_size, [&](dim_t begin, dim_t end) {
          for (dim_t i = begin; i < end; ++i) {
            const dim_t batch_index = i / num_indices_per_batch;
            const dim_t read_index = indices[i];
            primitives<Device::CPU>::copy(src + batch_index * batch_stride + read_index * copy_size,
                                          dst + i * copy_size,
                                          copy_size);
          }
        });

      } else {
        throw std::invalid_argument("Gather only supports indexing the first non batch dimension");
      }
    }

#define DECLARE_IMPL(T)                                         \
    template void                                               \
    Gather::compute<Device::CPU, T>(const StorageView& data,    \
                                    const StorageView& input,   \
                                    const dim_t axis,           \
                                    const dim_t batch_dims,     \
                                    StorageView& output) const;

    DECLARE_ALL_TYPES(DECLARE_IMPL)

  }
}