// rig/providers/hyperbolic.rs
1//! Hyperbolic Inference API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::hyperbolic;
6//!
7//! let client = hyperbolic::Client::new("YOUR_API_KEY");
8//!
9//! let llama_3_1_8b = client.completion_model(hyperbolic::LLAMA_3_1_8B);
10//! ```
11use super::openai::{AssistantContent, send_compatible_streaming_request};
12
13use crate::client::{self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder};
14use crate::client::{BearerAuth, ProviderClient};
15use crate::http_client::{self, HttpClientExt};
16use crate::streaming::StreamingCompletionResponse;
17
18use crate::providers::openai;
19use crate::{
20    OneOrMany,
21    completion::{self, CompletionError, CompletionRequest},
22    json_utils,
23    providers::openai::Message,
24};
25use serde::{Deserialize, Serialize};
26
27// ================================================================
28// Main Hyperbolic Client
29// ================================================================
30const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz";
31
/// Zero-sized marker type implementing the Hyperbolic provider extension.
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicExt;
/// Zero-sized builder marker used to construct a Hyperbolic [`Client`].
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicBuilder;

// Hyperbolic authenticates with a standard `Authorization: Bearer <key>` header.
type HyperbolicApiKey = BearerAuth;
38
impl Provider for HyperbolicExt {
    type Builder = HyperbolicBuilder;

    // Endpoint probed by the generic client to verify connectivity / API key.
    const VERIFY_PATH: &'static str = "/models";
}
44
// Capability map: Hyperbolic supports chat completions unconditionally, and
// image/audio generation only behind the corresponding cargo features.
// Embeddings, transcription and model listing are not wired up here.
impl<H> Capabilities<H> for HyperbolicExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<ImageGenerationModel<H>>;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<AudioGenerationModel<H>>;
}

// No provider-specific debug formatting; the trait's defaults suffice.
impl DebugExt for HyperbolicExt {}
57
// Builder wiring: a `ClientBuilder<HyperbolicBuilder, _, H>` produces a
// `HyperbolicExt` extension for any HTTP backend `H`.
impl ProviderBuilder for HyperbolicBuilder {
    type Extension<H>
        = HyperbolicExt
    where
        H: HttpClientExt;
    type ApiKey = HyperbolicApiKey;

    const BASE_URL: &'static str = HYPERBOLIC_API_BASE_URL;

    // The extension is a zero-sized marker, so building it cannot fail.
    fn build<H>(
        _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(HyperbolicExt)
    }
}
76
/// Hyperbolic client, defaulting to a `reqwest`-backed HTTP transport.
pub type Client<H = reqwest::Client> = client::Client<HyperbolicExt, H>;
/// Builder for a Hyperbolic [`Client`].
// NOTE(review): the key type parameter here is `String` while
// `ProviderBuilder::ApiKey` is `BearerAuth` — presumably the builder converts
// the raw string into a bearer credential; confirm against `client::ClientBuilder`.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<HyperbolicBuilder, String, H>;
79
80impl ProviderClient for Client {
81    type Input = HyperbolicApiKey;
82
83    /// Create a new Hyperbolic client from the `HYPERBOLIC_API_KEY` environment variable.
84    /// Panics if the environment variable is not set.
85    fn from_env() -> Self {
86        let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
87        Self::new(&api_key).unwrap()
88    }
89
90    fn from_val(input: Self::Input) -> Self {
91        Self::new(input).unwrap()
92    }
93}
94
// Error payload shape returned by the Hyperbolic API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

