// rig-core/src/providers/hyperbolic.rs
//! Hyperbolic Inference API client and Rig integration
//!
//! # Example
//! ```no_run
//! use rig_core::{client::CompletionClient, providers::hyperbolic};
//!
//! # fn run() -> Result<(), Box<dyn std::error::Error>> {
//! let client = hyperbolic::Client::new("YOUR_API_KEY")?;
//!
//! let llama_3_1_8b = client.completion_model(hyperbolic::LLAMA_3_1_8B);
//! # Ok(())
//! # }
//! ```
14use super::openai::{AssistantContent, send_compatible_streaming_request};
15
16use crate::client::{self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder};
17use crate::client::{BearerAuth, ProviderClient};
18use crate::http_client::{self, HttpClientExt};
19use crate::streaming::StreamingCompletionResponse;
20
21use crate::providers::openai;
22use crate::{
23    OneOrMany,
24    completion::{self, CompletionError, CompletionRequest},
25    json_utils,
26    providers::openai::Message,
27};
28use serde::{Deserialize, Serialize};
29
// ================================================================
// Main Hyperbolic Client
// ================================================================
const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz";

/// Zero-sized marker type carrying the Hyperbolic provider's capabilities.
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicExt;
/// Zero-sized builder marker for constructing [`HyperbolicExt`] clients.
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicBuilder;

// Hyperbolic authenticates with a standard bearer token.
type HyperbolicApiKey = BearerAuth;
41
impl Provider for HyperbolicExt {
    type Builder = HyperbolicBuilder;

    // Endpoint path used when verifying the client's credentials/connectivity.
    const VERIFY_PATH: &'static str = "/models";
}
47
// Hyperbolic supports chat completions unconditionally, plus image/audio
// generation behind their respective cargo features. Embeddings,
// transcription, and model listing are not implemented for this provider.
impl<H> Capabilities<H> for HyperbolicExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<ImageGenerationModel<H>>;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<AudioGenerationModel<H>>;
}

// Use the default debug-formatting behaviour.
impl DebugExt for HyperbolicExt {}
60
impl ProviderBuilder for HyperbolicBuilder {
    // The extension is the same zero-sized marker regardless of the
    // underlying HTTP client implementation.
    type Extension<H>
        = HyperbolicExt
    where
        H: HttpClientExt;
    type ApiKey = HyperbolicApiKey;

    const BASE_URL: &'static str = HYPERBOLIC_API_BASE_URL;

    /// Produce the (stateless) Hyperbolic extension. Infallible in practice:
    /// no builder state is read and `Ok` is always returned.
    fn build<H>(
        _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(HyperbolicExt)
    }
}
79
/// A Hyperbolic client over a pluggable HTTP transport (defaults to `reqwest`).
pub type Client<H = reqwest::Client> = client::Client<HyperbolicExt, H>;
/// Builder alias; the `Missing` marker indicates no HTTP client chosen yet.
pub type ClientBuilder<H = crate::markers::Missing> =
    client::ClientBuilder<HyperbolicBuilder, String, H>;

impl ProviderClient for Client {
    type Input = HyperbolicApiKey;
    type Error = crate::client::ProviderClientError;

    /// Create a new Hyperbolic client from the `HYPERBOLIC_API_KEY` environment variable.
    fn from_env() -> Result<Self, Self::Error> {
        let api_key = crate::client::required_env_var("HYPERBOLIC_API_KEY")?;
        Self::new(&api_key).map_err(Into::into)
    }

    /// Create a new Hyperbolic client from an explicit API key value.
    fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
        Self::new(input).map_err(Into::into)
    }
}
98
/// Error payload as returned by the Hyperbolic API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Untagged wrapper: a response body either deserializes as the expected
/// success payload `T`, or as an API error object. Serde tries `Ok` first.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
110
/// A single embedding entry in an embeddings response.
/// NOTE(review): not referenced elsewhere in this file — presumably kept for
/// API-shape completeness; verify against callers before removing.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}

