Skip to main content

rig_core/providers/
hyperbolic.rs

1//! Hyperbolic Inference API client and Rig integration
2//!
3//! # Example
4//! ```no_run
5//! use rig_core::{client::CompletionClient, providers::hyperbolic};
6//!
7//! # fn run() -> Result<(), Box<dyn std::error::Error>> {
8//! let client = hyperbolic::Client::new("YOUR_API_KEY")?;
9//!
10//! let llama_3_1_8b = client.completion_model(hyperbolic::LLAMA_3_1_8B);
11//! # Ok(())
12//! # }
13//! ```
14use super::openai::{AssistantContent, send_compatible_streaming_request};
15
16use crate::client::{self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder};
17use crate::client::{BearerAuth, ProviderClient};
18use crate::http_client::{self, HttpClientExt};
19use crate::streaming::StreamingCompletionResponse;
20
21use crate::providers::openai;
22use crate::{
23    OneOrMany,
24    completion::{self, CompletionError, CompletionRequest},
25    json_utils,
26    providers::openai::Message,
27};
28use serde::{Deserialize, Serialize};
29
30// ================================================================
31// Main Hyperbolic Client
32// ================================================================
/// Base URL of the Hyperbolic API, exposed to the generic client as `ProviderBuilder::BASE_URL`.
const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz";

/// Zero-sized provider extension marker selecting Hyperbolic behaviour in the generic client.
#[derive(Clone, Copy, Debug, Default)]
pub struct HyperbolicExt;

/// Zero-sized builder marker whose `build` produces a [`HyperbolicExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct HyperbolicBuilder;
39
40type HyperbolicApiKey = BearerAuth;
41
42impl Provider for HyperbolicExt {
43    type Builder = HyperbolicBuilder;
44
45    const VERIFY_PATH: &'static str = "/models";
46}
47
48impl<H> Capabilities<H> for HyperbolicExt {
49    type Completion = Capable<CompletionModel<H>>;
50    type Embeddings = Nothing;
51    type Transcription = Nothing;
52    type ModelListing = Nothing;
53    #[cfg(feature = "image")]
54    type ImageGeneration = Capable<ImageGenerationModel<H>>;
55    #[cfg(feature = "audio")]
56    type AudioGeneration = Capable<AudioGenerationModel<H>>;
57}
58
59impl DebugExt for HyperbolicExt {}
60
61impl ProviderBuilder for HyperbolicBuilder {
62    type Extension<H>
63        = HyperbolicExt
64    where
65        H: HttpClientExt;
66    type ApiKey = HyperbolicApiKey;
67
68    const BASE_URL: &'static str = HYPERBOLIC_API_BASE_URL;
69
70    fn build<H>(
71        _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
72    ) -> http_client::Result<Self::Extension<H>>
73    where
74        H: HttpClientExt,
75    {
76        Ok(HyperbolicExt)
77    }
78}
79
80pub type Client<H = reqwest::Client> = client::Client<HyperbolicExt, H>;
81pub type ClientBuilder<H = crate::markers::Missing> =
82    client::ClientBuilder<HyperbolicBuilder, String, H>;
83
84impl ProviderClient for Client {
85    type Input = HyperbolicApiKey;
86    type Error = crate::client::ProviderClientError;
87
88    /// Create a new Hyperbolic client from the `HYPERBOLIC_API_KEY` environment variable.
89    fn from_env() -> Result<Self, Self::Error> {
90        let api_key = crate::client::required_env_var("HYPERBOLIC_API_KEY")?;
91        Self::new(&api_key).map_err(Into::into)
92    }
93
94    fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
95        Self::new(input).map_err(Into::into)
96    }
97}
98
99#[derive(Debug, Deserialize)]
100struct ApiErrorResponse {
101    message: String,
102}
103
104#[derive(Debug, Deserialize)]
105#[serde(untagged)]
106enum ApiResponse<T> {
107    Ok(T),
108    Err(ApiErrorResponse),
109}
110
111#[derive(Debug, Deserialize)]
112pub struct EmbeddingData {
113    pub object: String,
114    pub embedding: Vec<f64>,
115    pub index: usize,
116}
117
118#[derive(Clone, Debug, Deserialize, Serialize)]
119pub struct Usage {
120    pub prompt_tokens: usize,
121    pub total_tokens: usize,
122}
123
124impl std::fmt::Display for Usage {
125    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126        write!(
127            f,
128            "Prompt tokens: {} Total tokens: {}",
129            self.prompt_tokens, self.total_tokens
130        )
131    }
132}
133
134// ================================================================
135// Hyperbolic Completion API
136// ================================================================
137
// Chat model identifiers accepted by Hyperbolic; pass one to `completion_model`.

