rig/providers/openai/audio_generation.rs

use crate::audio_generation::{
    self, AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
};
use crate::http_client::{self, HttpClientExt};
use crate::providers::openai::Client;
use bytes::Bytes;
use serde_json::json;

pub const TTS_1: &str = "tts-1";
pub const TTS_1_HD: &str = "tts-1-hd";

#[derive(Clone)]
pub struct AudioGenerationModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}

impl<T> AudioGenerationModel<T> {
    pub fn new(client: Client<T>, model: &str) -> Self {
        Self {
            client,
            model: model.to_string(),
        }
    }
}

impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
{
    type Response = Bytes;

    #[cfg_attr(feature = "worker", worker::send)]
    async fn audio_generation(
        &self,
        request: AudioGenerationRequest,
    ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
        // Build the JSON payload for OpenAI's `/audio/speech` endpoint.
        let body = serde_json::to_vec(&json!({
            "model": self.model,
            "input": request.text,
            "voice": request.voice,
            "speed": request.speed,
        }))?;

        let req = self
            .client
            .post("/audio/speech")?
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(http_client::Error::from)?;

        let response = self.client.send(req).await?;

        if !response.status().is_success() {
            let status = response.status();
            let bytes: Bytes = response.into_body().await?;

            // Surface the provider's error body alongside the HTTP status code.
            let text: String = String::from_utf8_lossy(&bytes).into();

            return Err(AudioGenerationError::ProviderError(format!(
                "{}: {}",
                status, text
            )));
        }

        let bytes: Bytes = response.into_body().await?;

        // The response body is the synthesized audio itself; keep the raw
        // `Bytes` as the provider response and a `Vec<u8>` copy as the audio.
        Ok(AudioGenerationResponse {
            audio: bytes.to_vec(),
            response: bytes,
        })
    }
}
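
// A minimal usage sketch, kept as a comment: only `AudioGenerationModel::new`,
// `TTS_1`, and the `audio_generation` trait method above are defined in this
// file; how the OpenAI `Client<T>` and the `AudioGenerationRequest` are built
// is assumed to come from elsewhere in the crate.
//
//     use crate::audio_generation::AudioGenerationModel as _;
//
//     async fn synthesize<T>(
//         client: Client<T>,
//         request: AudioGenerationRequest,
//     ) -> Result<Vec<u8>, AudioGenerationError>
//     where
//         T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
//     {
//         let model = AudioGenerationModel::new(client, TTS_1);
//         let response = model.audio_generation(request).await?;
//         Ok(response.audio)
//     }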