//! OpenRouter audio generation (text-to-speech via the `/audio/speech` endpoint).
//!
//! Path: `rig_core/providers/openrouter/audio_generation.rs`

use crate::audio_generation::{
    self, AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
};
use crate::http_client::{self, HttpClientExt};
use crate::providers::openrouter::Client;
use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
use bytes::Bytes;
use serde_json::{Map, Value, json};

// ================================================================
// Model constants
// ================================================================

/// OpenRouter model ID `openai/gpt-4o-mini-tts-2025-12-15`.
///
/// Pass as the `model` argument of `AudioGenerationModel::new`.
pub const GPT_4O_MINI_TTS: &str = "openai/gpt-4o-mini-tts-2025-12-15";
/// OpenRouter model ID `mistralai/voxtral-mini-tts-2603`.
///
/// Pass as the `model` argument of `AudioGenerationModel::new`.
pub const VOXTRAL_MINI_TTS: &str = "mistralai/voxtral-mini-tts-2603";
/// OpenRouter model ID `hexgrad/kokoro-82m`.
///
/// Pass as the `model` argument of `AudioGenerationModel::new`.
pub const KOKORO_82M: &str = "hexgrad/kokoro-82m";

21// ================================================================
22// Model
23// ================================================================
24
25#[derive(Clone)]
26pub struct AudioGenerationModel<T = reqwest::Client> {
27    client: Client<T>,
28    pub model: String,
29}
30
31impl<T> AudioGenerationModel<T> {
32    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
33        Self {
34            client,
35            model: model.into(),
36        }
37    }
38}
39
40impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
41where
42    T: HttpClientExt
43        + Clone
44        + std::fmt::Debug
45        + Default
46        + WasmCompatSend
47        + WasmCompatSync
48        + 'static,
49{
50    type Response = Bytes;
51    type Client = Client<T>;
52
53    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
54        Self::new(client.clone(), model)
55    }
56
57    async fn audio_generation(
58        &self,
59        request: AudioGenerationRequest,
60    ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
61        let mut body_map: serde_json::Map<String, serde_json::Value> = [
62            ("model".to_string(), json!(self.model)),
63            ("input".to_string(), json!(request.text)),
64            ("voice".to_string(), json!(request.voice)),
65            ("response_format".to_string(), json!("mp3")),
66            ("speed".to_string(), json!(request.speed)),
67        ]
68        .into_iter()
69        .collect();
70
71        if let Some(ref additional_params) = request.additional_params {
72            let params = additional_params.as_object().ok_or_else(|| {
73                AudioGenerationError::RequestError(Box::new(std::io::Error::new(
74                    std::io::ErrorKind::InvalidInput,
75                    "additional audio generation parameters must be a JSON object",
76                )))
77            })?;
78            for (k, v) in params {
79                body_map.insert(k.clone(), v.clone());
80            }
81        }
82
83        let body = serde_json::to_vec(&serde_json::Value::Object(body_map))?;
84
85        let req = self
86            .client
87            .post("/audio/speech")?
88            .header("Content-Type", "application/json")
89            .body(body)
90            .map_err(http_client::Error::from)?;
91
92        let response = self.client.send(req).await?;
93
94        if !response.status().is_success() {
95            let status = response.status();
96            let text = http_client::text(response).await?;
97            return Err(AudioGenerationError::ProviderError(format!(
98                "{}: {}",
99                status, text
100            )));
101        }
102
103        let audio: Vec<u8> = response.into_body().await?;
104
105        Ok(AudioGenerationResponse {
106            audio: audio.clone(),
107            response: Bytes::from(audio),
108        })
109    }
110}