use {
crate::error::{Error, Result},
base64::prelude::*,
serde::{Deserialize, Serialize},
tokio::fs::read,
};
/// Audio encoding selector that is (de)serialized as a lowercase string.
///
/// `Wav` is the `Default` variant.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "lowercase")]
pub enum AudioFormat {
/// Serialized as `"wav"`; the default variant.
#[default]
Wav,
/// Serialized as `"mp3"`.
Mp3,
/// Serialized as `"pcm"`.
Pcm,
/// Serialized as `"pcm16"` (explicit rename, same text the lowercase rule yields).
#[serde(rename = "pcm16")]
Pcm16,
}
/// Voice selector.
///
/// Serialization is hand-written (see the `Serialize`/`Deserialize` impls for
/// `Voice` in this file): every unit variant maps to a fixed string, and
/// `Custom` carries its string verbatim. `MimoDefault` is the `Default` variant.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum Voice {
/// Serialized as `"mimo_default"`; the default variant.
#[default]
MimoDefault,
/// Serialized as `"default_en"`.
DefaultEn,
/// Serialized as `"default_zh"`.
DefaultZh,
/// Serialized as `"冰糖"`.
Bingtang,
/// Serialized as `"茉莉"`.
Moli,
/// Serialized as `"苏打"`.
Suda,
/// Serialized as `"白桦"`.
Baihua,
/// Serialized as `"Mia"`.
Mia,
/// Serialized as `"Chloe"`.
Chloe,
/// Serialized as `"Milo"`.
Milo,
/// Serialized as `"Dean"`.
Dean,
/// Arbitrary voice string, serialized as-is. `Voice::from_audio_file`
/// stores a base64 `data:` URI here.
Custom(String),
}
impl Voice {
    /// Builds a [`Voice::Custom`] from any value convertible into a `String`.
    pub fn custom<S: Into<String>>(voice: S) -> Self {
        Voice::Custom(voice.into())
    }

    /// Reads an audio file and wraps its contents in a [`Voice::Custom`] holding
    /// a `data:<mime>;base64,<payload>` URI.
    ///
    /// The MIME type is chosen from the file extension, compared
    /// ASCII-case-insensitively: `mp3` maps to `audio/mpeg` and `wav` maps to
    /// `audio/wav`.
    ///
    /// # Errors
    ///
    /// * `Error::InvalidParameter` when the extension is missing, is not valid
    ///   UTF-8, or is neither `mp3` nor `wav`. This is checked before any I/O,
    ///   so an unsupported file is never read.
    /// * Whatever error the `?` conversion produces when reading the file fails.
    pub async fn from_audio_file<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
        let path = path.as_ref();

        // Validate the cheap thing first: there is no point loading a file whose
        // type is going to be rejected anyway.
        let extension = path.extension().and_then(|ext| ext.to_str());
        let mime_type = match extension {
            Some(ext) if ext.eq_ignore_ascii_case("mp3") => "audio/mpeg",
            Some(ext) if ext.eq_ignore_ascii_case("wav") => "audio/wav",
            _ => return Err(Error::InvalidParameter("Unsupported audio format".into())),
        };

