use {
crate::{KokoroError, Voice},
futures::{Sink, SinkExt, Stream},
pin_project::pin_project,
std::{
pin::Pin,
task::{Context, Poll},
time::Duration,
},
tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel},
};
/// One synthesis job queued by [`SynthSink`] for the background synthesis task.
struct Request<S> {
    /// Text to synthesize.
    text: S,
    /// Voice the text should be synthesized with.
    voice: Voice,
}
/// Output of one successful synthesis job, forwarded to [`SynthStream`].
struct Response {
    /// Duration returned by the synthesis callback alongside the samples.
    took: Duration,
    /// Sample data returned by the synthesis callback.
    data: Vec<f32>,
}
/// Receiving half of a synthesis session.
///
/// Yields `(samples, took)` for each request the background task synthesized, in the
/// order the task processed them, and ends (`None`) once the task has dropped its sender.
#[pin_project]
pub struct SynthStream {
    /// Responses produced by the background synthesis task.
    #[pin]
    rx: UnboundedReceiver<Response>,
}
impl Stream for SynthStream {
    /// `(samples, duration reported by the synthesis callback)` for one request.
    type Item = (Vec<f32>, Duration);

    /// Polls the response channel.
    ///
    /// Returns `Ready(None)` once the sending side (held by the background task) is gone
    /// and all buffered responses have been yielded.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // The projection already yields `Pin<&mut UnboundedReceiver<_>>`. The receiver is
        // `Unpin`, so its `&mut self` method can be called directly; wrapping it in another
        // `Pin::new` is unnecessary.
        self.project()
            .rx
            .poll_recv(cx)
            .map(|response| response.map(|Response { data, took }| (data, took)))
    }
}
/// Sending half of a synthesis session.
///
/// Requests can be queued with [`SynthSink::synth`], which uses the stored voice. They can
/// also be queued through the [`Sink`] implementation with an explicit `(Voice, text)` pair.
// NOTE(review): no field is `#[pin]`. pin-project should therefore make this type `Unpin`
// unconditionally, which `SinkExt::send` (used by `synth`) requires. Confirm before
// removing the attribute.
#[pin_project]
pub struct SynthSink<S> {
    /// Request channel read by the background synthesis task.
    tx: UnboundedSender<Request<S>>,
    /// Voice attached to requests queued via `synth`; changed with `set_voice`.
    voice: Voice,
}
impl<S> SynthSink<S> {
    /// Changes the voice used by later [`SynthSink::synth`] calls.
    ///
    /// Requests that were already queued keep the voice they were submitted with.
    pub fn set_voice(&mut self, voice: Voice) {
        self.voice = voice;
    }

    /// Queues `text` for synthesis using the sink's current voice.
    ///
    /// # Errors
    ///
    /// Returns [`KokoroError::Send`] if the request channel is closed, for example because
    /// the session's background task has exited.
    pub async fn synth(&mut self, text: S) -> Result<(), KokoroError> {
        let request = (self.voice, text);
        self.send(request).await
    }
}
impl<S> Sink<(Voice, S)> for SynthSink<S> {
type Error = KokoroError;
fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn start_send(self: Pin<&mut Self>, (voice, text): (Voice, S)) -> Result<(), Self::Error> {
self.tx
.send(Request { voice, text })
.map_err(|e| KokoroError::Send(e.to_string()))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
}
/// Starts a background synthesis session and returns its sending and receiving halves.
///
/// Each request pushed into the returned [`SynthSink`] is passed to
/// `synth_request_callback` as `(text, voice)`. Requests are processed one at a time, in
/// the order they were queued, and each successful result is forwarded to the returned
/// [`SynthStream`]. `voice` is the initial voice used by [`SynthSink::synth`].
///
/// The session's background task stops when:
/// - the sink is dropped and every queued request has been handled; the stream then ends;
/// - the stream is dropped; remaining queued requests are not synthesized;
/// - the callback returns an error, or a response cannot be delivered.
///
/// Once the task has stopped, the stream yields `None` and further sends on the sink fail
/// with [`KokoroError::Send`]. The callback's error value itself is not reported anywhere,
/// because the task's `JoinHandle` is not kept.
///
/// # Panics
///
/// Panics if called outside a Tokio runtime, as `tokio::spawn` does.
pub(super) fn start_synth_session<F, R, S>(
    voice: Voice,
    synth_request_callback: F,
) -> (SynthSink<S>, SynthStream)
where
    F: Fn(S, Voice) -> R + Send + 'static,
    R: Future<Output = Result<(Vec<f32>, Duration), KokoroError>> + Send,
    S: AsRef<str> + Send + 'static,
{
    let (request_tx, mut request_rx) = unbounded_channel::<Request<S>>();
    let (response_tx, response_rx) = unbounded_channel();
    tokio::spawn(async move {
        while let Some(req) = request_rx.recv().await {
            // The stream has been dropped: nobody can receive the result. Stop instead of
            // running a potentially expensive synthesis whose output would be discarded.
            if response_tx.is_closed() {
                break;
            }
            let (data, took) = match synth_request_callback(req.text, req.voice).await {
                Ok(output) => output,
                // The task's JoinHandle is not retained, so the error cannot be returned to
                // anyone. Leaving the loop drops both channel ends. Callers observe this as
                // the stream ending and later sink sends failing.
                Err(_) => break,
            };
            if response_tx.send(Response { data, took }).is_err() {
                // The stream was dropped while this request was being synthesized.
                break;
            }
        }
    });
    (
        SynthSink {
            tx: request_tx,
            voice,
        },
        SynthStream { rx: response_rx },
    )
}