rig/providers/openai/audio_generation.rs

1use crate::audio_generation::{
2    self, AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
3};
4use crate::http_client::{self, HttpClientExt};
5use crate::providers::openai::Client;
6use bytes::{Buf, Bytes};
7use serde_json::json;
8
/// Model identifier for OpenAI's `tts-1` text-to-speech model.
pub const TTS_1: &str = "tts-1";
/// Model identifier for OpenAI's `tts-1-hd` text-to-speech model.
pub const TTS_1_HD: &str = "tts-1-hd";
11
12#[derive(Clone)]
13pub struct AudioGenerationModel<T = reqwest::Client> {
14    client: Client<T>,
15    pub model: String,
16}
17
18impl<T> AudioGenerationModel<T> {
19    pub fn new(client: Client<T>, model: &str) -> Self {
20        Self {
21            client,
22            model: model.to_string(),
23        }
24    }
25}
26
27impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
28where
29    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
30{
31    type Response = Bytes;
32
33    #[cfg_attr(feature = "worker", worker::send)]
34    async fn audio_generation(
35        &self,
36        request: AudioGenerationRequest,
37    ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
38        let body = serde_json::to_vec(&json!({
39            "model": self.model,
40            "input": request.text,
41            "voice": request.voice,
42            "speed": request.speed,
43        }))?;
44
45        let req = self
46            .client
47            .post("/audio/speech")?
48            .header("Content-Type", "application/json")
49            .body(body)
50            .map_err(http_client::Error::from)?;
51
52        let response = self.client.send(req).await?;
53
54        if !response.status().is_success() {
55            let status = response.status();
56            let mut bytes: Bytes = response.into_body().await?;
57            let mut as_slice = Vec::new();
58            bytes.copy_to_slice(&mut as_slice);
59
60            let text: String = String::from_utf8_lossy(&as_slice).into();
61
62            return Err(AudioGenerationError::ProviderError(format!(
63                "{}: {}",
64                status, text
65            )));
66        }
67
68        let bytes: Bytes = response.into_body().await?;
69
70        Ok(AudioGenerationResponse {
71            audio: bytes.to_vec(),
72            response: bytes,
73        })
74    }
75}