        let data = read(path).await?;
        let base64_audio = BASE64_STANDARD.encode(&data);
        Ok(Voice::Custom(format!(
            "data:{};base64,{}",
            mime_type, base64_audio
        )))
    }
}
/// Serializes a [`Voice`] as a single string: a fixed name for each built-in
/// variant, or the contained text for `Custom`.
impl Serialize for Voice {
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let wire_name: &str = match self {
            // Custom voices are passed through untouched.
            Self::Custom(raw) => raw.as_str(),
            Self::MimoDefault => "mimo_default",
            Self::DefaultEn => "default_en",
            Self::DefaultZh => "default_zh",
            Self::Bingtang => "冰糖",
            Self::Moli => "茉莉",
            Self::Suda => "苏打",
            Self::Baihua => "白桦",
            Self::Mia => "Mia",
            Self::Chloe => "Chloe",
            Self::Milo => "Milo",
            Self::Dean => "Dean",
        };
        serializer.serialize_str(wire_name)
    }
}
/// Deserializes a [`Voice`] from a string. Known names map to their built-in
/// variant; any other string becomes `Voice::Custom` holding that string.
impl<'de> Deserialize<'de> for Voice {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        String::deserialize(deserializer).map(|name| match name.as_str() {
            "mimo_default" => Self::MimoDefault,
            "default_en" => Self::DefaultEn,
            "default_zh" => Self::DefaultZh,
            "冰糖" => Self::Bingtang,
            "茉莉" => Self::Moli,
            "苏打" => Self::Suda,
            "白桦" => Self::Baihua,
            "Mia" => Self::Mia,
            "Chloe" => Self::Chloe,
            "Milo" => Self::Milo,
            "Dean" => Self::Dean,
            // Unrecognized names are kept verbatim rather than rejected.
            _ => Self::Custom(name),
        })
    }
}
/// Audio options: an optional format and an optional voice.
///
/// A field that is `None` is left out of the serialized output entirely.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Audio {
/// Audio format, if one was chosen.
#[serde(skip_serializing_if = "Option::is_none")]
pub format: Option<AudioFormat>,
/// Voice, if one was chosen.
#[serde(skip_serializing_if = "Option::is_none")]
pub voice: Option<Voice>,
}
impl Audio {
    /// Creates a configuration with neither a format nor a voice set.
    pub fn new() -> Self {
        Self {
            format: None,
            voice: None,
        }
    }

    /// Returns this configuration with `format` set, keeping the voice as is.
    pub fn format(self, format: AudioFormat) -> Self {
        Self {
            format: Some(format),
            ..self
        }
    }

    /// Returns this configuration with `voice` set, keeping the format as is.
    pub fn voice(self, voice: Voice) -> Self {
        Self {
            voice: Some(voice),
            ..self
        }
    }

    /// Shorthand for a configuration whose format is [`AudioFormat::Wav`].
    pub fn wav() -> Self {
        Self {
            format: Some(AudioFormat::Wav),
            voice: None,
        }
    }

    /// Shorthand for a configuration whose format is [`AudioFormat::Mp3`].
    pub fn mp3() -> Self {
        Self {
            format: Some(AudioFormat::Mp3),
            voice: None,
        }
    }

    /// Shorthand for a configuration whose format is [`AudioFormat::Pcm`].
    pub fn pcm() -> Self {
        Self {
            format: Some(AudioFormat::Pcm),
            voice: None,
        }
    }

    /// Shorthand for a configuration whose format is [`AudioFormat::Pcm16`].
    pub fn pcm16() -> Self {
        Self {
            format: Some(AudioFormat::Pcm16),
            voice: None,
        }
    }
}
/// The default [`Audio`] is the same value `Audio::new()` produces: no format
/// and no voice.
impl Default for Audio {
    fn default() -> Self {
        Self {
            format: None,
            voice: None,
        }
    }
}
/// Audio payload with an id, base64 data, and optional expiry/transcript.
// NOTE(review): presumably the audio part of a complete (non-streaming) response — confirm against the caller.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponseAudio {
/// Identifier string for this audio object.
pub id: String,
/// Audio bytes as text; `decode_data` decodes it with the standard base64 engine.
pub data: String,
/// Optional expiry; `is_expired` compares it against seconds since the Unix epoch.
/// Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub expires_at: Option<i64>,
/// Optional transcript text. Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub transcript: Option<String>,
}
impl ResponseAudio {
    /// Decodes the base64 text in `data` into raw bytes.
    ///
    /// # Errors
    ///
    /// Returns the crate error converted from the base64 decode error when
    /// `data` is not valid standard base64.
    pub fn decode_data(&self) -> Result<Vec<u8>> {
        use base64::Engine;
        base64::engine::general_purpose::STANDARD
            .decode(&self.data)
            .map_err(Into::into)
    }

    /// Borrows the transcript, if there is one.
    pub fn transcript(&self) -> Option<&str> {
        self.transcript.as_deref()
    }

