// async_dashscope/operation/audio/tts/output.rs

1use crate::{error::DashScopeError, operation::common::Usage};
2use base64::prelude::*;
3use bytes::Bytes;
4use serde::{Deserialize, Serialize};
5use std::pin::Pin;
6use thiserror::Error;
7use tokio_stream::Stream;
8
9#[derive(Serialize, Deserialize, Debug, Clone)]
10pub struct TextToSpeechOutput {
11    pub request_id: String,
12    /// 调用结果信息。
13    #[serde(rename = "output")]
14    pub output: Output,
15    /// 本次chat请求使用的token信息。
16    #[serde(rename = "usage")]
17    pub usage: Option<Usage>,
18}
19
20#[derive(Serialize, Deserialize, Debug, Clone)]
21pub struct Output {
22    /// 有两种情况:
23    /// - 正在生成时为"null";
24    /// - 因模型输出自然结束,或触发输入参数中的stop条件而结束时为"stop"。
25    pub finish_reason: Option<String>,
26    /// 模型输出的音频信息。
27    pub audio: Audio,
28}
29
30#[derive(Serialize, Deserialize, Debug, Clone)]
31pub struct Audio {
32    pub id: String,
33    /// 模型输出的完整音频文件的URL,有效期24小时。
34    pub url: Option<String>,
35    /// url 将要过期的时间戳。
36    pub expires_at: i64,
37    /// 流式输出时的Base64 音频数据。
38    pub data: String,
39}
40
41pub type TextToSpeechOutputStream =
42    Pin<Box<dyn Stream<Item = Result<TextToSpeechOutput, DashScopeError>> + Send>>;
43
44#[derive(Error, Debug)]
45pub enum AudioOutputError {
46    #[error("Failed to download audio file:{}", 0)]
47    DownloadError(#[from] reqwest::Error),
48    #[error("Failed to save audio file:{}", 0)]
49    SaveError(#[from] std::io::Error),
50    #[error("Audio url is null")]
51    NullUrl,
52    #[error("Failed to decode audio data")]
53    DataDecodeError,
54}
55
56impl Audio {
57    pub fn get_audio_data(&self) -> String {
58        self.data.clone()
59    }
60
61    pub fn is_finished(&self) -> bool {
62        self.url.is_some()
63    }
64
65    /// 注意这是一个 pcm 数据,需要解码后才能播放
66    pub fn to_vec(&self) -> Result<Vec<u8>, AudioOutputError> {
67        BASE64_STANDARD
68            .decode(&self.data)
69            .map_err(|_| AudioOutputError::DataDecodeError)
70    }
71
72    #[cfg(feature = "wav-decoder")]
73    pub fn to_wav(&self,sample_rate: u32, num_channels: u16, bits_per_sample: u16) -> Result<Vec<u8>, AudioOutputError> {
74        use std::io::Cursor;
75        use hound::{WavSpec, WavWriter};
76
77        let pcm_data = self.to_vec()?;
78        let mut buffer = Cursor::new(Vec::new());
79        let spec = WavSpec {
80            channels: num_channels,
81            sample_rate,
82            bits_per_sample,
83            sample_format: hound::SampleFormat::Int,
84        };
85
86        let mut writer = WavWriter::new(&mut buffer, spec).map_err(|e| {
87            eprintln!("WAV writer error: {e}");
88            AudioOutputError::DataDecodeError
89        })?;
90
91        // 根据位深度写入PCM数据
92        match bits_per_sample {
93            16 => {
94                // 将字节转换为i16样本
95                for chunk in pcm_data.chunks_exact(2) {
96                    let sample = i16::from_le_bytes([chunk[0], chunk[1]]);
97                    writer.write_sample(sample).map_err(|_| AudioOutputError::DataDecodeError)?;
98                }
99            }
100            8 => {
101                // 直接写入u8样本
102                for &sample in &pcm_data {
103                    writer.write_sample(sample as i8).map_err(|_| AudioOutputError::DataDecodeError)?;
104                }
105            }
106            _ => return Err(AudioOutputError::DataDecodeError),
107        }
108
109        // 完成写入并返回WAV数据
110        writer.finalize().map_err(|_| AudioOutputError::DataDecodeError)?;
111        Ok(buffer.into_inner())
112    }
113
114    pub fn bytes(&self) -> Result<Bytes, AudioOutputError> {
115        Ok(Bytes::copy_from_slice(&self.to_vec()?))
116    }
117
118    pub async fn download(&self, save_path: &str) -> Result<(), AudioOutputError> {
119        let Some(url) = &self.url else {
120            return Err(AudioOutputError::NullUrl);
121        };
122        let r = reqwest::get(url).await?.bytes().await?;
123
124        // save file
125        tokio::fs::write(save_path, r).await?;
126
127        Ok(())
128    }
129}
130
131impl TextToSpeechOutput {
132    pub async fn download(&self, save_path: &str) -> Result<(), AudioOutputError> {
133        self.output.audio.download(save_path).await
134    }
135
136    pub fn is_finished(&self) -> bool {
137        self.output.audio.is_finished()
138    }
139}