ct2rs 0.9.19

Rust bindings for OpenNMT/CTranslate2
Documentation
#include "ctranslate2/ops/sum.h"

#include "dispatch.h"
namespace ctranslate2 {
  namespace ops {

    // Sum reuses Mean's constructor to store the reduction axis (_axis) and
    // delegates the actual reduction to Mean's compute() kernel below.
    Sum::Sum(const dim_t axis)
      : Mean(axis)
    {
    }

    // Sums `input` along the configured axis (_axis). `output` gets the input
    // shape with the reduced dimension set to 1.
    //
    // A negative _axis counts from the end (e.g. -1 is the last dimension).
    //
    // Throws std::out_of_range if _axis is outside [-rank, rank).
    void Sum::operator()(const StorageView& input, StorageView& output) const {
      PROFILE("Sum");

      const dim_t rank = input.rank();
      const dim_t axis = _axis < 0 ? rank + _axis : _axis;
      // Reject too-small negative axes as well: without the `axis < 0` test a
      // value like _axis = -(rank + 1) would survive normalization and later
      // be used as a negative index into output_shape.
      if (axis < 0 || axis >= rank)
        throw std::out_of_range("Cannot compute sum of axis " + std::to_string(_axis)
                                + " for input with rank " + std::to_string(rank));

      const dim_t axis_size = input.dim(axis);
      if (axis_size == 1) {
        // Summing a single element is the identity; the input shape already
        // has size 1 on this axis, so a plain copy is the result.
        output = input;
        return;
      }

      {
        Shape output_shape(input.shape());
        output_shape[axis] = 1;
        output.resize(std::move(output_shape));
      }

      // View the input as [outer_size, axis_size, inner_size]:
      // outer_size = product of dims before `axis`, inner_size = product after.
      dim_t inner_size = 1;
      dim_t outer_size = 1;
      for (dim_t i = 0; i < axis; ++i)
        outer_size *= input.dim(i);
      for (dim_t i = axis + 1; i < rank; ++i)
        inner_size *= input.dim(i);

      // NOTE(review): the trailing `true` presumably asks Mean::compute to
      // return the raw sum instead of dividing by axis_size — confirm in mean.h.
      DEVICE_AND_FLOAT_DISPATCH("Sum", input.device(), input.dtype(),
                                (compute<D, T>(input, outer_size, axis_size, inner_size, true, output)));
    }

  }
}