rig/
audio_generation.rs

//! Everything related to audio generation (i.e., text-to-speech).
//! Rig abstracts over a number of different providers using the [`AudioGenerationModel`] trait.
3use crate::{client::audio_generation::AudioGenerationModelHandle, http_client};
4use futures::future::BoxFuture;
5use serde_json::Value;
6use std::sync::Arc;
7use thiserror::Error;
8
9#[derive(Debug, Error)]
10pub enum AudioGenerationError {
11    /// Http error (e.g.: connection error, timeout, etc.)
12    #[error("HttpError: {0}")]
13    HttpError(#[from] http_client::Error),
14
15    /// Json error (e.g.: serialization, deserialization)
16    #[error("JsonError: {0}")]
17    JsonError(#[from] serde_json::Error),
18
19    /// Error building the transcription request
20    #[error("RequestError: {0}")]
21    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
22
23    /// Error parsing the transcription response
24    #[error("ResponseError: {0}")]
25    ResponseError(String),
26
27    /// Error returned by the transcription model provider
28    #[error("ProviderError: {0}")]
29    ProviderError(String),
30}
31pub trait AudioGeneration<M>
32where
33    M: AudioGenerationModel,
34{
35    /// Generates an audio generation request builder for the given `text` and `voice`.
36    /// This function is meant to be called by the user to further customize the
37    /// request at generation time before sending it.
38    ///
39    /// ❗IMPORTANT: The type that implements this trait might have already
40    /// populated fields in the builder (the exact fields depend on the type).
41    /// For fields that have already been set by the model, calling the corresponding
42    /// method on the builder will overwrite the value set by the model.
43    fn audio_generation(
44        &self,
45        text: &str,
46        voice: &str,
47    ) -> impl std::future::Future<
48        Output = Result<AudioGenerationRequestBuilder<M>, AudioGenerationError>,
49    > + Send;
50}
51
/// Outcome of an audio generation call.
///
/// `T` carries the provider-specific response; the type-erased path
/// ([`AudioGenerationModelDyn`]) uses `()` instead.
pub struct AudioGenerationResponse<T> {
    /// Bytes of the generated audio. The encoding/container is not fixed by this
    /// type — presumably provider- or request-dependent; TODO confirm.
    pub audio: Vec<u8>,
    /// Provider-specific response payload.
    pub response: T,
}
56
57pub trait AudioGenerationModel: Clone + Send + Sync {
58    type Response: Send + Sync;
59
60    fn audio_generation(
61        &self,
62        request: AudioGenerationRequest,
63    ) -> impl std::future::Future<
64        Output = Result<AudioGenerationResponse<Self::Response>, AudioGenerationError>,
65    > + Send;
66
67    fn audio_generation_request(&self) -> AudioGenerationRequestBuilder<Self> {
68        AudioGenerationRequestBuilder::new(self.clone())
69    }
70}
71
72pub trait AudioGenerationModelDyn: Send + Sync {
73    fn audio_generation(
74        &self,
75        request: AudioGenerationRequest,
76    ) -> BoxFuture<'_, Result<AudioGenerationResponse<()>, AudioGenerationError>>;
77
78    fn audio_generation_request(
79        &self,
80    ) -> AudioGenerationRequestBuilder<AudioGenerationModelHandle<'_>>;
81}
82
83impl<T> AudioGenerationModelDyn for T
84where
85    T: AudioGenerationModel,
86{
87    fn audio_generation(
88        &self,
89        request: AudioGenerationRequest,
90    ) -> BoxFuture<'_, Result<AudioGenerationResponse<()>, AudioGenerationError>> {
91        Box::pin(async move {
92            let resp = self.audio_generation(request).await;
93
94            resp.map(|r| AudioGenerationResponse {
95                audio: r.audio,
96                response: (),
97            })
98        })
99    }
100
101    fn audio_generation_request(
102        &self,
103    ) -> AudioGenerationRequestBuilder<AudioGenerationModelHandle<'_>> {
104        AudioGenerationRequestBuilder::new(AudioGenerationModelHandle {
105            inner: Arc::new(self.clone()),
106        })
107    }
108}
109
110#[non_exhaustive]
111pub struct AudioGenerationRequest {
112    pub text: String,
113    pub voice: String,
114    pub speed: f32,
115    pub additional_params: Option<Value>,
116}
117
118#[non_exhaustive]
119pub struct AudioGenerationRequestBuilder<M>
120where
121    M: AudioGenerationModel,
122{
123    model: M,
124    text: String,
125    voice: String,
126    speed: f32,
127    additional_params: Option<Value>,
128}
129
130impl<M> AudioGenerationRequestBuilder<M>
131where
132    M: AudioGenerationModel,
133{
134    pub fn new(model: M) -> Self {
135        Self {
136            model,
137            text: "".to_string(),
138            voice: "".to_string(),
139            speed: 1.0,
140            additional_params: None,
141        }
142    }
143
144    /// Sets the text for the audio generation request
145    pub fn text(mut self, text: &str) -> Self {
146        self.text = text.to_string();
147        self
148    }
149
150    /// The voice of the generated audio
151    pub fn voice(mut self, voice: &str) -> Self {
152        self.voice = voice.to_string();
153        self
154    }
155
156    /// The speed of the generated audio
157    pub fn speed(mut self, speed: f32) -> Self {
158        self.speed = speed;
159        self
160    }
161
162    /// Adds additional parameters to the audio generation request.
163    pub fn additional_params(mut self, params: Value) -> Self {
164        self.additional_params = Some(params);
165        self
166    }
167
168    pub fn build(self) -> AudioGenerationRequest {
169        AudioGenerationRequest {
170            text: self.text,
171            voice: self.voice,
172            speed: self.speed,
173            additional_params: self.additional_params,
174        }
175    }
176
177    pub async fn send(self) -> Result<AudioGenerationResponse<M::Response>, AudioGenerationError> {
178        let model = self.model.clone();
179
180        model.audio_generation(self.build()).await
181    }
182}