Skip to main content

rig_core/providers/
hyperbolic.rs

1//! Hyperbolic Inference API client and Rig integration
2//!
3//! # Example
4//! ```no_run
5//! use rig_core::{client::CompletionClient, providers::hyperbolic};
6//!
7//! # fn run() -> Result<(), Box<dyn std::error::Error>> {
8//! let client = hyperbolic::Client::new("YOUR_API_KEY")?;
9//!
10//! let llama_3_1_8b = client.completion_model(hyperbolic::LLAMA_3_1_8B);
11//! # Ok(())
12//! # }
13//! ```
14use super::openai::{AssistantContent, send_compatible_streaming_request};
15
16use crate::client::{self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder};
17use crate::client::{BearerAuth, ProviderClient};
18use crate::http_client::{self, HttpClientExt};
19use crate::streaming::StreamingCompletionResponse;
20
21use crate::providers::openai;
22use crate::{
23    OneOrMany,
24    completion::{self, CompletionError, CompletionRequest},
25    json_utils,
26    providers::openai::Message,
27};
28use serde::{Deserialize, Serialize};
29
30// ================================================================
31// Main Hyperbolic Client
32// ================================================================
/// Default base URL for the Hyperbolic API; exposed to the client builder as
/// `ProviderBuilder::BASE_URL` for `HyperbolicBuilder`.
const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz";
34
/// Provider marker for Hyperbolic. A unit struct: it carries no state of its own.
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicExt;
/// Builder marker for Hyperbolic. A unit struct whose `ProviderBuilder` impl yields
/// a [`HyperbolicExt`].
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicBuilder;

/// API key type used by the Hyperbolic provider: an alias for [`BearerAuth`].
type HyperbolicApiKey = BearerAuth;
41
/// Ties [`HyperbolicExt`] to its builder type and declares the path used for verification.
impl Provider for HyperbolicExt {
    type Builder = HyperbolicBuilder;

    // NOTE(review): every other endpoint requested in this file is under `/v1/...` and the
    // base URL has no version segment — confirm that `/models` (not `/v1/models`) is intended.
    const VERIFY_PATH: &'static str = "/models";
}
47
/// Capability table for Hyperbolic.
///
/// Completions are always available. Image and audio generation are declared only when the
/// `image` / `audio` crate features are enabled. Embeddings, transcription, model listing
/// and reranking are marked `Nothing`.
impl<H> Capabilities<H> for HyperbolicExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<ImageGenerationModel<H>>;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<AudioGenerationModel<H>>;
    type Rerank = Nothing;
}
59
// Empty impl: no `DebugExt` items are overridden for Hyperbolic.
impl DebugExt for HyperbolicExt {}
61
/// Builder configuration for Hyperbolic clients: the extension type produced, the API key
/// type accepted, and the default base URL.
impl ProviderBuilder for HyperbolicBuilder {
    type Extension<H>
        = HyperbolicExt
    where
        H: HttpClientExt;
    type ApiKey = HyperbolicApiKey;

    const BASE_URL: &'static str = HYPERBOLIC_API_BASE_URL;

    /// Produces the provider extension.
    ///
    /// The builder argument is ignored because [`HyperbolicExt`] is a unit struct with
    /// nothing to configure; this implementation always returns `Ok`.
    fn build<H>(
        _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(HyperbolicExt)
    }
}
80
/// Hyperbolic API client; the HTTP backend defaults to `reqwest::Client`.
pub type Client<H = reqwest::Client> = client::Client<HyperbolicExt, H>;
/// Builder for a Hyperbolic [`Client`]; the HTTP backend parameter defaults to the
/// `Missing` marker.
pub type ClientBuilder<H = crate::markers::Missing> =
    client::ClientBuilder<HyperbolicBuilder, String, H>;
