rig_core/providers/openrouter/
audio_generation.rs1use crate::audio_generation::{
2 self, AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse,
3};
4use crate::http_client::{self, HttpClientExt};
5use crate::providers::openrouter::Client;
6use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
7use bytes::Bytes;
8use serde_json::json;
9
/// Model identifier string `openai/gpt-4o-mini-tts-2025-12-15`.
///
/// NOTE(review): presumably an OpenRouter model slug meant to be passed as the
/// `model` argument when constructing an `AudioGenerationModel` — confirm.
pub const GPT_4O_MINI_TTS: &str = "openai/gpt-4o-mini-tts-2025-12-15";
/// Model identifier string `mistralai/voxtral-mini-tts-2603`.
pub const VOXTRAL_MINI_TTS: &str = "mistralai/voxtral-mini-tts-2603";
/// Model identifier string `hexgrad/kokoro-82m`.
pub const KOKORO_82M: &str = "hexgrad/kokoro-82m";
20
/// Audio-generation model handle bound to an OpenRouter [`Client`].
///
/// Pairs the client used to issue requests with the model identifier that is
/// sent as the `model` field of every audio-generation request body.
#[derive(Clone)]
pub struct AudioGenerationModel<T = reqwest::Client> {
    /// Client used to build and send the HTTP requests.
    client: Client<T>,
    /// Model identifier sent as the `model` field of the request body.
    pub model: String,
}
30
31impl<T> AudioGenerationModel<T> {
32 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
33 Self {
34 client,
35 model: model.into(),
36 }
37 }
38}
39
40impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
41where
42 T: HttpClientExt
43 + Clone
44 + std::fmt::Debug
45 + Default
46 + WasmCompatSend
47 + WasmCompatSync
48 + 'static,
49{
50 type Response = Bytes;
51 type Client = Client<T>;
52
53 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
54 Self::new(client.clone(), model)
55 }
56
57 async fn audio_generation(
58 &self,
59 request: AudioGenerationRequest,
60 ) -> Result<AudioGenerationResponse<Self::Response>, AudioGenerationError> {
61 let mut body_map: serde_json::Map<String, serde_json::Value> = [
62 ("model".to_string(), json!(self.model)),
63 ("input".to_string(), json!(request.text)),
64 ("voice".to_string(), json!(request.voice)),
65 ("response_format".to_string(), json!("mp3")),
66 ("speed".to_string(), json!(request.speed)),
67 ]
68 .into_iter()
69 .collect();
70
71 if let Some(ref additional_params) = request.additional_params {
72 let params = additional_params.as_object().ok_or_else(|| {
73 AudioGenerationError::RequestError(Box::new(std::io::Error::new(
74 std::io::ErrorKind::InvalidInput,
75 "additional audio generation parameters must be a JSON object",
76 )))
77 })?;
78 for (k, v) in params {
79 body_map.insert(k.clone(), v.clone());
80 }
81 }
82
83 let body = serde_json::to_vec(&serde_json::Value::Object(body_map))?;
84
85 let req = self
86 .client
87 .post("/audio/speech")?
88 .header("Content-Type", "application/json")
89 .body(body)
90 .map_err(http_client::Error::from)?;
91
92 let response = self.client.send(req).await?;
93
94 if !response.status().is_success() {
95 let status = response.status();
96 let text = http_client::text(response).await?;
97 return Err(AudioGenerationError::ProviderError(format!(
98 "{}: {}",
99 status, text
100 )));
101 }
102
103 let audio: Vec<u8> = response.into_body().await?;
104
105 Ok(AudioGenerationResponse {
106 audio: audio.clone(),
107 response: Bytes::from(audio),
108 })
109 }
110}