1mod error;
2mod g2p;
3mod stream;
4mod synthesizer;
5mod tokenizer;
6mod transcription;
7mod voice;
8
9use {
10 bincode::{config::standard, decode_from_slice},
11 ort::{execution_providers::CUDAExecutionProvider, session::Session},
12 std::{collections::HashMap, path::Path, sync::Arc, time::Duration},
13 tokio::{fs::read, sync::Mutex},
14};
15pub use {error::*, g2p::*, stream::*, tokenizer::*, transcription::*, voice::*};
16
/// Kokoro text-to-speech engine: an ONNX Runtime model session plus a table of voice packs.
///
/// Both fields are reference-counted. `KokoroTts::stream` captures only `Weak` handles to
/// them, so a streaming session does not by itself keep these resources alive.
pub struct KokoroTts {
    /// ONNX Runtime session behind a `tokio::sync::Mutex`, shared via `Arc`.
    model: Arc<Mutex<Session>>,
    /// Voice packs keyed by voice name (looked up with `Voice::get_name`).
    /// NOTE(review): the meaning of the three nesting levels of `f32` data is not visible in
    /// this file — confirm against the `synthesizer` module.
    voices: Arc<HashMap<String, Vec<Vec<Vec<f32>>>>>,
}
21
22impl KokoroTts {
23 pub async fn new<P: AsRef<Path>>(model_path: P, voices_path: P) -> Result<Self, KokoroError> {
24 let voices = read(voices_path).await?;
25 let (voices, _) = decode_from_slice(&voices, standard())?;
26
27 let model = Session::builder()?
28 .with_execution_providers([CUDAExecutionProvider::default().build()])?
29 .commit_from_file(model_path)?;
30 Ok(Self {
31 model: Arc::new(model.into()),
32 voices,
33 })
34 }
35
36 pub async fn new_from_bytes<B>(model: B, voices: B) -> Result<Self, KokoroError>
37 where
38 B: AsRef<[u8]>,
39 {
40 let (voices, _) = decode_from_slice(voices.as_ref(), standard())?;
41
42 let model = Session::builder()?
43 .with_execution_providers([CUDAExecutionProvider::default().build()])?
44 .commit_from_memory(model.as_ref())?;
45 Ok(Self {
46 model: Arc::new(model.into()),
47 voices,
48 })
49 }
50
51 pub async fn synth<S>(&self, text: S, voice: Voice) -> Result<(Vec<f32>, Duration), KokoroError>
52 where
53 S: AsRef<str>,
54 {
55 let name = voice.get_name();
56 let pack = self
57 .voices
58 .get(name)
59 .ok_or(KokoroError::VoiceNotFound(name.to_owned()))?;
60 synthesizer::synth(Arc::downgrade(&self.model), text, pack, voice).await
61 }
62
63 pub fn stream<S>(&self, voice: Voice) -> (SynthSink<S>, SynthStream)
64 where
65 S: AsRef<str> + Send + 'static,
66 {
67 let voices = Arc::downgrade(&self.voices);
68 let model = Arc::downgrade(&self.model);
69
70 start_synth_session(voice, move |text, voice| {
71 let voices = voices.clone();
72 let model = model.clone();
73 async move {
74 let name = voice.get_name();
75 let voices = voices.upgrade().ok_or(KokoroError::ModelReleased)?;
76 let pack = voices
77 .get(name)
78 .ok_or(KokoroError::VoiceNotFound(name.to_owned()))?;
79 synthesizer::synth(model, text, pack, voice).await
80 }
81 })
82 }
83}