use crate::{
NovelTTSError,
queue::{self, TTSQueueInput, TTSQueueOutput},
utils::{TextSegment, preprocess_text},
};
use kokoro_tts::{KokoroTts, Voice};
use rodio::buffer::SamplesBuffer;
use std::sync::Arc;
use tokio::sync::{Mutex, mpsc::Receiver};
use tokio_util::sync::CancellationToken;
/// Text-to-speech state for one chapter: the chapter's text segments plus the
/// handles needed to synthesize them in the background and track progress.
///
/// The type is `Clone`; `active_index`, `tts` and `queue` are `Arc`s, so clones
/// refer to the same underlying values.
#[derive(Clone)]
pub struct ChapterTTS {
/// Segments of the chapter, produced by `preprocess_text` in `new`.
pub texts: Vec<TextSegment>,
/// Token watched by the synthesis task; `cancel` triggers it and every call
/// to `stream` replaces it with a fresh token.
pub cancel_token: CancellationToken,
/// Shared segment index: read by `stream` as the segment to start from,
/// written by `set_index` and by the tasks `stream` spawns.
pub active_index: Arc<Mutex<usize>>,
/// Engine whose `synth` method is called for each segment.
pub tts: Arc<KokoroTts>,
/// Input side of the queue created by the most recent `stream` call;
/// `None` until `stream` has been called.
pub queue: Option<Arc<TTSQueueInput<SamplesBuffer>>>,
/// Value of `active_index` captured by the most recent `stream` call;
/// `retrieve_output` rejects indices below it.
pub generate_index: usize,
}
impl ChapterTTS {
    /// Creates the TTS state for a chapter.
    ///
    /// `text` is split into segments with `preprocess_text(text, 200)`
    /// (presumably `200` is a maximum segment length; confirm against
    /// `utils::preprocess_text`). The active index starts at `0`, no queue
    /// exists yet, and a fresh cancellation token is installed.
    pub fn new(tts: Arc<KokoroTts>, text: &str) -> Self {
        Self {
            cancel_token: CancellationToken::new(),
            active_index: Arc::new(Mutex::new(0)),
            tts,
            texts: preprocess_text(text, 200),
            queue: None,
            generate_index: 0,
        }
    }

