ct2rs 0.9.19

Rust bindings for OpenNMT/CTranslate2
Documentation
#include "ctranslate2/ops/transpose.h"

#include <stdexcept>
#include <string>
#include <vector>

#include "dispatch.h"

namespace ctranslate2 {
  namespace ops {

    Transpose::Transpose(const std::vector<dim_t>& perm)
      : _perm(perm) {
    }

    void Transpose::operator()(const StorageView& x, StorageView& y) const {
      PROFILE("Transpose");
      if (x.rank() <= 1) {
        y = x;
        return;
      }

      std::vector<dim_t> perm;
      bool identity = true;
      if (_perm.empty()) {
        perm.resize(x.rank());
        for (dim_t i = 0; i < x.rank(); ++i)
          perm[i] = x.rank() - i - 1;
        identity = false;
      } else {
        assert(_perm.size() == x.rank());
        perm = _perm;
        for (dim_t i = 0; i < x.rank(); ++i) {
          if (perm[i] != i) {
            identity = false;
            break;
          }
        }
      }

      if (identity) {
        y = x;
        return;
      }

      DEVICE_AND_TYPE_DISPATCH(x.device(), x.dtype(), (compute<D, T>(x, perm, y)));
    }

  }
}