// hematite-kokoros 0.1.3
//
// Hematite-maintained fork of the Kokoros Rust text-to-speech engine.
use crate::tts::koko::ModelStrategy;
use ort::session::Session;

pub trait OrtBase {
    fn set_sess(&mut self, sess: Session);
    fn sess(&self) -> Option<&Session>;

    fn load_model(&mut self, model_path: String) -> Result<(), Box<dyn std::error::Error>> {
        let sess = Session::builder()?
            .with_execution_providers([
                ort::execution_providers::CPUExecutionProvider::default().build()
            ])?
            .with_optimization_level(ort::session::builder::GraphOptimizationLevel::Disable)?
            .with_memory_pattern(false)?
            .with_parallel_execution(false)?
            .with_intra_threads(1)?
            .with_inter_threads(1)?
            .commit_from_file(model_path)?;

        self.set_sess(sess);
        Ok(())
    }

    fn load_model_from_memory(
        &mut self,
        model_bytes: &[u8],
    ) -> Result<(), Box<dyn std::error::Error>> {
        let sess = Session::builder()?
            .with_execution_providers([
                ort::execution_providers::CPUExecutionProvider::default().build()
            ])?
            .with_optimization_level(ort::session::builder::GraphOptimizationLevel::Disable)?
            .with_memory_pattern(false)?
            .with_parallel_execution(false)?
            .with_intra_threads(1)?
            .with_inter_threads(1)?
            .commit_from_memory(model_bytes)?;

        self.set_sess(sess);
        Ok(())
    }

    fn inputs(&self) -> Vec<String> {
        self.sess()
            .map(|s| s.inputs().iter().map(|i| i.name().to_string()).collect())
            .unwrap_or_default()
    }

    fn outputs(&self) -> Vec<String> {
        self.sess()
            .map(|s| s.outputs().iter().map(|o| o.name().to_string()).collect())
            .unwrap_or_default()
    }

    fn infer(
        &mut self,
        tokens_batch: Vec<Vec<i64>>,
        style: &[f32],
        speed: f32,
        strategy: &ModelStrategy,
    ) -> Result<Vec<f32>, Box<dyn std::error::Error>>;
}