84
85impl ProviderClient for Client {
86    type Input = HyperbolicApiKey;
87    type Error = crate::client::ProviderClientError;
88
89    /// Create a new Hyperbolic client from the `HYPERBOLIC_API_KEY` environment variable.
90    fn from_env() -> Result<Self, Self::Error> {
91        let api_key = crate::client::required_env_var("HYPERBOLIC_API_KEY")?;
92        Self::new(&api_key).map_err(Into::into)
93    }
94
95    fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
96        Self::new(input).map_err(Into::into)
97    }
98}
99
/// Error payload deserialized from a response body; only the `message` field is read.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
104
/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// The enum is `untagged`, so serde attempts the variants in declaration order:
/// the body is first parsed as `T`, and only if that fails as an error payload.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
111
/// A single embedding entry: object kind, the embedding vector, and its index.
///
/// NOTE(review): `Capabilities::Embeddings` is `Nothing` for this provider and nothing in
/// the visible file references this type — confirm whether it is still needed.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}
118
/// Token usage as reported in a Hyperbolic completion response.
///
/// No completion-token count is stored here; the conversion into Rig's usage type derives
/// output tokens from `total_tokens` and `prompt_tokens`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Number of tokens in the prompt.
    pub prompt_tokens: usize,
    /// Total number of tokens reported for the request.
    pub total_tokens: usize,
}
124
125impl std::fmt::Display for Usage {
126    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127        write!(
128            f,
129            "Prompt tokens: {} Total tokens: {}",
130            self.prompt_tokens, self.total_tokens
131        )
132    }
133}
134
135// ================================================================
136// Hyperbolic Completion API
137// ================================================================
138
/// Meta Llama 3.1 Instruct model with 8B parameters.
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Meta Llama 3.3 Instruct model with 70B parameters.
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// Meta Llama 3.1 Instruct model with 70B parameters.
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// Meta Llama 3 Instruct model with 70B parameters.
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// Hermes 3 model with 70B parameters.
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// DeepSeek V2.5 model.
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// Qwen 2.5 Instruct model with 72B parameters.
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// Meta Llama 3.2 Instruct model with 3B parameters.
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// Qwen 2.5 Coder Instruct model with 32B parameters.
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// Preview version of the Qwen QwQ model with 32B parameters.
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// DeepSeek R1 Zero model.
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// DeepSeek R1 model.
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
163
/// A Hyperbolic completion object.
///
/// For more information, see this link: <https://docs.hyperbolic.xyz/reference/create_chat_completion_v1_chat_completions_post>
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    /// Returned choices; conversion into Rig's response type reads only the first one.
    pub choices: Vec<Choice>,
    /// Token usage, if the response included it.
    pub usage: Option<Usage>,
}
176
177impl From<ApiErrorResponse> for CompletionError {
178    fn from(err: ApiErrorResponse) -> Self {
179        CompletionError::ProviderError(err.message)
180    }
181}
182
183impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
184    type Error = CompletionError;
185
186    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
187        let choice = response.choices.first().ok_or_else(|| {
188            CompletionError::ResponseError("Response contained no choices".to_owned())
189        })?;
190
191        let content = match &choice.message {
192            Message::Assistant {
193                content,
194                tool_calls,
195                ..
196            } => {
197                let mut content = content
198                    .iter()
199                    .map(|c| match c {
200                        AssistantContent::Text { text } => completion::AssistantContent::text(text),
201                        AssistantContent::Refusal { refusal } => {
202                            completion::AssistantContent::text(refusal)
203                        }
204                    })
205                    .collect::<Vec<_>>();
206
207                content.extend(
208                    tool_calls
209                        .iter()
210                        .map(|call| {
211                            completion::AssistantContent::tool_call(
212                                &call.id,
213                                &call.function.name,
214                                call.function.arguments.clone(),
215                            )
216                        })
217                        .collect::<Vec<_>>(),
218                );
219                Ok(content)
220            }
221            _ => Err(CompletionError::ResponseError(
222                "Response did not contain a valid message or tool call".into(),
223            )),
224        }?;
225
226        let choice = OneOrMany::many(content).map_err(|_| {
227            CompletionError::ResponseError(
228                "Response contained no message or tool call (empty)".to_owned(),
229            )
230        })?;
231
232        let usage = response
233            .usage
234            .as_ref()
235            .map(|usage| completion::Usage {
236                input_tokens: usage.prompt_tokens as u64,
237                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
238                total_tokens: usage.total_tokens as u64,
239                cached_input_tokens: 0,
240                cache_creation_input_tokens: 0,
241                tool_use_prompt_tokens: 0,
242                reasoning_tokens: 0,
243            })
244            .unwrap_or_default();
245
246        Ok(completion::CompletionResponse {
247            choice,
248            usage,
249            raw_response: response,
250            message_id: None,
251        })
252    }
253}
254
/// One entry of [`CompletionResponse::choices`]: its index, the returned message, and the
/// finish reason string.
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    pub finish_reason: String,
}
261
/// JSON body sent to the chat completions endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct HyperbolicCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    // Omitted from the serialized body when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Flattened: the keys of this JSON value are emitted at the top level of the body,
    // alongside `model` and `messages`.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
