// coqui-stt 1.0.2
//
// Safe wrapper around the Coqui STT C library
// Documentation
#![allow(clippy::missing_safety_doc)]
use crate::{Metadata, Stream};
use std::ffi::CStr;
use std::os::raw::c_uint;

/// A trained Coqui STT model.
///
/// Wraps the raw `ModelState` pointer handed out by `coqui_stt_sys`.
/// The pointer is released with `STT_FreeModel` when this value is dropped.
pub struct Model(pub(crate) *mut coqui_stt_sys::ModelState);

// SAFETY: these implementations rely on `ModelState` being safe to move and
// share between threads, as stated by the original author.
// NOTE(review): confirm this against the threading guarantees of the C library.
unsafe impl Send for Model {}
unsafe impl Sync for Model {}

impl Drop for Model {
    /// Release the underlying model state through the C library.
    #[inline]
    fn drop(&mut self) {
        let state = self.0;

        // SAFETY: `drop` runs at most once per `Model`, and `into_inner` suppresses it
        // when ownership of the pointer is handed back to the caller, so this wrapper
        // never frees the same state twice.
        unsafe {
            coqui_stt_sys::STT_FreeModel(state);
        }
    }
}

impl Model {
    /// Load a model from the file at `model_path`.
    ///
    /// # Errors
    /// Returns an error if the path contains a NUL byte, or if the C library
    /// reports a failure while creating the model.
    #[inline]
    pub fn new(model_path: impl Into<String>) -> crate::Result<Self> {
        let owned_path: String = model_path.into();
        Self::_new(owned_path)
    }

    fn _new(model_path: String) -> crate::Result<Self> {
        let mut model_path = model_path.into_bytes();
        model_path.reserve_exact(1);
        model_path.push(b'\0');
        let model_path = CStr::from_bytes_with_nul(model_path.as_ref())?;

        let mut state = std::ptr::null_mut::<coqui_stt_sys::ModelState>();

        // SAFETY: creating a model is only done with a null pointer and a model path,
        // both of which have been checked
        let retval = unsafe {
            coqui_stt_sys::STT_CreateModel(model_path.as_ptr(), std::ptr::addr_of_mut!(state))
        };

        if let Some(e) = crate::Error::from_c_int(retval) {
            return Err(e);
        }

        if state.is_null() {
            return Err(crate::Error::Unknown);
        }

        Ok(Self(state))
    }

    /// Create a new model from a memory buffer.
    ///
    /// `buffer` may be any type that can be viewed as a byte slice,
    /// such as `Vec<u8>`, `&[u8]` or a byte array.
    ///
    /// This function is not available on Windows.
    ///
    /// # Errors
    /// Returns an error if the model is invalid, or for other reasons.
    #[inline]
    #[cfg(not(target_os = "windows"))]
    pub fn new_from_buffer(buffer: impl AsRef<[u8]>) -> crate::Result<Self> {
        // The bound is `AsRef<[u8]>` (not `AsRef<&[u8]>`) so that ordinary byte
        // containers are accepted, matching `enable_external_scorer_from_buffer`.
        Self::_new_from_buffer(buffer.as_ref())
    }

    #[inline]
    #[cfg(not(target_os = "windows"))]
    fn _new_from_buffer(buffer: &[u8]) -> crate::Result<Self> {
        let mut state = std::ptr::null_mut::<coqui_stt_sys::ModelState>();

        // SAFETY: creating a model is only done with a null pointer and a model buffer
        // both of which have been checked
        let retval = unsafe {
            coqui_stt_sys::STT_CreateModelFromBuffer(
                buffer.as_ptr().cast::<i8>(),
                buffer.len() as c_uint,
                std::ptr::addr_of_mut!(state),
            )
        };

        if let Some(e) = crate::Error::from_c_int(retval) {
            return Err(e);
        }

        if state.is_null() {
            return Err(crate::Error::Unknown);
        }

        Ok(Self(state))
    }

    /// Consume this model and hand back the raw model state pointer.
    ///
    /// This is an escape hatch for functionality the safe API does not expose.
    ///
    /// # Safety
    /// After this call the model is no longer freed automatically.
    ///
    /// You are responsible for eventually passing the returned pointer to
    /// `STT_FreeModel` so the model is disposed of properly.
    #[inline]
    #[must_use]
    pub unsafe fn into_inner(self) -> *mut coqui_stt_sys::ModelState {
        // Wrapping `self` in `ManuallyDrop` suppresses `Drop`, which would
        // otherwise free the very state being returned.
        std::mem::ManuallyDrop::new(self).0
    }

    /// Create a new model from an existing model state.
    ///
    /// # Safety
    /// You must ensure `state` is a valid model state.
    #[inline]
    pub const unsafe fn from_model_state(state: *mut coqui_stt_sys::ModelState) -> Self {
        Self(state)
    }

