// rig/providers/hyperbolic.rs
//! Hyperbolic Inference API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::hyperbolic;
//!
//! let client = hyperbolic::Client::new("YOUR_API_KEY");
//!
//! let llama_3_1_8b = client.completion_model(hyperbolic::LLAMA_3_1_8B);
//! ```
use super::openai::{AssistantContent, send_compatible_streaming_request};

use crate::client::{self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder};
use crate::client::{BearerAuth, ProviderClient};
use crate::http_client::{self, HttpClientExt};
use crate::streaming::StreamingCompletionResponse;

use crate::providers::openai;
use crate::{
    OneOrMany,
    completion::{self, CompletionError, CompletionRequest},
    json_utils,
    providers::openai::Message,
};
use serde::{Deserialize, Serialize};
26
// ================================================================
// Main Hyperbolic Client
// ================================================================
/// Base URL of the Hyperbolic Inference API.
const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz";

/// Zero-sized marker type carrying the Hyperbolic provider capabilities.
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicExt;
/// Zero-sized builder marker used to construct Hyperbolic clients.
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicBuilder;

// Hyperbolic authenticates with a bearer token (see `BearerAuth`).
type HyperbolicApiKey = BearerAuth;
38
impl Provider for HyperbolicExt {
    type Builder = HyperbolicBuilder;

    // Endpoint polled to verify that the client/API key is usable.
    const VERIFY_PATH: &'static str = "/models";
}
44
// Capabilities exposed by Hyperbolic: completions always; image and audio
// generation only behind their respective cargo features. Embeddings,
// transcription and model listing are not implemented (`Nothing`).
impl<H> Capabilities<H> for HyperbolicExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<ImageGenerationModel<H>>;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<AudioGenerationModel<H>>;
}
55
// Use the default `Debug`-based formatting for the extension type.
impl DebugExt for HyperbolicExt {}
57
impl ProviderBuilder for HyperbolicBuilder {
    // The extension type is the same regardless of the HTTP client used.
    type Extension<H>
        = HyperbolicExt
    where
        H: HttpClientExt;
    type ApiKey = HyperbolicApiKey;

    const BASE_URL: &'static str = HYPERBOLIC_API_BASE_URL;

    /// Build the (stateless, zero-sized) Hyperbolic extension; never fails.
    fn build<H>(
        _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(HyperbolicExt)
    }
}
76
/// A Hyperbolic API client, generic over the HTTP client implementation.
pub type Client<H = reqwest::Client> = client::Client<HyperbolicExt, H>;
/// Builder for [`Client`].
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<HyperbolicBuilder, String, H>;
79
80impl ProviderClient for Client {
81    type Input = HyperbolicApiKey;
82
83    /// Create a new Hyperbolic client from the `HYPERBOLIC_API_KEY` environment variable.
84    /// Panics if the environment variable is not set.
85    fn from_env() -> Self {
86        let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
87        Self::new(&api_key).unwrap()
88    }
89
90    fn from_val(input: Self::Input) -> Self {
91        Self::new(input).unwrap()
92    }
93}
94
/// Error payload returned by the Hyperbolic API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Untagged API envelope: deserialization first tries the success payload
/// `T`, then falls back to the provider error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
106
/// One embedding vector in an OpenAI-style embeddings payload.
// NOTE(review): `Capabilities` for Hyperbolic declares `Embeddings = Nothing`,
// so this type appears unused in this file — confirm before removing.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}
113
/// Token usage statistics reported by the Hyperbolic API.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
119
120impl std::fmt::Display for Usage {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        write!(
123            f,
124            "Prompt tokens: {} Total tokens: {}",
125            self.prompt_tokens, self.total_tokens
126        )
127    }
128}
129
// ================================================================
// Hyperbolic Completion API
// ================================================================

