Skip to main content

rig/providers/xai/
audio_generation.rs

1use crate::audio_generation::{
2    self, AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
3};
4use crate::http_client::{self, HttpClientExt};
5use crate::json_utils::merge_inplace;
6use crate::providers::xai::Client;
7use bytes::Bytes;
8use serde_json::json;
9
// ================================================================
// xAI TTS API
// ================================================================

/// Identifier of the `tts-1` text-to-speech model.
pub const TTS_1: &str = "tts-1";
14
15#[derive(Clone)]
16pub struct AudioGenerationModel<T = reqwest::Client> {
17    client: Client<T>,
18    pub model: String,
19}
20
21impl<T> AudioGenerationModel<T> {
22    pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
23        Self {
24            client,
25            model: model.into(),
26        }
27    }
28}
29
30impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
31where
32    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
33{
34    type Response = Bytes;
35
36    type Client = Client<T>;
37
38    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
39        Self::new(client.clone(), model)
40    }
41
42    async fn audio_generation(
43        &self,
44        request: AudioGenerationRequest,
45    ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
46        let voice = if request.voice.is_empty() {
47            "eve".to_string()
48        } else {
49            request.voice
50        };
51
52        let mut body = json!({
53            "text": request.text,
54            "voice_id": voice,
55            "language": "en",
56        });
57
58        if let Some(additional_params) = request.additional_params {
59            merge_inplace(&mut body, additional_params);
60        }
61
62        let body = serde_json::to_vec(&body)?;
63
64        let req = self
65            .client
66            .post("/v1/tts")?
67            .body(body)
68            .map_err(http_client::Error::from)?;
69
70        let response = self.client.send(req).await?;
71
72        if !response.status().is_success() {
73            let status = response.status();
74            let text = http_client::text(response).await?;
75
76            return Err(AudioGenerationError::ProviderError(format!(
77                "{}: {}",
78                status, text,
79            )));
80        }
81
82        let bytes: Bytes = response.into_body().await?.into();
83
84        Ok(AudioGenerationResponse {
85            audio: bytes.to_vec(),
86            response: bytes,
87        })
88    }
89}