1use crate::{
12 NovelTTSError,
13 queue::{self, TTSQueueInput, TTSQueueOutput},
14 utils::{TextSegment, preprocess_text},
15};
16use kokoro_tts::{KokoroTts, Voice};
17use rodio::buffer::SamplesBuffer;
18use std::sync::Arc;
19use tokio::sync::{Mutex, mpsc::Receiver};
20use tokio_util::sync::CancellationToken;
21
22#[derive(Clone)]
24pub struct ChapterTTS {
25 pub texts: Vec<TextSegment>,
26 pub cancel_token: CancellationToken,
28 pub active_index: Arc<Mutex<usize>>,
29 pub tts: Arc<KokoroTts>,
30 pub queue: Option<Arc<TTSQueueInput<SamplesBuffer>>>,
31 pub generate_index: usize,
32}
33
34impl ChapterTTS {
35 pub fn new(tts: Arc<KokoroTts>, text: &str) -> Self {
43 Self {
44 cancel_token: CancellationToken::new(),
45 active_index: Arc::new(Mutex::new(0)),
46 tts,
47 texts: preprocess_text(text, 200),
48 queue: None,
49 generate_index: 0,
50 }
51 }
52
53 pub fn stream(
70 &mut self,
71 voice: Voice,
72 on_error: impl Fn(NovelTTSError) + Send + Sync + 'static,
73 ) -> (TTSQueueOutput<SamplesBuffer>, Receiver<Option<usize>>) {
74 let (audio_queue_tx, audio_queue_rx) = queue::queue();
75 self.queue.replace(audio_queue_tx.clone());
76
77 let (position_tx, position_rx) = tokio::sync::mpsc::channel::<Option<usize>>(1);
78
79 self.cancel_token = CancellationToken::new();
80
81 let cancel_token = self.cancel_token.clone();
82 let tts = self.tts.clone();
83 let active_index = self.active_index.clone();
84 let texts = self.texts.clone();
85 self.generate_index = *self.active_index.try_lock().unwrap();
86
87 tokio::spawn(async move {
88 let n = *active_index.lock().await;
89 let len: usize = texts.len() - n;
90 for (index, TextSegment { text, .. }) in texts.iter().skip(n).enumerate() {
91 tokio::select! {
92 _ = cancel_token.cancelled() => {
93 break;
94 }
95 res = tts.synth(text, voice) =>{
96 let Ok((data, _)) = res else{
97 on_error(NovelTTSError::from(res.err().unwrap()));
98 continue;
99 };
100 let buffer = SamplesBuffer::new(1, 24000, data);
101
102 let mut signal = audio_queue_tx.append_with_signal(buffer.clone());
103
104
105 if index == len-1{
107 audio_queue_tx.set_is_finished(true);
108 }
109
110 tokio::spawn({
111 let position_tx = position_tx.clone();
112 let active_index = active_index.clone();
113 async move {
114 while let Some(end) = signal.recv().await {
115 if !end {
116 let _ = position_tx.send(Some(n+index)).await;
117 *active_index.lock().await = n+index;
118 } else if index == len-1 {
119 let _ = position_tx.send(None).await;
120 }
121 }
122
123 }
124 });
125 }
126 }
127 }
128 });
129 (audio_queue_rx, position_rx)
130 }
131
132 pub fn cancel(&self) {
136 self.cancel_token.cancel();
137 }
138
139 pub fn set_index(&self, index: usize) {
141 if index <= self.texts.len() {
142 let mut active_index = self.active_index.try_lock().unwrap();
143 *active_index = index;
144 }
145 }
146
147 pub fn retrieve_output(&self, index: usize) -> Option<TTSQueueOutput<SamplesBuffer>> {
149 if index >= self.texts.len() || index < self.generate_index {
150 return None;
151 }
152 self.queue
153 .as_ref()
154 .map(|q| TTSQueueOutput::new(q.clone(), index))
155 }
156}