1use crate::synthesis::{SynthesisClient, SynthesisEvent, SynthesisOption, SynthesisType};
2use anyhow::Result;
3use anyhow::anyhow;
4use async_trait::async_trait;
5use bytes::Bytes;
6use futures::SinkExt;
7use futures::StreamExt;
8use futures::TryStreamExt;
9use futures::future;
10use futures::future::FutureExt;
11use futures::stream;
12use futures::stream::SplitSink;
13use futures::{Stream, stream::BoxStream};
14use serde::Deserialize;
15use serde::Serialize;
16use tokio::net::TcpStream;
17use tokio::sync::mpsc;
18use tokio_stream::wrappers::UnboundedReceiverStream;
19use tokio_tungstenite::MaybeTlsStream;
20use tokio_tungstenite::WebSocketStream;
21use tokio_tungstenite::connect_async;
22use tokio_tungstenite::tungstenite::Message;
23use tokio_tungstenite::tungstenite::client::IntoClientRequest;
24use tracing::warn;
25use url::Url;
26
/// Websocket connection to the Deepgram TTS endpoint (plain or TLS-wrapped TCP).
/// The streaming client splits it into a sink (`WsSink`) and a source.
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Write half of a `WsStream`, used to send JSON `Command`s to Deepgram.
type WsSink = SplitSink<WsStream, Message>;

