active_call/synthesis/
aliyun.rs

1use super::{SynthesisClient, SynthesisOption, SynthesisType};
2use crate::synthesis::SynthesisEvent;
3use anyhow::{Result, anyhow};
4use async_trait::async_trait;
5use futures::{
6    FutureExt, SinkExt, Stream, StreamExt, future,
7    stream::{self, BoxStream, SplitSink},
8};
9use serde::{Deserialize, Serialize};
10use serde_with::skip_serializing_none;
11use std::sync::Arc;
12use tokio::{
13    net::TcpStream,
14    sync::{Notify, mpsc},
15};
16use tokio_stream::wrappers::UnboundedReceiverStream;
17use tokio_tungstenite::{
18    MaybeTlsStream, WebSocketStream, connect_async,
19    tungstenite::{self, Message, client::IntoClientRequest},
20};
21use tracing::warn;
22use uuid::Uuid;
23type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
24type WsSink = SplitSink<WsStream, Message>;
25
26/// Aliyun CosyVoice WebSocket API Client
27/// https://help.aliyun.com/zh/model-studio/cosyvoice-websocket-api
28
29#[derive(Debug, Serialize)]
30struct Command {
31    header: CommandHeader,
32    payload: CommandPayload,
33}
34
35#[derive(Debug, Serialize)]
36#[serde(untagged)]
37enum CommandPayload {
38    Run(RunTaskPayload),
39    Continue(ContinueTaskPayload),
40    Finish(FinishTaskPayload),
41}
42
43impl Command {
44    fn run_task(option: &SynthesisOption, task_id: &str) -> Self {
45        let voice = option
46            .speaker
47            .clone()
48            .unwrap_or_else(|| "longyumi_v2".to_string());
49
50        let format = option.codec.as_deref().unwrap_or("pcm");
51
52        let sample_rate = option.samplerate.unwrap_or(16000) as u32;
53        let volume = option.volume.unwrap_or(50) as u32;
54        let rate = option.speed.unwrap_or(1.0);
55        let model = option
56            .model
57            .clone()
58            .unwrap_or_else(|| "cosyvoice-v2".to_string());
59
60        Command {
61            header: CommandHeader {
62                action: "run-task".to_string(),
63                task_id: task_id.to_string(),
64                streaming: "duplex".to_string(),
65            },
66            payload: CommandPayload::Run(RunTaskPayload {
67                task_group: "audio".to_string(),
68                task: "tts".to_string(),
69                function: "SpeechSynthesizer".to_string(),
70                model,
71                parameters: RunTaskParameters {
72                    text_type: "PlainText".to_string(),
73                    voice,
74                    format: Some(format.to_string()),
75                    sample_rate: Some(sample_rate),
76                    volume: Some(volume),
77                    rate: Some(rate),
78                },
79                input: EmptyInput {},
80            }),
81        }
82    }
83
84    fn continue_task(task_id: &str, text: &str) -> Self {
85        Command {
86            header: CommandHeader {
87                action: "continue-task".to_string(),
88                task_id: task_id.to_string(),
89                streaming: "duplex".to_string(),
90            },
91            payload: CommandPayload::Continue(ContinueTaskPayload {
92                input: PayloadInput {
93                    text: text.to_string(),
94                },
95            }),
96        }
97    }
98
99    fn finish_task(task_id: &str) -> Self {
100        Command {
101            header: CommandHeader {
102                action: "finish-task".to_string(),
103                task_id: task_id.to_string(),
104                streaming: "duplex".to_string(),
105            },
106            payload: CommandPayload::Finish(FinishTaskPayload {
107                input: EmptyInput {},
108            }),
109        }
110    }
111}
112
113#[derive(Debug, Serialize)]
114struct CommandHeader {
115    action: String,
116    task_id: String,
117    streaming: String,
118}
119
120#[derive(Debug, Serialize)]
121struct RunTaskPayload {
122    task_group: String,
123    task: String,
124    function: String,
125    model: String,
126    parameters: RunTaskParameters,
127    input: EmptyInput,
128}
129
130#[skip_serializing_none]
131#[derive(Debug, Serialize)]
132struct RunTaskParameters {
133    text_type: String,
134    voice: String,
135    format: Option<String>,
136    sample_rate: Option<u32>,
137    volume: Option<u32>,
138    rate: Option<f32>,
139}
140
141#[derive(Debug, Serialize)]
142struct ContinueTaskPayload {
143    input: PayloadInput,
144}
145
146#[derive(Debug, Serialize, Deserialize)]
147struct PayloadInput {
148    text: String,
149}
150
151#[derive(Debug, Serialize)]
152struct FinishTaskPayload {
153    input: EmptyInput,
154}
155
156#[derive(Debug, Serialize)]
157struct EmptyInput {}
158
159/// WebSocket event response structure
160#[derive(Debug, Deserialize)]
161struct Event {
162    header: EventHeader,
163}
164
165#[allow(dead_code)]
166#[derive(Debug, Deserialize)]
167struct EventHeader {
168    task_id: String,
169    event: String,
170    error_code: Option<String>,
171    error_message: Option<String>,
172}
173
174async fn connect(task_id: String, option: SynthesisOption) -> Result<WsStream> {
175    let api_key = option
176        .secret_key
177        .clone()
178        .or_else(|| std::env::var("DASHSCOPE_API_KEY").ok())
179        .ok_or_else(|| anyhow!("Aliyun TTS: missing api key"))?;
180    let ws_url = option
181        .endpoint
182        .as_deref()
183        .unwrap_or("wss://dashscope.aliyuncs.com/api-ws/v1/inference");
184
185    let mut request = ws_url.into_client_request()?;
186    let headers = request.headers_mut();
187    headers.insert("Authorization", format!("Bearer {}", api_key).parse()?);
188    headers.insert("X-DashScope-DataInspection", "enable".parse()?);
189
190    let (mut ws_stream, _) = connect_async(request).await?;
191    let run_task_cmd = Command::run_task(&option, task_id.as_str());
192    let run_task_json = serde_json::to_string(&run_task_cmd)?;
193    ws_stream.send(Message::text(run_task_json)).await?;
194    while let Some(message) = ws_stream.next().await {
195        match message {
196            Ok(Message::Text(text)) => {
197                let event = serde_json::from_str::<Event>(&text)?;
198                match event.header.event.as_str() {
199                    "task-started" => {
200                        break;
201                    }
202                    "task-failed" => {
203                        let error_code = event
204                            .header
205                            .error_code
206                            .unwrap_or_else(|| "Unknown error code".to_string());
207                        let error_msg = event
208                            .header
209                            .error_message
210                            .unwrap_or_else(|| "Unknown error message".to_string());
211                        return Err(anyhow!(
212                            "Aliyun TTS Task: {} failed: {}, {}",
213                            task_id,
214                            error_code,
215                            error_msg
216                        ))?;
217                    }
218                    _ => {
219                        warn!("Aliyun TTS Task: {} unexpected event: {:?}", task_id, event);
220                    }
221                }
222            }
223            Ok(Message::Close(_)) => {
224                return Err(anyhow!("Aliyun TTS start failed: closed by server"));
225            }
226            Err(e) => {
227                return Err(anyhow!("Aliyun TTS start failed:: {}", e));
228            }
229            _ => {}
230        }
231    }
232    Ok(ws_stream)
233}
234
235fn event_stream<T>(ws_stream: T) -> impl Stream<Item = Result<SynthesisEvent>> + Send + 'static
236where
237    T: Stream<Item = Result<Message, tungstenite::Error>> + Send + 'static,
238{
239    let notify = Arc::new(Notify::new());
240    let notify_clone = notify.clone();
241    ws_stream
242        .take_until(notify.notified_owned())
243        .filter_map(move |message| {
244            let notify = notify_clone.clone();
245            async move {
246                match message {
247                    Ok(Message::Binary(data)) => Some(Ok(SynthesisEvent::AudioChunk(data))),
248                    Ok(Message::Text(text)) => {
249                        let event: Event =
250                            serde_json::from_str(&text).expect("Aliyun TTS API changed!");
251
252                        match event.header.event.as_str() {
253                            "task-finished" => {
254                                notify.notify_one();
255                                Some(Ok(SynthesisEvent::Finished))
256                            }
257                            "task-failed" => {
258                                let error_code = event
259                                    .header
260                                    .error_code
261                                    .unwrap_or_else(|| "Unknown error code".to_string());
262                                let error_msg = event
263                                    .header
264                                    .error_message
265                                    .unwrap_or_else(|| "Unknown error message".to_string());
266                                notify.notify_one();
267                                Some(Err(anyhow!(
268                                    "Aliyun TTS Task: {} failed: {}, {}",
269                                    event.header.task_id,
270                                    error_code,
271                                    error_msg
272                                )))
273                            }
274                            _ => None,
275                        }
276                    }
277                    Ok(Message::Close(_)) => {
278                        notify.notify_one();
279                        warn!("Aliyun TTS: closed by remote");
280                        None
281                    }
282                    Err(e) => {
283                        notify.notify_one();
284                        Some(Err(anyhow!("Aliyun TTS: websocket error: {:?}", e)))
285                    }
286                    _ => None,
287                }
288            }
289        })
290}
291#[derive(Debug)]
292pub struct StreamingClient {
293    task_id: String,
294    option: SynthesisOption,
295    ws_sink: Option<WsSink>,
296}
297
298impl StreamingClient {
299    pub fn new(option: SynthesisOption) -> Self {
300        Self {
301            task_id: Uuid::new_v4().to_string(),
302            option,
303            ws_sink: None,
304        }
305    }
306}
307
308#[async_trait]
309impl SynthesisClient for StreamingClient {
310    fn provider(&self) -> SynthesisType {
311        SynthesisType::Aliyun
312    }
313
314    async fn start(
315        &mut self,
316    ) -> Result<BoxStream<'static, (Option<usize>, Result<SynthesisEvent>)>> {
317        let ws_stream = connect(self.task_id.clone(), self.option.clone()).await?;
318        let (ws_sink, ws_source) = ws_stream.split();
319        self.ws_sink.replace(ws_sink);
320        Ok(event_stream(ws_source).map(move |x| (None, x)).boxed())
321    }
322
323    async fn synthesize(
324        &mut self,
325        text: &str,
326        _cmd_seq: Option<usize>,
327        _option: Option<SynthesisOption>,
328    ) -> Result<()> {
329        if let Some(ws_sink) = self.ws_sink.as_mut() {
330            if !text.is_empty() {
331                let continue_task_cmd = Command::continue_task(self.task_id.as_str(), text);
332                let continue_task_json = serde_json::to_string(&continue_task_cmd)?;
333                ws_sink.send(Message::text(continue_task_json)).await?;
334            }
335        } else {
336            return Err(anyhow!("Aliyun TTS Task: not connected"));
337        }
338
339        Ok(())
340    }
341
342    async fn stop(&mut self) -> Result<()> {
343        if let Some(ws_sink) = self.ws_sink.as_mut() {
344            let finish_task_cmd = Command::finish_task(self.task_id.as_str());
345            let finish_task_json = serde_json::to_string(&finish_task_cmd)?;
346            ws_sink.send(Message::text(finish_task_json)).await?;
347        }
348
349        Ok(())
350    }
351}
352
353pub struct NonStreamingClient {
354    option: SynthesisOption,
355    tx: Option<mpsc::UnboundedSender<(String, Option<usize>, Option<SynthesisOption>)>>,
356}
357
358impl NonStreamingClient {
359    pub fn new(option: SynthesisOption) -> Self {
360        Self { option, tx: None }
361    }
362}
363
364#[async_trait]
365impl SynthesisClient for NonStreamingClient {
366    fn provider(&self) -> SynthesisType {
367        SynthesisType::Aliyun
368    }
369
370    async fn start(
371        &mut self,
372    ) -> Result<BoxStream<'static, (Option<usize>, Result<SynthesisEvent>)>> {
373        let (tx, rx) = mpsc::unbounded_channel();
374        self.tx.replace(tx);
375        let client_option = self.option.clone();
376        let max_concurrent_tasks = client_option.max_concurrent_tasks.unwrap_or(1);
377
378        let stream = UnboundedReceiverStream::new(rx)
379            .flat_map_unordered(max_concurrent_tasks, move |(text, cmd_seq, option)| {
380                let option = client_option.merge_with(option);
381                let task_id = Uuid::new_v4().to_string();
382                let text_clone = text.clone();
383                let task_id_clone = task_id.clone();
384                connect(task_id, option)
385                    .then(async move |res| match res {
386                        Ok(mut ws_stream) => {
387                            let continue_task_cmd =
388                                Command::continue_task(task_id_clone.as_str(), text_clone.as_str());
389                            let continue_task_json = serde_json::to_string(&continue_task_cmd)
390                                .expect("Aliyun TTS API changed!");
391                            ws_stream.send(Message::text(continue_task_json)).await.ok();
392                            let finish_task_cmd = Command::finish_task(task_id_clone.as_str());
393                            let finish_task_json = serde_json::to_string(&finish_task_cmd)
394                                .expect("Aliyun TTS API changed!");
395                            ws_stream.send(Message::text(finish_task_json)).await.ok();
396                            event_stream(ws_stream).boxed()
397                        }
398                        Err(e) => {
399                            warn!("Aliyun TTS: websocket error: {:?}, {:?}", cmd_seq, e);
400                            stream::once(future::ready(Err(e.into()))).boxed()
401                        }
402                    })
403                    .flatten_stream()
404                    .map(move |x| (cmd_seq, x))
405                    .boxed()
406            })
407            .boxed();
408        Ok(stream)
409    }
410
411    async fn synthesize(
412        &mut self,
413        text: &str,
414        cmd_seq: Option<usize>,
415        option: Option<SynthesisOption>,
416    ) -> Result<()> {
417        if let Some(tx) = &self.tx {
418            tx.send((text.to_string(), cmd_seq, option))?;
419        } else {
420            return Err(anyhow!("Aliyun TTS Task: not connected"));
421        }
422        Ok(())
423    }
424
425    async fn stop(&mut self) -> Result<()> {
426        self.tx.take();
427        Ok(())
428    }
429}
430
431pub struct AliyunTtsClient;
432impl AliyunTtsClient {
433    pub fn create(streaming: bool, option: &SynthesisOption) -> Result<Box<dyn SynthesisClient>> {
434        if streaming {
435            Ok(Box::new(StreamingClient::new(option.clone())))
436        } else {
437            Ok(Box::new(NonStreamingClient::new(option.clone())))
438        }
439    }
440}