ct2rs 0.9.19

Rust bindings for OpenNMT/CTranslate2
Documentation
#include "ctranslate2/ops/concat.h"

#include <stdexcept>
#include <string>

#include "dispatch.h"

namespace ctranslate2 {
  namespace ops {

    // Creates a concatenation op along `axis`. The value is stored as given;
    // a negative axis is resolved against the input rank when the op is applied.
    Concat::Concat(int axis) : _axis(axis) {}

    // Concatenates `inputs` along the configured axis and stores the result in `output`.
    //
    // A negative axis is counted from the end, relative to the rank of the first input.
    // `output` is resized to the shape of the first input with the concatenation
    // dimension replaced by the sum of every input's size along that axis, then
    // compute<D, T> is dispatched on the device and dtype of `output`.
    //
    // Throws std::invalid_argument if `inputs` is empty or if the axis does not
    // designate a dimension of the first input.
    void Concat::operator()(const std::vector<const StorageView*>& inputs,
                            StorageView& output) const {
      PROFILE("Concat");

      // front() on an empty vector is undefined behavior, so reject it explicitly.
      if (inputs.empty())
        throw std::invalid_argument("Concat: at least one input is required");

      const dim_t rank = inputs.front()->rank();
      const dim_t axis = _axis < 0 ? rank + _axis : _axis;

      // The axis indexes output_shape directly below; make sure it is in bounds.
      if (axis < 0 || axis >= rank)
        throw std::invalid_argument("Concat: axis " + std::to_string(_axis)
                                    + " is out of range for inputs of rank "
                                    + std::to_string(rank));

      // Size of the output along the concatenation axis.
      dim_t concat_dims = 0;
      for (const StorageView* x : inputs) {
        assert(x->rank() == rank);
        concat_dims += x->dim(axis);
      }

      // All other dimensions are taken from the first input.
      Shape output_shape(inputs.front()->shape());
      output_shape[axis] = concat_dims;
      output.resize(std::move(output_shape));

      DEVICE_AND_TYPE_DISPATCH(output.device(), output.dtype(), (compute<D, T>(inputs, output)));
    }

  }
}