1use super::{SynthesisClient, SynthesisOption, SynthesisType};
2use crate::synthesis::SynthesisEvent;
3use anyhow::{Result, anyhow};
4use async_trait::async_trait;
5use futures::{
6 FutureExt, SinkExt, Stream, StreamExt, future,
7 stream::{self, BoxStream, SplitSink},
8};
9use serde::{Deserialize, Serialize};
10use serde_with::skip_serializing_none;
11use std::sync::Arc;
12use tokio::{
13 net::TcpStream,
14 sync::{Notify, mpsc},
15};
16use tokio_stream::wrappers::UnboundedReceiverStream;
17use tokio_tungstenite::{
18 MaybeTlsStream, WebSocketStream, connect_async,
19 tungstenite::{self, Message, client::IntoClientRequest},
20};
21use tracing::warn;
22use uuid::Uuid;
/// Full-duplex websocket connection to the DashScope endpoint (possibly TLS-wrapped).
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Write half of a split [`WsStream`]; kept by the streaming client to send commands.
type WsSink = SplitSink<WsStream, Message>;
25
/// A client-to-server command frame: a [`CommandHeader`] plus an action-specific payload.
///
/// Serialized with `serde_json` and sent as a websocket text message. Field names and
/// order form the JSON wire format, so do not rename or reorder them casually.
#[derive(Debug, Serialize)]
struct Command {
    header: CommandHeader,
    payload: CommandPayload,
}
34
/// The payload of a [`Command`], one variant per action.
///
/// `untagged`: each variant serializes as its inner struct's fields directly, with no
/// variant name in the JSON. The action itself is identified by `CommandHeader::action`.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum CommandPayload {
    /// Payload for `run-task`.
    Run(RunTaskPayload),
    /// Payload for `continue-task` (carries text to synthesize).
    Continue(ContinueTaskPayload),
    /// Payload for `finish-task`.
    Finish(FinishTaskPayload),
}
42
43impl Command {
44 fn run_task(option: &SynthesisOption, task_id: &str) -> Self {
45 let voice = option
46 .speaker
47 .clone()
48 .unwrap_or_else(|| "longyumi_v2".to_string());
49
50 let format = option.codec.as_deref().unwrap_or("pcm");
51
52 let sample_rate = option.samplerate.unwrap_or(16000) as u32;
53 let volume = option.volume.unwrap_or(50) as u32;
54 let rate = option.speed.unwrap_or(1.0);
55 let model = option
56 .model
57 .clone()
58 .unwrap_or_else(|| "cosyvoice-v2".to_string());
59
60 Command {
61 header: CommandHeader {
62 action: "run-task".to_string(),
63 task_id: task_id.to_string(),
64 streaming: "duplex".to_string(),
65 },
66 payload: CommandPayload::Run(RunTaskPayload {
67 task_group: "audio".to_string(),
68 task: "tts".to_string(),
69 function: "SpeechSynthesizer".to_string(),
70 model,
71 parameters: RunTaskParameters {
72 text_type: "PlainText".to_string(),
73 voice,
74 format: Some(format.to_string()),
75 sample_rate: Some(sample_rate),
76 volume: Some(volume),
77 rate: Some(rate),
78 },
79 input: EmptyInput {},
80 }),
81 }
82 }
83
84 fn continue_task(task_id: &str, text: &str) -> Self {
85 Command {
86 header: CommandHeader {
87 action: "continue-task".to_string(),
88 task_id: task_id.to_string(),
89 streaming: "duplex".to_string(),
90 },
91 payload: CommandPayload::Continue(ContinueTaskPayload {
92 input: PayloadInput {
93 text: text.to_string(),
94 },
95 }),
96 }
97 }
98
99 fn finish_task(task_id: &str) -> Self {
100 Command {
101 header: CommandHeader {
102 action: "finish-task".to_string(),
103 task_id: task_id.to_string(),
104 streaming: "duplex".to_string(),
105 },
106 payload: CommandPayload::Finish(FinishTaskPayload {
107 input: EmptyInput {},
108 }),
109 }
110 }
111}
112
/// Header common to every [`Command`].
#[derive(Debug, Serialize)]
struct CommandHeader {
    // One of "run-task", "continue-task" or "finish-task" in this file.
    action: String,
    // Identifies the task; all commands for one task reuse the same id.
    task_id: String,
    // Always "duplex" for commands built in this file.
    streaming: String,
}
119
/// Payload of the `run-task` command: which service/model to run and its parameters.
#[derive(Debug, Serialize)]
struct RunTaskPayload {
    // "audio" in this file.
    task_group: String,
    // "tts" in this file.
    task: String,
    // "SpeechSynthesizer" in this file.
    function: String,
    model: String,
    parameters: RunTaskParameters,
    // Sent as an empty JSON object; text arrives later via `continue-task`.
    input: EmptyInput,
}
129
/// Synthesis parameters sent with `run-task`.
///
/// `skip_serializing_none` omits any `None` field from the JSON entirely (rather than
/// emitting `null`). `Command::run_task` currently fills every optional field with `Some`.
#[skip_serializing_none]
#[derive(Debug, Serialize)]
struct RunTaskParameters {
    // "PlainText" in this file.
    text_type: String,
    voice: String,
    // Audio encoding; defaults to "pcm" in `Command::run_task`.
    format: Option<String>,
    sample_rate: Option<u32>,
    volume: Option<u32>,
    // Speaking rate, taken from `SynthesisOption::speed`.
    rate: Option<f32>,
}
140
/// Payload of the `continue-task` command: the next piece of text to synthesize.
#[derive(Debug, Serialize)]
struct ContinueTaskPayload {
    input: PayloadInput,
}
145
/// Text input carried by `continue-task`.
// NOTE(review): `Deserialize` is not used within this file; presumably kept for symmetry
// or use elsewhere — confirm before removing.
#[derive(Debug, Serialize, Deserialize)]
struct PayloadInput {
    text: String,
}
150
/// Payload of the `finish-task` command; its `input` is always an empty object.
#[derive(Debug, Serialize)]
struct FinishTaskPayload {
    input: EmptyInput,
}
155
/// Serializes as an empty JSON object (`{}`), used where the API expects an `input` field.
#[derive(Debug, Serialize)]
struct EmptyInput {}
158
/// A server-to-client event frame (websocket text message).
///
/// Only the header is decoded; any other top-level fields the server sends are ignored
/// by serde's default handling of unknown fields.
#[derive(Debug, Deserialize)]
struct Event {
    header: EventHeader,
}
164
/// Header of a server [`Event`].
// NOTE(review): every field below is read somewhere in this file, so this `allow` looks
// redundant — confirm and remove, or document which build configuration needs it.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct EventHeader {
    task_id: String,
    // Event name; this file reacts to "task-started", "task-finished" and "task-failed".
    event: String,
    // Populated on "task-failed"; absent otherwise (hence `Option`).
    error_code: Option<String>,
    error_message: Option<String>,
}
173
174async fn connect(task_id: String, option: SynthesisOption) -> Result<WsStream> {
175 let api_key = option
176 .secret_key
177 .clone()
178 .or_else(|| std::env::var("DASHSCOPE_API_KEY").ok())
179 .ok_or_else(|| anyhow!("Aliyun TTS: missing api key"))?;
180 let ws_url = option
181 .endpoint
182 .as_deref()
183 .unwrap_or("wss://dashscope.aliyuncs.com/api-ws/v1/inference");
184
185 let mut request = ws_url.into_client_request()?;
186 let headers = request.headers_mut();
187 headers.insert("Authorization", format!("Bearer {}", api_key).parse()?);
188 headers.insert("X-DashScope-DataInspection", "enable".parse()?);
189
190 let (mut ws_stream, _) = connect_async(request).await?;
191 let run_task_cmd = Command::run_task(&option, task_id.as_str());
192 let run_task_json = serde_json::to_string(&run_task_cmd)?;
193 ws_stream.send(Message::text(run_task_json)).await?;
194 while let Some(message) = ws_stream.next().await {
195 match message {
196 Ok(Message::Text(text)) => {
197 let event = serde_json::from_str::<Event>(&text)?;
198 match event.header.event.as_str() {
199 "task-started" => {
200 break;
201 }
202 "task-failed" => {
203 let error_code = event
204 .header
205 .error_code
206 .unwrap_or_else(|| "Unknown error code".to_string());
207 let error_msg = event
208 .header
209 .error_message
210 .unwrap_or_else(|| "Unknown error message".to_string());
211 return Err(anyhow!(
212 "Aliyun TTS Task: {} failed: {}, {}",
213 task_id,
214 error_code,
215 error_msg
216 ))?;
217 }
218 _ => {
219 warn!("Aliyun TTS Task: {} unexpected event: {:?}", task_id, event);
220 }
221 }
222 }
223 Ok(Message::Close(_)) => {
224 return Err(anyhow!("Aliyun TTS start failed: closed by server"));
225 }
226 Err(e) => {
227 return Err(anyhow!("Aliyun TTS start failed:: {}", e));
228 }
229 _ => {}
230 }
231 }
232 Ok(ws_stream)
233}
234
235fn event_stream<T>(ws_stream: T) -> impl Stream<Item = Result<SynthesisEvent>> + Send + 'static
236where
237 T: Stream<Item = Result<Message, tungstenite::Error>> + Send + 'static,
238{
239 let notify = Arc::new(Notify::new());
240 let notify_clone = notify.clone();
241 ws_stream
242 .take_until(notify.notified_owned())
243 .filter_map(move |message| {
244 let notify = notify_clone.clone();
245 async move {
246 match message {
247 Ok(Message::Binary(data)) => Some(Ok(SynthesisEvent::AudioChunk(data))),
248 Ok(Message::Text(text)) => {
249 let event: Event =
250 serde_json::from_str(&text).expect("Aliyun TTS API changed!");
251
252 match event.header.event.as_str() {
253 "task-finished" => {
254 notify.notify_one();
255 Some(Ok(SynthesisEvent::Finished))
256 }
257 "task-failed" => {
258 let error_code = event
259 .header
260 .error_code
261 .unwrap_or_else(|| "Unknown error code".to_string());
262 let error_msg = event
263 .header
264 .error_message
265 .unwrap_or_else(|| "Unknown error message".to_string());
266 notify.notify_one();
267 Some(Err(anyhow!(
268 "Aliyun TTS Task: {} failed: {}, {}",
269 event.header.task_id,
270 error_code,
271 error_msg
272 )))
273 }
274 _ => None,
275 }
276 }
277 Ok(Message::Close(_)) => {
278 notify.notify_one();
279 warn!("Aliyun TTS: closed by remote");
280 None
281 }
282 Err(e) => {
283 notify.notify_one();
284 Some(Err(anyhow!("Aliyun TTS: websocket error: {:?}", e)))
285 }
286 _ => None,
287 }
288 }
289 })
290}
/// Aliyun TTS client that keeps one websocket task open and streams text into it
/// incrementally via `continue-task` commands.
#[derive(Debug)]
pub struct StreamingClient {
    // Generated once in `new` and used for every command this client sends.
    task_id: String,
    option: SynthesisOption,
    // Write half of the socket; `None` until `start` succeeds.
    ws_sink: Option<WsSink>,
}
297
298impl StreamingClient {
299 pub fn new(option: SynthesisOption) -> Self {
300 Self {
301 task_id: Uuid::new_v4().to_string(),
302 option,
303 ws_sink: None,
304 }
305 }
306}
307
308#[async_trait]
309impl SynthesisClient for StreamingClient {
310 fn provider(&self) -> SynthesisType {
311 SynthesisType::Aliyun
312 }
313
314 async fn start(
315 &mut self,
316 ) -> Result<BoxStream<'static, (Option<usize>, Result<SynthesisEvent>)>> {
317 let ws_stream = connect(self.task_id.clone(), self.option.clone()).await?;
318 let (ws_sink, ws_source) = ws_stream.split();
319 self.ws_sink.replace(ws_sink);
320 Ok(event_stream(ws_source).map(move |x| (None, x)).boxed())
321 }
322
323 async fn synthesize(
324 &mut self,
325 text: &str,
326 _cmd_seq: Option<usize>,
327 _option: Option<SynthesisOption>,
328 ) -> Result<()> {
329 if let Some(ws_sink) = self.ws_sink.as_mut() {
330 if !text.is_empty() {
331 let continue_task_cmd = Command::continue_task(self.task_id.as_str(), text);
332 let continue_task_json = serde_json::to_string(&continue_task_cmd)?;
333 ws_sink.send(Message::text(continue_task_json)).await?;
334 }
335 } else {
336 return Err(anyhow!("Aliyun TTS Task: not connected"));
337 }
338
339 Ok(())
340 }
341
342 async fn stop(&mut self) -> Result<()> {
343 if let Some(ws_sink) = self.ws_sink.as_mut() {
344 let finish_task_cmd = Command::finish_task(self.task_id.as_str());
345 let finish_task_json = serde_json::to_string(&finish_task_cmd)?;
346 ws_sink.send(Message::text(finish_task_json)).await?;
347 }
348
349 Ok(())
350 }
351}
352
/// Aliyun TTS client that runs every `synthesize` call as its own one-shot task on a
/// separate websocket connection.
pub struct NonStreamingClient {
    option: SynthesisOption,
    // Queue of (text, cmd_seq, per-call option) consumed by the pump created in
    // `start`; `None` before `start` and after `stop`.
    tx: Option<mpsc::UnboundedSender<(String, Option<usize>, Option<SynthesisOption>)>>,
}
357
358impl NonStreamingClient {
359 pub fn new(option: SynthesisOption) -> Self {
360 Self { option, tx: None }
361 }
362}
363
364#[async_trait]
365impl SynthesisClient for NonStreamingClient {
366 fn provider(&self) -> SynthesisType {
367 SynthesisType::Aliyun
368 }
369
370 async fn start(
371 &mut self,
372 ) -> Result<BoxStream<'static, (Option<usize>, Result<SynthesisEvent>)>> {
373 let (tx, rx) = mpsc::unbounded_channel();
374 self.tx.replace(tx);
375 let client_option = self.option.clone();
376 let max_concurrent_tasks = client_option.max_concurrent_tasks.unwrap_or(1);
377
378 let stream = UnboundedReceiverStream::new(rx)
379 .flat_map_unordered(max_concurrent_tasks, move |(text, cmd_seq, option)| {
380 let option = client_option.merge_with(option);
381 let task_id = Uuid::new_v4().to_string();
382 let text_clone = text.clone();
383 let task_id_clone = task_id.clone();
384 connect(task_id, option)
385 .then(async move |res| match res {
386 Ok(mut ws_stream) => {
387 let continue_task_cmd =
388 Command::continue_task(task_id_clone.as_str(), text_clone.as_str());
389 let continue_task_json = serde_json::to_string(&continue_task_cmd)
390 .expect("Aliyun TTS API changed!");
391 ws_stream.send(Message::text(continue_task_json)).await.ok();
392 let finish_task_cmd = Command::finish_task(task_id_clone.as_str());
393 let finish_task_json = serde_json::to_string(&finish_task_cmd)
394 .expect("Aliyun TTS API changed!");
395 ws_stream.send(Message::text(finish_task_json)).await.ok();
396 event_stream(ws_stream).boxed()
397 }
398 Err(e) => {
399 warn!("Aliyun TTS: websocket error: {:?}, {:?}", cmd_seq, e);
400 stream::once(future::ready(Err(e.into()))).boxed()
401 }
402 })
403 .flatten_stream()
404 .map(move |x| (cmd_seq, x))
405 .boxed()
406 })
407 .boxed();
408 Ok(stream)
409 }
410
411 async fn synthesize(
412 &mut self,
413 text: &str,
414 cmd_seq: Option<usize>,
415 option: Option<SynthesisOption>,
416 ) -> Result<()> {
417 if let Some(tx) = &self.tx {
418 tx.send((text.to_string(), cmd_seq, option))?;
419 } else {
420 return Err(anyhow!("Aliyun TTS Task: not connected"));
421 }
422 Ok(())
423 }
424
425 async fn stop(&mut self) -> Result<()> {
426 self.tx.take();
427 Ok(())
428 }
429}
430
431pub struct AliyunTtsClient;
432impl AliyunTtsClient {
433 pub fn create(streaming: bool, option: &SynthesisOption) -> Result<Box<dyn SynthesisClient>> {
434 if streaming {
435 Ok(Box::new(StreamingClient::new(option.clone())))
436 } else {
437 Ok(Box::new(NonStreamingClient::new(option.clone())))
438 }
439 }
440}