ct2rs 0.9.18

Rust bindings for OpenNMT/CTranslate2
Documentation
#pragma once

#include "op.h"

namespace ctranslate2 {
  namespace ops {

    // Layer normalization operator.
    //
    // NOTE(review): the normalization formula and the exact interpretation of
    // `axis` live in the implementation file, not in this header; the comments
    // below describe only what these declarations show and hedge the rest.
    class LayerNorm : public TernaryOp {
    public:
      // axis:    dimension to normalize over (default -1, presumably "last
      //          dimension" per the usual negative-index convention — confirm).
      // epsilon: small constant, presumably added to the variance for numerical
      //          stability (default 1e-5).
      //
      // Declared `explicit` because every parameter has a default: without it
      // this is an implicit converting constructor, and a lone integer would be
      // silently turned into a LayerNorm (e.g. when passed to a
      // `const LayerNorm&`). Direct initialization such as `LayerNorm()`,
      // `LayerNorm{}` or `LayerNorm(-1, eps)` is unaffected. The `f` suffix keeps
      // the default argument a float literal instead of a narrowed double.
      explicit LayerNorm(const dim_t axis = -1, const float epsilon = 1e-5f);

      // Re-expose the base-class operator() overloads; without this
      // using-declaration they would be hidden by the overloads declared below.
      using TernaryOp::operator();

      // Normalizes `input` into `output` using `gamma` and `beta` (conventionally
      // the scale and shift parameters — confirm in the implementation).
      void operator()(const StorageView& beta,
                      const StorageView& gamma,
                      const StorageView& input,
                      StorageView& output) const;

      // Overload without beta/gamma taking a single mutable tensor; presumably
      // normalizes `input` in place — TODO confirm.
      void operator()(StorageView& input) const;
      // Overload without beta/gamma that writes the result into `output`.
      void operator()(const StorageView& input, StorageView& output) const;

    private:
      // Pointer-based variant; beta/gamma are pointers presumably so the public
      // overloads can forward here with nullptr when they are not provided —
      // verify against the implementation.
      void operator()(const StorageView* beta,
                      const StorageView* gamma,
                      const StorageView& input,
                      StorageView& output) const;

      // Kernel specialized per device `D` and element type `T`.
      // outer_size / axis_size / inner_size presumably describe `input` viewed
      // as [outer, axis, inner] around the normalized axis — TODO confirm.
      template <Device D, typename T>
      void compute(const StorageView* beta,
                   const StorageView* gamma,
                   const StorageView& input,
                   const dim_t axis,
                   const dim_t outer_size,
                   const dim_t axis_size,
                   const dim_t inner_size,
                   StorageView& output) const;

      const dim_t _axis;     // presumably set from the constructor's `axis`
      const float _epsilon;  // presumably set from the constructor's `epsilon`
    };

  }
}