/// Meta Llama 3.1 Instruct model with 8B parameters.
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Meta Llama 3.3 Instruct model with 70B parameters.
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// Meta Llama 3.1 Instruct model with 70B parameters.
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// Meta Llama 3 Instruct model with 70B parameters.
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// Nous Research Hermes 3 model (Llama 3.1 based) with 70B parameters.
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// DeepSeek V2.5 model.
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// Qwen 2.5 Instruct model with 72B parameters.
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// Meta Llama 3.2 Instruct model with 3B parameters.
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// Qwen 2.5 Coder Instruct model with 32B parameters.
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// Preview release of Qwen's QwQ model with 32B parameters.
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// DeepSeek R1 Zero model.
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// DeepSeek R1 model.
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
162
163/// A Hyperbolic completion object.
164///
165/// For more information, see this link: <https://docs.hyperbolic.xyz/reference/create_chat_completion_v1_chat_completions_post>
166#[derive(Debug, Deserialize, Serialize)]
167pub struct CompletionResponse {
168    pub id: String,
169    pub object: String,
170    pub created: u64,
171    pub model: String,
172    pub choices: Vec<Choice>,
173    pub usage: Option<Usage>,
174}
175
176impl From<ApiErrorResponse> for CompletionError {
177    fn from(err: ApiErrorResponse) -> Self {
178        CompletionError::ProviderError(err.message)
179    }
180}
181
182impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
183    type Error = CompletionError;
184
185    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
186        let choice = response.choices.first().ok_or_else(|| {
187            CompletionError::ResponseError("Response contained no choices".to_owned())
188        })?;
189
190        let content = match &choice.message {
191            Message::Assistant {
192                content,
193                tool_calls,
194                ..
195            } => {
196                let mut content = content
197                    .iter()
198                    .map(|c| match c {
199                        AssistantContent::Text { text } => completion::AssistantContent::text(text),
200                        AssistantContent::Refusal { refusal } => {
201                            completion::AssistantContent::text(refusal)
202                        }
203                    })
204                    .collect::<Vec<_>>();
205
206                content.extend(
207                    tool_calls
208                        .iter()
209                        .map(|call| {
210                            completion::AssistantContent::tool_call(
211                                &call.id,
212                                &call.function.name,
213                                call.function.arguments.clone(),
214                            )
215                        })
216                        .collect::<Vec<_>>(),
217                );
218                Ok(content)
219            }
220            _ => Err(CompletionError::ResponseError(
221                "Response did not contain a valid message or tool call".into(),
222            )),
223        }?;
224
225        let choice = OneOrMany::many(content).map_err(|_| {
226            CompletionError::ResponseError(
227                "Response contained no message or tool call (empty)".to_owned(),
228            )
229        })?;
230
231        let usage = response
232            .usage
233            .as_ref()
234            .map(|usage| completion::Usage {
235                input_tokens: usage.prompt_tokens as u64,
236                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
237                total_tokens: usage.total_tokens as u64,
238                cached_input_tokens: 0,
239                cache_creation_input_tokens: 0,
240                tool_use_prompt_tokens: 0,
241                reasoning_tokens: 0,
242            })
243            .unwrap_or_default();
244
245        Ok(completion::CompletionResponse {
246            choice,
247            usage,
248            raw_response: response,
249            message_id: None,
250        })
251    }
252}
253
254#[derive(Debug, Deserialize, Serialize)]
255pub struct Choice {
256    pub index: usize,
257    pub message: Message,
258    pub finish_reason: String,
259}
260
261#[derive(Debug, Serialize, Deserialize)]
262pub(super) struct HyperbolicCompletionRequest {
263    model: String,
264    pub messages: Vec<Message>,
265    #[serde(skip_serializing_if = "Option::is_none")]
266    temperature: Option<f64>,
267    #[serde(flatten, skip_serializing_if = "Option::is_none")]
268    pub additional_params: Option<serde_json::Value>,
269}
270
271impl TryFrom<(&str, CompletionRequest)> for HyperbolicCompletionRequest {
272    type Error = CompletionError;
273
274    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
275        if req.output_schema.is_some() {
276            tracing::warn!("Structured outputs currently not supported for Hyperbolic");
277        }
278
279        let model = req.model.clone().unwrap_or_else(|| model.to_string());
280        if req.tool_choice.is_some() {
281            tracing::warn!("WARNING: `tool_choice` not supported on Hyperbolic");
282        }
283
284        if !req.tools.is_empty() {
285            tracing::warn!("WARNING: `tools` not supported on Hyperbolic");
286        }
287
288        let mut full_history: Vec<Message> = match &req.preamble {
289            Some(preamble) => vec![Message::system(preamble)],
290            None => vec![],
291        };
292
293        if let Some(docs) = req.normalized_documents() {
294            let docs: Vec<Message> = docs.try_into()?;
295            full_history.extend(docs);
296        }
297
298        let chat_history: Vec<Message> = req
299            .chat_history
300            .clone()
301            .into_iter()
302            .map(|message| message.try_into())
303            .collect::<Result<Vec<Vec<Message>>, _>>()?
304            .into_iter()
305            .flatten()
306            .collect();
307
308        full_history.extend(chat_history);
309
310        Ok(Self {
311            model: model.to_string(),
312            messages: full_history,
313            temperature: req.temperature,
314            additional_params: req.additional_params,
315        })
316    }
317}
318
319#[derive(Clone)]
320pub struct CompletionModel<T = reqwest::Client> {
321    client: Client<T>,
322    /// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
323    pub model: String,
324}
325
326impl<T> CompletionModel<T> {
327    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
328        Self {
329            client,
330            model: model.into(),
331        }
332    }
333
334    pub fn with_model(client: Client<T>, model: &str) -> Self {
335        Self {
336            client,
337            model: model.into(),
338        }
339    }
340}
341
342impl<T> completion::CompletionModel for CompletionModel<T>
343where
344    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
345{
346    type Response = CompletionResponse;
347    type StreamingResponse = openai::StreamingCompletionResponse;
348
349    type Client = Client<T>;
350
351    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
352        Self::new(client.clone(), model)
353    }
354
355    async fn completion(
356        &self,
357        completion_request: CompletionRequest,
358    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
359        let span = if tracing::Span::current().is_disabled() {
360            info_span!(
361                target: "rig::completions",
362                "chat",
363                gen_ai.operation.name = "chat",
364                gen_ai.provider.name = "hyperbolic",
365                gen_ai.request.model = self.model,
366                gen_ai.system_instructions = tracing::field::Empty,
367                gen_ai.response.id = tracing::field::Empty,
368                gen_ai.response.model = tracing::field::Empty,
369                gen_ai.usage.output_tokens = tracing::field::Empty,
370                gen_ai.usage.input_tokens = tracing::field::Empty,
371                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
372            )
373        } else {
374            tracing::Span::current()
375        };
376
377        span.record("gen_ai.system_instructions", &completion_request.preamble);
378        let request =
379            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
380
381        if tracing::enabled!(tracing::Level::TRACE) {
382            tracing::trace!(target: "rig::completions",
383                "Hyperbolic completion request: {}",
384                serde_json::to_string_pretty(&request)?
385            );
386        }
387
388        let body = serde_json::to_vec(&request)?;
389
390        let req = self
391            .client
392            .post("/v1/chat/completions")?
393            .body(body)
394            .map_err(http_client::Error::from)?;
395
396        let async_block = async move {
397            let response = self.client.send::<_, bytes::Bytes>(req).await?;
398
399            let status = response.status();
400            let response_body = response.into_body().into_future().await?.to_vec();
401
402            if status.is_success() {
403                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
404                    ApiResponse::Ok(response) => {
405                        if tracing::enabled!(tracing::Level::TRACE) {
406                            tracing::trace!(target: "rig::completions",
407                                "Hyperbolic completion response: {}",
408                                serde_json::to_string_pretty(&response)?
409                            );
410                        }
411
412                        response.try_into()
413                    }
414                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
415                }
416            } else {
417                Err(CompletionError::ProviderError(
418                    String::from_utf8_lossy(&response_body).to_string(),
419                ))
420            }
421        };
422
423        async_block.instrument(span).await
424    }
425
426    async fn stream(
427        &self,
428        completion_request: CompletionRequest,
429    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
430        let span = if tracing::Span::current().is_disabled() {
431            info_span!(
432                target: "rig::completions",
433                "chat_streaming",
434                gen_ai.operation.name = "chat_streaming",
435                gen_ai.provider.name = "hyperbolic",
436                gen_ai.request.model = self.model,
437                gen_ai.system_instructions = tracing::field::Empty,
438                gen_ai.response.id = tracing::field::Empty,
439                gen_ai.response.model = tracing::field::Empty,
440                gen_ai.usage.output_tokens = tracing::field::Empty,
441                gen_ai.usage.input_tokens = tracing::field::Empty,
442                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
443            )
444        } else {
445            tracing::Span::current()
446        };
447
448        span.record("gen_ai.system_instructions", &completion_request.preamble);
449        let mut request =
450            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
451
452        let params = json_utils::merge(
453            request.additional_params.unwrap_or(serde_json::json!({})),
454            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
455        );
456
457        request.additional_params = Some(params);
458
459        if tracing::enabled!(tracing::Level::TRACE) {
460            tracing::trace!(target: "rig::completions",
461                "Hyperbolic streaming completion request: {}",
462                serde_json::to_string_pretty(&request)?
463            );
464        }
465
466        let body = serde_json::to_vec(&request)?;
467
468        let req = self
469            .client
470            .post("/v1/chat/completions")?
471            .body(body)
472            .map_err(http_client::Error::from)?;
473
474        send_compatible_streaming_request(self.client.clone(), req)
475            .instrument(span)
476            .await
477    }
478}
479
480// =======================================
481// Hyperbolic Image Generation API
482// =======================================
483
#[cfg(feature = "image")]
pub use image_generation::*;