/// Meta Llama 3.1 Instruct model with 8B parameters.
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Meta Llama 3.3 Instruct model with 70B parameters.
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// Meta Llama 3.1 Instruct model with 70B parameters.
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// Meta Llama 3 Instruct model with 70B parameters.
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// Hermes 3 (Llama 3.1 based) Instruct model with 70B parameters.
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// Deepseek v2.5 model.
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// Qwen 2.5 model with 72B parameters.
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// Meta Llama 3.2 Instruct model with 3B parameters.
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// Qwen 2.5 Coder Instruct model with 32B parameters.
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// Preview (latest) version of the Qwen QwQ model with 32B parameters.
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// Deepseek R1 Zero model.
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// Deepseek R1 model.
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
158
/// A Hyperbolic completion object.
///
/// For more information, see this link: <https://docs.hyperbolic.xyz/reference/create_chat_completion_v1_chat_completions_post>
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    /// Candidate completions; only the first choice is consumed downstream.
    pub choices: Vec<Choice>,
    /// Token accounting; absent on some responses.
    pub usage: Option<Usage>,
}
171
172impl From<ApiErrorResponse> for CompletionError {
173    fn from(err: ApiErrorResponse) -> Self {
174        CompletionError::ProviderError(err.message)
175    }
176}
177
178impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
179    type Error = CompletionError;
180
181    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
182        let choice = response.choices.first().ok_or_else(|| {
183            CompletionError::ResponseError("Response contained no choices".to_owned())
184        })?;
185
186        let content = match &choice.message {
187            Message::Assistant {
188                content,
189                tool_calls,
190                ..
191            } => {
192                let mut content = content
193                    .iter()
194                    .map(|c| match c {
195                        AssistantContent::Text { text } => completion::AssistantContent::text(text),
196                        AssistantContent::Refusal { refusal } => {
197                            completion::AssistantContent::text(refusal)
198                        }
199                    })
200                    .collect::<Vec<_>>();
201
202                content.extend(
203                    tool_calls
204                        .iter()
205                        .map(|call| {
206                            completion::AssistantContent::tool_call(
207                                &call.id,
208                                &call.function.name,
209                                call.function.arguments.clone(),
210                            )
211                        })
212                        .collect::<Vec<_>>(),
213                );
214                Ok(content)
215            }
216            _ => Err(CompletionError::ResponseError(
217                "Response did not contain a valid message or tool call".into(),
218            )),
219        }?;
220
221        let choice = OneOrMany::many(content).map_err(|_| {
222            CompletionError::ResponseError(
223                "Response contained no message or tool call (empty)".to_owned(),
224            )
225        })?;
226
227        let usage = response
228            .usage
229            .as_ref()
230            .map(|usage| completion::Usage {
231                input_tokens: usage.prompt_tokens as u64,
232                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
233                total_tokens: usage.total_tokens as u64,
234                cached_input_tokens: 0,
235                cache_creation_input_tokens: 0,
236            })
237            .unwrap_or_default();
238
239        Ok(completion::CompletionResponse {
240            choice,
241            usage,
242            raw_response: response,
243            message_id: None,
244        })
245    }
246}
247
/// A single completion choice within a [`CompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    pub finish_reason: String,
}
254
/// Request body for Hyperbolic's OpenAI-compatible chat completions endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct HyperbolicCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Extra provider parameters are flattened into the top-level JSON object.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
264
265impl TryFrom<(&str, CompletionRequest)> for HyperbolicCompletionRequest {
266    type Error = CompletionError;
267
268    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
269        if req.output_schema.is_some() {
270            tracing::warn!("Structured outputs currently not supported for Hyperbolic");
271        }
272
273        let model = req.model.clone().unwrap_or_else(|| model.to_string());
274        if req.tool_choice.is_some() {
275            tracing::warn!("WARNING: `tool_choice` not supported on Hyperbolic");
276        }
277
278        if !req.tools.is_empty() {
279            tracing::warn!("WARNING: `tools` not supported on Hyperbolic");
280        }
281
282        let mut full_history: Vec<Message> = match &req.preamble {
283            Some(preamble) => vec![Message::system(preamble)],
284            None => vec![],
285        };
286
287        if let Some(docs) = req.normalized_documents() {
288            let docs: Vec<Message> = docs.try_into()?;
289            full_history.extend(docs);
290        }
291
292        let chat_history: Vec<Message> = req
293            .chat_history
294            .clone()
295            .into_iter()
296            .map(|message| message.try_into())
297            .collect::<Result<Vec<Vec<Message>>, _>>()?
298            .into_iter()
299            .flatten()
300            .collect();
301
302        full_history.extend(chat_history);
303
304        Ok(Self {
305            model: model.to_string(),
306            messages: full_history,
307            temperature: req.temperature,
308            additional_params: req.additional_params,
309        })
310    }
311}
312
/// Completion model handle tied to a Hyperbolic [`Client`].
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
    pub model: String,
}
319
320impl<T> CompletionModel<T> {
321    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
322        Self {
323            client,
324            model: model.into(),
325        }
326    }
327
328    pub fn with_model(client: Client<T>, model: &str) -> Self {
329        Self {
330            client,
331            model: model.into(),
332        }
333    }
334}
335
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    // Streaming reuses the OpenAI-compatible SSE response type.
    type StreamingResponse = openai::StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Execute a non-streaming chat completion against
    /// `POST /v1/chat/completions`.
    ///
    /// # Errors
    /// Returns `CompletionError` on serialization, transport, non-success
    /// HTTP status, or a provider-reported error payload.
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Only open a fresh gen_ai span when the caller has not already
        // established one; otherwise record into the current span.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Pretty-printing is gated so the serialization cost is only paid
        // when TRACE logging is actually enabled.
        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Hyperbolic completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        // The request/response handling runs inside the span via
        // `.instrument(span)` below.
        let async_block = async move {
            let response = self.client.send::<_, bytes::Bytes>(req).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                // A 2xx body may still carry a provider error payload; the
                // untagged `ApiResponse` distinguishes the two shapes.
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        if tracing::enabled!(tracing::Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Hyperbolic completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        };

        async_block.instrument(span).await
    }

    /// Execute a streaming chat completion; delegates SSE handling to the
    /// OpenAI-compatible `send_compatible_streaming_request`.
    ///
    /// # Errors
    /// Returns `CompletionError` on serialization or transport failures.
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let mut request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Force streaming mode and ask the provider to include usage data in
        // the stream, merged on top of any caller-supplied extras.
        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Hyperbolic streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        send_compatible_streaming_request(self.client.clone(), req)
            .instrument(span)
            .await
    }
}
473
// =======================================
// Hyperbolic Image Generation API
// =======================================