    /// Reports whether the current time is strictly past `expires_at`.
    ///
    /// Returns `false` when `expires_at` is `None`. The comparison is done in
    /// whole seconds relative to the Unix epoch. This never panics: a system
    /// clock set before the epoch is treated as a negative timestamp instead
    /// of unwrapping the error.
    pub fn is_expired(&self) -> bool {
        let expires_at = match self.expires_at {
            Some(deadline) => deadline,
            None => return false,
        };

        // Clamp before casting so an out-of-range u64 cannot wrap around.
        let clamp = |secs: u64| secs.min(i64::MAX as u64) as i64;
        let now = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
            Ok(since_epoch) => clamp(since_epoch.as_secs()),
            // The clock is before 1970: the error carries how far before.
            Err(before_epoch) => -clamp(before_epoch.duration().as_secs()),
        };

        now > expires_at
    }
}
/// Audio payload with the same field layout as `ResponseAudio`.
// NOTE(review): the name suggests an incremental/streamed chunk — confirm against the caller.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeltaAudio {
/// Identifier string for this audio object.
pub id: String,
/// Audio bytes as text; `decode_data` decodes it with the standard base64 engine.
pub data: String,
/// Optional expiry value. Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub expires_at: Option<i64>,
/// Optional transcript text. Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub transcript: Option<String>,
}
impl DeltaAudio {
    /// Decodes the base64 text in `data` into raw bytes.
    ///
    /// # Errors
    ///
    /// Returns the crate error converted from the base64 decode error when
    /// `data` is not valid standard base64.
    pub fn decode_data(&self) -> Result<Vec<u8>> {
        use base64::Engine as _;

        let engine = base64::engine::general_purpose::STANDARD;
        engine.decode(&self.data).map_err(|err| err.into())
    }
}
/// Ordered collection of style names that `TtsStyle::apply` turns into a
/// `<style>…</style>` prefix on the text.
#[derive(Debug, Clone, Default)]
pub struct TtsStyle {
// Style names in insertion order; `apply` joins them with a single space.
styles: Vec<String>,
}
impl TtsStyle {
    /// Creates a style set with no styles in it.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends one style name and returns the updated set.
    pub fn with_style(self, style: impl Into<String>) -> Self {
        let mut styles = self.styles;
        styles.push(style.into());
        Self { styles }
    }

