rig/providers/
hyperbolic.rs

1//! Hyperbolic Inference API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::hyperbolic;
6//!
7//! let client = hyperbolic::Client::new("YOUR_API_KEY");
8//!
9//! let llama_3_1_8b = client.completion_model(hyperbolic::LLAMA_3_1_8B);
10//! ```
11use super::openai::{AssistantContent, send_compatible_streaming_request};
12
13use crate::client::{CompletionClient, ProviderClient, VerifyClient, VerifyError};
14use crate::http_client::{self, HttpClientExt};
15use crate::json_utils::merge_inplace;
16use crate::message;
17use crate::streaming::StreamingCompletionResponse;
18
19use crate::impl_conversion_traits;
20use crate::providers::openai;
21use crate::{
22    OneOrMany,
23    completion::{self, CompletionError, CompletionRequest},
24    json_utils,
25    providers::openai::Message,
26};
27use serde::{Deserialize, Serialize};
28use serde_json::{Value, json};
29
30// ================================================================
31// Main Hyperbolic Client
32// ================================================================
/// Default base URL of the Hyperbolic API; `ClientBuilder::new` starts from this value.
const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz";
34
/// Builder for a Hyperbolic [`Client`].
///
/// The API key and base URL are borrowed for `'a`; `build` copies both into owned `String`s.
/// `T` is the HTTP client type and defaults to `reqwest::Client`.
pub struct ClientBuilder<'a, T = reqwest::Client> {
    /// API key the built client sends as a bearer token.
    api_key: &'a str,
    /// Base URL that request paths are appended to.
    base_url: &'a str,
    /// HTTP client instance handed over to the built [`Client`].
    http_client: T,
}
40
41impl<'a, T> ClientBuilder<'a, T>
42where
43    T: Default,
44{
45    pub fn new(api_key: &'a str) -> Self {
46        Self {
47            api_key,
48            base_url: HYPERBOLIC_API_BASE_URL,
49            http_client: Default::default(),
50        }
51    }
52}
53
54impl<'a, T> ClientBuilder<'a, T> {
55    pub fn base_url(mut self, base_url: &'a str) -> Self {
56        self.base_url = base_url;
57        self
58    }
59
60    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
61        ClientBuilder {
62            api_key: self.api_key,
63            base_url: self.base_url,
64            http_client,
65        }
66    }
67
68    pub fn build(self) -> Client<T> {
69        Client {
70            base_url: self.base_url.to_string(),
71            api_key: self.api_key.to_string(),
72            http_client: self.http_client,
73        }
74    }
75}
76
/// Client for the Hyperbolic API.
///
/// `T` is the HTTP client type and defaults to `reqwest::Client`.
#[derive(Clone)]
pub struct Client<T = reqwest::Client> {
    /// Base URL that request paths are appended to.
    base_url: String,
    /// API key sent as a bearer token; redacted in the `Debug` output.
    api_key: String,
    /// Underlying HTTP client used to send requests.
    http_client: T,
}
83
84impl<T> std::fmt::Debug for Client<T>
85where
86    T: std::fmt::Debug,
87{
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        f.debug_struct("Client")
90            .field("base_url", &self.base_url)
91            .field("http_client", &self.http_client)
92            .field("api_key", &"<REDACTED>")
93            .finish()
94    }
95}
96
97impl<T> Client<T>
98where
99    T: Default,
100{
101    /// Create a new Hyperbolic client builder.
102    ///
103    /// # Example
104    /// ```
105    /// use rig::providers::hyperbolic::{ClientBuilder, self};
106    ///
107    /// // Initialize the Hyperbolic client
108    /// let hyperbolic = Client::builder("your-hyperbolic-api-key")
109    ///    .build()
110    /// ```
111    pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
112        ClientBuilder::new(api_key)
113    }
114
115    /// Create a new Hyperbolic client. For more control, use the `builder` method.
116    ///
117    /// # Panics
118    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
119    pub fn new(api_key: &str) -> Self {
120        Self::builder(api_key).build()
121    }
122}
123
124impl<T> Client<T>
125where
126    T: HttpClientExt,
127{
128    fn req(
129        &self,
130        method: http_client::Method,
131        path: &str,
132    ) -> http_client::Result<http_client::Builder> {
133        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
134
135        http_client::with_bearer_auth(
136            http_client::Builder::new().method(method).uri(url),
137            &self.api_key,
138        )
139    }
140
141    fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
142        self.req(http_client::Method::GET, path)
143    }
144}
145
146impl Client<reqwest::Client> {
147    fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder {
148        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
149
150        self.http_client.post(url).bearer_auth(&self.api_key)
151    }
152}
153
154impl ProviderClient for Client<reqwest::Client> {
155    /// Create a new Hyperbolic client from the `HYPERBOLIC_API_KEY` environment variable.
156    /// Panics if the environment variable is not set.
157    fn from_env() -> Self {
158        let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
159        Self::new(&api_key)
160    }
161
162    fn from_val(input: crate::client::ProviderValue) -> Self {
163        let crate::client::ProviderValue::Simple(api_key) = input else {
164            panic!("Incorrect provider value type")
165        };
166        Self::new(&api_key)
167    }
168}
169
170impl CompletionClient for Client<reqwest::Client> {
171    type CompletionModel = CompletionModel<reqwest::Client>;
172
173    /// Create a completion model with the given name.
174    ///
175    /// # Example
176    /// ```
177    /// use rig::providers::hyperbolic::{Client, self};
178    ///
179    /// // Initialize the Hyperbolic client
180    /// let hyperbolic = Client::new("your-hyperbolic-api-key");
181    ///
182    /// let llama_3_1_8b = hyperbolic.completion_model(hyperbolic::LLAMA_3_1_8B);
183    /// ```
184    fn completion_model(&self, model: &str) -> CompletionModel<reqwest::Client> {
185        CompletionModel::new(self.clone(), model)
186    }
187}
188
189impl VerifyClient for Client<reqwest::Client> {
190    #[cfg_attr(feature = "worker", worker::send)]
191    async fn verify(&self) -> Result<(), VerifyError> {
192        let req = self
193            .get("/models")?
194            .body(http_client::NoBody)
195            .map_err(http_client::Error::from)?;
196
197        let response = HttpClientExt::send(&self.http_client, req).await?;
198
199        match response.status() {
200            reqwest::StatusCode::OK => Ok(()),
201            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
202            reqwest::StatusCode::INTERNAL_SERVER_ERROR
203            | reqwest::StatusCode::SERVICE_UNAVAILABLE
204            | reqwest::StatusCode::BAD_GATEWAY => {
205                let text = http_client::text(response).await?;
206                Err(VerifyError::ProviderError(text))
207            }
208            _ => {
209                //response.error_for_status()?;
210                Ok(())
211            }
212        }
213    }
214}
215
// Invokes the crate's `impl_conversion_traits!` macro for `Client<T>` with the
// `AsEmbeddings` and `AsTranscription` traits.
// NOTE(review): presumably this generates the impls that mark those capabilities as
// unavailable for this provider — confirm against the macro definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription for Client<T>
);
220
/// Error payload deserialized from a Hyperbolic API response.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error message reported by the API.
    message: String,
}
225
/// Response body that is either a successful payload `T` or an [`ApiErrorResponse`].
///
/// The enum is `untagged`, so serde tries the variants in declaration order:
/// `Ok(T)` first, then `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    /// The body deserialized as the expected payload type.
    Ok(T),
    /// The body deserialized as an error object.
    Err(ApiErrorResponse),
}
232
/// A single embedding entry as deserialized from a JSON response.
///
/// NOTE(review): not referenced anywhere else in the visible part of this file.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    /// Value of the `object` field of the entry.
    pub object: String,
    /// The embedding vector.
    pub embedding: Vec<f64>,
    /// Value of the `index` field — presumably the position of the entry in the
    /// request batch; confirm against the API reference.
    pub index: usize,
}
239
/// Token usage reported by the Hyperbolic API.
///
/// Only prompt and total counts are captured here; no separate completion-token
/// field is deserialized.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Number of tokens in the prompt.
    pub prompt_tokens: usize,
    /// Total number of tokens reported for the request.
    pub total_tokens: usize,
}
245
246impl std::fmt::Display for Usage {
247    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
248        write!(
249            f,
250            "Prompt tokens: {} Total tokens: {}",
251            self.prompt_tokens, self.total_tokens
252        )
253    }
254}
255
256// ================================================================
257// Hyperbolic Completion API
258// ================================================================
/// Meta Llama 3.1 Instruct model with 8B parameters.
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Meta Llama 3.3 Instruct model with 70B parameters.
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// Meta Llama 3.1 Instruct model with 70B parameters.
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// Meta Llama 3 Instruct model with 70B parameters.
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// Hermes 3 (Llama 3.1 based) model with 70B parameters.
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// DeepSeek V2.5 model.
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// Qwen 2.5 Instruct model with 72B parameters.
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// Meta Llama 3.2 Instruct model with 3B parameters.
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// Qwen 2.5 Coder Instruct model with 32B parameters.
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// Preview version of the Qwen QwQ model with 32B parameters.
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// DeepSeek R1 Zero model.
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// DeepSeek R1 model.
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
283
/// A Hyperbolic completion object.
///
/// For more information, see this link: <https://docs.hyperbolic.xyz/reference/create_chat_completion_v1_chat_completions_post>
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Value of the response's `id` field.
    pub id: String,
    /// Value of the response's `object` field.
    pub object: String,
    /// Value of the response's `created` field — presumably a Unix timestamp; confirm
    /// against the API reference.
    pub created: u64,
    /// Value of the response's `model` field.
    pub model: String,
    /// Generated choices; only the first one is used when converting into the
    /// crate-level completion response.
    pub choices: Vec<Choice>,
    /// Token usage, if the API reported it. When absent, the conversion into the
    /// crate-level response falls back to default usage values.
    pub usage: Option<Usage>,
}
296
297impl From<ApiErrorResponse> for CompletionError {
298    fn from(err: ApiErrorResponse) -> Self {
299        CompletionError::ProviderError(err.message)
300    }
301}
302
303impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
304    type Error = CompletionError;
305
306    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
307        let choice = response.choices.first().ok_or_else(|| {
308            CompletionError::ResponseError("Response contained no choices".to_owned())
309        })?;
310
311        let content = match &choice.message {
312            Message::Assistant {
313                content,
314                tool_calls,
315                ..
316            } => {
317                let mut content = content
318                    .iter()
319                    .map(|c| match c {
320                        AssistantContent::Text { text } => completion::AssistantContent::text(text),
321                        AssistantContent::Refusal { refusal } => {
322                            completion::AssistantContent::text(refusal)
323                        }
324                    })
325                    .collect::<Vec<_>>();
326
327                content.extend(
328                    tool_calls
329                        .iter()
330                        .map(|call| {
331                            completion::AssistantContent::tool_call(
332                                &call.id,
333                                &call.function.name,
334                                call.function.arguments.clone(),
335                            )
336                        })
337                        .collect::<Vec<_>>(),
338                );
339                Ok(content)
340            }
341            _ => Err(CompletionError::ResponseError(
342                "Response did not contain a valid message or tool call".into(),
343            )),
344        }?;
345
346        let choice = OneOrMany::many(content).map_err(|_| {
347            CompletionError::ResponseError(
348                "Response contained no message or tool call (empty)".to_owned(),
349            )
350        })?;
351
352        let usage = response
353            .usage
354            .as_ref()
355            .map(|usage| completion::Usage {
356                input_tokens: usage.prompt_tokens as u64,
357                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
358                total_tokens: usage.total_tokens as u64,
359            })
360            .unwrap_or_default();
361
362        Ok(completion::CompletionResponse {
363            choice,
364            usage,
365            raw_response: response,
366        })
367    }
368}
369
/// One entry of the `choices` array in a [`CompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    /// Value of the choice's `index` field.
    pub index: usize,
    /// The generated message (OpenAI-compatible message type).
    pub message: Message,
    /// Value of the choice's `finish_reason` field.
    pub finish_reason: String,
}
376
/// Handle to a Hyperbolic completion model: a client plus the model name to request.
#[derive(Clone)]
pub struct CompletionModel<T> {
    /// Client used to send the completion requests.
    client: Client<T>,
    /// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
    pub model: String,
}
383
384impl<T> CompletionModel<T> {
385    pub fn new(client: Client<T>, model: &str) -> Self {
386        Self {
387            client,
388            model: model.to_string(),
389        }
390    }
391
392    pub(crate) fn create_completion_request(
393        &self,
394        completion_request: CompletionRequest,
395    ) -> Result<Value, CompletionError> {
396        if completion_request.tool_choice.is_some() {
397            tracing::warn!("WARNING: `tool_choice` not supported on Hyperbolic");
398        }
399        // Build up the order of messages (context, chat_history, prompt)
400        let mut partial_history = vec![];
401        if let Some(docs) = completion_request.normalized_documents() {
402            partial_history.push(docs);
403        }
404        partial_history.extend(completion_request.chat_history);
405
406        // Initialize full history with preamble (or empty if non-existent)
407        let mut full_history: Vec<Message> = completion_request
408            .preamble
409            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
410
411        // Convert and extend the rest of the history
412        full_history.extend(
413            partial_history
414                .into_iter()
415                .map(message::Message::try_into)
416                .collect::<Result<Vec<Vec<Message>>, _>>()?
417                .into_iter()
418                .flatten()
419                .collect::<Vec<_>>(),
420        );
421
422        let request = json!({
423            "model": self.model,
424            "messages": full_history,
425            "temperature": completion_request.temperature,
426        });
427
428        let request = if let Some(params) = completion_request.additional_params {
429            json_utils::merge(request, params)
430        } else {
431            request
432        };
433
434        Ok(request)
435    }
436}
437
438impl completion::CompletionModel for CompletionModel<reqwest::Client> {
439    type Response = CompletionResponse;
440    type StreamingResponse = openai::StreamingCompletionResponse;
441
442    #[cfg_attr(feature = "worker", worker::send)]
443    async fn completion(
444        &self,
445        completion_request: CompletionRequest,
446    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
447        let preamble = completion_request.preamble.clone();
448        let request = self.create_completion_request(completion_request)?;
449
450        let span = if tracing::Span::current().is_disabled() {
451            info_span!(
452                target: "rig::completions",
453                "chat",
454                gen_ai.operation.name = "chat",
455                gen_ai.provider.name = "hyperbolic",
456                gen_ai.request.model = self.model,
457                gen_ai.system_instructions = preamble,
458                gen_ai.response.id = tracing::field::Empty,
459                gen_ai.response.model = tracing::field::Empty,
460                gen_ai.usage.output_tokens = tracing::field::Empty,
461                gen_ai.usage.input_tokens = tracing::field::Empty,
462                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
463                gen_ai.output.messages = tracing::field::Empty,
464            )
465        } else {
466            tracing::Span::current()
467        };
468
469        let async_block = async move {
470            let response = self
471                .client
472                .reqwest_post("/v1/chat/completions")
473                .json(&request)
474                .send()
475                .await
476                .map_err(|e| http_client::Error::Instance(e.into()))?;
477
478            if response.status().is_success() {
479                match response
480                    .json::<ApiResponse<CompletionResponse>>()
481                    .await
482                    .map_err(|e| http_client::Error::Instance(e.into()))?
483                {
484                    ApiResponse::Ok(response) => {
485                        tracing::info!(target: "rig",
486                            "Hyperbolic completion token usage: {:?}",
487                            response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
488                        );
489
490                        response.try_into()
491                    }
492                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
493                }
494            } else {
495                Err(CompletionError::ProviderError(
496                    response
497                        .text()
498                        .await
499                        .map_err(|e| http_client::Error::Instance(e.into()))?,
500                ))
501            }
502        };
503
504        async_block.instrument(span).await
505    }
506
507    #[cfg_attr(feature = "worker", worker::send)]
508    async fn stream(
509        &self,
510        completion_request: CompletionRequest,
511    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
512        let preamble = completion_request.preamble.clone();
513        let mut request = self.create_completion_request(completion_request)?;
514
515        let span = if tracing::Span::current().is_disabled() {
516            info_span!(
517                target: "rig::completions",
518                "chat_streaming",
519                gen_ai.operation.name = "chat_streaming",
520                gen_ai.provider.name = "hyperbolic",
521                gen_ai.request.model = self.model,
522                gen_ai.system_instructions = preamble,
523                gen_ai.response.id = tracing::field::Empty,
524                gen_ai.response.model = tracing::field::Empty,
525                gen_ai.usage.output_tokens = tracing::field::Empty,
526                gen_ai.usage.input_tokens = tracing::field::Empty,
527                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
528                gen_ai.output.messages = tracing::field::Empty,
529            )
530        } else {
531            tracing::Span::current()
532        };
533
534        merge_inplace(
535            &mut request,
536            json!({"stream": true, "stream_options": {"include_usage": true}}),
537        );
538
539        let builder = self
540            .client
541            .reqwest_post("/v1/chat/completions")
542            .json(&request);
543
544        send_compatible_streaming_request(builder)
545            .instrument(span)
546            .await
547    }
548}
549
550// =======================================
551// Hyperbolic Image Generation API
552// =======================================
553
554#[cfg(feature = "image")]
555pub use image_generation::*;
556
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use super::{ApiResponse, Client};
    use crate::client::ImageGenerationClient;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use crate::{http_client, image_generation};
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    /// Model name `SDXL1.0-base`.
    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    /// Model name `SD2`.
    pub const SD2: &str = "SD2";
    /// Model name `SD1.5`.
    pub const SD1_5: &str = "SD1.5";
    /// Model name `SSD`.
    pub const SSD: &str = "SSD";
    /// Model name `SDXL-turbo`.
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    /// Model name `SDXL-ControlNet`.
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    /// Model name `SD1.5-ControlNet`.
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    /// Handle to a Hyperbolic image generation model: a client plus the model name.
    #[derive(Clone)]
    pub struct ImageGenerationModel<T> {
        client: Client<T>,
        /// Name of the model, sent as `model_name` in the request body.
        pub model: String,
    }

    impl<T> ImageGenerationModel<T> {
        /// Create an image generation model handle for `model` backed by `client`.
        pub(crate) fn new(client: Client<T>, model: &str) -> ImageGenerationModel<T> {
            Self {
                client,
                model: model.to_string(),
            }
        }
    }

    /// One generated image, as a base64-encoded string.
    #[derive(Clone, Deserialize)]
    pub struct Image {
        image: String,
    }

    /// Raw body of a successful image generation response.
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        images: Vec<Image>,
    }

    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// Decode the first image of the response from base64.
        ///
        /// # Errors
        /// Returns `ImageGenerationError::ResponseError` when the response contains
        /// no images or when the first image is not valid base64. Both cases are
        /// reported as errors rather than panicking, since they depend on data
        /// received from the provider.
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            let first = value.images.first().ok_or_else(|| {
                ImageGenerationError::ResponseError("Response contained no images".to_owned())
            })?;

            let data = BASE64_STANDARD.decode(&first.image).map_err(|e| {
                ImageGenerationError::ResponseError(format!("Could not decode image: {e}"))
            })?;

            Ok(Self {
                image: data,
                response: value,
            })
        }
    }

    impl image_generation::ImageGenerationModel for ImageGenerationModel<reqwest::Client> {
        type Response = ImageGenerationResponse;

        /// Send an image generation request to `/v1/image/generation`.
        ///
        /// The body carries the model name, prompt, height and width;
        /// `additional_params`, when present, are merged into it.
        ///
        /// # Errors
        /// Returns `ProviderError` (status and body) on a non-success status,
        /// `ResponseError` when the body is an API error object or cannot be turned
        /// into an image, and `HttpError` on transport or decoding failures.
        #[cfg_attr(feature = "worker", worker::send)]
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let response = self
                .client
                .reqwest_post("/v1/image/generation")
                .json(&request)
                .send()
                .await
                .map_err(|e| {
                    ImageGenerationError::HttpError(http_client::Error::Instance(e.into()))
                })?;

            if !response.status().is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status().as_str(),
                    response.text().await.map_err(|e| {
                        ImageGenerationError::HttpError(http_client::Error::Instance(e.into()))
                    })?
                )));
            }

            match response
                .json::<ApiResponse<ImageGenerationResponse>>()
                .await
                .map_err(|e| {
                    ImageGenerationError::HttpError(http_client::Error::Instance(e.into()))
                })? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
            }
        }
    }

    impl ImageGenerationClient for Client<reqwest::Client> {
        type ImageGenerationModel = ImageGenerationModel<reqwest::Client>;

        /// Create an image generation model with the given name.
        ///
        /// # Example
        /// ```
        /// use rig::providers::hyperbolic::{Client, self};
        ///
        /// // Initialize the Hyperbolic client
        /// let hyperbolic = Client::new("your-hyperbolic-api-key");
        ///
        /// let ssd = hyperbolic.image_generation_model(hyperbolic::SSD);
        /// ```
        fn image_generation_model(&self, model: &str) -> ImageGenerationModel<reqwest::Client> {
            ImageGenerationModel::new(self.clone(), model)
        }
    }
}
691
692// ======================================
693// Hyperbolic Audio Generation API
694// ======================================
695#[cfg(feature = "audio")]
696pub use audio_generation::*;
697use tracing::{Instrument, info_span};
698
#[cfg(feature = "audio")]
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::{ApiResponse, Client};
    use crate::audio_generation;
    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
    use crate::client::AudioGenerationClient;
    use crate::http_client;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    /// Handle to Hyperbolic audio generation: a client plus the language to request.
    #[derive(Clone)]
    pub struct AudioGenerationModel<T> {
        client: Client<T>,
        /// Language sent as the `language` field of the request body.
        pub language: String,
    }

    impl<T> AudioGenerationModel<T> {
        /// Create an audio generation model handle for `language` backed by `client`.
        pub(crate) fn new(client: Client<T>, language: &str) -> AudioGenerationModel<T> {
            Self {
                client,
                language: language.to_string(),
            }
        }
    }

    /// Raw body of a successful audio generation response; `audio` is base64-encoded.
    #[derive(Clone, Deserialize)]
    pub struct AudioGenerationResponse {
        audio: String,
    }

    impl TryFrom<AudioGenerationResponse>
        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
    {
        type Error = AudioGenerationError;

        /// Decode the audio payload of the response from base64.
        ///
        /// # Errors
        /// Returns `AudioGenerationError::ProviderError` when the payload is not
        /// valid base64. This is reported as an error rather than a panic, since
        /// the payload is data received from the provider.
        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
            let data = BASE64_STANDARD.decode(&value.audio).map_err(|e| {
                AudioGenerationError::ProviderError(format!("Could not decode audio: {e}"))
            })?;

            Ok(Self {
                audio: data,
                response: value,
            })
        }
    }

    impl audio_generation::AudioGenerationModel for AudioGenerationModel<reqwest::Client> {
        type Response = AudioGenerationResponse;

        /// Send an audio generation request to `/v1/audio/generation`.
        ///
        /// The body carries the model's language plus the request's voice (as
        /// `speaker`), text and speed.
        ///
        /// # Errors
        /// Returns `ProviderError` on a non-success status, when the body is an API
        /// error object, or when the audio payload cannot be decoded; transport
        /// failures are returned as `HttpError`, and JSON parsing failures are
        /// propagated via `?`.
        #[cfg_attr(feature = "worker", worker::send)]
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            let request = json!({
                "language": self.language,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let response = self
                .client
                .reqwest_post("/v1/audio/generation")
                .json(&request)
                .send()
                .await
                .map_err(|e| {
                    AudioGenerationError::HttpError(http_client::Error::Instance(e.into()))
                })?;

            if !response.status().is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status(),
                    response.text().await.map_err(|e| {
                        AudioGenerationError::HttpError(http_client::Error::Instance(e.into()))
                    })?
                )));
            }

            match serde_json::from_str::<ApiResponse<AudioGenerationResponse>>(
                &response.text().await.map_err(|e| {
                    AudioGenerationError::HttpError(http_client::Error::Instance(e.into()))
                })?,
            )? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }

    impl AudioGenerationClient for Client<reqwest::Client> {
        type AudioGenerationModel = AudioGenerationModel<reqwest::Client>;

        /// Create an audio generation model for the given language.
        ///
        /// # Example
        /// ```
        /// use rig::providers::hyperbolic::{Client, self};
        ///
        /// // Initialize the Hyperbolic client
        /// let hyperbolic = Client::new("your-hyperbolic-api-key");
        ///
        /// let tts = hyperbolic.audio_generation_model("EN");
        /// ```
        fn audio_generation_model(&self, language: &str) -> AudioGenerationModel<reqwest::Client> {
            AudioGenerationModel::new(self.clone(), language)
        }
    }
}