// ct2rs 0.9.18
//
// Rust bindings for OpenNMT/CTranslate2
// Documentation
#pragma once

#include "model.h"

namespace ctranslate2 {
  namespace models {

    // Name-indexed registry of model builders.
    //
    // Each entry maps a model name to a callable that constructs a new model
    // instance. The class can only be default-constructed from within itself;
    // callers reach the shared instance through get_instance().
    class ModelFactory {
    public:
      // Returns the single shared factory, held as a function-local static.
      static ModelFactory& get_instance() {
        static ModelFactory instance;
        return instance;
      }

      // Registers a builder for `Model` under `name`.
      //
      // `args` are captured by copy in the stored builder and passed to the
      // `Model` constructor every time the builder is invoked.
      //
      // Returns true if the entry was inserted, false if `name` was already
      // present (the existing entry is left in place).
      template <typename Model, typename... Args>
      bool register_model(const std::string& name, Args&&... args) {
        const auto result = _registry.emplace(
          name,
          Builder([args...]() { return std::make_shared<Model>(args...); }));
        return result.second;
      }

      // Invokes the builder registered under `name` and returns its result.
      //
      // Throws std::invalid_argument if no builder is registered for `name`.
      std::shared_ptr<Model> create_model(const std::string& name) const {
        const auto entry = _registry.find(name);
        if (entry != _registry.end())
          return entry->second();
        throw std::invalid_argument("Unknown model " + name);
      }

    private:
      // Construction is restricted to get_instance().
      ModelFactory() = default;

      // Type-erased, argument-free constructor of a model.
      using Builder = std::function<std::shared_ptr<models::Model>(void)>;

      // Registered builders, keyed by model name.
      std::unordered_map<std::string, Builder> _registry;
    };

    template <typename Model, typename... Args>
    bool register_model(const std::string& name, Args&&... args) {
      return ModelFactory::get_instance().register_model<Model>(name, std::forward<Args>(args)...);
    }

    // Free-function counterpart for creating a model by name.
    // Declared here only; the definition lives outside this header.
    // NOTE(review): presumably forwards to
    // ModelFactory::get_instance().create_model(name) -- confirm against the
    // implementation file.
    std::shared_ptr<Model> create_model(const std::string& name);

  }
}