/// Token usage as reported by the Hyperbolic API (prompt + total only;
/// completion tokens are derived as the difference).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
123
124impl std::fmt::Display for Usage {
125    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126        write!(
127            f,
128            "Prompt tokens: {} Total tokens: {}",
129            self.prompt_tokens, self.total_tokens
130        )
131    }
132}
133
// ================================================================
// Hyperbolic Completion API
// ================================================================

/// Meta Llama 3.1 Instruct model with 8B parameters.
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Meta Llama 3.3 Instruct model with 70B parameters.
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// Meta Llama 3.1 Instruct model with 70B parameters.
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// Meta Llama 3 Instruct model with 70B parameters.
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// Hermes 3 (Llama 3.1 based) Instruct model with 70B parameters.
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// Deepseek v2.5 model.
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// Qwen 2.5 model with 72B parameters.
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// Meta Llama 3.2 Instruct model with 3B parameters.
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// Qwen 2.5 Coder Instruct model with 32B parameters.
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// Preview (latest) version of Qwen QwQ model with 32B parameters.
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// Deepseek R1 Zero model.
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// Deepseek R1 model.
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
162
/// A Hyperbolic completion object.
///
/// For more information, see this link: <https://docs.hyperbolic.xyz/reference/create_chat_completion_v1_chat_completions_post>
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    // Unix timestamp of creation, per the OpenAI-compatible schema.
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    // Token accounting; may be absent on some responses.
    pub usage: Option<Usage>,
}
175
176impl From<ApiErrorResponse> for CompletionError {
177    fn from(err: ApiErrorResponse) -> Self {
178        CompletionError::ProviderError(err.message)
179    }
180}
181
182impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
183    type Error = CompletionError;
184
185    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
186        let choice = response.choices.first().ok_or_else(|| {
187            CompletionError::ResponseError("Response contained no choices".to_owned())
188        })?;
189
190        let content = match &choice.message {
191            Message::Assistant {
192                content,
193                tool_calls,
194                ..
195            } => {
196                let mut content = content
197                    .iter()
198                    .map(|c| match c {
199                        AssistantContent::Text { text } => completion::AssistantContent::text(text),
200                        AssistantContent::Refusal { refusal } => {
201                            completion::AssistantContent::text(refusal)
202                        }
203                    })
204                    .collect::<Vec<_>>();
205
206                content.extend(
207                    tool_calls
208                        .iter()
209                        .map(|call| {
210                            completion::AssistantContent::tool_call(
211                                &call.id,
212                                &call.function.name,
213                                call.function.arguments.clone(),
214                            )
215                        })
216                        .collect::<Vec<_>>(),
217                );
218                Ok(content)
219            }
220            _ => Err(CompletionError::ResponseError(
221                "Response did not contain a valid message or tool call".into(),
222            )),
223        }?;
224
225        let choice = OneOrMany::many(content).map_err(|_| {
226            CompletionError::ResponseError(
227                "Response contained no message or tool call (empty)".to_owned(),
228            )
229        })?;
230
231        let usage = response
232            .usage
233            .as_ref()
234            .map(|usage| completion::Usage {
235                input_tokens: usage.prompt_tokens as u64,
236                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
237                total_tokens: usage.total_tokens as u64,
238                cached_input_tokens: 0,
239                cache_creation_input_tokens: 0,
240                reasoning_tokens: 0,
241            })
242            .unwrap_or_default();
243
244        Ok(completion::CompletionResponse {
245            choice,
246            usage,
247            raw_response: response,
248            message_id: None,
249        })
250    }
251}
252
/// A single completion choice within a [`CompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    // e.g. why generation stopped ("stop", "length", ...) — exact values are
    // provider-defined; see the Hyperbolic API docs.
    pub finish_reason: String,
}
259
/// Serializable body for Hyperbolic's OpenAI-compatible chat completion endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct HyperbolicCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Extra provider-specific parameters are flattened into the top-level
    // JSON object (e.g. `stream`, `stream_options`).
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
269
270impl TryFrom<(&str, CompletionRequest)> for HyperbolicCompletionRequest {
271    type Error = CompletionError;
272
273    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
274        if req.output_schema.is_some() {
275            tracing::warn!("Structured outputs currently not supported for Hyperbolic");
276        }
277
278        let model = req.model.clone().unwrap_or_else(|| model.to_string());
279        if req.tool_choice.is_some() {
280            tracing::warn!("WARNING: `tool_choice` not supported on Hyperbolic");
281        }
282
283        if !req.tools.is_empty() {
284            tracing::warn!("WARNING: `tools` not supported on Hyperbolic");
285        }
286
287        let mut full_history: Vec<Message> = match &req.preamble {
288            Some(preamble) => vec![Message::system(preamble)],
289            None => vec![],
290        };
291
292        if let Some(docs) = req.normalized_documents() {
293            let docs: Vec<Message> = docs.try_into()?;
294            full_history.extend(docs);
295        }
296
297        let chat_history: Vec<Message> = req
298            .chat_history
299            .clone()
300            .into_iter()
301            .map(|message| message.try_into())
302            .collect::<Result<Vec<Vec<Message>>, _>>()?
303            .into_iter()
304            .flatten()
305            .collect();
306
307        full_history.extend(chat_history);
308
309        Ok(Self {
310            model: model.to_string(),
311            messages: full_history,
312            temperature: req.temperature,
313            additional_params: req.additional_params,
314        })
315    }
316}
317
/// A completion model bound to a Hyperbolic [`Client`].
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
    pub model: String,
}
324
325impl<T> CompletionModel<T> {
326    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
327        Self {
328            client,
329            model: model.into(),
330        }
331    }
332
333    pub fn with_model(client: Client<T>, model: &str) -> Self {
334        Self {
335            client,
336            model: model.into(),
337        }
338    }
339}
340
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    // Streaming reuses the OpenAI-compatible SSE response type.
    type StreamingResponse = openai::StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Send a (non-streaming) chat completion request to
    /// `POST /v1/chat/completions` and convert the result into Rig's
    /// generic response type.
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Reuse the caller's span if one is active; otherwise open a fresh
        // OTel GenAI-style span with empty fields to be recorded later.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Pretty-printing is gated so the serialization cost is only paid
        // when TRACE logging is actually enabled.
        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Hyperbolic completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        let async_block = async move {
            let response = self.client.send::<_, bytes::Bytes>(req).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                // Even with a 2xx status the body may be an API error object;
                // the untagged `ApiResponse` distinguishes the two.
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        if tracing::enabled!(tracing::Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Hyperbolic completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                // Non-2xx: surface the raw body as the provider error text.
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        };

        async_block.instrument(span).await
    }

    /// Send a streaming chat completion request. Forces `stream: true` and
    /// `stream_options.include_usage` into the request parameters, then
    /// delegates SSE handling to the shared OpenAI-compatible machinery.
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let mut request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Merge streaming flags into any caller-provided extra parameters.
        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Hyperbolic streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        send_compatible_streaming_request(self.client.clone(), req)
            .instrument(span)
            .await
    }
}
478
479// =======================================
480// Hyperbolic Image Generation API
481// =======================================
482
483#[cfg(feature = "image")]
484pub use image_generation::*;
485
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use super::{ApiResponse, Client};
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    // Stable Diffusion family model identifiers accepted by Hyperbolic.
    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    pub const SD2: &str = "SD2";
    pub const SD1_5: &str = "SD1.5";
    pub const SSD: &str = "SSD";
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    /// An image generation model bound to a Hyperbolic [`Client`].
    #[derive(Clone)]
    pub struct ImageGenerationModel<T> {
        client: Client<T>,
        // Diffusion model identifier, e.g. one of the constants above.
        pub model: String,
    }

    impl<T> ImageGenerationModel<T> {
        /// Create an image generation model for `model` on the given client.
        pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
            Self {
                client,
                model: model.into(),
            }
        }

        /// Convenience constructor taking the model name as a `&str`.
        pub fn with_model(client: Client<T>, model: &str) -> Self {
            Self {
                client,
                model: model.into(),
            }
        }
    }

    /// A single generated image; `image` is base64-encoded data.
    #[derive(Clone, Deserialize)]
    pub struct Image {
        image: String,
    }

    /// Raw response from Hyperbolic's image generation endpoint.
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        images: Vec<Image>,
    }

    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// Decode the first image's base64 payload into raw bytes.
        /// Fails if no image is present or the payload is not valid base64.
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            let image = value
                .images
                .first()
                .ok_or_else(|| ImageGenerationError::ResponseError("missing image data".into()))?;
            let data = BASE64_STANDARD
                .decode(&image.image)
                .map_err(|err| ImageGenerationError::ResponseError(err.to_string()))?;

            Ok(Self {
                image: data,
                response: value,
            })
        }
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        /// Send a request to `POST /v1/image/generation` and decode the result.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            // NOTE: this endpoint uses `model_name`, not `model`.
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            // Caller-supplied extras are merged on top of the defaults.
            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let body = serde_json::to_vec(&request)?;

            let request = self
                .client
                .post("/v1/image/generation")?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, bytes::Bytes>(request).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            // 2xx bodies may still be API error objects; `ApiResponse` sorts it out.
            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
            }
        }
    }
}
616
617// ======================================
618// Hyperbolic Audio Generation API
619// ======================================
620#[cfg(feature = "audio")]
621pub use audio_generation::*;
622use tracing::{Instrument, info_span};
623
#[cfg(feature = "audio")]
// FIX: the docs.rs cfg banner previously said `feature = "image"` — a
// copy-paste error from the image module; this module is gated on "audio".
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::{ApiResponse, Client};
    use crate::audio_generation;
    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
    use crate::http_client::{self, HttpClientExt};
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use bytes::Bytes;
    use serde::Deserialize;
    use serde_json::json;

    /// An audio (TTS) generation model bound to a Hyperbolic [`Client`].
    /// Unlike completions, this endpoint is parameterized by language rather
    /// than by a model name.
    #[derive(Clone)]
    pub struct AudioGenerationModel<T> {
        client: Client<T>,
        pub language: String,
    }

    /// Raw response; `audio` is base64-encoded audio data.
    #[derive(Clone, Deserialize)]
    pub struct AudioGenerationResponse {
        audio: String,
    }

    impl TryFrom<AudioGenerationResponse>
        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
    {
        type Error = AudioGenerationError;

        /// Decode the base64 audio payload into raw bytes.
        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
            let data = BASE64_STANDARD
                .decode(&value.audio)
                .map_err(|err| AudioGenerationError::ResponseError(err.to_string()))?;

            Ok(Self {
                audio: data,
                response: value,
            })
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = AudioGenerationResponse;
        type Client = Client<T>;

        fn make(client: &Self::Client, language: impl Into<String>) -> Self {
            Self {
                client: client.clone(),
                language: language.into(),
            }
        }

        /// Send a request to `POST /v1/audio/generation` and decode the result.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            // The API calls the voice parameter `speaker`.
            let request = json!({
                "language": self.language,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post("/v1/audio/generation")?
                .body(body)
                .map_err(http_client::Error::from)?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            // 2xx bodies may still be API error objects; `ApiResponse` sorts it out.
            match serde_json::from_slice::<ApiResponse<AudioGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }
}
717
#[cfg(test)]
mod tests {
    /// Smoke test: both construction paths (`Client::new` and the builder)
    /// accept a dummy key. No network requests are made here.
    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::hyperbolic::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = crate::providers::hyperbolic::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
729}