kokoro_tts/
lib.rs

1mod error;
2mod g2p;
3mod stream;
4mod synthesizer;
5mod tokenizer;
6mod transcription;
7mod voice;
8
9use {
10    bincode::{config::standard, decode_from_slice},
11    ort::{execution_providers::CUDAExecutionProvider, session::Session},
12    std::{collections::HashMap, path::Path, sync::Arc, time::Duration},
13    tokio::{fs::read, sync::Mutex},
14};
15pub use {error::*, g2p::*, stream::*, tokenizer::*, transcription::*, voice::*};
16
17pub struct KokoroTts {
18    model: Arc<Mutex<Session>>,
19    voices: Arc<HashMap<String, Vec<Vec<Vec<f32>>>>>,
20}
21
22impl KokoroTts {
23    pub async fn new<P: AsRef<Path>>(model_path: P, voices_path: P) -> Result<Self, KokoroError> {
24        let voices = read(voices_path).await?;
25        let (voices, _) = decode_from_slice(&voices, standard())?;
26
27        let model = Session::builder()?
28            .with_execution_providers([CUDAExecutionProvider::default().build()])?
29            .commit_from_file(model_path)?;
30        Ok(Self {
31            model: Arc::new(model.into()),
32            voices,
33        })
34    }
35
36    pub async fn new_from_bytes<B>(model: B, voices: B) -> Result<Self, KokoroError>
37    where
38        B: AsRef<[u8]>,
39    {
40        let (voices, _) = decode_from_slice(voices.as_ref(), standard())?;
41
42        let model = Session::builder()?
43            .with_execution_providers([CUDAExecutionProvider::default().build()])?
44            .commit_from_memory(model.as_ref())?;
45        Ok(Self {
46            model: Arc::new(model.into()),
47            voices,
48        })
49    }
50
51    pub async fn synth<S>(&self, text: S, voice: Voice) -> Result<(Vec<f32>, Duration), KokoroError>
52    where
53        S: AsRef<str>,
54    {
55        let name = voice.get_name();
56        let pack = self
57            .voices
58            .get(name)
59            .ok_or(KokoroError::VoiceNotFound(name.to_owned()))?;
60        synthesizer::synth(Arc::downgrade(&self.model), text, pack, voice).await
61    }
62
63    pub fn stream<S>(&self, voice: Voice) -> (SynthSink<S>, SynthStream)
64    where
65        S: AsRef<str> + Send + 'static,
66    {
67        let voices = Arc::downgrade(&self.voices);
68        let model = Arc::downgrade(&self.model);
69
70        start_synth_session(voice, move |text, voice| {
71            let voices = voices.clone();
72            let model = model.clone();
73            async move {
74                let name = voice.get_name();
75                let voices = voices.upgrade().ok_or(KokoroError::ModelReleased)?;
76                let pack = voices
77                    .get(name)
78                    .ok_or(KokoroError::VoiceNotFound(name.to_owned()))?;
79                synthesizer::synth(model, text, pack, voice).await
80            }
81        })
82    }
83}