ct2rs 0.9.19

Rust bindings for OpenNMT/CTranslate2
Documentation
#include "ctranslate2/ops/slide.h"

#include <numeric>

#include "dispatch.h"

namespace ctranslate2 {
  namespace ops {

    Slide::Slide(dim_t axis, const dim_t& index, const dim_t& size, bool no_copy)
      : _axis(axis)
      , _index(index)
      , _size(size)
      , _no_copy(no_copy) {
      check_arguments();
    }

    void Slide::operator()(const StorageView& input, StorageView& output) const {
      PROFILE("Slide");
      const dim_t axis = _axis < 0 ? input.rank() + _axis : _axis;

      if (_index < 0 || _index >= input.dim(axis))
        throw std::invalid_argument("Index or Size given is not valid");

      dim_t offset = input.stride(0) * _index;
      auto shape = input.shape();
      shape[axis] = _size;
      if (_no_copy) {
        TYPE_DISPATCH(input.dtype(),
                      output.view(const_cast<T*>(input.data<T>() + offset), std::move(shape)));
      }
      else {
        output.resize(std::move(shape));
      }

      if (!_no_copy) {
        DEVICE_AND_TYPE_DISPATCH(input.device(), input.dtype(), (compute<D, T>(input, output, _index)));
      }
    }

    void Slide::check_arguments() const {
      if (_no_copy && _axis != 0)
        throw std::invalid_argument("no_copy is only defined when splitting across the first dimension");
    }

  }
}