271
272impl TryFrom<(&str, CompletionRequest)> for HyperbolicCompletionRequest {
273    type Error = CompletionError;
274
275    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
276        let chat_history = req.chat_history_with_documents();
277        if req.output_schema.is_some() {
278            tracing::warn!("Structured outputs currently not supported for Hyperbolic");
279        }
280
281        let model = req.model.clone().unwrap_or_else(|| model.to_string());
282        if req.tool_choice.is_some() {
283            tracing::warn!("WARNING: `tool_choice` not supported on Hyperbolic");
284        }
285
286        if !req.tools.is_empty() {
287            tracing::warn!("WARNING: `tools` not supported on Hyperbolic");
288        }
289
290        let mut full_history: Vec<Message> = match &req.preamble {
291            Some(preamble) => vec![Message::system(preamble)],
292            None => vec![],
293        };
294
295        let chat_history: Vec<Message> = chat_history
296            .into_iter()
297            .map(|message| message.try_into())
298            .collect::<Result<Vec<Vec<Message>>, _>>()?
299            .into_iter()
300            .flatten()
301            .collect();
302
303        full_history.extend(chat_history);
304
305        Ok(Self {
306            model: model.to_string(),
307            messages: full_history,
308            temperature: req.temperature,
309            additional_params: req.additional_params,
310        })
311    }
312}
313
/// Completion model handle: a Hyperbolic [`Client`] paired with the model name to request.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
    pub model: String,
}
320
321impl<T> CompletionModel<T> {
322    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
323        Self {
324            client,
325            model: model.into(),
326        }
327    }
328
329    pub fn with_model(client: Client<T>, model: &str) -> Self {
330        Self {
331            client,
332            model: model.into(),
333        }
334    }
335}
336
337impl<T> completion::CompletionModel for CompletionModel<T>
338where
339    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
340{
341    type Response = CompletionResponse;
342    type StreamingResponse = openai::StreamingCompletionResponse;
343
344    type Client = Client<T>;
345
346    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
347        Self::new(client.clone(), model)
348    }
349
350    async fn completion(
351        &self,
352        completion_request: CompletionRequest,
353    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
354        let span = if tracing::Span::current().is_disabled() {
355            info_span!(
356                target: "rig::completions",
357                "chat",
358                gen_ai.operation.name = "chat",
359                gen_ai.provider.name = "hyperbolic",
360                gen_ai.request.model = self.model,
361                gen_ai.system_instructions = tracing::field::Empty,
362                gen_ai.response.id = tracing::field::Empty,
363                gen_ai.response.model = tracing::field::Empty,
364                gen_ai.usage.output_tokens = tracing::field::Empty,
365                gen_ai.usage.input_tokens = tracing::field::Empty,
366                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
367            )
368        } else {
369            tracing::Span::current()
370        };
371
372        span.record("gen_ai.system_instructions", &completion_request.preamble);
373        let request =
374            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
375
376        if tracing::enabled!(tracing::Level::TRACE) {
377            tracing::trace!(target: "rig::completions",
378                "Hyperbolic completion request: {}",
379                serde_json::to_string_pretty(&request)?
380            );
381        }
382
383        let body = serde_json::to_vec(&request)?;
384
385        let req = self
386            .client
387            .post("/v1/chat/completions")?
388            .body(body)
389            .map_err(http_client::Error::from)?;
390
391        let async_block = async move {
392            let response = self.client.send::<_, bytes::Bytes>(req).await?;
393
394            let status = response.status();
395            let response_body = response.into_body().into_future().await?.to_vec();
396
397            if status.is_success() {
398                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
399                    ApiResponse::Ok(response) => {
400                        if tracing::enabled!(tracing::Level::TRACE) {
401                            tracing::trace!(target: "rig::completions",
402                                "Hyperbolic completion response: {}",
403                                serde_json::to_string_pretty(&response)?
404                            );
405                        }
406
407                        response.try_into()
408                    }
409                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
410                }
411            } else {
412                Err(CompletionError::ProviderError(
413                    String::from_utf8_lossy(&response_body).to_string(),
414                ))
415            }
416        };
417
418        async_block.instrument(span).await
419    }
420
421    async fn stream(
422        &self,
423        completion_request: CompletionRequest,
424    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
425        let span = if tracing::Span::current().is_disabled() {
426            info_span!(
427                target: "rig::completions",
428                "chat_streaming",
429                gen_ai.operation.name = "chat_streaming",
430                gen_ai.provider.name = "hyperbolic",
431                gen_ai.request.model = self.model,
432                gen_ai.system_instructions = tracing::field::Empty,
433                gen_ai.response.id = tracing::field::Empty,
434                gen_ai.response.model = tracing::field::Empty,
435                gen_ai.usage.output_tokens = tracing::field::Empty,
436                gen_ai.usage.input_tokens = tracing::field::Empty,
437                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
438            )
439        } else {
440            tracing::Span::current()
441        };
442
443        span.record("gen_ai.system_instructions", &completion_request.preamble);
444        let mut request =
445            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
446
447        let params = json_utils::merge(
448            request.additional_params.unwrap_or(serde_json::json!({})),
449            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
450        );
451
452        request.additional_params = Some(params);
453
454        if tracing::enabled!(tracing::Level::TRACE) {
455            tracing::trace!(target: "rig::completions",
456                "Hyperbolic streaming completion request: {}",
457                serde_json::to_string_pretty(&request)?
458            );
459        }
460
461        let body = serde_json::to_vec(&request)?;
462
463        let req = self
464            .client
465            .post("/v1/chat/completions")?
466            .body(body)
467            .map_err(http_client::Error::from)?;
468
469        send_compatible_streaming_request(self.client.clone(), req)
470            .instrument(span)
471            .await
472    }
473}
474
475// =======================================
476// Hyperbolic Image Generation API
477// =======================================
478
479#[cfg(feature = "image")]
480pub use image_generation::*;
481
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use super::{ApiResponse, Client};
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    // Model names accepted in the `model_name` field of the image generation request.
    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    pub const SD2: &str = "SD2";
    pub const SD1_5: &str = "SD1.5";
    pub const SSD: &str = "SSD";
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    /// Image generation handle: a Hyperbolic client paired with a model name.
    #[derive(Clone)]
    pub struct ImageGenerationModel<T> {
        client: Client<T>,
        pub model: String,
    }

    impl<T> ImageGenerationModel<T> {
        /// Creates a model handle from a client and anything convertible into a model name.
        pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
            let model = model.into();
            Self { client, model }
        }

        /// Creates a model handle from a client and a borrowed model name.
        pub fn with_model(client: Client<T>, model: &str) -> Self {
            Self::new(client, model)
        }
    }

    /// One generated image, carried as a base64 string.
    #[derive(Clone, Deserialize)]
    pub struct Image {
        image: String,
    }

    /// Response body of the image generation endpoint.
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        images: Vec<Image>,
    }

    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// Decodes the first returned image from base64 and keeps the raw response alongside.
        ///
        /// # Errors
        /// Returns `ResponseError` when no image is present or the base64 is invalid.
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            let encoded = match value.images.first() {
                Some(first) => &first.image,
                None => {
                    return Err(ImageGenerationError::ResponseError(
                        "missing image data".into(),
                    ));
                }
            };

            let image = BASE64_STANDARD
                .decode(encoded)
                .map_err(|err| ImageGenerationError::ResponseError(err.to_string()))?;

            Ok(Self {
                image,
                response: value,
            })
        }
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        /// Builds a model handle from a borrowed client (cloned) and a model name.
        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        /// Posts an image generation request to `/v1/image/generation`.
        ///
        /// The body carries the model name, prompt, height and width; any additional
        /// parameters on the request are merged into it before sending.
        ///
        /// # Errors
        /// Returns an error when serialization or the HTTP call fails, when the status is
        /// not successful, when the body cannot be parsed, or when the provider returns an
        /// error payload.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let mut payload = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            if let Some(extra) = generation_request.additional_params {
                merge_inplace(&mut payload, extra);
            }

            let encoded = serde_json::to_vec(&payload)?;

            let http_request = self
                .client
                .post("/v1/image/generation")?
                .header("Content-Type", "application/json")
                .body(encoded)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let http_response = self.client.send::<_, bytes::Bytes>(http_request).await?;

            let status = http_response.status();
            let raw_body = http_response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&raw_body)? {
                    ApiResponse::Ok(parsed) => parsed.try_into(),
                    ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
                }
            } else {
                Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&raw_body)
                )))
            }
        }
    }
}
612
613// ======================================
614// Hyperbolic Audio Generation API
615// ======================================
616#[cfg(feature = "audio")]
617pub use audio_generation::*;
618use tracing::{Instrument, info_span};
619
#[cfg(feature = "audio")]
// The docs.rs feature badge must name the feature that actually gates this module.
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::{ApiResponse, Client};
    use crate::audio_generation;
    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
    use crate::http_client::{self, HttpClientExt};
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use bytes::Bytes;
    use serde::Deserialize;
    use serde_json::json;

    /// Audio generation handle: a Hyperbolic client paired with the language sent with
    /// every request.
    #[derive(Clone)]
    pub struct AudioGenerationModel<T> {
        client: Client<T>,
        /// Value sent as the `language` field of the request body.
        pub language: String,
    }

    /// Response body of the audio generation endpoint; `audio` is a base64 string.
    #[derive(Clone, Deserialize)]
    pub struct AudioGenerationResponse {
        audio: String,
    }

    impl TryFrom<AudioGenerationResponse>
        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
    {
        type Error = AudioGenerationError;

        /// Decodes the base64 audio payload and keeps the raw response alongside.
        ///
        /// # Errors
        /// Returns `ResponseError` when the payload is not valid base64.
        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
            let data = BASE64_STANDARD
                .decode(&value.audio)
                .map_err(|err| AudioGenerationError::ResponseError(err.to_string()))?;

            Ok(Self {
                audio: data,
                response: value,
            })
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = AudioGenerationResponse;
        type Client = Client<T>;

        /// Builds a model handle from a borrowed client (cloned). The second argument is
        /// stored as the request language.
        fn make(client: &Self::Client, language: impl Into<String>) -> Self {
            Self {
                client: client.clone(),
                language: language.into(),
            }
        }

        /// Posts an audio generation request to `/v1/audio/generation`.
        ///
        /// The body carries the configured language plus the request's voice (sent as
        /// `speaker`), text and speed.
        ///
        /// # Errors
        /// Returns an error when serialization or the HTTP call fails, when the status is
        /// not successful, when the body cannot be parsed, or when the provider returns an
        /// error payload.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            let request = json!({
                "language": self.language,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post("/v1/audio/generation")?
                .body(body)
                .map_err(http_client::Error::from)?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<AudioGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }
}
713
#[cfg(test)]
mod tests {
    use crate::providers::hyperbolic::Client;

    /// Both construction paths must accept a placeholder API key without error.
    #[test]
    fn test_client_initialization() {
        let _direct = Client::new("dummy-key").expect("Client::new() failed");

        let builder = Client::builder().api_key("dummy-key");
        let _built = builder.build().expect("Client::builder() failed");
    }
}