    /// Enable an external scorer, loaded from the file at `scorer_path`.
    ///
    /// # Errors
    /// Returns an error if `scorer_path` contains a NUL byte,
    /// or if the path or the file it points to is invalid in some way.
    #[inline]
    pub fn enable_external_scorer(&mut self, scorer_path: impl Into<String>) -> crate::Result<()> {
        let owned_path: String = scorer_path.into();
        self._enable_external_scorer(owned_path)
    }

    /// Non-generic body of [`enable_external_scorer`](crate::Model::enable_external_scorer).
    #[inline]
    fn _enable_external_scorer(&mut self, scorer_path: String) -> crate::Result<()> {
        // Append the terminating NUL the C API expects; `from_bytes_with_nul`
        // rejects the path if it already contained a NUL byte.
        let mut path_bytes = scorer_path.into_bytes();
        path_bytes.reserve_exact(1);
        path_bytes.push(b'\0');
        let c_path = CStr::from_bytes_with_nul(&path_bytes)?;

        let state = self.0;
        // `handle_error!` is defined elsewhere in the crate; presumably it maps
        // the C status code onto `crate::Result<()>`.
        handle_error!(coqui_stt_sys::STT_EnableExternalScorer(
            state,
            c_path.as_ptr()
        ))
    }

    /// Enable an external scorer that is held in an in-memory buffer.
    ///
    /// This function is not available on Windows.
    ///
    /// # Errors
    /// Returns an error if the scorer in memory is invalid in some way.
    #[inline]
    #[cfg(not(target_os = "windows"))]
    pub fn enable_external_scorer_from_buffer(
        &mut self,
        buffer: impl AsRef<[u8]>,
    ) -> crate::Result<()> {
        let scorer_bytes: &[u8] = buffer.as_ref();
        self._enable_external_scorer_from_buffer(scorer_bytes)
    }

    #[inline]
    #[cfg(not(target_os = "windows"))]
    fn _enable_external_scorer_from_buffer(&mut self, buffer: &[u8]) -> crate::Result<()> {
        handle_error!(coqui_stt_sys::STT_EnableExternalScorerFromBuffer(
            self.0,
            buffer.as_ptr().cast::<i8>(),
            buffer.len() as c_uint
        ))
    }

    /// Turn off an external scorer that was previously enabled through
    /// [`enable_external_scorer`](crate::Model::enable_external_scorer).
    ///
    /// # Errors
    /// Returns an error if an error happened while disabling the scorer.
    #[inline]
    pub fn disable_external_scorer(&mut self) -> crate::Result<()> {
        let state = self.0;
        handle_error!(coqui_stt_sys::STT_DisableExternalScorer(state))
    }

    /// Register a hot-word together with its boost value.
    ///
    /// Words that don't occur in the scorer (e.g. proper nouns),
    /// or strings that contain spaces won't be taken into account.
    ///
    /// # Errors
    /// Passes through any errors from the C library. See enum [`Error`](crate::Error).
    ///
    /// Additionally, if the input word contains a NUL character anywhere in it, returns an error.
    #[inline]
    pub fn add_hot_word(&mut self, word: impl Into<String>, boost: f32) -> crate::Result<()> {
        let owned_word: String = word.into();
        self._add_hot_word(owned_word, boost)
    }

    /// Non-generic body of [`add_hot_word`](crate::Model::add_hot_word).
    #[inline]
    fn _add_hot_word(&mut self, word: String, boost: f32) -> crate::Result<()> {
        // Append the terminating NUL the C API expects; `from_bytes_with_nul`
        // rejects the word if it already contained a NUL byte.
        let mut word_bytes = word.into_bytes();
        word_bytes.reserve_exact(1);
        word_bytes.push(b'\0');
        let c_word = CStr::from_bytes_with_nul(&word_bytes)?;

        let state = self.0;
        handle_error!(coqui_stt_sys::STT_AddHotWord(state, c_word.as_ptr(), boost))
    }

    /// Delete the entry for `word` from the hot-words map.
    ///
    /// # Errors
    /// Passes through any errors from the C library. See enum [`Error`](crate::Error).
    ///
    /// Additionally, if the input word contains a NUL character anywhere in it, returns an error.
    #[inline]
    pub fn erase_hot_word(&mut self, word: impl Into<String>) -> crate::Result<()> {
        let owned_word: String = word.into();
        self._erase_hot_word(owned_word)
    }

    /// Non-generic body of [`erase_hot_word`](crate::Model::erase_hot_word).
    #[inline]
    fn _erase_hot_word(&mut self, word: String) -> crate::Result<()> {
        // Append the terminating NUL the C API expects; `from_bytes_with_nul`
        // rejects the word if it already contained a NUL byte.
        let mut word_bytes = word.into_bytes();
        word_bytes.reserve_exact(1);
        word_bytes.push(b'\0');
        let c_word = CStr::from_bytes_with_nul(&word_bytes)?;

        let state = self.0;
        handle_error!(coqui_stt_sys::STT_EraseHotWord(state, c_word.as_ptr()))
    }

