ct2rs 0.9.18

Rust bindings for OpenNMT/CTranslate2
Documentation
#include "module.h"

#include <ctranslate2/models/wav2vec2bert.h>

#include "replica_pool.h"

namespace ctranslate2 {
  namespace python {

    // Wrapper exposed to Python for models::Wav2Vec2Bert. It derives from
    // ReplicaPoolHelper<models::Wav2Vec2Bert> and reuses the helper's
    // constructors unchanged.
    class Wav2Vec2BertWrapper : public ReplicaPoolHelper<models::Wav2Vec2Bert> {
    public:
      using ReplicaPoolHelper::ReplicaPoolHelper;

      // Runs the encoder on `features` through the replica pool and returns the
      // resulting StorageView. `to_cpu` is forwarded as-is to the pool.
      StorageView encode(const StorageView& features, const bool to_cpu) {
        // Shared (reader) lock on the helper's mutex, held for the whole call.
        std::shared_lock guard(_mutex);
        assert_model_is_ready();

        // The pool hands back a handle whose get() yields the result
        // (presumably a future that blocks until the job completes -- confirm
        // against the pool's declaration).
        auto pending = _pool->encode(features, to_cpu);
        return pending.get();
      }
    };


    // Registers the Python class "Wav2Vec2Bert" on module `m`. The class is
    // backed by Wav2Vec2BertWrapper and exposes its constructor, a set of
    // read-only properties, and the encode / unload_model / load_model methods.
    // Registrations are issued one statement at a time, in the same order as
    // they appear in the Python-facing API.
    void register_wav2vec2bert(py::module& m) {
      using Wrapper = Wav2Vec2BertWrapper;

      py::class_<Wrapper> cls(
        m, "Wav2Vec2Bert",
        R"pbdoc(
            Implements the Wav2Vec2Bert speech recognition model published by Facebook.

            See Also:
               https://github.com/facebookresearch/fairseq/tree/main/examples/wav2vec
        )pbdoc");

      // Constructor: model_path and device are positional, everything after
      // py::kw_only() must be passed by keyword.
      cls.def(
             py::init<const std::string&, const std::string&, const std::variant<int, std::vector<int>>&, const StringOrMap&, size_t, size_t, long, bool, bool, py::object>(),
             py::arg("model_path"),
             py::arg("device")="cpu",
             py::kw_only(),
             py::arg("device_index")=0,
             py::arg("compute_type")="default",
             py::arg("inter_threads")=1,
             py::arg("intra_threads")=0,
             py::arg("max_queued_batches")=0,
             py::arg("flash_attention")=false,
             py::arg("tensor_parallel")=false,
             py::arg("files")=py::none(),
             R"pbdoc(
                 Initializes a Wav2Vec2Bert model from a converted model.

                 Arguments:
                   model_path: Path to the CTranslate2 model directory.
                   device: Device to use (possible values are: cpu, cuda, auto).
                   device_index: Device IDs where to place this model on.
                   compute_type: Model computation type or a dictionary mapping a device name
                     to the computation type (possible values are: default, auto, int8, int8_float32,
                     int8_float16, int8_bfloat16, int16, float16, bfloat16, float32).
                   inter_threads: Number of workers to allow executing multiple batches in parallel.
                   intra_threads: Number of OpenMP threads per worker (0 to use a default value).
                   max_queued_batches: Maximum numbers of batches in the worker queue (-1 for unlimited,
                     0 for an automatic value). When the queue is full, future requests will block
                     until a free slot is available.
                   flash_attention: run model with flash attention 2 for self-attention layer
                   tensor_parallel: run model with tensor parallel mode
                   files: Load model files from the memory. This argument is a dictionary mapping
                     file names to file contents as file-like or bytes objects. If this is set,
                     :obj:`model_path` acts as an identifier for this model.
             )pbdoc");

      // Read-only properties. Note that the Python name "num_workers" is bound
      // to the wrapper's num_replicas accessor.
      cls.def_property_readonly("device",
                                &Wrapper::device,
                                "Device this model is running on.");
      cls.def_property_readonly("device_index",
                                &Wrapper::device_index,
                                "List of device IDs where this model is running on.");
      cls.def_property_readonly("compute_type",
                                &Wrapper::compute_type,
                                "Computation type used by the model.");
      cls.def_property_readonly("num_workers",
                                &Wrapper::num_replicas,
                                "Number of model workers backing this instance.");
      cls.def_property_readonly("num_queued_batches",
                                &Wrapper::num_queued_batches,
                                "Number of batches waiting to be processed.");
      cls.def_property_readonly("tensor_parallel",
                                &Wrapper::tensor_parallel,
                                "Run model with tensor parallel mode.");
      cls.def_property_readonly("num_active_batches",
                                &Wrapper::num_active_batches,
                                "Number of batches waiting to be processed or currently processed.");

      // Methods. Each one is registered with a call guard that releases the
      // GIL for the duration of the C++ call.
      cls.def(
             "encode", &Wrapper::encode,
             py::arg("features"),
             py::arg("to_cpu")=false,
             py::call_guard<py::gil_scoped_release>(),
             R"pbdoc(
                 Encodes the input features.

                 Arguments:
                   features: Mel spectogram of the audio, as a float array with shape
                     ``[batch_size, 80, 3000]``.
                   to_cpu: Copy the encoder output to the CPU before returning the value.

                 Returns:
                   The encoder output.
             )pbdoc");

      cls.def(
             "unload_model", &Wrapper::unload_model,
             py::arg("to_cpu")=false,
             py::call_guard<py::gil_scoped_release>(),
             R"pbdoc(
                 Unloads the model attached to this wav2vec2bert but keep enough runtime context
                 to quickly resume wav2vec2bert on the initial device.

                 Arguments:
                   to_cpu: If ``True``, the model is moved to the CPU memory and not fully unloaded.
             )pbdoc");

      cls.def(
             "load_model", &Wrapper::load_model,
             py::arg("keep_cache")=false,
             py::call_guard<py::gil_scoped_release>(),
             R"pbdoc(
                 Loads the model back to the initial device.

                 Arguments:
                   keep_cache: If ``True``, the model cache in the CPU memory is not deleted if it exists.
             )pbdoc");

      // Registered last, after the load/unload methods it relates to.
      cls.def_property_readonly("model_is_loaded",
                                &Wrapper::model_is_loaded,
                                "Whether the model is loaded on the initial device and ready to be used.");
    }

  }
}