ct2rs 0.9.18

Rust bindings for OpenNMT/CTranslate2
Documentation
#pragma once

#include "op.h"

namespace ctranslate2 {
  namespace ops {

    /// Softmax operator over StorageView tensors (a UnaryOp).
    ///
    /// NOTE(review): the `log` constructor flag (stored in `_log`) presumably
    /// switches the op to log-softmax, which is what the LogSoftMax subclass
    /// suggests -- confirm against the definitions in the .cpp file.
    class SoftMax : public UnaryOp {
    public:
      /// @param log  Flag stored by the op (in `_log`); defaults to false.
      ///
      /// Declared `explicit` so that a bare bool (e.g. `true`) cannot be
      /// implicitly converted into a SoftMax object at a call site expecting
      /// an op. Direct construction (`SoftMax()`, `SoftMax(true)`) is unchanged.
      explicit SoftMax(bool log = false);

      // Re-expose the base-class operator() overloads; without this
      // using-declaration the overloads declared below would hide them.
      using UnaryOp::operator();

      /// Single-argument form. NOTE(review): `x` is taken by non-const
      /// reference, so this presumably applies the op in place -- confirm.
      void operator()(StorageView& x) const;

      /// Reads the input `x` and writes the result into `y`.
      void operator()(const StorageView& x, StorageView& y) const override;

      /// Variant taking an extra `lengths` tensor. NOTE(review): presumably
      /// restricts the normalization to the first `lengths[i]` entries of
      /// each row -- confirm in the .cpp.
      void operator()(const StorageView& x, const StorageView& lengths, StorageView& y) const;

      /// Same as the reference overload, but `lengths` is passed by pointer.
      /// NOTE(review): presumably a nullptr means "no lengths" -- confirm.
      void operator()(const StorageView& x, const StorageView* lengths, StorageView& y) const;

    private:
      /// Implementation templated on the device `D` and element type `T`;
      /// receives the (possibly null) `lengths` pointer.
      template <Device D, typename T>
      void compute(const StorageView& input, const StorageView* lengths, StorageView& output) const;

      // Value of the `log` constructor argument.
      bool _log;
    };


    /// Subclass of SoftMax whose only addition is a no-argument constructor;
    /// every operator() overload is inherited unchanged from SoftMax.
    ///
    /// NOTE(review): the constructor presumably forwards `log = true` to
    /// SoftMax to select log-softmax -- confirm against its definition.
    class LogSoftMax : public SoftMax {
    public:
      // Defined out of line (not visible in this header).
      LogSoftMax();
    };

  }
}