#[cfg(feature = "image")]
pub use image_generation::*;

#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use super::{ApiResponse, Client};
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    // Image model identifiers accepted by the Hyperbolic API.
    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    pub const SD2: &str = "SD2";
    pub const SD1_5: &str = "SD1.5";
    pub const SSD: &str = "SSD";
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    /// Image generation model handle tied to a Hyperbolic [`Client`].
    #[derive(Clone)]
    pub struct ImageGenerationModel<T> {
        client: Client<T>,
        pub model: String,
    }

    impl<T> ImageGenerationModel<T> {
        /// Create an image generation handle for `model` on the given client.
        pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
            Self {
                client,
                model: model.into(),
            }
        }

        /// Convenience constructor taking the model name as a string slice.
        pub fn with_model(client: Client<T>, model: &str) -> Self {
            Self {
                client,
                model: model.into(),
            }
        }
    }

    /// A single base64-encoded image in the provider response.
    #[derive(Clone, Deserialize)]
    pub struct Image {
        image: String,
    }

    /// Raw image generation response payload.
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        images: Vec<Image>,
    }

    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// Decode the first base64 image in the response.
        ///
        /// # Errors
        /// Returns a `ResponseError` when the response contains no images or
        /// the payload is not valid base64. (The previous implementation
        /// indexed `images[0]` and `expect`ed the decode, panicking on
        /// malformed provider responses despite returning `Result`.)
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            let first = value.images.first().ok_or_else(|| {
                ImageGenerationError::ResponseError("Response contained no images".to_owned())
            })?;

            let data = BASE64_STANDARD.decode(&first.image).map_err(|e| {
                ImageGenerationError::ResponseError(format!("Could not decode image: {e}"))
            })?;

            Ok(Self {
                image: data,
                response: value,
            })
        }
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        /// Generate an image via `POST /v1/image/generation`.
        ///
        /// # Errors
        /// Returns `ImageGenerationError` on serialization, transport,
        /// non-success HTTP status, or a provider-reported error payload.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            // Caller-supplied extras are merged into the request body.
            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let body = serde_json::to_vec(&request)?;

            let request = self
                .client
                .post("/v1/image/generation")?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, bytes::Bytes>(request).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
            }
        }
    }
}
607
608// ======================================
609// Hyperbolic Audio Generation API
610// ======================================
611#[cfg(feature = "audio")]
612pub use audio_generation::*;
613use tracing::{Instrument, info_span};
614
615#[cfg(feature = "audio")]
616#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
617mod audio_generation {
618    use super::{ApiResponse, Client};
619    use crate::audio_generation;
620    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
621    use crate::http_client::{self, HttpClientExt};
622    use base64::Engine;
623    use base64::prelude::BASE64_STANDARD;
624    use bytes::Bytes;
625    use serde::Deserialize;
626    use serde_json::json;
627
628    #[derive(Clone)]
629    pub struct AudioGenerationModel<T> {
630        client: Client<T>,
631        pub language: String,
632    }
633
634    #[derive(Clone, Deserialize)]
635    pub struct AudioGenerationResponse {
636        audio: String,
637    }
638
639    impl TryFrom<AudioGenerationResponse>
640        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
641    {
642        type Error = AudioGenerationError;
643
644        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
645            let data = BASE64_STANDARD
646                .decode(&value.audio)
647                .expect("Could not decode audio.");
648
649            Ok(Self {
650                audio: data,
651                response: value,
652            })
653        }
654    }
655
656    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
657    where
658        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
659    {
660        type Response = AudioGenerationResponse;
661        type Client = Client<T>;
662
663        fn make(client: &Self::Client, language: impl Into<String>) -> Self {
664            Self {
665                client: client.clone(),
666                language: language.into(),
667            }
668        }
669
670        async fn audio_generation(
671            &self,
672            request: AudioGenerationRequest,
673        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
674        {
675            let request = json!({
676                "language": self.language,
677                "speaker": request.voice,
678                "text": request.text,
679                "speed": request.speed
680            });
681
682            let body = serde_json::to_vec(&request)?;
683
684            let req = self
685                .client
686                .post("/v1/audio/generation")?
687                .body(body)
688                .map_err(http_client::Error::from)?;
689
690            let response = self.client.send::<_, Bytes>(req).await?;
691            let status = response.status();
692            let response_body = response.into_body().into_future().await?.to_vec();
693
694            if !status.is_success() {
695                return Err(AudioGenerationError::ProviderError(format!(
696                    "{status}: {}",
697                    String::from_utf8_lossy(&response_body)
698                )));
699            }
700
701            match serde_json::from_slice::<ApiResponse<AudioGenerationResponse>>(&response_body)? {
702                ApiResponse::Ok(response) => response.try_into(),
703                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
704            }
705        }
706    }
707}
708
#[cfg(test)]
mod tests {
    /// Smoke test: both client construction paths succeed with a dummy key.
    #[test]
    fn test_client_initialization() {
        let direct = crate::providers::hyperbolic::Client::new("dummy-key");
        direct.expect("Client::new() failed");

        let built = crate::providers::hyperbolic::Client::builder()
            .api_key("dummy-key")
            .build();
        built.expect("Client::builder() failed");
    }
}