// ct2rs 0.9.19
//
// Rust bindings for OpenNMT/CTranslate2
// Documentation
#include "ctranslate2/ops/rotary.h"

#include "cpu/parallel.h"

namespace ctranslate2 {
  namespace ops {

    template <typename T, bool interleave>
    void rotary_kernel(const T* input,
                       const T* sin,
                       const T* cos,
                       T* output,
                       const dim_t batch_size,
                       const dim_t max_time,
                       const dim_t ndims,
                       const dim_t depth) {
      const dim_t middle = ndims / 2;

      cpu::parallel_for(0, batch_size, 1, [&](dim_t begin, dim_t end) {
        for (dim_t b = begin; b < end; ++b) {
          for (dim_t t = 0; t < max_time; ++t) {
            const T* s = sin + t * ndims;
            const T* c = cos + t * ndims;

            const T* x = input + b * (max_time * depth) + t * depth;
            T* y = output + b * (max_time * depth) + t * depth;

            for (dim_t i = 0; i < ndims; ++i) {
              if (interleave)
                y[i] = x[i] * c[i] + (i % 2 == 0 ? -x[i + 1] : x[i - 1]) * s[i];
              else
                y[i] = x[i] * c[i] + (i < middle ? -x[i + middle] : x[i - middle]) * s[i];
            }

            if (ndims < depth)
              std::copy(x + ndims, x + depth, y + ndims);
          }
        }
      });
    }

    // CPU implementation of Rotary::compute: rotates `input` using the `sin` /
    // `cos` tables and writes the result into `output`.
    //
    // The time dimension is read from dim(-2) when `is_transposed` is true and
    // from dim(-3) otherwise. The last dimension is the feature depth. All
    // remaining elements are folded into `batch_size`, and rotary_kernel walks
    // the data as contiguous rows indexed (b * max_time + t).
    // `_ndims == 0` selects the full depth for rotation.
    //
    // NOTE(review): when !is_transposed, dim(-2) is folded into batch_size even
    // though it is laid out *inside* the time dimension. The time step the
    // kernel derives for a row (row % max_time) therefore only matches the real
    // one when dim(-2) == 1 or max_time == 1. Confirm callers never reach this
    // path on CPU with dim(-2) > 1.
    // NOTE(review): `output` is presumably pre-allocated with the shape of
    // `input` by the caller (TODO confirm); only output.data<T>() is used here.
    // Nothing here checks that _ndims <= depth.
    template <Device D, typename T>
    void Rotary::compute(const StorageView& input,
                         const StorageView& sin,
                         const StorageView& cos,
                         StorageView& output,
                         bool is_transposed) const {
      const dim_t max_time = is_transposed ? input.dim(-2) : input.dim(-3);
      const dim_t depth = input.dim(-1);

      // An empty time or feature dimension leaves nothing to rotate. Returning
      // here also avoids the integer division by zero that computing
      // batch_size would otherwise perform.
      if (max_time == 0 || depth == 0)
        return;

      const dim_t batch_size = input.size() / (max_time * depth);
      const dim_t ndims = _ndims == 0 ? depth : _ndims;

      const auto* x = input.data<T>();
      const auto* s = sin.data<T>();
      const auto* c = cos.data<T>();
      auto* y = output.data<T>();

      if (_interleave)
        rotary_kernel<T, true>(x, s, c, y, batch_size, max_time, ndims, depth);
      else
        rotary_kernel<T, false>(x, s, c, y, batch_size, max_time, ndims, depth);
    }

// Explicitly instantiates the CPU implementation of Rotary::compute for the
// element type T; only float is instantiated in this file.
#define DECLARE_IMPL(T)                                                 \
    template void Rotary::compute<Device::CPU, T>(                      \
        const StorageView&, const StorageView&, const StorageView&,     \
        StorageView&, bool) const;

    DECLARE_IMPL(float)

  }
}