ct2rs 0.9.18

Rust bindings for OpenNMT/CTranslate2
Documentation
#include <pybind11/pybind11.h>

#include <ctranslate2/devices.h>
#include <ctranslate2/models/model.h>
#include <ctranslate2/random.h>
#include <ctranslate2/types.h>
#include <ctranslate2/utils.h>

#include "module.h"
#include "utils.h"


// Builds the set of compute type names that can be used on a device.
//
// device_str is converted with ctranslate2::str_to_device and, together with
// device_index, passed to the ctranslate2::mayiuse_* capability queries.
// "float32" is always part of the result; every other name is added only when
// the matching capability query (or pair of queries, for the mixed int8_*
// types) returned true.
static std::unordered_set<std::string>
get_supported_compute_types(const std::string& device_str, const int device_index) {
  const auto target = ctranslate2::str_to_device(device_str);

  // All capabilities are queried up front, before anything is inserted.
  const bool has_bfloat16 = ctranslate2::mayiuse_bfloat16(target, device_index);
  const bool has_float16 = ctranslate2::mayiuse_float16(target, device_index);
  const bool has_int16 = ctranslate2::mayiuse_int16(target, device_index);
  const bool has_int8 = ctranslate2::mayiuse_int8(target, device_index);

  std::unordered_set<std::string> supported;

  // Inserts `name` into the result only when `enabled` is true.
  const auto add_if = [&supported](const bool enabled, const char* name) {
    if (enabled)
      supported.emplace(name);
  };

  add_if(true, "float32");
  add_if(has_bfloat16, "bfloat16");
  add_if(has_float16, "float16");
  add_if(has_int16, "int16");

  // The int8 family: the mixed types additionally need the float type.
  add_if(has_int8, "int8");
  add_if(has_int8, "int8_float32");
  add_if(has_int8 && has_float16, "int8_float16");
  add_if(has_int8 && has_bfloat16, "int8_bfloat16");

  return supported;
}


// Initialization of the `_ext` Python extension module: exposes a few free
// helper functions and then delegates to the ctranslate2::python::register_*
// functions for everything else.
PYBIND11_MODULE(_ext, m)
{
  namespace ct2 = ctranslate2;
  namespace bindings = ctranslate2::python;

  // Turn off the enum-members section of generated docstrings.
  py::options docstring_options;
  docstring_options.disable_enum_members_docstring();

  // Docstring attached to get_supported_compute_types below.
  const char* const supported_compute_types_doc =
         R"pbdoc(
             Returns the set of supported compute types on a device.

             Arguments:
               device: Device name (cpu or cuda).
               device_index: Device index.

             Example:
                 >>> ctranslate2.get_supported_compute_types("cpu")
                 {'int16', 'float32', 'int8', 'int8_float32'}
                 >>> ctranslate2.get_supported_compute_types("cuda")
                 {'float32', 'int8_float16', 'float16', 'int8', 'int8_float32'}
         )pbdoc";

  // Module-level helper functions.
  m.def("contains_model",
        &ct2::models::contains_model,
        py::arg("path"),
        "Helper function to check if a directory seems to contain a CTranslate2 model.");

  m.def("get_cuda_device_count",
        &ct2::get_gpu_count,
        "Returns the number of visible GPU devices.");

  m.def("get_supported_compute_types",
        &get_supported_compute_types,
        py::arg("device"),
        py::arg("device_index") = 0,
        supported_compute_types_doc);

  m.def("set_random_seed",
        &ct2::set_random_seed,
        py::arg("seed"),
        "Sets the seed of random generators.");

  // Remaining bindings. The call order is kept fixed on purpose, in case a
  // later registration relies on something registered earlier.
  bindings::register_logging(m);
  bindings::register_storage_view(m);
  bindings::register_translation_stats(m);
  bindings::register_translation_result(m);
  bindings::register_scoring_result(m);
  bindings::register_generation_result(m);
  bindings::register_translator(m);
  bindings::register_generator(m);
  bindings::register_encoder(m);
  bindings::register_whisper(m);
  bindings::register_wav2vec2(m);
  bindings::register_wav2vec2bert(m);
  bindings::register_mpi(m);
}