/// Deepgram `speak` endpoint. `request_url` replaces the scheme per transport
/// (`https` for REST, `wss` for the websocket).
const DEEPGRAM_BASE_URL: &str = "https://api.deepgram.com/v1/speak";
/// Characters treated as sentence terminators when the streaming client splits text
/// into `Speak` commands and decides when to send a `Flush`.
const TERMINATORS: [char; 3] = ['.', '?', '!'];
32
/// Deepgram TTS client that issues one REST request per `synthesize` call.
pub struct RestClient {
    /// Client-level options; combined with per-command options via `merge_with`
    /// for each request.
    option: SynthesisOption,
    /// Sender for queued requests `(text, cmd_seq, per-command options)`, consumed by
    /// the stream returned from `start`. `None` before `start` and after `stop`.
    tx: Option<mpsc::UnboundedSender<(String, Option<usize>, Option<SynthesisOption>)>>,
}
38
/// JSON body of the REST `speak` request, serialized as `{"text": "..."}`.
#[derive(Serialize)]
struct Payload {
    /// Text to synthesize.
    text: String,
}
43
44impl RestClient {
45 pub fn new(option: SynthesisOption) -> Self {
46 Self { option, tx: None }
47 }
48}
49
50fn request_url(option: &SynthesisOption, protocol: &str) -> Url {
55 let mut url = Url::parse(DEEPGRAM_BASE_URL).expect("Deepgram base url is invalid");
56 url.set_scheme(protocol).expect("illegal url scheme");
57
58 let mut query = url.query_pairs_mut();
59
60 if let Some(speaker) = option.speaker.as_ref() {
61 query.append_pair("model", speaker);
62 }
63
64 if let Some(codec) = option.codec.as_ref() {
65 match codec.as_str() {
66 "pcm" => query.append_pair("encoding", "linear16"),
67 "pcmu" => query.append_pair("encoding", "mulaw"),
68 "pcma" => query.append_pair("encoding", "alaw"),
69 _ => query.append_pair("encoding", "linear16"),
70 };
71 } else {
72 query.append_pair("encoding", "linear16");
73 }
74
75 let samplerate = option.samplerate.unwrap_or(16000);
76 query.append_pair("sample_rate", samplerate.to_string().as_str());
77
78 drop(query);
79 url
80}
81
82async fn chunked_stream(
83 option: SynthesisOption,
84 text: String,
85) -> Result<impl Stream<Item = Result<Bytes>>> {
86 let url = request_url(&option, "https");
87 let token = option
88 .secret_key
89 .as_ref()
90 .ok_or_else(|| anyhow!("Deepegram tts: missing api key"))?;
91 let payload = Payload { text };
92 let client = reqwest::Client::new();
93 let resp = client
94 .post(url)
95 .header("Content-Type", "application/json")
96 .header("Authorization", format!("Token {}", token))
97 .json(&payload)
98 .send()
99 .await?;
100 Ok(resp.bytes_stream().map_err(anyhow::Error::from))
101}
102
103#[async_trait]
104impl SynthesisClient for RestClient {
105 fn provider(&self) -> SynthesisType {
106 SynthesisType::Deepgram
107 }
108
109 async fn start(
110 &mut self,
111 ) -> Result<BoxStream<'static, (Option<usize>, Result<SynthesisEvent>)>> {
112 let (tx, rx) = mpsc::unbounded_channel();
113 self.tx = Some(tx);
114 let max_concurrent_tasks = self.option.max_concurrent_tasks.unwrap_or(1);
115 let client_option = self.option.clone();
116 let stream = UnboundedReceiverStream::new(rx).flat_map_unordered(
117 max_concurrent_tasks,
118 move |(text, cmd_seq, cmd_option)| {
119 let option = client_option.merge_with(cmd_option);
120 chunked_stream(option, text)
121 .map(move |res| match res {
122 Ok(stream) => stream
123 .map(move |res| res.map(|bytes| SynthesisEvent::AudioChunk(bytes)))
124 .chain(stream::once(future::ready(Ok(SynthesisEvent::Finished))))
125 .boxed(),
126 Err(e) => stream::once(future::ready(Err(e))).boxed(),
127 })
128 .flatten_stream()
129 .map(move |res| (cmd_seq, res))
130 .boxed()
131 },
132 );
133
134 Ok(stream.boxed())
135 }
136
137 async fn synthesize(
138 &mut self,
139 text: &str,
140 cmd_seq: Option<usize>,
141 option: Option<SynthesisOption>,
142 ) -> Result<()> {
143 if let Some(tx) = &self.tx {
144 tx.send((text.to_string(), cmd_seq, option))?;
145 } else {
146 return Err(anyhow::anyhow!("Deepgram TTS: missing client sender"));
147 };
148 Ok(())
149 }
150
151 async fn stop(&mut self) -> Result<()> {
152 self.tx.take();
153 Ok(())
154 }
155}
156
/// Deepgram TTS client that sends text over a single websocket connection.
struct StreamingClient {
    /// Options used by `start` to build the websocket URL and credentials.
    option: SynthesisOption,
    /// Write half of the websocket. `None` before `start` and after `stop`.
    sink: Option<WsSink>,
}
161
162impl StreamingClient {
163 pub fn new(option: SynthesisOption) -> Self {
164 Self { option, sink: None }
165 }
166}
167
/// Control messages sent to the Deepgram TTS websocket.
///
/// Serialized as JSON objects whose `type` field holds the variant name, e.g.
/// `{"type":"Speak","text":"..."}` or `{"type":"Flush"}`.
#[derive(Serialize)]
#[serde(tag = "type")]
enum Command {
    /// Sends a piece of text for synthesis.
    Speak { text: String },
    /// Sent after each piece ending in a sentence terminator; presumably asks the
    /// server to synthesize the buffered text now (TODO confirm against Deepgram docs).
    Flush,
    /// Sent by `stop` before the sink is dropped.
    Close,
}
175
/// JSON text messages received from the Deepgram TTS websocket, discriminated by their
/// `type` field.
// Only `Warning`'s fields are read (in `start`); the others exist so the messages
// deserialize, which is why `dead_code` is allowed here.
#[allow(dead_code)]
#[derive(Deserialize, Debug)]
#[serde(tag = "type")]
enum Event {
    /// Information about the model serving the connection.
    Metadata {
        request_id: String,
        model_name: String,
        model_version: String,
        model_uuid: String,
    },
    /// Received with a `sequence_id`; presumably acknowledges a `Flush` command.
    Flushed {
        sequence_id: usize,
    },
    /// Received with a `sequence_id`; this client never sends a clear command.
    Cleared {
        sequence_id: usize,
    },
    /// Non-fatal warning from the server; logged by the streaming client.
    Warning {
        description: String,
        code: String,
    },
}
197
198async fn connect(option: SynthesisOption) -> Result<WsStream> {
199 let url = request_url(&option, "wss");
200 let mut request = url.as_str().into_client_request()?;
201 let token = option
202 .secret_key
203 .as_ref()
204 .ok_or_else(|| anyhow!("Deepegram tts: missing api key"))?;
205 request
206 .headers_mut()
207 .insert("Authorization", format!("Token {}", token).parse()?);
208 let (ws_stream, _) = connect_async(request).await?;
209 Ok(ws_stream)
210}
211
212#[async_trait]
213impl SynthesisClient for StreamingClient {
214 fn provider(&self) -> SynthesisType {
215 SynthesisType::Deepgram
216 }
217
218 async fn start(
219 &mut self,
220 ) -> Result<BoxStream<'static, (Option<usize>, Result<SynthesisEvent>)>> {
221 let (sink, source) = connect(self.option.clone()).await?.split();
222 self.sink = Some(sink);
223 let stream = source
224 .filter_map(async move |message| match message {
225 Ok(Message::Binary(bytes)) => Some(Ok(SynthesisEvent::AudioChunk(bytes))),
226 Ok(Message::Text(text)) => {
227 let event: Event =
228 serde_json::from_str(&text).expect("Deepgram TTS API changed!");
229
230 if let Event::Warning { description, code } = event {
231 warn!("Deepgram TTS: warning: {}, {}", description, code);
232 }
233
234 None
235 }
236 Ok(Message::Close(_)) => Some(Ok(SynthesisEvent::Finished)),
237 Err(e) => Some(Err(anyhow!("Deepgram TTS: websocket error: {:?}", e))),
238 _ => None,
239 })
240 .map(|res| (None, res))
241 .boxed();
242 Ok(stream)
243 }
244
245 async fn synthesize(
246 &mut self,
247 text: &str,
248 _cmd_seq: Option<usize>,
249 _option: Option<SynthesisOption>,
250 ) -> Result<()> {
251 if let Some(sink) = &mut self.sink {
252 for sentence in text.split_inclusive(&TERMINATORS[..]) {
255 if !sentence.is_empty() {
256 let speak_cmd = Command::Speak {
257 text: sentence.to_string(),
258 };
259 let speak_json = serde_json::to_string(&speak_cmd)?;
260 sink.send(Message::text(speak_json)).await?;
261 }
262
263 if sentence.ends_with(&TERMINATORS[..]) {
264 let flush_cmd = Command::Flush;
265 let flush_json = serde_json::to_string(&flush_cmd)?;
266 sink.send(Message::text(flush_json)).await?;
267 }
268 }
269 } else {
270 return Err(anyhow::anyhow!("Deepgram TTS: missing sink"));
271 };
272 Ok(())
273 }
274
275 async fn stop(&mut self) -> Result<()> {
276 if let Some(mut sink) = self.sink.take() {
277 let close_cmd = Command::Close;
278 let close_json = serde_json::to_string(&close_cmd)?;
279 sink.send(Message::text(close_json)).await?;
280 } else {
281 warn!("Deepgram TTS: missing sink");
282 }
283 Ok(())
284 }
285}
286
287pub struct DeepegramTtsClient;
288
289impl DeepegramTtsClient {
290 pub fn create(streaming: bool, option: &SynthesisOption) -> Result<Box<dyn SynthesisClient>> {
291 if streaming {
292 Ok(Box::new(StreamingClient::new(option.clone())))
293 } else {
294 Ok(Box::new(RestClient::new(option.clone())))
295 }
296 }
297}