// Untagged: serde first attempts to deserialize the success payload `T`,
// then falls back to the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
106
// Single embedding entry in an OpenAI-style embeddings payload.
// NOTE(review): the capability map sets `Embeddings = Nothing`, so this type
// appears unused by this integration — confirm before removing.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}
113
114#[derive(Clone, Debug, Deserialize, Serialize)]
115pub struct Usage {
116    pub prompt_tokens: usize,
117    pub total_tokens: usize,
118}
119
120impl std::fmt::Display for Usage {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        write!(
123            f,
124            "Prompt tokens: {} Total tokens: {}",
125            self.prompt_tokens, self.total_tokens
126        )
127    }
128}
129
130// ================================================================
131// Hyperbolic Completion API
132// ================================================================
133
/// Meta Llama 3.1 Instruct model with 8B parameters.
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Meta Llama 3.3 Instruct model with 70B parameters.
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// Meta Llama 3.1 Instruct model with 70B parameters.
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// Meta Llama 3 Instruct model with 70B parameters.
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// Hermes 3 (Llama 3.1 base) Instruct model with 70B parameters.
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// DeepSeek v2.5 model.
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// Qwen 2.5 Instruct model with 72B parameters.
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// Meta Llama 3.2 Instruct model with 3B parameters.
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// Qwen 2.5 Coder Instruct model with 32B parameters.
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// Preview (latest) version of Qwen QwQ model with 32B parameters.
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// DeepSeek R1 Zero model.
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// DeepSeek R1 model.
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
158
/// A Hyperbolic completion object.
///
/// For more information, see this link: <https://docs.hyperbolic.xyz/reference/create_chat_completion_v1_chat_completions_post>
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    /// One entry per generated completion; this integration consumes the first.
    pub choices: Vec<Choice>,
    /// Token accounting; absent on some responses.
    pub usage: Option<Usage>,
}
171
172impl From<ApiErrorResponse> for CompletionError {
173    fn from(err: ApiErrorResponse) -> Self {
174        CompletionError::ProviderError(err.message)
175    }
176}
177
178impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
179    type Error = CompletionError;
180
181    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
182        let choice = response.choices.first().ok_or_else(|| {
183            CompletionError::ResponseError("Response contained no choices".to_owned())
184        })?;
185
186        let content = match &choice.message {
187            Message::Assistant {
188                content,
189                tool_calls,
190                ..
191            } => {
192                let mut content = content
193                    .iter()
194                    .map(|c| match c {
195                        AssistantContent::Text { text } => completion::AssistantContent::text(text),
196                        AssistantContent::Refusal { refusal } => {
197                            completion::AssistantContent::text(refusal)
198                        }
199                    })
200                    .collect::<Vec<_>>();
201
202                content.extend(
203                    tool_calls
204                        .iter()
205                        .map(|call| {
206                            completion::AssistantContent::tool_call(
207                                &call.id,
208                                &call.function.name,
209                                call.function.arguments.clone(),
210                            )
211                        })
212                        .collect::<Vec<_>>(),
213                );
214                Ok(content)
215            }
216            _ => Err(CompletionError::ResponseError(
217                "Response did not contain a valid message or tool call".into(),
218            )),
219        }?;
220
221        let choice = OneOrMany::many(content).map_err(|_| {
222            CompletionError::ResponseError(
223                "Response contained no message or tool call (empty)".to_owned(),
224            )
225        })?;
226
227        let usage = response
228            .usage
229            .as_ref()
230            .map(|usage| completion::Usage {
231                input_tokens: usage.prompt_tokens as u64,
232                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
233                total_tokens: usage.total_tokens as u64,
234                cached_input_tokens: 0,
235            })
236            .unwrap_or_default();
237
238        Ok(completion::CompletionResponse {
239            choice,
240            usage,
241            raw_response: response,
242            message_id: None,
243        })
244    }
245}
246
/// A single completion choice within a [`CompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    pub finish_reason: String,
}
253
// Serialized request body for Hyperbolic's OpenAI-compatible
// `/v1/chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct HyperbolicCompletionRequest {
    // Model identifier, e.g. "deepseek-ai/DeepSeek-R1".
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Extra provider parameters flattened into the top-level JSON object
    // (e.g. `stream`, `stream_options`).
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
263
264impl TryFrom<(&str, CompletionRequest)> for HyperbolicCompletionRequest {
265    type Error = CompletionError;
266
267    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
268        if req.output_schema.is_some() {
269            tracing::warn!("Structured outputs currently not supported for Hyperbolic");
270        }
271
272        let model = req.model.clone().unwrap_or_else(|| model.to_string());
273        if req.tool_choice.is_some() {
274            tracing::warn!("WARNING: `tool_choice` not supported on Hyperbolic");
275        }
276
277        if !req.tools.is_empty() {
278            tracing::warn!("WARNING: `tools` not supported on Hyperbolic");
279        }
280
281        let mut full_history: Vec<Message> = match &req.preamble {
282            Some(preamble) => vec![Message::system(preamble)],
283            None => vec![],
284        };
285
286        if let Some(docs) = req.normalized_documents() {
287            let docs: Vec<Message> = docs.try_into()?;
288            full_history.extend(docs);
289        }
290
291        let chat_history: Vec<Message> = req
292            .chat_history
293            .clone()
294            .into_iter()
295            .map(|message| message.try_into())
296            .collect::<Result<Vec<Vec<Message>>, _>>()?
297            .into_iter()
298            .flatten()
299            .collect();
300
301        full_history.extend(chat_history);
302
303        Ok(Self {
304            model: model.to_string(),
305            messages: full_history,
306            temperature: req.temperature,
307            additional_params: req.additional_params,
308        })
309    }
310}
311
/// A chat completion model backed by the Hyperbolic API.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
    pub model: String,
}
318
319impl<T> CompletionModel<T> {
320    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
321        Self {
322            client,
323            model: model.into(),
324        }
325    }
326
327    pub fn with_model(client: Client<T>, model: &str) -> Self {
328        Self {
329            client,
330            model: model.into(),
331        }
332    }
333}
334
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    // Streaming reuses the OpenAI-compatible SSE response type.
    type StreamingResponse = openai::StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    // Non-streaming chat completion against `/v1/chat/completions`,
    // instrumented with OpenTelemetry-style `gen_ai.*` span fields.
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Open a fresh span only when the caller has not already set one up;
        // otherwise record onto the caller's current span.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Only pay the pretty-printing cost when TRACE logging is on.
        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Hyperbolic completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        // The send/parse work is wrapped in one future so the whole exchange
        // is instrumented with the span above.
        let async_block = async move {
            let response = self.client.send::<_, bytes::Bytes>(req).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                // A 2xx body may still carry an API-level error (untagged enum).
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        if tracing::enabled!(tracing::Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Hyperbolic completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                // Non-2xx: surface the raw body as the provider error text.
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        };

        async_block.instrument(span).await
    }

    // Streaming chat completion; delegates SSE handling to the shared
    // OpenAI-compatible streaming helper.
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let mut request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Force streaming mode and ask the provider to emit usage in the
        // final SSE chunk; merged over any caller-supplied extras.
        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Hyperbolic streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        send_compatible_streaming_request(self.client.clone(), req)
            .instrument(span)
            .await
    }
}
470
471// =======================================
472// Hyperbolic Image Generation API
473// =======================================
474
475#[cfg(feature = "image")]
476pub use image_generation::*;
477
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use super::{ApiResponse, Client};
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    pub const SD2: &str = "SD2";
    pub const SD1_5: &str = "SD1.5";
    pub const SSD: &str = "SSD";
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    /// An image generation model backed by Hyperbolic's image endpoint.
    #[derive(Clone)]
    pub struct ImageGenerationModel<T> {
        client: Client<T>,
        pub model: String,
    }

    impl<T> ImageGenerationModel<T> {
        /// Construct a model for `model` backed by `client`.
        pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
            Self {
                client,
                model: model.into(),
            }
        }

        /// Convenience alias for [`Self::new`] taking a `&str` model name.
        pub fn with_model(client: Client<T>, model: &str) -> Self {
            Self::new(client, model)
        }
    }

    /// A single generated image, base64-encoded.
    #[derive(Clone, Deserialize)]
    pub struct Image {
        image: String,
    }

    /// Raw API response for an image generation request.
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        images: Vec<Image>,
    }

    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// Decode the first returned image from base64.
        ///
        /// # Errors
        /// Returns a `ResponseError` when the image list is empty or the
        /// payload is not valid base64. (The original indexed `images[0]` and
        /// `expect`ed the decode, panicking on provider-side malformation.)
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            let first = value.images.first().ok_or_else(|| {
                ImageGenerationError::ResponseError("Response contained no images".to_owned())
            })?;

            let data = BASE64_STANDARD.decode(&first.image).map_err(|e| {
                ImageGenerationError::ResponseError(format!("Could not decode image: {e}"))
            })?;

            Ok(Self {
                image: data,
                response: value,
            })
        }
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        /// POST `/v1/image/generation` and decode the first returned image.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            // Hyperbolic uses `model_name` (not `model`) on this endpoint.
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            // Caller-supplied extras are merged over the defaults.
            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let body = serde_json::to_vec(&request)?;

            let request = self
                .client
                .post("/v1/image/generation")?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, bytes::Bytes>(request).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
            }
        }
    }
}
604
605// ======================================
606// Hyperbolic Audio Generation API
607// ======================================
608#[cfg(feature = "audio")]
609pub use audio_generation::*;
610use tracing::{Instrument, info_span};
611
#[cfg(feature = "audio")]
// Fixed copy-paste bug: this doc(cfg) previously advertised feature "image".
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::{ApiResponse, Client};
    use crate::audio_generation;
    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
    use crate::http_client::{self, HttpClientExt};
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use bytes::Bytes;
    use serde::Deserialize;
    use serde_json::json;

    /// A text-to-speech model backed by Hyperbolic's audio endpoint.
    ///
    /// Hyperbolic selects speech generation by `language`; the speaker/voice
    /// is chosen per request.
    #[derive(Clone)]
    pub struct AudioGenerationModel<T> {
        client: Client<T>,
        pub language: String,
    }

    /// Raw API response: audio payload base64-encoded in the `audio` field.
    #[derive(Clone, Deserialize)]
    pub struct AudioGenerationResponse {
        audio: String,
    }

    impl TryFrom<AudioGenerationResponse>
        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
    {
        type Error = AudioGenerationError;

        /// Decode the base64 audio payload.
        ///
        /// # Errors
        /// Returns a `ProviderError` when the payload is not valid base64.
        /// (The original `expect`ed the decode, panicking on provider-side
        /// malformation inside a fallible conversion.)
        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
            let data = BASE64_STANDARD.decode(&value.audio).map_err(|e| {
                AudioGenerationError::ProviderError(format!("Could not decode audio: {e}"))
            })?;

            Ok(Self {
                audio: data,
                response: value,
            })
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = AudioGenerationResponse;
        type Client = Client<T>;

        fn make(client: &Self::Client, language: impl Into<String>) -> Self {
            Self {
                client: client.clone(),
                language: language.into(),
            }
        }

        /// POST `/v1/audio/generation` and decode the base64 audio payload.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            let request = json!({
                "language": self.language,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post("/v1/audio/generation")?
                .body(body)
                .map_err(http_client::Error::from)?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<AudioGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }
}
705
#[cfg(test)]
mod tests {
    // Smoke test: both construction paths (direct `new` and the builder)
    // succeed with a syntactically valid key; no network traffic occurs.
    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::hyperbolic::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = crate::providers::hyperbolic::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}