// ct2rs 0.9.19
//
// Rust bindings for OpenNMT/CTranslate2
// Documentation
#include "ctranslate2/ops/mean.h"

#include "cpu/parallel.h"
#include "type_dispatch.h"

namespace ctranslate2 {
  namespace ops {

    template <Device D, typename T>
    void Mean::compute(const StorageView& input,
                       const dim_t outer_size,
                       const dim_t axis_size,
                       const dim_t inner_size,
                       const bool get_sum,
                       StorageView& output) const {
      const auto* src = input.data<T>();
      auto* dst = output.data<T>();

      cpu::parallel_for(0, outer_size, 1, [&](dim_t begin, dim_t end) {
        for (dim_t i = begin; i < end; ++i) {
          for (dim_t j = 0; j < inner_size; ++j) {
            float sum = 0.f;
            for (dim_t k = 0; k < axis_size; ++k) {
              sum += src[i * axis_size * inner_size + k * inner_size + j];
            }
            dst[i * inner_size + j] = sum;
            if (!get_sum)
              dst[i * inner_size + j] /= float(axis_size);
          }
        }
      });
    }

// Explicitly instantiates the CPU variant of Mean::compute for element type T,
// so the template definition in this file is emitted for that type.
#define DECLARE_IMPL(T)                                                 \
    template void Mean::compute<Device::CPU, T>(const StorageView&,     \
                                                const dim_t,            \
                                                const dim_t,            \
                                                const dim_t,            \
                                                const bool,             \
                                                StorageView&) const;

    // Only float is instantiated for the CPU device in this file.
    DECLARE_IMPL(float)

  }
}