//! natural-tts 0.3.1
//!
//! High-level bindings to a variety of text-to-speech libraries.
//!
//! Documentation
use super::*;
use pyo3::prelude::*;
use pyo3::types::PyDict;

/// Handle to a Coqui TTS model (the Python `TTS.api.TTS` object) accessed through PyO3.
#[derive(Debug)]
pub struct CoquiModel {
    /// The Python `TTS.api.TTS` model object, as returned by its `.to(device)` call in `new`.
    model: Py<PyAny>,
    /// Device string chosen in `new`: `"cuda:0"` or `"cpu"`.
    device: String,
}

impl Clone for CoquiModel {
    /// Returns a new handle to the *same* Python model object.
    ///
    /// The `Py` handle is duplicated with `clone_ref`, which needs a `Python`
    /// token; that is why this impl attaches to the interpreter and is written
    /// by hand instead of derived. The Python model itself is not copied.
    fn clone(&self) -> Self {
        Python::attach(|py| Self {
            model: self.model.clone_ref(py),
            device: self.device.clone(),
        })
    }
}

impl CoquiModel {
    pub fn new(model_name: String, use_gpu: bool) -> Result<Self, TtsError> {
        let m = Python::attach(|py| -> Result<Self, TtsError> {
            let torch = py.import("torch")?;
            let tts = py.import("TTS.api")?;

            let cuda: bool = torch
                .getattr("cuda")?
                .getattr("is_available")?
                .call0()?
                .extract()?;

            let device: String = if cuda && use_gpu {
                "cuda:0".to_string()
            } else {
                "cpu".to_string()
            };

            let model = tts
                .getattr("TTS")?
                .call1((("model_name", model_name), ("progress_bar", false)))?
                .getattr("to")?
                .call1((device.clone(), ("return_tensors", "pt")))?
                .unbind();

            Ok(Self { model, device })
        });

        m
    }

    pub fn generate(&self, message: String, path: &PathBuf) -> Result<(), TtsError> {
        Python::attach(|py| -> Result<(), TtsError> {
            self.model.getattr(py, "tts_to_file")?.call1(
                py,
                (("text", message), ("file_path", path.to_str().unwrap())),
            )?;
            Ok(())
        })
    }
}

impl Default for CoquiModel {
    /// Loads `tts_models/en/ljspeech/vits`, using the GPU if CUDA is available.
    ///
    /// # Panics
    ///
    /// Panics if [`CoquiModel::new`] fails (for example when `torch` or
    /// `TTS.api` cannot be imported), because `Default` cannot return an error.
    /// Use [`CoquiModel::new`] directly to handle that case.
    fn default() -> Self {
        Self::new("tts_models/en/ljspeech/vits".to_string(), true)
            .expect("failed to load default Coqui model `tts_models/en/ljspeech/vits`")
    }
}

impl NaturalModelTrait for CoquiModel {
    type SynthesizeType = f32;

    /// Generates speech for `message` into `path`, then returns the result of
    /// `did_save(path)`.
    ///
    /// NOTE(review): `did_save` is defined elsewhere; presumably it checks that
    /// the file was actually written. Confirm against its definition.
    ///
    /// # Errors
    ///
    /// Propagates any [`TtsError`] from [`CoquiModel::generate`] or `did_save`.
    fn save(&mut self, message: String, path: &PathBuf) -> Result<(), TtsError> {
        self.generate(message, path)?;
        did_save(path)
    }
}