ct2rs 0.9.19

Rust bindings for OpenNMT/CTranslate2
Documentation
#include "ctranslate2/ops/rms_norm.h"

#include "cpu/kernels.h"

namespace ctranslate2 {
  namespace ops {

    template <Device D, typename T>
    void RMSNorm::compute(const StorageView& gamma,
                          const StorageView& input,
                          StorageView& output) const {
      const dim_t depth = input.dim(-1);
      const dim_t batch_size = input.size() / depth;
      CPU_ISA_DISPATCH((cpu::rms_norm<ISA>(input.data<T>(),
                                           gamma.data<T>(),
                                           output.data<T>(),
                                           batch_size,
                                           depth,
                                           _epsilon,
                                           _use_residual)));
    }

// Explicitly instantiates RMSNorm::compute on Device::CPU for one element type.
#define DECLARE_IMPL(DTYPE)                                  \
    template void                                            \
    RMSNorm::compute<Device::CPU, DTYPE>(const StorageView&, \
                                         const StorageView&, \
                                         StorageView&) const;

    // Only the float instantiation is emitted here.
    DECLARE_IMPL(float)

  }
}