#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use super::{ApiResponse, Client};
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    // Image model identifiers; the chosen one is sent as `model_name` in the request body.
    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    pub const SD2: &str = "SD2";
    pub const SD1_5: &str = "SD1.5";
    pub const SSD: &str = "SSD";
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    /// Hyperbolic image-generation model bound to a [`Client`].
    #[derive(Clone)]
    pub struct ImageGenerationModel<T> {
        client: Client<T>,
        /// Model identifier, sent as `model_name`.
        pub model: String,
    }

    impl<T> ImageGenerationModel<T> {
        pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
            Self {
                client,
                model: model.into(),
            }
        }

        pub fn with_model(client: Client<T>, model: &str) -> Self {
            Self {
                client,
                model: model.into(),
            }
        }
    }

    /// One generated image, base64-encoded as returned by the API.
    #[derive(Clone, Debug, Deserialize)]
    pub struct Image {
        image: String,
    }

    /// Raw response body of `/v1/image/generation`.
    #[derive(Clone, Debug, Deserialize)]
    pub struct ImageGenerationResponse {
        images: Vec<Image>,
    }

    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// Decodes the first returned image. Any further images remain only in `response`.
        ///
        /// # Errors
        /// Returns [`ImageGenerationError::ResponseError`] in two cases:
        /// - no image was returned;
        /// - the first image is not valid base64.
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            let image = value
                .images
                .first()
                .ok_or_else(|| ImageGenerationError::ResponseError("missing image data".into()))?;
            let data = BASE64_STANDARD
                .decode(&image.image)
                .map_err(|err| ImageGenerationError::ResponseError(err.to_string()))?;

            Ok(Self {
                image: data,
                response: value,
            })
        }
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        /// Posts a generation request to `/v1/image/generation`.
        ///
        /// `additional_params`, when present, are merged into the JSON body.
        ///
        /// # Errors
        /// Returns [`ImageGenerationError::ProviderError`] for a non-success status or an
        /// API error payload. Request, transport and decoding errors are propagated.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let body = serde_json::to_vec(&request)?;

            let request = self
                .client
                .post("/v1/image/generation")?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, bytes::Bytes>(request).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                // The provider explicitly reported a failure. This is a provider error,
                // not a malformed response.
                ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
            }
        }
    }
}
617
618// ======================================
619// Hyperbolic Audio Generation API
620// ======================================
621#[cfg(feature = "audio")]
622pub use audio_generation::*;
623use tracing::{Instrument, info_span};
624
#[cfg(feature = "audio")]
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::{ApiResponse, Client};
    use crate::audio_generation;
    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
    use crate::http_client::{self, HttpClientExt};
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use bytes::Bytes;
    use serde::Deserialize;
    use serde_json::json;

    /// Hyperbolic text-to-speech model handle.
    ///
    /// The name passed to `make` is stored as `language` and sent as the request's
    /// `language` field.
    #[derive(Clone)]
    pub struct AudioGenerationModel<T> {
        client: Client<T>,
        pub language: String,
    }

    /// Raw response body of `/v1/audio/generation`: the audio, base64-encoded.
    #[derive(Clone, Debug, Deserialize)]
    pub struct AudioGenerationResponse {
        audio: String,
    }

    impl TryFrom<AudioGenerationResponse>
        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
    {
        type Error = AudioGenerationError;

        /// Decodes the base64 audio payload, keeping the raw response alongside it.
        ///
        /// # Errors
        /// Returns [`AudioGenerationError::ResponseError`] if `audio` is not valid base64.
        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
            let data = BASE64_STANDARD
                .decode(&value.audio)
                .map_err(|err| AudioGenerationError::ResponseError(err.to_string()))?;

            Ok(Self {
                audio: data,
                response: value,
            })
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = AudioGenerationResponse;
        type Client = Client<T>;

        fn make(client: &Self::Client, language: impl Into<String>) -> Self {
            Self {
                client: client.clone(),
                language: language.into(),
            }
        }

        /// Posts a speech-synthesis request to `/v1/audio/generation`.
        ///
        /// `voice` is sent as `speaker`, alongside `language`, `text` and `speed`.
        ///
        /// # Errors
        /// Returns [`AudioGenerationError::ProviderError`] for a non-success status or an
        /// API error payload. Request, transport and decoding errors are propagated.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            let request = json!({
                "language": self.language,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post("/v1/audio/generation")?
                .body(body)
                .map_err(http_client::Error::from)?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<AudioGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }
}
718
#[cfg(test)]
mod tests {
    use super::Client;

    /// Both construction paths — `Client::new` and the builder — succeed with a
    /// placeholder API key.
    #[test]
    fn test_client_initialization() {
        let _direct = Client::new("dummy-key").expect("Client::new() failed");
        let _built = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}