// ct2rs 0.9.19
//
// Rust bindings for OpenNMT/CTranslate2
// Documentation
#include "ctranslate2/ops/rms_norm.h"

#include "dispatch.h"

namespace ctranslate2 {
  namespace ops {

    // Constructs the op by copying its two configuration values into the
    // matching members: `epsilon` into _epsilon and `use_residual` into
    // _use_residual. No other work is done here.
    RMSNorm::RMSNorm(const float epsilon, const bool use_residual)
      : _epsilon(epsilon),
        _use_residual(use_residual) {}

    // Runs the op on `input` with weights `gamma`, writing into `output`.
    //
    // `output` is first resized with resize_as(input); the call is then
    // forwarded to the compute<D, T> instantiation chosen by the dispatch
    // macro from the input's device and dtype.
    void RMSNorm::operator()(const StorageView& gamma,
                             const StorageView& input,
                             StorageView& output) const {
      PROFILE("RMSNorm");

      // Size the destination before dispatching to the implementation.
      output.resize_as(input);

      // Dispatch keys are both taken from the input (queried after the resize,
      // same as when they were passed inline to the macro).
      const auto target_device = input.device();
      const auto value_type = input.dtype();

      DEVICE_AND_FLOAT_DISPATCH("RMSNorm", target_device, value_type,
                                (compute<D, T>(gamma, input, output)));
    }

  }
}