use crate::audio_generation::{
self, AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
};
use crate::http_client::{self, HttpClientExt};
use crate::json_utils::merge_inplace;
use crate::providers::xai::Client;
use bytes::Bytes;
use serde_json::json;
pub const TTS_1: &str = "tts-1";
#[derive(Clone)]
pub struct AudioGenerationModel<T = reqwest::Client> {
client: Client<T>,
pub model: String,
}
impl<T> AudioGenerationModel<T> {
pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
Self {
client,
model: model.into(),
}
}
}
impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
where
T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
{
type Response = Bytes;
type Client = Client<T>;
fn make(client: &Self::Client, model: impl Into<String>) -> Self {
Self::new(client.clone(), model)
}
async fn audio_generation(
&self,
request: AudioGenerationRequest,
) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
let voice = if request.voice.is_empty() {
"eve".to_string()
} else {
request.voice
};
let mut body = json!({
"text": request.text,
"voice_id": voice,
"language": "en",
});
if let Some(additional_params) = request.additional_params {
merge_inplace(&mut body, additional_params);
}
let body = serde_json::to_vec(&body)?;
let req = self
.client
.post("/v1/tts")?
.body(body)
.map_err(http_client::Error::from)?;
let response = self.client.send(req).await?;
if !response.status().is_success() {
let status = response.status();
let text = http_client::text(response).await?;
return Err(AudioGenerationError::ProviderError(format!(
"{}: {}",
status, text,
)));
}
let bytes: Bytes = response.into_body().await?.into();
Ok(AudioGenerationResponse {
audio: bytes.to_vec(),
response: bytes,
})
}
}