1use crate::{client::audio_generation::AudioGenerationModelHandle, http_client};
4use futures::future::BoxFuture;
5use serde_json::Value;
6use std::sync::Arc;
7use thiserror::Error;
8
9#[derive(Debug, Error)]
10pub enum AudioGenerationError {
11 #[error("HttpError: {0}")]
13 HttpError(#[from] http_client::Error),
14
15 #[error("JsonError: {0}")]
17 JsonError(#[from] serde_json::Error),
18
19 #[error("RequestError: {0}")]
21 RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
22
23 #[error("ResponseError: {0}")]
25 ResponseError(String),
26
27 #[error("ProviderError: {0}")]
29 ProviderError(String),
30}
31pub trait AudioGeneration<M>
32where
33 M: AudioGenerationModel,
34{
35 fn audio_generation(
44 &self,
45 text: &str,
46 voice: &str,
47 ) -> impl std::future::Future<
48 Output = Result<AudioGenerationRequestBuilder<M>, AudioGenerationError>,
49 > + Send;
50}
51
/// Outcome of a successful audio generation call.
///
/// `T` is the model-specific response payload; the type-erased
/// [`AudioGenerationModelDyn`] path uses `T = ()`.
///
/// `Debug` and `Clone` are derived so the type can be logged and duplicated.
/// Both derives only apply when `T` itself implements the corresponding trait.
#[derive(Debug, Clone)]
pub struct AudioGenerationResponse<T> {
    /// Generated audio as raw bytes (format/encoding is decided by the model;
    /// not specified here).
    pub audio: Vec<u8>,
    /// Model-specific response payload returned alongside the audio.
    pub response: T,
}
56
57pub trait AudioGenerationModel: Clone + Send + Sync {
58 type Response: Send + Sync;
59
60 fn audio_generation(
61 &self,
62 request: AudioGenerationRequest,
63 ) -> impl std::future::Future<
64 Output = Result<AudioGenerationResponse<Self::Response>, AudioGenerationError>,
65 > + Send;
66
67 fn audio_generation_request(&self) -> AudioGenerationRequestBuilder<Self> {
68 AudioGenerationRequestBuilder::new(self.clone())
69 }
70}
71
72pub trait AudioGenerationModelDyn: Send + Sync {
73 fn audio_generation(
74 &self,
75 request: AudioGenerationRequest,
76 ) -> BoxFuture<'_, Result<AudioGenerationResponse<()>, AudioGenerationError>>;
77
78 fn audio_generation_request(
79 &self,
80 ) -> AudioGenerationRequestBuilder<AudioGenerationModelHandle<'_>>;
81}
82
83impl<T> AudioGenerationModelDyn for T
84where
85 T: AudioGenerationModel,
86{
87 fn audio_generation(
88 &self,
89 request: AudioGenerationRequest,
90 ) -> BoxFuture<'_, Result<AudioGenerationResponse<()>, AudioGenerationError>> {
91 Box::pin(async move {
92 let resp = self.audio_generation(request).await;
93
94 resp.map(|r| AudioGenerationResponse {
95 audio: r.audio,
96 response: (),
97 })
98 })
99 }
100
101 fn audio_generation_request(
102 &self,
103 ) -> AudioGenerationRequestBuilder<AudioGenerationModelHandle<'_>> {
104 AudioGenerationRequestBuilder::new(AudioGenerationModelHandle {
105 inner: Arc::new(self.clone()),
106 })
107 }
108}
109
110#[non_exhaustive]
111pub struct AudioGenerationRequest {
112 pub text: String,
113 pub voice: String,
114 pub speed: f32,
115 pub additional_params: Option<Value>,
116}
117
118#[non_exhaustive]
119pub struct AudioGenerationRequestBuilder<M>
120where
121 M: AudioGenerationModel,
122{
123 model: M,
124 text: String,
125 voice: String,
126 speed: f32,
127 additional_params: Option<Value>,
128}
129
130impl<M> AudioGenerationRequestBuilder<M>
131where
132 M: AudioGenerationModel,
133{
134 pub fn new(model: M) -> Self {
135 Self {
136 model,
137 text: "".to_string(),
138 voice: "".to_string(),
139 speed: 1.0,
140 additional_params: None,
141 }
142 }
143
144 pub fn text(mut self, text: &str) -> Self {
146 self.text = text.to_string();
147 self
148 }
149
150 pub fn voice(mut self, voice: &str) -> Self {
152 self.voice = voice.to_string();
153 self
154 }
155
156 pub fn speed(mut self, speed: f32) -> Self {
158 self.speed = speed;
159 self
160 }
161
162 pub fn additional_params(mut self, params: Value) -> Self {
164 self.additional_params = Some(params);
165 self
166 }
167
168 pub fn build(self) -> AudioGenerationRequest {
169 AudioGenerationRequest {
170 text: self.text,
171 voice: self.voice,
172 speed: self.speed,
173 additional_params: self.additional_params,
174 }
175 }
176
177 pub async fn send(self) -> Result<AudioGenerationResponse<M::Response>, AudioGenerationError> {
178 let model = self.model.clone();
179
180 model.audio_generation(self.build()).await
181 }
182}