    /// Prefixes `text` with a `<style>` tag listing every style, separated by
    /// single spaces. With no styles, `text` is returned unchanged.
    pub fn apply(&self, text: &str) -> String {
        if self.styles.is_empty() {
            return text.to_owned();
        }
        let tags = self.styles.join(" ");
        format!("<style>{}</style>{}", tags, text)
    }
}
/// Convenience wrapper: applies exactly one style to `text`, producing
/// `<style>{style}</style>{text}`.
pub fn styled_text(style: &str, text: &str) -> String {
    let single_style = TtsStyle::new().with_style(style);
    single_style.apply(text)
}
/// Unit tests for the audio/voice configuration types and the TTS style helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine;

    #[test]
    fn test_audio_format_default() {
        let format = AudioFormat::default();
        assert_eq!(format, AudioFormat::Wav);
    }

    #[test]
    fn test_voice_default() {
        let voice = Voice::default();
        assert_eq!(voice, Voice::MimoDefault);
    }

    #[test]
    fn test_audio_config() {
        let audio = Audio::wav().voice(Voice::DefaultZh);
        assert_eq!(audio.format, Some(AudioFormat::Wav));
        assert_eq!(audio.voice, Some(Voice::DefaultZh));
    }

    #[test]
    fn test_audio_serialization() {
        let audio = Audio::mp3().voice(Voice::DefaultEn);
        let json = serde_json::to_string(&audio).unwrap();
        assert!(json.contains("\"format\":\"mp3\""));
        assert!(json.contains("\"voice\":\"default_en\""));
    }

    // Unset options must be omitted entirely, not emitted as null.
    #[test]
    fn test_audio_serialization_skips_unset_fields() {
        let json = serde_json::to_string(&Audio::new()).unwrap();
        assert_eq!(json, "{}");
    }

    #[test]
    fn test_audio_formats() {
        assert_eq!(Audio::wav().format, Some(AudioFormat::Wav));
        assert_eq!(Audio::mp3().format, Some(AudioFormat::Mp3));
        assert_eq!(Audio::pcm().format, Some(AudioFormat::Pcm));
        assert_eq!(Audio::pcm16().format, Some(AudioFormat::Pcm16));
    }

    #[test]
    fn test_audio_format_wire_names() {
        assert_eq!(
            serde_json::to_string(&AudioFormat::Pcm16).unwrap(),
            "\"pcm16\""
        );
        assert_eq!(
            serde_json::from_str::<AudioFormat>("\"pcm16\"").unwrap(),
            AudioFormat::Pcm16
        );
        assert_eq!(
            serde_json::from_str::<AudioFormat>("\"wav\"").unwrap(),
            AudioFormat::Wav
        );
    }

    #[test]
    fn test_voice_serde_round_trip() {
        let json = serde_json::to_string(&Voice::Bingtang).unwrap();
        assert_eq!(json, "\"冰糖\"");
        assert_eq!(
            serde_json::from_str::<Voice>(&json).unwrap(),
            Voice::Bingtang
        );

        // Unknown names fall back to a custom voice holding the same text.
        let custom: Voice = serde_json::from_str("\"my_voice\"").unwrap();
        assert_eq!(custom, Voice::custom("my_voice"));
        assert_eq!(serde_json::to_string(&custom).unwrap(), "\"my_voice\"");
    }

    #[test]
    fn test_tts_style_single() {
        let text = TtsStyle::new().with_style("开心").apply("Hello");
        assert_eq!(text, "<style>开心</style>Hello");
    }

    #[test]
    fn test_tts_style_multiple() {
        let text = TtsStyle::new()
            .with_style("开心")
            .with_style("变快")
            .apply("Hello");
        assert!(text.starts_with("<style>"));
        assert!(text.contains("开心"));
        assert!(text.contains("变快"));
        assert!(text.ends_with("Hello"));
    }

    #[test]
    fn test_tts_style_empty() {
        let text = TtsStyle::new().apply("Hello");
        assert_eq!(text, "Hello");
    }

    #[test]
    fn test_styled_text_helper() {
        let text = styled_text("东北话", "哎呀妈呀");
        assert_eq!(text, "<style>东北话</style>哎呀妈呀");
    }

    #[test]
    fn test_response_audio_decode() {
        let audio = ResponseAudio {
            id: "test-id".to_string(),
            data: base64::engine::general_purpose::STANDARD.encode(b"test audio data"),
            expires_at: None,
            transcript: Some("test".to_string()),
        };
        let decoded = audio.decode_data().unwrap();
        assert_eq!(decoded, b"test audio data");
    }

    #[test]
    fn test_response_audio_decode_rejects_invalid_base64() {
        let audio = ResponseAudio {
            id: "test-id".to_string(),
            data: "not base64!".to_string(),
            expires_at: None,
            transcript: None,
        };
        assert!(audio.decode_data().is_err());
    }

    #[test]
    fn test_response_audio_transcript() {
        let audio = ResponseAudio {
            id: "test-id".to_string(),
            data: String::new(),
            expires_at: None,
            transcript: Some("Hello world".to_string()),
        };
        assert_eq!(audio.transcript(), Some("Hello world"));
    }

    #[test]
    fn test_response_audio_is_expired() {
        let with_expiry = |expires_at| ResponseAudio {
            id: "test-id".to_string(),
            data: String::new(),
            expires_at,
            transcript: None,
        };
        // No expiry recorded: never expired.
        assert!(!with_expiry(None).is_expired());
        // The epoch itself is in the past on any sanely configured clock.
        assert!(with_expiry(Some(0)).is_expired());
        // The far future has not arrived yet.
        assert!(!with_expiry(Some(i64::MAX)).is_expired());
    }

    #[test]
    fn test_delta_audio_decode() {
        let audio = DeltaAudio {
            id: "test-id".to_string(),
            data: base64::engine::general_purpose::STANDARD.encode(b"chunk"),
            expires_at: None,
            transcript: None,
        };
        assert_eq!(audio.decode_data().unwrap(), b"chunk");
    }
}