    /// Empty the hot-words map, removing every entry.
    ///
    /// # Errors
    /// Passes through any errors from the C library. See enum [`Error`](crate::Error).
    #[inline]
    pub fn clear_hot_words(&mut self) -> crate::Result<()> {
        let state = self.0;
        handle_error!(coqui_stt_sys::STT_ClearHotWords(state))
    }

    /// Configure the `alpha` and `beta` hyperparameters of the external scorer.
    ///
    /// `alpha` is the alpha hyperparameter of the decoder: the language model weight.
    ///
    /// `beta` is the beta hyperparameter of the decoder: the word insertion weight.
    ///
    /// # Errors
    /// Passes through any errors from the C library. See enum [`Error`](crate::Error).
    #[inline]
    pub fn set_scorer_alpha_beta(&mut self, alpha: f32, beta: f32) -> crate::Result<()> {
        let state = self.0;
        handle_error!(coqui_stt_sys::STT_SetScorerAlphaBeta(state, alpha, beta))
    }

    /// Sample rate, in Hz, that this model expects its input audio to have.
    #[inline]
    #[must_use]
    pub fn get_sample_rate(&self) -> i32 {
        let state: *const coqui_stt_sys::ModelState = self.0;

        // SAFETY: `state` is the pointer this `Model` was built around; the
        // constructors only store non-null states, and `from_model_state`
        // requires the caller to supply a valid one.
        unsafe { coqui_stt_sys::STT_GetModelSampleRate(state) }
    }

    /// Use the Coqui STT model to convert speech to text.
    ///
    /// `buffer` should be a 16-bit, mono, raw audio signal
    /// at the appropriate sample rate, matching what the model was trained on.
    /// The required sample rate can be obtained from [`get_sample_rate`](crate::Model::get_sample_rate).
    ///
    /// # Errors
    /// Passes through any errors from the C library. See enum [`Error`](crate::Error).
    ///
    /// Additionally, if the returned string is not valid UTF-8, this function returns an error.
    #[allow(clippy::missing_inline_in_public_items)]
    pub fn speech_to_text(&mut self, buffer: &[i16]) -> crate::Result<String> {
        let ptr = unsafe {
            coqui_stt_sys::STT_SpeechToText(self.0, buffer.as_ptr(), buffer.len() as c_uint)
        };

        if ptr.is_null() {
            return Err(crate::Error::Unknown);
        }

        // SAFETY: STT_SpeechToText will always return a valid CStr
        let cstr = unsafe { CStr::from_ptr(ptr) };
        let mut unchecked_str = Vec::new();
        unchecked_str.extend_from_slice(cstr.to_bytes());

        // SAFETY: the pointer the string points to is not used anywhere after this call
        unsafe { coqui_stt_sys::STT_FreeString(ptr) }

        Ok(String::from_utf8(unchecked_str)?)
    }

    /// Use the Coqui STT model to convert speech to text and output results including metadata.
    ///
    /// `buffer` should be a 16-bit, mono, raw audio signal
    /// at the appropriate sample rate, matching what the model was trained on.
    /// The required sample rate can be obtained from [`get_sample_rate`](crate::Model::get_sample_rate).
    ///
    /// `num_results` is the maximum number of possible transcriptions to return.
    /// Note that it is not guaranteed this many will be returned at minimum,
    /// but there will never be more than this number at maximum.
    ///
    /// # Errors
    /// Passes through any errors from the C library. See enum [`Error`](crate::Error).
    #[inline]
    pub fn speech_to_text_with_metadata(
        &mut self,
        buffer: &[i16],
        num_results: u32,
    ) -> crate::Result<Metadata> {
        let ptr = unsafe {
            coqui_stt_sys::STT_SpeechToTextWithMetadata(
                self.0,
                buffer.as_ptr(),
                buffer.len() as c_uint,
                num_results,
            )
        };

        if ptr.is_null() {
            return Err(crate::Error::Unknown);
        }

        Ok(crate::Metadata::new(ptr))
    }

    /// Convert this model into one used for streaming inference states.
    ///
    /// Note that this requires exclusive access to the model,
    /// so it is not possible to use the same model for multiple streams concurrently.
    ///
    /// # Errors
    /// Passes through any errors from the C library. See enum [`Error`](crate::Error).
    #[allow(clippy::missing_inline_in_public_items)]
    pub fn as_streaming(&mut self) -> crate::Result<Stream> {
        let mut state = std::ptr::null_mut();

        let retval = unsafe { coqui_stt_sys::STT_CreateStream(self.0, &mut state) };

        if let Some(e) = crate::Error::from_c_int(retval) {
            return Err(e);
        }

        if state.is_null() {
            return Err(crate::Error::Unknown);
        }

        Ok(Stream {
            model: self,
            state,
            already_freed: false,
        })
    }
}