use crate::markers::{Missing, Provided};
use crate::{
http_client,
wasm_compat::{WasmCompatSend, WasmCompatSync},
};
use serde_json::Value;
use thiserror::Error;
/// Error type returned by the audio generation APIs in this module.
///
/// The `#[from]` variants can be produced with `?` from their source error types.
#[derive(Debug, Error)]
pub enum AudioGenerationError {
/// Wraps an `http_client::Error`; converted automatically via `From`.
#[error("HttpError: {0}")]
HttpError(#[from] http_client::Error),
/// Wraps a `serde_json::Error`; converted automatically via `From`.
#[error("JsonError: {0}")]
JsonError(#[from] serde_json::Error),
/// Wraps any boxed, thread-safe error; converted automatically via `From`.
#[error("RequestError: {0}")]
RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
/// Message-only error displayed with a `ResponseError: ` prefix.
// NOTE(review): presumably used for malformed or unexpected responses — confirm with callers.
#[error("ResponseError: {0}")]
ResponseError(String),
/// Message-only error displayed with a `ProviderError: ` prefix.
// NOTE(review): presumably carries an error reported by the remote provider — confirm with callers.
#[error("ProviderError: {0}")]
ProviderError(String),
}
/// Entry point for starting an audio generation request for model type `M`
/// directly from a text and a voice.
pub trait AudioGeneration<M>
where
M: AudioGenerationModel,
{
/// Starts an audio generation request from `text` and `voice`.
///
/// The returned future must be `Send`. On success it resolves to an
/// [`AudioGenerationRequestBuilder`] whose text and voice type parameters are
/// both `Provided<String>`, so it can be built or sent without further setup.
///
/// # Errors
///
/// Resolves to an [`AudioGenerationError`] when the implementor cannot
/// produce the builder.
fn audio_generation(
&self,
text: &str,
voice: &str,
) -> impl std::future::Future<
Output = Result<
AudioGenerationRequestBuilder<M, Provided<String>, Provided<String>>,
AudioGenerationError,
>,
> + Send;
}
/// Result of an audio generation call: the audio bytes plus a value of type `T`.
///
/// [`AudioGenerationModel::audio_generation`] returns this with `T` set to the
/// model's associated `Response` type.
pub struct AudioGenerationResponse<T> {
/// Generated audio as raw bytes.
// NOTE(review): the container/encoding is not fixed by this type — confirm per provider.
pub audio: Vec<u8>,
/// Accompanying response value of type `T`.
pub response: T,
}
/// A model capable of turning an [`AudioGenerationRequest`] into audio.
///
/// Implementors must be `Sized`, `Clone`, and satisfy the crate's
/// `WasmCompatSend` / `WasmCompatSync` marker bounds.
pub trait AudioGenerationModel: Sized + Clone + WasmCompatSend + WasmCompatSync {
    /// Value returned alongside the audio bytes in [`AudioGenerationResponse`].
    type Response: Send + Sync;

    /// Client type from which this model is constructed (see [`Self::make`]).
    type Client;

    /// Constructs the model from a borrowed client and a `model` string.
    // NOTE(review): `model` is presumably the provider's model identifier — confirm with implementors.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self;

    /// Executes `request` against this model.
    ///
    /// The returned future must be `Send`; it resolves to the generated audio
    /// together with the model's `Response` value, or to an
    /// [`AudioGenerationError`].
    fn audio_generation(
        &self,
        request: AudioGenerationRequest,
    ) -> impl std::future::Future<
        Output = Result<AudioGenerationResponse<Self::Response>, AudioGenerationError>,
    > + Send;

    /// Returns a fresh request builder that owns a clone of this model, with
    /// neither text nor voice set yet.
    fn audio_generation_request(&self) -> AudioGenerationRequestBuilder<Self, Missing, Missing> {
        let model = self.clone();
        AudioGenerationRequestBuilder::new(model)
    }
}
/// A fully specified audio generation request, as passed to
/// [`AudioGenerationModel::audio_generation`].
///
/// Marked `#[non_exhaustive]`, so code outside this crate cannot construct it
/// with a struct literal; use [`AudioGenerationRequestBuilder`] instead.
#[non_exhaustive]
pub struct AudioGenerationRequest {
/// Text of the request.
pub text: String,
/// Voice of the request.
// NOTE(review): presumably a provider-specific voice name or id — confirm with implementors.
pub voice: String,
/// Speed of the request; the builder defaults this to `1.0`.
// NOTE(review): presumably a playback-rate multiplier — valid range is provider-defined, confirm.
pub speed: f32,
/// Optional extra JSON parameters; `None` unless set on the builder.
pub additional_params: Option<Value>,
}
/// Typestate builder for an [`AudioGenerationRequest`] bound to a model `M`.
///
/// The type parameters `T` and `V` record whether the text and the voice have
/// been supplied: each starts as `Missing` and becomes `Provided<String>` once
/// the corresponding setter is called. `build` and `send` are only available
/// when both are `Provided<String>`.
#[non_exhaustive]
pub struct AudioGenerationRequestBuilder<M, T = Missing, V = Missing>
where
M: AudioGenerationModel,
{
/// Model the request will be sent to.
model: M,
/// Text slot: `Missing` or `Provided<String>`.
text: T,
/// Voice slot: `Missing` or `Provided<String>`.
voice: V,
/// Speed to put in the request; initialised to `1.0` by `new`.
speed: f32,
/// Optional extra JSON parameters; initialised to `None` by `new`.
additional_params: Option<Value>,
}
impl<M> AudioGenerationRequestBuilder<M, Missing, Missing>
where
M: AudioGenerationModel,
{
pub fn new(model: M) -> Self {
Self {
model,
text: Missing,
voice: Missing,
speed: 1.0,
additional_params: None,
}
}
}
impl<M, T, V> AudioGenerationRequestBuilder<M, T, V>
where
    M: AudioGenerationModel,
{
    /// Sets the request text, replacing any previous text slot, and marks the
    /// text type parameter as `Provided<String>`.
    pub fn text(self, text: &str) -> AudioGenerationRequestBuilder<M, Provided<String>, V> {
        // The old text slot is deliberately left behind in `self` and dropped.
        let Self {
            model,
            voice,
            speed,
            additional_params,
            ..
        } = self;
        AudioGenerationRequestBuilder {
            model,
            text: Provided(text.to_owned()),
            voice,
            speed,
            additional_params,
        }
    }

    /// Sets the request voice, replacing any previous voice slot, and marks
    /// the voice type parameter as `Provided<String>`.
    pub fn voice(self, voice: &str) -> AudioGenerationRequestBuilder<M, T, Provided<String>> {
        // The old voice slot is deliberately left behind in `self` and dropped.
        let Self {
            model,
            text,
            speed,
            additional_params,
            ..
        } = self;
        AudioGenerationRequestBuilder {
            model,
            text,
            voice: Provided(voice.to_owned()),
            speed,
            additional_params,
        }
    }

    /// Overrides the request speed (the builder starts at `1.0`).
    pub fn speed(self, speed: f32) -> Self {
        Self { speed, ..self }
    }

    /// Sets the additional JSON parameters, replacing any previously set value.
    pub fn additional_params(self, params: Value) -> Self {
        Self {
            additional_params: Some(params),
            ..self
        }
    }
}
impl<M> AudioGenerationRequestBuilder<M, Provided<String>, Provided<String>>
where
    M: AudioGenerationModel,
{
    /// Splits the builder into the model it owns and the finished request.
    fn into_parts(self) -> (M, AudioGenerationRequest) {
        let Self {
            model,
            text,
            voice,
            speed,
            additional_params,
        } = self;
        let request = AudioGenerationRequest {
            text: text.0,
            voice: voice.0,
            speed,
            additional_params,
        };
        (model, request)
    }

    /// Consumes the builder and returns the finished [`AudioGenerationRequest`].
    ///
    /// Only available once both text and voice have been provided. The model
    /// held by the builder is dropped.
    pub fn build(self) -> AudioGenerationRequest {
        self.into_parts().1
    }

    /// Builds the request and executes it on the builder's model.
    ///
    /// # Errors
    ///
    /// Returns whatever [`AudioGenerationError`] the model's
    /// `audio_generation` call resolves to.
    pub async fn send(self) -> Result<AudioGenerationResponse<M::Response>, AudioGenerationError> {
        // Move the model out of the builder instead of cloning it: the builder
        // is consumed here anyway, so a clone would be redundant work.
        let (model, request) = self.into_parts();
        model.audio_generation(request).await
    }
}