rig/
audio_generation.rs

//! Everything related to audio generation (i.e., text-to-speech).
//! Rig abstracts over a number of different providers using the [`AudioGenerationModel`] trait.
3#[allow(deprecated)]
4use crate::client::audio_generation::AudioGenerationModelHandle;
5use crate::{
6    http_client,
7    wasm_compat::{WasmCompatSend, WasmCompatSync},
8};
9use futures::future::BoxFuture;
10use serde_json::Value;
11use std::sync::Arc;
12use thiserror::Error;
13
14#[derive(Debug, Error)]
15pub enum AudioGenerationError {
16    /// Http error (e.g.: connection error, timeout, etc.)
17    #[error("HttpError: {0}")]
18    HttpError(#[from] http_client::Error),
19
20    /// Json error (e.g.: serialization, deserialization)
21    #[error("JsonError: {0}")]
22    JsonError(#[from] serde_json::Error),
23
24    /// Error building the transcription request
25    #[error("RequestError: {0}")]
26    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
27
28    /// Error parsing the transcription response
29    #[error("ResponseError: {0}")]
30    ResponseError(String),
31
32    /// Error returned by the transcription model provider
33    #[error("ProviderError: {0}")]
34    ProviderError(String),
35}
36pub trait AudioGeneration<M>
37where
38    M: AudioGenerationModel,
39{
40    /// Generates an audio generation request builder for the given `text` and `voice`.
41    /// This function is meant to be called by the user to further customize the
42    /// request at generation time before sending it.
43    ///
44    /// ❗IMPORTANT: The type that implements this trait might have already
45    /// populated fields in the builder (the exact fields depend on the type).
46    /// For fields that have already been set by the model, calling the corresponding
47    /// method on the builder will overwrite the value set by the model.
48    fn audio_generation(
49        &self,
50        text: &str,
51        voice: &str,
52    ) -> impl std::future::Future<
53        Output = Result<AudioGenerationRequestBuilder<M>, AudioGenerationError>,
54    > + Send;
55}
56
/// Outcome of a successful audio generation call.
///
/// `T` is the type of the accompanying `response` value; models report their
/// own associated response type here, while the type-erased path uses `()`.
pub struct AudioGenerationResponse<T> {
    /// The generated audio as raw bytes (the encoding is not specified by this type).
    pub audio: Vec<u8>,
    /// The response value that accompanies the audio.
    pub response: T,
}
61
62pub trait AudioGenerationModel: Sized + Clone + WasmCompatSend + WasmCompatSync {
63    type Response: Send + Sync;
64
65    type Client;
66
67    fn make(client: &Self::Client, model: impl Into<String>) -> Self;
68
69    fn audio_generation(
70        &self,
71        request: AudioGenerationRequest,
72    ) -> impl std::future::Future<
73        Output = Result<AudioGenerationResponse<Self::Response>, AudioGenerationError>,
74    > + Send;
75
76    fn audio_generation_request(&self) -> AudioGenerationRequestBuilder<Self> {
77        AudioGenerationRequestBuilder::new(self.clone())
78    }
79}
80
81#[allow(deprecated)]
82#[deprecated(
83    since = "0.25.0",
84    note = "`DynClientBuilder` and related features have been deprecated and will be removed in a future release. In this case, use `AudioGenerationModel` instead."
85)]
86pub trait AudioGenerationModelDyn: Send + Sync {
87    fn audio_generation(
88        &self,
89        request: AudioGenerationRequest,
90    ) -> BoxFuture<'_, Result<AudioGenerationResponse<()>, AudioGenerationError>>;
91
92    fn audio_generation_request(
93        &self,
94    ) -> AudioGenerationRequestBuilder<AudioGenerationModelHandle<'_>>;
95}
96
97#[allow(deprecated)]
98impl<T> AudioGenerationModelDyn for T
99where
100    T: AudioGenerationModel,
101{
102    fn audio_generation(
103        &self,
104        request: AudioGenerationRequest,
105    ) -> BoxFuture<'_, Result<AudioGenerationResponse<()>, AudioGenerationError>> {
106        Box::pin(async move {
107            let resp = self.audio_generation(request).await;
108
109            resp.map(|r| AudioGenerationResponse {
110                audio: r.audio,
111                response: (),
112            })
113        })
114    }
115
116    fn audio_generation_request(
117        &self,
118    ) -> AudioGenerationRequestBuilder<AudioGenerationModelHandle<'_>> {
119        AudioGenerationRequestBuilder::new(AudioGenerationModelHandle {
120            inner: Arc::new(self.clone()),
121        })
122    }
123}
124
125#[non_exhaustive]
126pub struct AudioGenerationRequest {
127    pub text: String,
128    pub voice: String,
129    pub speed: f32,
130    pub additional_params: Option<Value>,
131}
132
133#[non_exhaustive]
134pub struct AudioGenerationRequestBuilder<M>
135where
136    M: AudioGenerationModel,
137{
138    model: M,
139    text: String,
140    voice: String,
141    speed: f32,
142    additional_params: Option<Value>,
143}
144
145impl<M> AudioGenerationRequestBuilder<M>
146where
147    M: AudioGenerationModel,
148{
149    pub fn new(model: M) -> Self {
150        Self {
151            model,
152            text: "".to_string(),
153            voice: "".to_string(),
154            speed: 1.0,
155            additional_params: None,
156        }
157    }
158
159    /// Sets the text for the audio generation request
160    pub fn text(mut self, text: &str) -> Self {
161        self.text = text.to_string();
162        self
163    }
164
165    /// The voice of the generated audio
166    pub fn voice(mut self, voice: &str) -> Self {
167        self.voice = voice.to_string();
168        self
169    }
170
171    /// The speed of the generated audio
172    pub fn speed(mut self, speed: f32) -> Self {
173        self.speed = speed;
174        self
175    }
176
177    /// Adds additional parameters to the audio generation request.
178    pub fn additional_params(mut self, params: Value) -> Self {
179        self.additional_params = Some(params);
180        self
181    }
182
183    pub fn build(self) -> AudioGenerationRequest {
184        AudioGenerationRequest {
185            text: self.text,
186            voice: self.voice,
187            speed: self.speed,
188            additional_params: self.additional_params,
189        }
190    }
191
192    pub async fn send(self) -> Result<AudioGenerationResponse<M::Response>, AudioGenerationError> {
193        let model = self.model.clone();
194
195        model.audio_generation(self.build()).await
196    }
197}