1use {
7 crate::error::{Error, Result},
8 base64::prelude::*,
9 serde::{Deserialize, Serialize},
10 tokio::fs::read,
11};
12
13#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
15#[serde(rename_all = "lowercase")]
16pub enum AudioFormat {
17 #[default]
19 Wav,
20 Mp3,
22 Pcm,
24 #[serde(rename = "pcm16")]
26 Pcm16,
27}
28
/// Voice selector for audio output.
///
/// The built-in variants serialize to fixed identifiers (listed on each
/// variant); [`Voice::Custom`] carries an arbitrary string that is serialized
/// verbatim. The default is [`Voice::MimoDefault`].
///
/// `Hash` is derived alongside `Eq` so a voice can be used as a `HashMap` /
/// `HashSet` key.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub enum Voice {
    /// Serialized as `"mimo_default"`; the default voice.
    #[default]
    MimoDefault,
    /// Serialized as `"default_en"`.
    DefaultEn,
    /// Serialized as `"default_zh"`.
    DefaultZh,
    /// Serialized as `"冰糖"`.
    Bingtang,
    /// Serialized as `"茉莉"`.
    Moli,
    /// Serialized as `"苏打"`.
    Suda,
    /// Serialized as `"白桦"`.
    Baihua,
    /// Serialized as `"Mia"`.
    Mia,
    /// Serialized as `"Chloe"`.
    Chloe,
    /// Serialized as `"Milo"`.
    Milo,
    /// Serialized as `"Dean"`.
    Dean,
    /// Any other voice string, serialized exactly as stored. Also used to
    /// hold the `data:` URL produced by `Voice::from_audio_file`.
    Custom(String),
}
61
62impl Voice {
63 pub fn custom<S: Into<String>>(voice: S) -> Self {
67 Voice::Custom(voice.into())
68 }
69
70 pub async fn from_audio_file<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
75 let path = path.as_ref();
76 let data = read(path).await?;
77
78 let mime_type = match path.extension().and_then(|ext| ext.to_str()) {
79 Some("mp3") => "audio/mpeg",
80 Some("wav") => "audio/wav",
81 _ => return Err(Error::InvalidParameter("Unsupported audio format".into())),
82 };
83
84 let base64_audio = BASE64_STANDARD.encode(&data);
85 let voice_str = format!("data:{};base64,{}", mime_type, base64_audio);
86
87 Ok(Voice::Custom(voice_str))
88 }
89}
90
91impl Serialize for Voice {
93 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
94 where
95 S: serde::Serializer,
96 {
97 let s = match self {
98 Voice::MimoDefault => "mimo_default",
99 Voice::DefaultEn => "default_en",
100 Voice::DefaultZh => "default_zh",
101 Voice::Bingtang => "冰糖",
102 Voice::Moli => "茉莉",
103 Voice::Suda => "苏打",
104 Voice::Baihua => "白桦",
105 Voice::Mia => "Mia",
106 Voice::Chloe => "Chloe",
107 Voice::Milo => "Milo",
108 Voice::Dean => "Dean",
109 Voice::Custom(s) => s.as_str(),
110 };
111 serializer.serialize_str(s)
112 }
113}
114
115impl<'de> Deserialize<'de> for Voice {
117 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
118 where
119 D: serde::Deserializer<'de>,
120 {
121 let s = String::deserialize(deserializer)?;
122 Ok(match s.as_str() {
123 "mimo_default" => Voice::MimoDefault,
124 "default_en" => Voice::DefaultEn,
125 "default_zh" => Voice::DefaultZh,
126 "冰糖" => Voice::Bingtang,
127 "茉莉" => Voice::Moli,
128 "苏打" => Voice::Suda,
129 "白桦" => Voice::Baihua,
130 "Mia" => Voice::Mia,
131 "Chloe" => Voice::Chloe,
132 "Milo" => Voice::Milo,
133 "Dean" => Voice::Dean,
134 _ => Voice::Custom(s),
135 })
136 }
137}
138
139#[derive(Debug, Clone, Serialize, Deserialize)]
141pub struct Audio {
142 #[serde(skip_serializing_if = "Option::is_none")]
144 pub format: Option<AudioFormat>,
145 #[serde(skip_serializing_if = "Option::is_none")]
147 pub voice: Option<Voice>,
148}
149
150impl Audio {
151 pub fn new() -> Self {
163 Self {
164 format: None,
165 voice: None,
166 }
167 }
168
169 pub fn format(mut self, format: AudioFormat) -> Self {
171 self.format = Some(format);
172 self
173 }
174
175 pub fn voice(mut self, voice: Voice) -> Self {
177 self.voice = Some(voice);
178 self
179 }
180
181 pub fn wav() -> Self {
183 Self::new().format(AudioFormat::Wav)
184 }
185
186 pub fn mp3() -> Self {
188 Self::new().format(AudioFormat::Mp3)
189 }
190
191 pub fn pcm() -> Self {
193 Self::new().format(AudioFormat::Pcm)
194 }
195
196 pub fn pcm16() -> Self {
198 Self::new().format(AudioFormat::Pcm16)
199 }
200}
201
202impl Default for Audio {
203 fn default() -> Self {
204 Self::new()
205 }
206}
207
208#[derive(Debug, Clone, Serialize, Deserialize)]
210pub struct ResponseAudio {
211 pub id: String,
213 pub data: String,
215 #[serde(skip_serializing_if = "Option::is_none")]
217 pub expires_at: Option<i64>,
218 #[serde(skip_serializing_if = "Option::is_none")]
220 pub transcript: Option<String>,
221}
222
223impl ResponseAudio {
224 pub fn decode_data(&self) -> Result<Vec<u8>> {
247 use base64::Engine;
248 base64::engine::general_purpose::STANDARD
249 .decode(&self.data)
250 .map_err(Into::into)
251 }
252
253 pub fn transcript(&self) -> Option<&str> {
255 self.transcript.as_deref()
256 }
257
258 pub fn is_expired(&self) -> bool {
260 if let Some(expires_at) = self.expires_at {
261 let now = std::time::SystemTime::now()
262 .duration_since(std::time::UNIX_EPOCH)
263 .unwrap()
264 .as_secs() as i64;
265 now > expires_at
266 } else {
267 false
268 }
269 }
270}
271
272#[derive(Debug, Clone, Serialize, Deserialize)]
274pub struct DeltaAudio {
275 pub id: String,
277 pub data: String,
279 #[serde(skip_serializing_if = "Option::is_none")]
281 pub expires_at: Option<i64>,
282 #[serde(skip_serializing_if = "Option::is_none")]
284 pub transcript: Option<String>,
285}
286
287impl DeltaAudio {
288 pub fn decode_data(&self) -> Result<Vec<u8>> {
290 use base64::Engine;
291 base64::engine::general_purpose::STANDARD
292 .decode(&self.data)
293 .map_err(Into::into)
294 }
295}
296
/// Ordered collection of style labels for text-to-speech input.
///
/// `apply` joins the labels with single spaces and prepends them to the text
/// inside a `<style>…</style>` tag.
#[derive(Clone, Debug, Default)]
pub struct TtsStyle {
    /// Labels in the order they were added.
    styles: Vec<String>,
}
305
306impl TtsStyle {
307 pub fn new() -> Self {
309 Self { styles: Vec::new() }
310 }
311
312 pub fn with_style(mut self, style: impl Into<String>) -> Self {
336 self.styles.push(style.into());
337 self
338 }
339
340 pub fn apply(&self, text: &str) -> String {
344 if self.styles.is_empty() {
345 text.to_string()
346 } else {
347 format!("<style>{}</style>{}", self.styles.join(" "), text)
348 }
349 }
350}
351
352pub fn styled_text(style: &str, text: &str) -> String {
363 TtsStyle::new().with_style(style).apply(text)
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine;

    /// `Wav` is the default audio format.
    #[test]
    fn test_audio_format_default() {
        assert_eq!(AudioFormat::default(), AudioFormat::Wav);
    }

    /// `MimoDefault` is the default voice.
    #[test]
    fn test_voice_default() {
        assert_eq!(Voice::default(), Voice::MimoDefault);
    }

    /// The builder stores both the format and the voice.
    #[test]
    fn test_audio_config() {
        let config = Audio::wav().voice(Voice::DefaultZh);

        assert_eq!(config.format, Some(AudioFormat::Wav));
        assert_eq!(config.voice, Some(Voice::DefaultZh));
    }

    /// Format and voice appear in the JSON under their wire names.
    #[test]
    fn test_audio_serialization() {
        let config = Audio::mp3().voice(Voice::DefaultEn);
        let encoded = serde_json::to_string(&config).unwrap();

        assert!(encoded.contains("\"format\":\"mp3\""));
        assert!(encoded.contains("\"voice\":\"default_en\""));
    }

    /// Each shorthand constructor selects the matching format.
    #[test]
    fn test_audio_formats() {
        assert_eq!(Audio::wav().format, Some(AudioFormat::Wav));
        assert_eq!(Audio::mp3().format, Some(AudioFormat::Mp3));
        assert_eq!(Audio::pcm().format, Some(AudioFormat::Pcm));
    }

    /// A single label is wrapped in one `<style>` tag.
    #[test]
    fn test_tts_style_single() {
        let rendered = TtsStyle::new().with_style("开心").apply("Hello");
        assert_eq!(rendered, "<style>开心</style>Hello");
    }

    /// Several labels all end up inside the tag, ahead of the text.
    #[test]
    fn test_tts_style_multiple() {
        let style = TtsStyle::new().with_style("开心").with_style("变快");
        let rendered = style.apply("Hello");

        assert!(rendered.starts_with("<style>"));
        assert!(rendered.contains("开心"));
        assert!(rendered.contains("变快"));
        assert!(rendered.ends_with("Hello"));
    }

    /// Without labels the text passes through untouched.
    #[test]
    fn test_tts_style_empty() {
        let rendered = TtsStyle::new().apply("Hello");
        assert_eq!(rendered, "Hello");
    }

    /// The free helper produces the same output as a one-label `TtsStyle`.
    #[test]
    fn test_styled_text_helper() {
        let rendered = styled_text("东北话", "哎呀妈呀");
        assert_eq!(rendered, "<style>东北话</style>哎呀妈呀");
    }

    /// `decode_data` reverses standard base64 encoding.
    #[test]
    fn test_response_audio_decode() {
        let payload = b"test audio data";
        let response = ResponseAudio {
            id: "test-id".to_string(),
            data: base64::engine::general_purpose::STANDARD.encode(payload),
            expires_at: None,
            transcript: Some("test".to_string()),
        };

        let decoded = response.decode_data().unwrap();
        assert_eq!(decoded, payload);
    }

    /// `transcript()` exposes the stored transcript as `&str`.
    #[test]
    fn test_response_audio_transcript() {
        let response = ResponseAudio {
            id: "test-id".to_string(),
            data: String::new(),
            expires_at: None,
            transcript: Some("Hello world".to_string()),
        };

        assert_eq!(response.transcript(), Some("Hello world"));
    }
}