Skip to main content

rig/
audio_generation.rs

//! Everything related to audio generation (i.e., text-to-speech).
//! Rig abstracts over a number of different providers using the [`AudioGenerationModel`] trait.
3use crate::markers::{Missing, Provided};
4use crate::{
5    http_client,
6    wasm_compat::{WasmCompatSend, WasmCompatSync},
7};
8use serde_json::Value;
9use thiserror::Error;
10
/// Errors that can be returned by audio generation operations.
#[derive(Debug, Error)]
pub enum AudioGenerationError {
    /// Http error (e.g.: connection error, timeout, etc.)
    #[error("HttpError: {0}")]
    HttpError(#[from] http_client::Error),

    /// Json error (e.g.: serialization, deserialization)
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Error building the audio generation request
    #[error("RequestError: {0}")]
    RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    /// Error parsing the audio generation response
    #[error("ResponseError: {0}")]
    ResponseError(String),

    /// Error returned by the audio generation model provider
    #[error("ProviderError: {0}")]
    ProviderError(String),
}
/// Implemented by types that can hand out an [`AudioGenerationRequestBuilder`]
/// for an [`AudioGenerationModel`] `M` with the text and voice already provided.
pub trait AudioGeneration<M>
where
    M: AudioGenerationModel,
{
    /// Generates an audio generation request builder for the given `text` and `voice`.
    /// This function is meant to be called by the user to further customize the
    /// request at generation time before sending it.
    ///
    /// ❗IMPORTANT: The type that implements this trait might have already
    /// populated fields in the builder (the exact fields depend on the type).
    /// For fields that have already been set by the model, calling the corresponding
    /// method on the builder will overwrite the value set by the model.
    ///
    /// # Errors
    ///
    /// The returned future resolves to an [`AudioGenerationError`] when the
    /// implementor fails to produce the builder; the exact conditions depend
    /// on the implementing type.
    fn audio_generation(
        &self,
        text: &str,
        voice: &str,
    ) -> impl std::future::Future<
        Output = Result<
            AudioGenerationRequestBuilder<M, Provided<String>, Provided<String>>,
            AudioGenerationError,
        >,
    > + Send;
}
56
/// The result of an audio generation request.
///
/// `T` is the model-specific response type:
/// [`AudioGenerationModel::audio_generation`] returns this struct with `T`
/// set to the model's `Response` associated type.
pub struct AudioGenerationResponse<T> {
    /// The generated audio as raw bytes. The encoding/container format is not
    /// specified by this type (presumably provider-dependent; confirm against
    /// the model implementation).
    pub audio: Vec<u8>,
    /// The model-specific response value accompanying the audio.
    pub response: T,
}
61
/// A model that can turn an [`AudioGenerationRequest`] into audio.
///
/// Implementors must be `Sized` and `Clone`, and must satisfy the
/// `WasmCompatSend` / `WasmCompatSync` marker bounds.
pub trait AudioGenerationModel: Sized + Clone + WasmCompatSend + WasmCompatSync {
    /// The model-specific response type carried in
    /// [`AudioGenerationResponse::response`].
    type Response: Send + Sync;

    /// The client type this model is constructed from (see [`Self::make`]).
    type Client;

    /// Creates a model from a `client` reference and a `model` value that is
    /// convertible into a `String` (presumably the provider's model name).
    fn make(client: &Self::Client, model: impl Into<String>) -> Self;

    /// Performs the audio generation described by `request`.
    ///
    /// The returned future resolves to an [`AudioGenerationResponse`] holding
    /// the audio bytes and the model-specific response, or to an
    /// [`AudioGenerationError`] on failure.
    fn audio_generation(
        &self,
        request: AudioGenerationRequest,
    ) -> impl std::future::Future<
        Output = Result<AudioGenerationResponse<Self::Response>, AudioGenerationError>,
    > + Send;

