use crate::{error::DashScopeError, operation::common::Usage};
use base64::prelude::*;
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use std::pin::Pin;
use thiserror::Error;
use tokio_stream::Stream;
/// One text-to-speech response message.
///
/// For streaming calls this is also the item type yielded by
/// [`TextToSpeechOutputStream`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextToSpeechOutput {
    /// Identifier of the request, as returned by the service.
    pub request_id: String,
    /// The generated audio and the completion state.
    pub output: Output,
    /// Usage statistics; `None` when the field is absent from the response.
    pub usage: Option<Usage>,
}
/// The `output` object of a text-to-speech response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Output {
    /// Reason generation stopped, when the service supplies one.
    pub finish_reason: Option<String>,
    /// Audio payload: inline base64 data and/or a download URL.
    pub audio: Audio,
}
/// Audio payload carried by a text-to-speech response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Audio {
    /// Identifier of the audio resource.
    pub id: String,
    /// Download URL for the audio; `None` when the response carries none.
    pub url: Option<String>,
    /// Expiry timestamp of `url` (NOTE(review): presumably Unix seconds — confirm with the API docs).
    pub expires_at: i64,
    /// Audio bytes encoded as standard-alphabet base64.
    pub data: String,
}
/// Pinned, boxed, `Send` stream of streaming text-to-speech messages.
///
/// Each item is either a parsed [`TextToSpeechOutput`] chunk or a [`DashScopeError`].
pub type TextToSpeechOutputStream =
    Pin<Box<dyn Stream<Item = Result<TextToSpeechOutput, DashScopeError>> + Send + 'static>>;
#[derive(Error, Debug)]
pub enum AudioOutputError {
#[error("Failed to download audio file:{}", 0)]
DownloadError(#[from] reqwest::Error),
#[error("Failed to save audio file:{}", 0)]
SaveError(#[from] std::io::Error),
#[error("Audio url is null")]
NullUrl,
#[error("Failed to decode audio data")]
DataDecodeError,
}
impl Audio {
    /// Returns an owned copy of the base64-encoded `data` field.
    pub fn get_audio_data(&self) -> String {
        self.data.clone()
    }

    /// Returns `true` when this audio carries a download URL (`url` is `Some`).
    ///
    /// NOTE(review): this treats "URL present" as "finished"; presumably the
    /// service only sends a URL on the final chunk — confirm against the API.
    pub fn is_finished(&self) -> bool {
        self.url.is_some()
    }

    /// Decodes the base64 `data` field into raw bytes.
    ///
    /// # Errors
    /// Returns [`AudioOutputError::DataDecodeError`] if `data` is not valid
    /// standard-alphabet base64.
    pub fn to_vec(&self) -> Result<Vec<u8>, AudioOutputError> {
        BASE64_STANDARD
            .decode(&self.data)
            .map_err(|_| AudioOutputError::DataDecodeError)
    }

    /// Wraps the decoded `data` (treated as little-endian integer PCM) in a WAV container.
    ///
    /// Only 8- and 16-bit samples are supported. For 16-bit audio, a trailing
    /// odd byte (an incomplete sample) is ignored.
    ///
    /// # Errors
    /// Returns [`AudioOutputError::DataDecodeError`] if `bits_per_sample` is not
    /// 8 or 16, if `data` is not valid base64, or if the WAV writer fails.
    #[cfg(feature = "wav-decoder")]
    pub fn to_wav(
        &self,
        sample_rate: u32,
        num_channels: u16,
        bits_per_sample: u16,
    ) -> Result<Vec<u8>, AudioOutputError> {
        use hound::{WavSpec, WavWriter};
        use std::io::Cursor;

        // Reject unsupported widths before doing any decoding or writer setup.
        if !matches!(bits_per_sample, 8 | 16) {
            return Err(AudioOutputError::DataDecodeError);
        }

        let pcm_data = self.to_vec()?;
        let mut buffer = Cursor::new(Vec::new());
        let spec = WavSpec {
            channels: num_channels,
            sample_rate,
            bits_per_sample,
            sample_format: hound::SampleFormat::Int,
        };
        let mut writer = WavWriter::new(&mut buffer, spec).map_err(|e| {
            eprintln!("WAV writer error: {e}");
            AudioOutputError::DataDecodeError
        })?;

        if bits_per_sample == 16 {
            for chunk in pcm_data.chunks_exact(2) {
                let sample = i16::from_le_bytes([chunk[0], chunk[1]]);
                writer
                    .write_sample(sample)
                    .map_err(|_| AudioOutputError::DataDecodeError)?;
            }
        } else {
            for &sample in &pcm_data {
                writer
                    .write_sample(sample as i8)
                    .map_err(|_| AudioOutputError::DataDecodeError)?;
            }
        }

        writer
            .finalize()
            .map_err(|_| AudioOutputError::DataDecodeError)?;
        Ok(buffer.into_inner())
    }

    /// Decodes the base64 `data` field into a [`Bytes`] buffer.
    ///
    /// # Errors
    /// Returns [`AudioOutputError::DataDecodeError`] if `data` is not valid base64.
    pub fn bytes(&self) -> Result<Bytes, AudioOutputError> {
        // `From<Vec<u8>>` takes ownership of the decoded buffer instead of copying it.
        Ok(Bytes::from(self.to_vec()?))
    }

    /// Downloads the audio at `url` and writes it to `save_path`.
    ///
    /// # Errors
    /// - [`AudioOutputError::NullUrl`] if `url` is `None`.
    /// - [`AudioOutputError::DownloadError`] if the request fails or the server
    ///   answers with a non-success (4xx/5xx) status.
    /// - [`AudioOutputError::SaveError`] if writing the file fails.
    pub async fn download(&self, save_path: &str) -> Result<(), AudioOutputError> {
        let url = self.url.as_deref().ok_or(AudioOutputError::NullUrl)?;
        let body = reqwest::get(url)
            .await?
            // Without this check, an HTTP error page would be saved to disk as if it were audio.
            .error_for_status()?
            .bytes()
            .await?;
        tokio::fs::write(save_path, body).await?;
        Ok(())
    }
}
impl TextToSpeechOutput {
pub async fn download(&self, save_path: &str) -> Result<(), AudioOutputError> {
self.output.audio.download(save_path).await
}
pub fn is_finished(&self) -> bool {
self.output.audio.is_finished()
}
}