    /// Starts synthesizing the chapter in a background task, beginning at the
    /// current active index.
    ///
    /// A new queue and a new cancellation token are created on every call;
    /// the previous token is replaced, not cancelled, so call [`Self::cancel`]
    /// first if an earlier stream should stop.
    ///
    /// # Parameters
    /// * `voice` - voice passed to every `synth` call.
    /// * `on_error` - invoked with each synthesis error; the failing segment
    ///   is skipped and synthesis continues with the next one.
    ///
    /// # Returns
    /// The output side of the audio queue and a receiver of position updates:
    /// `Some(i)` when the queue signals `false` for segment `i`, and `None`
    /// when it signals `true` for the last segment (presumably "started" and
    /// "finished" playing; confirm against the queue module). If the last
    /// segment fails to synthesize no explicit `None` is sent; the receiver
    /// then only observes the channel closing.
    ///
    /// # Panics
    /// Panics if the active-index lock is held by another task at the moment
    /// of the call. Must be called from within a Tokio runtime because it
    /// spawns tasks.
    pub fn stream(
        &mut self,
        voice: Voice,
        on_error: impl Fn(NovelTTSError) + Send + Sync + 'static,
    ) -> (TTSQueueOutput<SamplesBuffer>, Receiver<Option<usize>>) {
        let (audio_queue_tx, audio_queue_rx) = queue::queue();
        self.queue.replace(audio_queue_tx.clone());
        let (position_tx, position_rx) = tokio::sync::mpsc::channel::<Option<usize>>(1);

        self.cancel_token = CancellationToken::new();
        let cancel_token = self.cancel_token.clone();
        let tts = self.tts.clone();
        let active_index = self.active_index.clone();
        let texts = self.texts.clone();

        // Read the start index exactly once and hand the same value to the
        // task, so `generate_index` can never disagree with where synthesis
        // really began (a `set_index` racing with task start-up could
        // otherwise slip in between two separate reads).
        let start = *self
            .active_index
            .try_lock()
            .expect("active_index lock is only held momentarily and never across an await");
        self.generate_index = start;

        tokio::spawn(async move {
            let total = texts.len();

            // Nothing left to synthesize: still mark the queue as finished so
            // the output side is not left waiting for data that never comes.
            if start >= total {
                audio_queue_tx.set_is_finished(true);
                return;
            }

            // `enumerate().skip(start)` yields absolute segment indices, which
            // avoids the `len - start` arithmetic and its underflow risk.
            for (index, TextSegment { text, .. }) in texts.iter().enumerate().skip(start) {
                let is_last = index + 1 == total;
                tokio::select! {
                    _ = cancel_token.cancelled() => {
                        break;
                    }
                    res = tts.synth(text, voice) => {
                        let data = match res {
                            Ok((data, _)) => data,
                            Err(err) => {
                                on_error(NovelTTSError::from(err));
                                // A failure on the final segment must not
                                // leave the queue open forever.
                                if is_last {
                                    audio_queue_tx.set_is_finished(true);
                                }
                                continue;
                            }
                        };

                        // One channel at 24000 Hz, per rodio's
                        // `SamplesBuffer::new(channels, sample_rate, data)`.
                        let buffer = SamplesBuffer::new(1, 24000, data);
                        // The buffer is moved into the queue; no copy needed.
                        let mut signal = audio_queue_tx.append_with_signal(buffer);
                        if is_last {
                            audio_queue_tx.set_is_finished(true);
                        }

                        // Relay this segment's queue signals as position
                        // updates without blocking further synthesis.
                        tokio::spawn({
                            let position_tx = position_tx.clone();
                            let active_index = active_index.clone();
                            async move {
                                while let Some(end) = signal.recv().await {
                                    if !end {
                                        // Store the index before notifying: a
                                        // full channel must not delay the
                                        // update, and a receiver reacting with
                                        // `set_index` must not be overwritten
                                        // afterwards.
                                        *active_index.lock().await = index;
                                        let _ = position_tx.send(Some(index)).await;
                                    } else if is_last {
                                        let _ = position_tx.send(None).await;
                                    }
                                }
                            }
                        });
                    }
                }
            }
        });

        (audio_queue_rx, position_rx)
    }

    /// Cancels the synthesis task started by the most recent [`Self::stream`]
    /// call. Segments already appended to the queue are not removed here.
    pub fn cancel(&self) {
        self.cancel_token.cancel();
    }

    /// Sets the segment index the next [`Self::stream`] call starts from.
    ///
    /// Indices greater than the number of segments are ignored. An index equal
    /// to the number of segments is accepted; a following `stream` then has
    /// nothing to synthesize.
    ///
    /// # Panics
    /// Panics if the active-index lock is held by another task at the moment
    /// of the call.
    pub fn set_index(&self, index: usize) {
        if index <= self.texts.len() {
            let mut active_index = self
                .active_index
                .try_lock()
                .expect("active_index lock is only held momentarily and never across an await");
            *active_index = index;
        }
    }

    /// Returns a new output handle on the current queue for segment `index`.
    ///
    /// Yields `None` when `index` is past the last segment, when it lies
    /// before the index the current stream started from, or when `stream` has
    /// not been called yet.
    ///
    /// NOTE(review): `index` is passed to `TTSQueueOutput::new` unchanged even
    /// though the queue was filled starting at `generate_index`; confirm the
    /// queue expects an absolute segment index here.
    pub fn retrieve_output(&self, index: usize) -> Option<TTSQueueOutput<SamplesBuffer>> {
        if index >= self.texts.len() || index < self.generate_index {
            return None;
        }
        self.queue
            .as_ref()
            .map(|q| TTSQueueOutput::new(q.clone(), index))
    }
}