rig/providers/
hyperbolic.rs

//! Hyperbolic Inference API client and Rig integration
//!
//! # Example
//! ```
//! use rig::client::CompletionClient;
//! use rig::providers::hyperbolic;
//!
//! let client = hyperbolic::Client::new("YOUR_API_KEY");
//!
//! let llama_3_1_8b = client.completion_model(hyperbolic::LLAMA_3_1_8B);
//! ```
11use super::openai::{AssistantContent, send_compatible_streaming_request};
12
13use crate::client::{
14    ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError,
15};
16use crate::json_utils::merge_inplace;
17use crate::message;
18use crate::streaming::StreamingCompletionResponse;
19
20use crate::impl_conversion_traits;
21use crate::providers::openai;
22use crate::{
23    OneOrMany,
24    completion::{self, CompletionError, CompletionRequest},
25    json_utils,
26    providers::openai::Message,
27};
28use serde::{Deserialize, Serialize};
29use serde_json::{Value, json};
30
31// ================================================================
32// Main Hyperbolic Client
33// ================================================================
34const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz";
35
36pub struct ClientBuilder<'a> {
37    api_key: &'a str,
38    base_url: &'a str,
39    http_client: Option<reqwest::Client>,
40}
41
42impl<'a> ClientBuilder<'a> {
43    pub fn new(api_key: &'a str) -> Self {
44        Self {
45            api_key,
46            base_url: HYPERBOLIC_API_BASE_URL,
47            http_client: None,
48        }
49    }
50
51    pub fn base_url(mut self, base_url: &'a str) -> Self {
52        self.base_url = base_url;
53        self
54    }
55
56    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
57        self.http_client = Some(client);
58        self
59    }
60
61    pub fn build(self) -> Result<Client, ClientBuilderError> {
62        let http_client = if let Some(http_client) = self.http_client {
63            http_client
64        } else {
65            reqwest::Client::builder().build()?
66        };
67
68        Ok(Client {
69            base_url: self.base_url.to_string(),
70            api_key: self.api_key.to_string(),
71            http_client,
72        })
73    }
74}
75
76#[derive(Clone)]
77pub struct Client {
78    base_url: String,
79    api_key: String,
80    http_client: reqwest::Client,
81}
82
83impl std::fmt::Debug for Client {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        f.debug_struct("Client")
86            .field("base_url", &self.base_url)
87            .field("http_client", &self.http_client)
88            .field("api_key", &"<REDACTED>")
89            .finish()
90    }
91}
92
93impl Client {
94    /// Create a new Hyperbolic client builder.
95    ///
96    /// # Example
97    /// ```
98    /// use rig::providers::hyperbolic::{ClientBuilder, self};
99    ///
100    /// // Initialize the Hyperbolic client
101    /// let hyperbolic = Client::builder("your-hyperbolic-api-key")
102    ///    .build()
103    /// ```
104    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
105        ClientBuilder::new(api_key)
106    }
107
108    /// Create a new Hyperbolic client. For more control, use the `builder` method.
109    ///
110    /// # Panics
111    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
112    pub fn new(api_key: &str) -> Self {
113        Self::builder(api_key)
114            .build()
115            .expect("Hyperbolic client should build")
116    }
117
118    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
119        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
120        self.http_client.post(url).bearer_auth(&self.api_key)
121    }
122
123    pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder {
124        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
125        self.http_client.get(url).bearer_auth(&self.api_key)
126    }
127}
128
129impl ProviderClient for Client {
130    /// Create a new Hyperbolic client from the `HYPERBOLIC_API_KEY` environment variable.
131    /// Panics if the environment variable is not set.
132    fn from_env() -> Self {
133        let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
134        Self::new(&api_key)
135    }
136
137    fn from_val(input: crate::client::ProviderValue) -> Self {
138        let crate::client::ProviderValue::Simple(api_key) = input else {
139            panic!("Incorrect provider value type")
140        };
141        Self::new(&api_key)
142    }
143}
144
145impl CompletionClient for Client {
146    type CompletionModel = CompletionModel;
147
148    /// Create a completion model with the given name.
149    ///
150    /// # Example
151    /// ```
152    /// use rig::providers::hyperbolic::{Client, self};
153    ///
154    /// // Initialize the Hyperbolic client
155    /// let hyperbolic = Client::new("your-hyperbolic-api-key");
156    ///
157    /// let llama_3_1_8b = hyperbolic.completion_model(hyperbolic::LLAMA_3_1_8B);
158    /// ```
159    fn completion_model(&self, model: &str) -> CompletionModel {
160        CompletionModel::new(self.clone(), model)
161    }
162}
163
164impl VerifyClient for Client {
165    #[cfg_attr(feature = "worker", worker::send)]
166    async fn verify(&self) -> Result<(), VerifyError> {
167        let response = self.get("/users/me").send().await?;
168        match response.status() {
169            reqwest::StatusCode::OK => Ok(()),
170            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
171            reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
172                Err(VerifyError::ProviderError(response.text().await?))
173            }
174            _ => {
175                response.error_for_status()?;
176                Ok(())
177            }
178        }
179    }
180}
181
182impl_conversion_traits!(
183    AsEmbeddings,
184    AsTranscription for Client
185);
186
187#[derive(Debug, Deserialize)]
188struct ApiErrorResponse {
189    message: String,
190}
191
192#[derive(Debug, Deserialize)]
193#[serde(untagged)]
194enum ApiResponse<T> {
195    Ok(T),
196    Err(ApiErrorResponse),
197}
198
199#[derive(Debug, Deserialize)]
200pub struct EmbeddingData {
201    pub object: String,
202    pub embedding: Vec<f64>,
203    pub index: usize,
204}
205
206#[derive(Clone, Debug, Deserialize, Serialize)]
207pub struct Usage {
208    pub prompt_tokens: usize,
209    pub total_tokens: usize,
210}
211
212impl std::fmt::Display for Usage {
213    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214        write!(
215            f,
216            "Prompt tokens: {} Total tokens: {}",
217            self.prompt_tokens, self.total_tokens
218        )
219    }
220}
221
222// ================================================================
223// Hyperbolic Completion API
224// ================================================================
/// Meta Llama 3.1 Instruct model with 8B parameters.
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Meta Llama 3.3 Instruct model with 70B parameters.
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// Meta Llama 3.1 Instruct model with 70B parameters.
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// Meta Llama 3 Instruct model with 70B parameters.
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// Nous Research Hermes 3 model (Llama 3.1 based) with 70B parameters.
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// DeepSeek V2.5 model.
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// Qwen 2.5 Instruct model with 72B parameters.
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// Meta Llama 3.2 Instruct model with 3B parameters.
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// Qwen 2.5 Coder Instruct model with 32B parameters.
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// Qwen QwQ preview model with 32B parameters.
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// DeepSeek R1 Zero model.
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// DeepSeek R1 model.
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
249
250/// A Hyperbolic completion object.
251///
252/// For more information, see this link: <https://docs.hyperbolic.xyz/reference/create_chat_completion_v1_chat_completions_post>
253#[derive(Debug, Deserialize, Serialize)]
254pub struct CompletionResponse {
255    pub id: String,
256    pub object: String,
257    pub created: u64,
258    pub model: String,
259    pub choices: Vec<Choice>,
260    pub usage: Option<Usage>,
261}
262
263impl From<ApiErrorResponse> for CompletionError {
264    fn from(err: ApiErrorResponse) -> Self {
265        CompletionError::ProviderError(err.message)
266    }
267}
268
269impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
270    type Error = CompletionError;
271
272    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
273        let choice = response.choices.first().ok_or_else(|| {
274            CompletionError::ResponseError("Response contained no choices".to_owned())
275        })?;
276
277        let content = match &choice.message {
278            Message::Assistant {
279                content,
280                tool_calls,
281                ..
282            } => {
283                let mut content = content
284                    .iter()
285                    .map(|c| match c {
286                        AssistantContent::Text { text } => completion::AssistantContent::text(text),
287                        AssistantContent::Refusal { refusal } => {
288                            completion::AssistantContent::text(refusal)
289                        }
290                    })
291                    .collect::<Vec<_>>();
292
293                content.extend(
294                    tool_calls
295                        .iter()
296                        .map(|call| {
297                            completion::AssistantContent::tool_call(
298                                &call.id,
299                                &call.function.name,
300                                call.function.arguments.clone(),
301                            )
302                        })
303                        .collect::<Vec<_>>(),
304                );
305                Ok(content)
306            }
307            _ => Err(CompletionError::ResponseError(
308                "Response did not contain a valid message or tool call".into(),
309            )),
310        }?;
311
312        let choice = OneOrMany::many(content).map_err(|_| {
313            CompletionError::ResponseError(
314                "Response contained no message or tool call (empty)".to_owned(),
315            )
316        })?;
317
318        let usage = response
319            .usage
320            .as_ref()
321            .map(|usage| completion::Usage {
322                input_tokens: usage.prompt_tokens as u64,
323                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
324                total_tokens: usage.total_tokens as u64,
325            })
326            .unwrap_or_default();
327
328        Ok(completion::CompletionResponse {
329            choice,
330            usage,
331            raw_response: response,
332        })
333    }
334}
335
336#[derive(Debug, Deserialize, Serialize)]
337pub struct Choice {
338    pub index: usize,
339    pub message: Message,
340    pub finish_reason: String,
341}
342
343#[derive(Clone)]
344pub struct CompletionModel {
345    client: Client,
346    /// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
347    pub model: String,
348}
349
350impl CompletionModel {
351    pub(crate) fn create_completion_request(
352        &self,
353        completion_request: CompletionRequest,
354    ) -> Result<Value, CompletionError> {
355        // Build up the order of messages (context, chat_history, prompt)
356        let mut partial_history = vec![];
357        if let Some(docs) = completion_request.normalized_documents() {
358            partial_history.push(docs);
359        }
360        partial_history.extend(completion_request.chat_history);
361
362        // Initialize full history with preamble (or empty if non-existent)
363        let mut full_history: Vec<Message> = completion_request
364            .preamble
365            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
366
367        // Convert and extend the rest of the history
368        full_history.extend(
369            partial_history
370                .into_iter()
371                .map(message::Message::try_into)
372                .collect::<Result<Vec<Vec<Message>>, _>>()?
373                .into_iter()
374                .flatten()
375                .collect::<Vec<_>>(),
376        );
377
378        let request = json!({
379            "model": self.model,
380            "messages": full_history,
381            "temperature": completion_request.temperature,
382        });
383
384        let request = if let Some(params) = completion_request.additional_params {
385            json_utils::merge(request, params)
386        } else {
387            request
388        };
389
390        Ok(request)
391    }
392}
393
394impl CompletionModel {
395    pub fn new(client: Client, model: &str) -> Self {
396        Self {
397            client,
398            model: model.to_string(),
399        }
400    }
401}
402
403impl completion::CompletionModel for CompletionModel {
404    type Response = CompletionResponse;
405    type StreamingResponse = openai::StreamingCompletionResponse;
406
407    #[cfg_attr(feature = "worker", worker::send)]
408    async fn completion(
409        &self,
410        completion_request: CompletionRequest,
411    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
412        let request = self.create_completion_request(completion_request)?;
413
414        let response = self
415            .client
416            .post("/v1/chat/completions")
417            .json(&request)
418            .send()
419            .await?;
420
421        if response.status().is_success() {
422            match response.json::<ApiResponse<CompletionResponse>>().await? {
423                ApiResponse::Ok(response) => {
424                    tracing::info!(target: "rig",
425                        "Hyperbolic completion token usage: {:?}",
426                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
427                    );
428
429                    response.try_into()
430                }
431                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
432            }
433        } else {
434            Err(CompletionError::ProviderError(response.text().await?))
435        }
436    }
437
438    #[cfg_attr(feature = "worker", worker::send)]
439    async fn stream(
440        &self,
441        completion_request: CompletionRequest,
442    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
443        let mut request = self.create_completion_request(completion_request)?;
444
445        merge_inplace(
446            &mut request,
447            json!({"stream": true, "stream_options": {"include_usage": true}}),
448        );
449
450        let builder = self.client.post("/v1/chat/completions").json(&request);
451
452        send_compatible_streaming_request(builder).await
453    }
454}
455
456// =======================================
457// Hyperbolic Image Generation API
458// =======================================
459
460#[cfg(feature = "image")]
461pub use image_generation::*;
462
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use super::{ApiResponse, Client};
    use crate::client::ImageGenerationClient;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    // Image model identifiers; sent to the API as `model_name`.
    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    pub const SD2: &str = "SD2";
    pub const SD1_5: &str = "SD1.5";
    pub const SSD: &str = "SSD";
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    /// A Hyperbolic image generation model bound to a [`Client`].
    #[derive(Clone)]
    pub struct ImageGenerationModel {
        client: Client,
        /// Model identifier sent as `model_name` (e.g. [`SSD`]).
        pub model: String,
    }

    impl ImageGenerationModel {
        pub(crate) fn new(client: Client, model: &str) -> ImageGenerationModel {
            Self {
                client,
                model: model.to_string(),
            }
        }
    }

    /// A single generated image as a base64 string.
    #[derive(Clone, Deserialize)]
    pub struct Image {
        image: String,
    }

    /// Raw Hyperbolic image generation response.
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        images: Vec<Image>,
    }

    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// Decode the first returned image from standard base64.
        ///
        /// # Errors
        /// `ImageGenerationError::ResponseError` if the response contains no images or the
        /// first image is not valid base64. The payload comes from the network, so these
        /// cases are reported as errors rather than panics.
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            let first = value.images.first().ok_or_else(|| {
                ImageGenerationError::ResponseError("Response contained no images".to_owned())
            })?;
            let data = BASE64_STANDARD.decode(&first.image).map_err(|err| {
                ImageGenerationError::ResponseError(format!("Could not decode image: {err}"))
            })?;

            Ok(Self {
                image: data,
                response: value,
            })
        }
    }

    impl image_generation::ImageGenerationModel for ImageGenerationModel {
        type Response = ImageGenerationResponse;

        /// Send `POST /v1/image/generation` with the model, prompt and dimensions, merged
        /// with any `additional_params`.
        ///
        /// # Errors
        /// - A non-2xx status yields `ProviderError` with the status code and body.
        /// - An API error body yields `ResponseError` with its message.
        /// - Decoding failures are reported as described on the `TryFrom` impl.
        #[cfg_attr(feature = "worker", worker::send)]
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let response = self
                .client
                .post("/v1/image/generation")
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status().as_str(),
                    response.text().await?
                )));
            }

            match response
                .json::<ApiResponse<ImageGenerationResponse>>()
                .await?
            {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
            }
        }
    }

    impl ImageGenerationClient for Client {
        type ImageGenerationModel = ImageGenerationModel;

        /// Create an image generation model with the given name.
        ///
        /// # Example
        /// ```
        /// use rig::client::ImageGenerationClient;
        /// use rig::providers::hyperbolic::{self, Client};
        ///
        /// // Initialize the Hyperbolic client
        /// let client = Client::new("your-hyperbolic-api-key");
        ///
        /// let ssd = client.image_generation_model(hyperbolic::SSD);
        /// ```
        fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
            ImageGenerationModel::new(self.clone(), model)
        }
    }
}
590
591// ======================================
592// Hyperbolic Audio Generation API
593// ======================================
594#[cfg(feature = "audio")]
595pub use audio_generation::*;
596
#[cfg(feature = "audio")]
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::{ApiResponse, Client};
    use crate::audio_generation;
    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
    use crate::client::AudioGenerationClient;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    /// A Hyperbolic text-to-speech model bound to a [`Client`] and a language.
    #[derive(Clone)]
    pub struct AudioGenerationModel {
        client: Client,
        /// Language code sent as `language` (e.g. `"EN"`).
        pub language: String,
    }

    impl AudioGenerationModel {
        pub(crate) fn new(client: Client, language: &str) -> AudioGenerationModel {
            Self {
                client,
                language: language.to_string(),
            }
        }
    }

    /// Raw Hyperbolic audio generation response: the audio as a base64 string.
    #[derive(Clone, Deserialize)]
    pub struct AudioGenerationResponse {
        audio: String,
    }

    impl TryFrom<AudioGenerationResponse>
        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
    {
        type Error = AudioGenerationError;

        /// Decode the returned audio from standard base64.
        ///
        /// # Errors
        /// `AudioGenerationError::ProviderError` if the payload is not valid base64. The
        /// payload comes from the network, so this is reported instead of panicking.
        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
            let data = BASE64_STANDARD.decode(&value.audio).map_err(|err| {
                AudioGenerationError::ProviderError(format!("Could not decode audio: {err}"))
            })?;

            Ok(Self {
                audio: data,
                response: value,
            })
        }
    }

    impl audio_generation::AudioGenerationModel for AudioGenerationModel {
        type Response = AudioGenerationResponse;

        /// Send `POST /v1/audio/generation` with the language, voice (as `speaker`), text
        /// and speed.
        ///
        /// # Errors
        /// - A non-2xx status yields `ProviderError` with the status and body.
        /// - An API error body yields `ProviderError` with its message.
        /// - Transport, JSON and decoding failures are propagated.
        #[cfg_attr(feature = "worker", worker::send)]
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            let request = json!({
                "language": self.language,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let response = self
                .client
                .post("/v1/audio/generation")
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status(),
                    response.text().await?
                )));
            }

            match serde_json::from_str::<ApiResponse<AudioGenerationResponse>>(
                &response.text().await?,
            )? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }

    impl AudioGenerationClient for Client {
        type AudioGenerationModel = AudioGenerationModel;

        /// Create an audio generation (text-to-speech) model for the given language.
        ///
        /// # Example
        /// ```
        /// use rig::client::AudioGenerationClient;
        /// use rig::providers::hyperbolic::Client;
        ///
        /// // Initialize the Hyperbolic client
        /// let client = Client::new("your-hyperbolic-api-key");
        ///
        /// let tts = client.audio_generation_model("EN");
        /// ```
        fn audio_generation_model(&self, language: &str) -> AudioGenerationModel {
            AudioGenerationModel::new(self.clone(), language)
        }
    }
}