rig/providers/openai/audio_generation.rs

use crate::audio_generation::{
    self, AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
};
use crate::http_client::{self, HttpClientExt};
use crate::providers::openai::Client;
use bytes::Bytes;
use serde_json::json;

pub const TTS_1: &str = "tts-1";
pub const TTS_1_HD: &str = "tts-1-hd";

#[derive(Clone)]
pub struct AudioGenerationModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}

impl<T> AudioGenerationModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
        }
    }
}

impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
{
    type Response = Bytes;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    #[cfg_attr(feature = "worker", worker::send)]
    async fn audio_generation(
        &self,
        request: AudioGenerationRequest,
    ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
        let body = serde_json::to_vec(&json!({
            "model": self.model,
            "input": request.text,
            "voice": request.voice,
            "speed": request.speed,
        }))?;

        let req = self
            .client
            .post("/audio/speech")?
            .body(body)
            .map_err(http_client::Error::from)?;

        let response = self.client.send(req).await?;

        // A non-success status carries an error payload in the body; surface
        // it as a provider error alongside the status code.
        if !response.status().is_success() {
            let status = response.status();
            let bytes: Bytes = response.into_body().await?;
            let text: String = String::from_utf8_lossy(&bytes).into();

            return Err(AudioGenerationError::ProviderError(format!(
                "{}: {}",
                status, text
            )));
        }

        // On success the endpoint returns the raw audio bytes directly.
        let bytes: Bytes = response.into_body().await?;

        Ok(AudioGenerationResponse {
            audio: bytes.to_vec(),
            response: bytes,
        })
    }
}
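
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the provider implementation).
// It assumes an OpenAI `Client` constructor along the lines of
// `Client::new("<api key>")` and an `AudioGenerationRequest` exposing the
// `text`, `voice`, and `speed` values serialized above; the concrete
// constructors/builders in the crate may differ. The
// `audio_generation::AudioGenerationModel` trait must be in scope for the
// method call.
//
// async fn tts_example() -> Result<(), AudioGenerationError> {
//     let client = Client::new("<OPENAI_API_KEY>");
//     let model = AudioGenerationModel::new(client, TTS_1);
//
//     let request = AudioGenerationRequest {
//         text: "Hello from rig".to_string(),
//         voice: "alloy".to_string(),
//         speed: 1.0,
//     };
//
//     // `response.audio` holds the raw audio bytes returned by /audio/speech.
//     let response = model.audio_generation(request).await?;
//     assert!(!response.audio.is_empty());
//     Ok(())
// }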