rig/providers/openai/
audio_generation.rs1use crate::audio_generation::{
2 self, AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
3};
4use crate::http_client::{self, HttpClientExt};
5use crate::providers::openai::Client;
6use bytes::{Buf, Bytes};
7use serde_json::json;
8
/// Model name for OpenAI's `tts-1` text-to-speech model.
///
/// Pass this to [`AudioGenerationModel::new`] as the `model` argument.
pub const TTS_1: &str = "tts-1";
/// Model name for OpenAI's `tts-1-hd` text-to-speech model.
///
/// Pass this to [`AudioGenerationModel::new`] as the `model` argument.
pub const TTS_1_HD: &str = "tts-1-hd";
11
12#[derive(Clone)]
13pub struct AudioGenerationModel<T = reqwest::Client> {
14 client: Client<T>,
15 pub model: String,
16}
17
18impl<T> AudioGenerationModel<T> {
19 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
20 Self {
21 client,
22 model: model.into(),
23 }
24 }
25}
26
27impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
28where
29 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
30{
31 type Response = Bytes;
32
33 type Client = Client<T>;
34
35 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
36 Self::new(client.clone(), model)
37 }
38
39 async fn audio_generation(
40 &self,
41 request: AudioGenerationRequest,
42 ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
43 let body = serde_json::to_vec(&json!({
44 "model": self.model,
45 "input": request.text,
46 "voice": request.voice,
47 "speed": request.speed,
48 }))?;
49
50 let req = self
51 .client
52 .post("/audio/speech")?
53 .body(body)
54 .map_err(http_client::Error::from)?;
55
56 let response = self.client.send(req).await?;
57
58 if !response.status().is_success() {
59 let status = response.status();
60 let mut bytes: Bytes = response.into_body().await?;
61 let mut as_slice = Vec::new();
62 bytes.copy_to_slice(&mut as_slice);
63
64 let text: String = String::from_utf8_lossy(&as_slice).into();
65
66 return Err(AudioGenerationError::ProviderError(format!(
67 "{}: {}",
68 status, text
69 )));
70 }
71
72 let bytes: Bytes = response.into_body().await?;
73
74 Ok(AudioGenerationResponse {
75 audio: bytes.to_vec(),
76 response: bytes,
77 })
78 }
79}