#[allow(deprecated)]
use crate::client::audio_generation::AudioGenerationModelHandle;
use crate::{
http_client,
wasm_compat::{WasmCompatSend, WasmCompatSync},
};
use futures::future::BoxFuture;
use serde_json::Value;
use std::sync::Arc;
use thiserror::Error;
/// Error type returned by the audio generation APIs in this module.
///
/// The `HttpError`, `JsonError` and `RequestError` variants can be produced
/// from their source errors via `From` (and therefore via `?`).
#[derive(Debug, Error)]
pub enum AudioGenerationError {
/// Wraps an error from the crate's `http_client` module.
#[error("HttpError: {0}")]
HttpError(#[from] http_client::Error),
/// Wraps a `serde_json` error.
#[error("JsonError: {0}")]
JsonError(#[from] serde_json::Error),
/// Wraps any other boxed error.
/// NOTE(review): presumably raised while preparing the request — confirm against callers.
#[error("RequestError: {0}")]
RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
/// Free-form message.
/// NOTE(review): presumably used for malformed or unexpected responses — confirm against providers.
#[error("ResponseError: {0}")]
ResponseError(String),
/// Free-form message.
/// NOTE(review): presumably an error reported by the provider itself — confirm against providers.
#[error("ProviderError: {0}")]
ProviderError(String),
}
/// Implemented by types that can start an audio generation request for a
/// model of type `M`.
pub trait AudioGeneration<M>
where
M: AudioGenerationModel,
{
/// Asynchronously produces an [`AudioGenerationRequestBuilder`] for `M`
/// from the given `text` and `voice`.
///
/// How `text` and `voice` are applied is up to the implementor
/// (presumably they pre-populate the builder — confirm against implementations).
/// The returned future is required to be `Send`.
///
/// # Errors
/// Returns an [`AudioGenerationError`] when the implementor cannot produce the builder.
fn audio_generation(
&self,
text: &str,
voice: &str,
) -> impl std::future::Future<
Output = Result<AudioGenerationRequestBuilder<M>, AudioGenerationError>,
> + Send;
}
/// Result of an audio generation call: the audio bytes plus a response value
/// of type `T`.
pub struct AudioGenerationResponse<T> {
/// Generated audio as raw bytes (the encoding/container is not specified here).
pub audio: Vec<u8>,
/// Response value accompanying the audio; it is `()` when the response is
/// obtained through the type-erased `AudioGenerationModelDyn` path.
pub response: T,
}
/// A model that can execute an [`AudioGenerationRequest`] and return audio.
///
/// Implementors must be `Sized`, `Clone`, and satisfy the crate's
/// `WasmCompatSend` / `WasmCompatSync` bounds.
pub trait AudioGenerationModel: Sized + Clone + WasmCompatSend + WasmCompatSync {
    /// Response value returned next to the audio bytes.
    type Response: Send + Sync;

    /// Client type a model instance is created from.
    type Client;

    /// Creates a model instance from `client` and the `model` string.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self;

    /// Executes `request` and resolves to the generated audio together with
    /// the implementor's [`Self::Response`] value.
    ///
    /// The returned future is required to be `Send`.
    ///
    /// # Errors
    /// Resolves to an [`AudioGenerationError`] when generation fails.
    fn audio_generation(
        &self,
        request: AudioGenerationRequest,
    ) -> impl std::future::Future<
        Output = Result<AudioGenerationResponse<Self::Response>, AudioGenerationError>,
    > + Send;

    /// Returns a fresh [`AudioGenerationRequestBuilder`] that owns a clone of
    /// this model.
    fn audio_generation_request(&self) -> AudioGenerationRequestBuilder<Self> {
        let owned_model = self.clone();
        AudioGenerationRequestBuilder::new(owned_model)
    }
}
/// Variant of `AudioGenerationModel` that uses boxed futures and no
/// associated types; responses carry `()` instead of a model-specific value.
///
/// Deprecated: see the deprecation note below.
#[allow(deprecated)]
#[deprecated(
since = "0.25.0",
note = "`DynClientBuilder` and related features have been deprecated and will be removed in a future release. In this case, use `AudioGenerationModel` instead."
)]
pub trait AudioGenerationModelDyn: Send + Sync {
/// Executes `request`, returning a boxed future that resolves to the audio
/// bytes with a unit `response`, or to an `AudioGenerationError`.
fn audio_generation(
&self,
request: AudioGenerationRequest,
) -> BoxFuture<'_, Result<AudioGenerationResponse<()>, AudioGenerationError>>;
/// Returns a request builder whose model is an `AudioGenerationModelHandle`.
fn audio_generation_request(
&self,
) -> AudioGenerationRequestBuilder<AudioGenerationModelHandle<'_>>;
}
/// Blanket implementation: every [`AudioGenerationModel`] is also usable
/// through the deprecated `AudioGenerationModelDyn` interface.
#[allow(deprecated)]
impl<T> AudioGenerationModelDyn for T
where
    T: AudioGenerationModel,
{
    /// Runs the typed model and discards its model-specific response value,
    /// keeping only the audio bytes.
    fn audio_generation(
        &self,
        request: AudioGenerationRequest,
    ) -> BoxFuture<'_, Result<AudioGenerationResponse<()>, AudioGenerationError>> {
        Box::pin(async move {
            // Name the typed trait explicitly: both traits expose a method
            // called `audio_generation`.
            let typed = <T as AudioGenerationModel>::audio_generation(self, request).await;
            typed.map(|AudioGenerationResponse { audio, .. }| AudioGenerationResponse {
                audio,
                response: (),
            })
        })
    }

    /// Builds a request builder around a handle that holds an `Arc` of a
    /// clone of this model.
    fn audio_generation_request(
        &self,
    ) -> AudioGenerationRequestBuilder<AudioGenerationModelHandle<'_>> {
        let handle = AudioGenerationModelHandle {
            inner: Arc::new(self.clone()),
        };
        AudioGenerationRequestBuilder::new(handle)
    }
}
/// Parameters of a single audio generation call, as produced by
/// [`AudioGenerationRequestBuilder::build`].
#[non_exhaustive]
pub struct AudioGenerationRequest {
/// Text for the request (presumably the text to be spoken — confirm against providers).
pub text: String,
/// Voice for the request (presumably a provider-specific voice name — confirm against providers).
pub voice: String,
/// Speed value; the builder defaults it to `1.0`.
pub speed: f32,
/// Optional extra JSON value carried with the request; how it is used is up to the model.
pub additional_params: Option<Value>,
}
/// Builder that accumulates the fields of an [`AudioGenerationRequest`]
/// together with the model `M` that will execute it.
#[non_exhaustive]
pub struct AudioGenerationRequestBuilder<M>
where
M: AudioGenerationModel,
{
// Model that `send` executes the request on.
model: M,
// Pending request fields; each maps to the same-named field of `AudioGenerationRequest`.
text: String,
voice: String,
speed: f32,
additional_params: Option<Value>,
}
impl<M> AudioGenerationRequestBuilder<M>
where
    M: AudioGenerationModel,
{
    /// Creates a builder bound to `model` with empty `text` and `voice`,
    /// a `speed` of `1.0`, and no additional parameters.
    pub fn new(model: M) -> Self {
        Self {
            model,
            text: String::new(),
            voice: String::new(),
            speed: 1.0,
            additional_params: None,
        }
    }

    /// Sets the request text, replacing any previous value.
    pub fn text(mut self, text: &str) -> Self {
        self.text = text.to_string();
        self
    }

    /// Sets the request voice, replacing any previous value.
    pub fn voice(mut self, voice: &str) -> Self {
        self.voice = voice.to_string();
        self
    }

    /// Sets the request speed, replacing the default of `1.0`.
    pub fn speed(mut self, speed: f32) -> Self {
        self.speed = speed;
        self
    }

    /// Sets the additional JSON parameters, replacing any previous value.
    pub fn additional_params(mut self, params: Value) -> Self {
        self.additional_params = Some(params);
        self
    }

    /// Splits the builder into its model and the finished request, moving
    /// every field so nothing has to be cloned.
    fn into_model_and_request(self) -> (M, AudioGenerationRequest) {
        let Self {
            model,
            text,
            voice,
            speed,
            additional_params,
        } = self;
        let request = AudioGenerationRequest {
            text,
            voice,
            speed,
            additional_params,
        };
        (model, request)
    }

    /// Consumes the builder and returns the accumulated
    /// [`AudioGenerationRequest`]; the model is dropped.
    pub fn build(self) -> AudioGenerationRequest {
        self.into_model_and_request().1
    }

    /// Consumes the builder, builds the request, and executes it on the
    /// builder's model.
    ///
    /// # Errors
    /// Returns whatever [`AudioGenerationError`] the model's
    /// `audio_generation` call resolves to.
    pub async fn send(self) -> Result<AudioGenerationResponse<M::Response>, AudioGenerationError> {
        // The builder is consumed here, so the model is moved out rather than cloned.
        let (model, request) = self.into_model_and_request();
        model.audio_generation(request).await
    }
}