    /// Returns a new [`AudioGenerationRequestBuilder`] that owns a clone of
    /// this model and has neither text nor voice set yet.
    fn audio_generation_request(&self) -> AudioGenerationRequestBuilder<Self, Missing, Missing> {
        AudioGenerationRequestBuilder::new(self.clone())
    }
}
/// A fully specified audio generation request, as accepted by
/// [`AudioGenerationModel::audio_generation`] and produced by
/// [`AudioGenerationRequestBuilder::build`].
#[non_exhaustive]
pub struct AudioGenerationRequest {
    /// The text for the audio generation request.
    pub text: String,
    /// The voice of the generated audio.
    pub voice: String,
    /// The speed of the generated audio. The builder initialises this to `1.0`.
    pub speed: f32,
    /// Optional extra parameters as a JSON value; `None` unless set on the
    /// builder. How they are interpreted is up to the model implementation.
    pub additional_params: Option<Value>,
}
87
/// Builder for an [`AudioGenerationRequest`] tied to a model `M`.
///
/// The `T` and `V` type parameters track, at compile time, whether the text
/// and the voice have been set: both start as `Missing` and become
/// `Provided<String>` once `text` / `voice` are called. `build` and `send`
/// are only available when both are `Provided<String>`.
#[non_exhaustive]
pub struct AudioGenerationRequestBuilder<M, T = Missing, V = Missing>
where
    M: AudioGenerationModel,
{
    /// The model the request will be sent to by `send`.
    model: M,
    /// Text slot: `Missing` until set, then `Provided<String>`.
    text: T,
    /// Voice slot: `Missing` until set, then `Provided<String>`.
    voice: V,
    /// Speed of the generated audio; initialised to `1.0` by `new`.
    speed: f32,
    /// Optional extra parameters forwarded into the built request.
    additional_params: Option<Value>,
}
99
100impl<M> AudioGenerationRequestBuilder<M, Missing, Missing>
101where
102    M: AudioGenerationModel,
103{
104    pub fn new(model: M) -> Self {
105        Self {
106            model,
107            text: Missing,
108            voice: Missing,
109            speed: 1.0,
110            additional_params: None,
111        }
112    }
113}
114
115impl<M, T, V> AudioGenerationRequestBuilder<M, T, V>
116where
117    M: AudioGenerationModel,
118{
119    /// Sets the text for the audio generation request
120    pub fn text(self, text: &str) -> AudioGenerationRequestBuilder<M, Provided<String>, V> {
121        AudioGenerationRequestBuilder {
122            model: self.model,
123            text: Provided(text.to_string()),
124            voice: self.voice,
125            speed: self.speed,
126            additional_params: self.additional_params,
127        }
128    }
129
130    /// The voice of the generated audio
131    pub fn voice(self, voice: &str) -> AudioGenerationRequestBuilder<M, T, Provided<String>> {
132        AudioGenerationRequestBuilder {
133            model: self.model,
134            text: self.text,
135            voice: Provided(voice.to_string()),
136            speed: self.speed,
137            additional_params: self.additional_params,
138        }
139    }
140
141    /// The speed of the generated audio
142    pub fn speed(mut self, speed: f32) -> Self {
143        self.speed = speed;
144        self
145    }
146
147    /// Adds additional parameters to the audio generation request.
148    pub fn additional_params(mut self, params: Value) -> Self {
149        self.additional_params = Some(params);
150        self
151    }
152}
153
154impl<M> AudioGenerationRequestBuilder<M, Provided<String>, Provided<String>>
155where
156    M: AudioGenerationModel,
157{
158    pub fn build(self) -> AudioGenerationRequest {
159        AudioGenerationRequest {
160            text: self.text.0,
161            voice: self.voice.0,
162            speed: self.speed,
163            additional_params: self.additional_params,
164        }
165    }
166
167    pub async fn send(self) -> Result<AudioGenerationResponse<M::Response>, AudioGenerationError> {
168        let model = self.model.clone();
169
170        model.audio_generation(self.build()).await
171    }
172}