Skip to main content

rig/providers/
hyperbolic.rs

//! Hyperbolic Inference API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::hyperbolic;
//!
//! let client = hyperbolic::Client::new("YOUR_API_KEY");
//!
//! let llama_3_1_8b = client.completion_model(hyperbolic::LLAMA_3_1_8B);
//! ```
use super::openai::{AssistantContent, send_compatible_streaming_request};

use crate::client::{self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder};
use crate::client::{BearerAuth, ProviderClient};
use crate::http_client::{self, HttpClientExt};
use crate::streaming::StreamingCompletionResponse;

use crate::providers::openai;
use crate::{
    OneOrMany,
    completion::{self, CompletionError, CompletionRequest},
    json_utils,
    providers::openai::Message,
};
use serde::{Deserialize, Serialize};
26
// ================================================================
// Main Hyperbolic Client
// ================================================================

/// Base URL for the Hyperbolic Inference API.
const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz";
31
/// Marker type implementing the Hyperbolic provider extension.
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicExt;
/// Builder marker type used to construct a Hyperbolic [`Client`].
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicBuilder;
36
// Hyperbolic authenticates with a bearer token in the `Authorization` header.
type HyperbolicApiKey = BearerAuth;
38
impl Provider for HyperbolicExt {
    type Builder = HyperbolicBuilder;

    // Endpoint used to verify that an API key is valid.
    const VERIFY_PATH: &'static str = "/models";
}
44
// Hyperbolic supports chat completions plus (feature-gated) image and audio
// generation; embeddings, transcription and model listing are not wired up.
impl<H> Capabilities<H> for HyperbolicExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<ImageGenerationModel<H>>;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<AudioGenerationModel<H>>;
}
55
// Use the default debug representation provided by `DebugExt`.
impl DebugExt for HyperbolicExt {}
57
impl ProviderBuilder for HyperbolicBuilder {
    type Extension<H>
        = HyperbolicExt
    where
        H: HttpClientExt;
    type ApiKey = HyperbolicApiKey;

    const BASE_URL: &'static str = HYPERBOLIC_API_BASE_URL;

    /// Build the (stateless) Hyperbolic extension; this cannot actually fail,
    /// the `Result` is required by the `ProviderBuilder` contract.
    fn build<H>(
        _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(HyperbolicExt)
    }
}
76
/// A Hyperbolic client, generic over the HTTP backend (defaults to `reqwest`).
pub type Client<H = reqwest::Client> = client::Client<HyperbolicExt, H>;
/// Builder for [`Client`]; the API key is supplied as a `String`.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<HyperbolicBuilder, String, H>;
79
80impl ProviderClient for Client {
81    type Input = HyperbolicApiKey;
82
83    /// Create a new Hyperbolic client from the `HYPERBOLIC_API_KEY` environment variable.
84    /// Panics if the environment variable is not set.
85    fn from_env() -> Self {
86        let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
87        Self::new(&api_key).unwrap()
88    }
89
90    fn from_val(input: Self::Input) -> Self {
91        Self::new(input).unwrap()
92    }
93}
94
/// Error payload returned by the Hyperbolic API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
99
/// Untagged wrapper: a Hyperbolic body deserializes either into the expected
/// payload `T` or, failing that, into an [`ApiErrorResponse`].
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
106
/// A single embedding entry in an API response.
// NOTE(review): no embedding capability is wired up for Hyperbolic in this
// file (`type Embeddings = Nothing`), so this type appears unused here —
// confirm before removing.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}
113
/// Token usage statistics reported by Hyperbolic (prompt and total counts;
/// output tokens are derived as `total - prompt` elsewhere in this file).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
119
120impl std::fmt::Display for Usage {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        write!(
123            f,
124            "Prompt tokens: {} Total tokens: {}",
125            self.prompt_tokens, self.total_tokens
126        )
127    }
128}
129
// ================================================================
// Hyperbolic Completion API
// ================================================================

/// Meta Llama 3.1 Instruct model with 8B parameters.
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Meta Llama 3.3 Instruct model with 70B parameters.
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// Meta Llama 3.1 Instruct model with 70B parameters.
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// Meta Llama 3 Instruct model with 70B parameters.
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// Hermes 3 (Llama 3.1) Instruct model with 70B parameters.
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// Deepseek v2.5 model.
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// Qwen 2.5 model with 72B parameters.
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// Meta Llama 3.2 Instruct model with 3B parameters.
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// Qwen 2.5 Coder Instruct model with 32B parameters.
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// Preview (latest) version of the Qwen QwQ model with 32B parameters.
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// Deepseek R1 Zero model.
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// Deepseek R1 model.
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
158
/// A Hyperbolic completion object.
///
/// For more information, see this link: <https://docs.hyperbolic.xyz/reference/create_chat_completion_v1_chat_completions_post>
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    /// Candidate completions; only the first is consumed by the conversion
    /// into Rig's generic response type.
    pub choices: Vec<Choice>,
    /// Token accounting; may be absent, in which case usage defaults to zero.
    pub usage: Option<Usage>,
}
171
172impl From<ApiErrorResponse> for CompletionError {
173    fn from(err: ApiErrorResponse) -> Self {
174        CompletionError::ProviderError(err.message)
175    }
176}
177
178impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
179    type Error = CompletionError;
180
181    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
182        let choice = response.choices.first().ok_or_else(|| {
183            CompletionError::ResponseError("Response contained no choices".to_owned())
184        })?;
185
186        let content = match &choice.message {
187            Message::Assistant {
188                content,
189                tool_calls,
190                ..
191            } => {
192                let mut content = content
193                    .iter()
194                    .map(|c| match c {
195                        AssistantContent::Text { text } => completion::AssistantContent::text(text),
196                        AssistantContent::Refusal { refusal } => {
197                            completion::AssistantContent::text(refusal)
198                        }
199                    })
200                    .collect::<Vec<_>>();
201
202                content.extend(
203                    tool_calls
204                        .iter()
205                        .map(|call| {
206                            completion::AssistantContent::tool_call(
207                                &call.id,
208                                &call.function.name,
209                                call.function.arguments.clone(),
210                            )
211                        })
212                        .collect::<Vec<_>>(),
213                );
214                Ok(content)
215            }
216            _ => Err(CompletionError::ResponseError(
217                "Response did not contain a valid message or tool call".into(),
218            )),
219        }?;
220
221        let choice = OneOrMany::many(content).map_err(|_| {
222            CompletionError::ResponseError(
223                "Response contained no message or tool call (empty)".to_owned(),
224            )
225        })?;
226
227        let usage = response
228            .usage
229            .as_ref()
230            .map(|usage| completion::Usage {
231                input_tokens: usage.prompt_tokens as u64,
232                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
233                total_tokens: usage.total_tokens as u64,
234                cached_input_tokens: 0,
235            })
236            .unwrap_or_default();
237
238        Ok(completion::CompletionResponse {
239            choice,
240            usage,
241            raw_response: response,
242            message_id: None,
243        })
244    }
245}
246
/// A single completion choice returned by the Hyperbolic API.
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    pub finish_reason: String,
}
253
/// Request body sent to Hyperbolic's `/v1/chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct HyperbolicCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Flattened into the top-level JSON object (used e.g. to inject `stream`
    // and `stream_options` for streaming requests).
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
263
264impl TryFrom<(&str, CompletionRequest)> for HyperbolicCompletionRequest {
265    type Error = CompletionError;
266
267    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
268        if req.output_schema.is_some() {
269            tracing::warn!("Structured outputs currently not supported for Hyperbolic");
270        }
271
272        let model = req.model.clone().unwrap_or_else(|| model.to_string());
273        if req.tool_choice.is_some() {
274            tracing::warn!("WARNING: `tool_choice` not supported on Hyperbolic");
275        }
276
277        if !req.tools.is_empty() {
278            tracing::warn!("WARNING: `tools` not supported on Hyperbolic");
279        }
280
281        let mut full_history: Vec<Message> = match &req.preamble {
282            Some(preamble) => vec![Message::system(preamble)],
283            None => vec![],
284        };
285
286        if let Some(docs) = req.normalized_documents() {
287            let docs: Vec<Message> = docs.try_into()?;
288            full_history.extend(docs);
289        }
290
291        let chat_history: Vec<Message> = req
292            .chat_history
293            .clone()
294            .into_iter()
295            .map(|message| message.try_into())
296            .collect::<Result<Vec<Vec<Message>>, _>>()?
297            .into_iter()
298            .flatten()
299            .collect();
300
301        full_history.extend(chat_history);
302
303        Ok(Self {
304            model: model.to_string(),
305            messages: full_history,
306            temperature: req.temperature,
307            additional_params: req.additional_params,
308        })
309    }
310}
311
/// A chat completion model bound to a Hyperbolic [`Client`].
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
    pub model: String,
}
318
319impl<T> CompletionModel<T> {
320    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
321        Self {
322            client,
323            model: model.into(),
324        }
325    }
326
327    pub fn with_model(client: Client<T>, model: &str) -> Self {
328        Self {
329            client,
330            model: model.into(),
331        }
332    }
333}
334
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = openai::StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Send a (non-streaming) chat completion request to Hyperbolic.
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Reuse the caller's span if one is active; otherwise open a fresh
        // `gen_ai`-style span for this chat call.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Pretty-printing the request body is only done when TRACE is on.
        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Hyperbolic completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        // The network round-trip runs inside the span via `.instrument`.
        let async_block = async move {
            let response = self.client.send::<_, bytes::Bytes>(req).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                // A 2xx body can still be an API-level error; the untagged
                // `ApiResponse` enum distinguishes the two cases.
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        if tracing::enabled!(tracing::Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Hyperbolic completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                // Non-2xx: surface the raw body as a provider error.
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        };

        async_block.instrument(span).await
    }

    /// Send a streaming chat completion request to Hyperbolic.
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let mut request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Force streaming on and request usage stats in the final chunk,
        // merged over any caller-supplied additional params.
        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Hyperbolic streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        // The stream format is OpenAI-compatible, so the shared helper from
        // the openai provider handles SSE parsing.
        send_compatible_streaming_request(self.client.clone(), req)
            .instrument(span)
            .await
    }
}
472
// =======================================
// Hyperbolic Image Generation API
// =======================================

#[cfg(feature = "image")]
pub use image_generation::*;

#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use super::{ApiResponse, Client};
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    pub const SD2: &str = "SD2";
    pub const SD1_5: &str = "SD1.5";
    pub const SSD: &str = "SSD";
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    /// An image generation model bound to a Hyperbolic [`Client`].
    #[derive(Clone)]
    pub struct ImageGenerationModel<T> {
        client: Client<T>,
        pub model: String,
    }

    impl<T> ImageGenerationModel<T> {
        pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
            Self {
                client,
                model: model.into(),
            }
        }

        pub fn with_model(client: Client<T>, model: &str) -> Self {
            Self {
                client,
                model: model.into(),
            }
        }
    }

    /// A single base64-encoded image in the API response.
    #[derive(Clone, Deserialize)]
    pub struct Image {
        image: String,
    }

    /// Raw image generation response from the Hyperbolic API.
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        images: Vec<Image>,
    }

    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// Decode the first returned image from base64.
        ///
        /// # Errors
        /// Returns `ResponseError` when the response contains no images or
        /// the payload is not valid base64. (Previously both cases panicked:
        /// `images[0]` on an empty vec and `.expect()` on the decode.)
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            let first = value.images.first().ok_or_else(|| {
                ImageGenerationError::ResponseError("Response contained no images".to_string())
            })?;

            let data = BASE64_STANDARD.decode(&first.image).map_err(|e| {
                ImageGenerationError::ResponseError(format!("Could not decode image: {e}"))
            })?;

            Ok(Self {
                image: data,
                response: value,
            })
        }
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        /// POST `/v1/image/generation` and decode the response.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            // Caller-supplied extras are merged over the defaults.
            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let body = serde_json::to_vec(&request)?;

            let request = self
                .client
                .post("/v1/image/generation")?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, bytes::Bytes>(request).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
            }
        }
    }
}
606
607// ======================================
608// Hyperbolic Audio Generation API
609// ======================================
610#[cfg(feature = "audio")]
611pub use audio_generation::*;
612use tracing::{Instrument, info_span};
613
614#[cfg(feature = "audio")]
615#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
616mod audio_generation {
617    use super::{ApiResponse, Client};
618    use crate::audio_generation;
619    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
620    use crate::http_client::{self, HttpClientExt};
621    use base64::Engine;
622    use base64::prelude::BASE64_STANDARD;
623    use bytes::Bytes;
624    use serde::Deserialize;
625    use serde_json::json;
626
627    #[derive(Clone)]
628    pub struct AudioGenerationModel<T> {
629        client: Client<T>,
630        pub language: String,
631    }
632
633    #[derive(Clone, Deserialize)]
634    pub struct AudioGenerationResponse {
635        audio: String,
636    }
637
638    impl TryFrom<AudioGenerationResponse>
639        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
640    {
641        type Error = AudioGenerationError;
642
643        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
644            let data = BASE64_STANDARD
645                .decode(&value.audio)
646                .expect("Could not decode audio.");
647
648            Ok(Self {
649                audio: data,
650                response: value,
651            })
652        }
653    }
654
655    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
656    where
657        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
658    {
659        type Response = AudioGenerationResponse;
660        type Client = Client<T>;
661
662        fn make(client: &Self::Client, language: impl Into<String>) -> Self {
663            Self {
664                client: client.clone(),
665                language: language.into(),
666            }
667        }
668
669        async fn audio_generation(
670            &self,
671            request: AudioGenerationRequest,
672        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
673        {
674            let request = json!({
675                "language": self.language,
676                "speaker": request.voice,
677                "text": request.text,
678                "speed": request.speed
679            });
680
681            let body = serde_json::to_vec(&request)?;
682
683            let req = self
684                .client
685                .post("/v1/audio/generation")?
686                .body(body)
687                .map_err(http_client::Error::from)?;
688
689            let response = self.client.send::<_, Bytes>(req).await?;
690            let status = response.status();
691            let response_body = response.into_body().into_future().await?.to_vec();
692
693            if !status.is_success() {
694                return Err(AudioGenerationError::ProviderError(format!(
695                    "{status}: {}",
696                    String::from_utf8_lossy(&response_body)
697                )));
698            }
699
700            match serde_json::from_slice::<ApiResponse<AudioGenerationResponse>>(&response_body)? {
701                ApiResponse::Ok(response) => response.try_into(),
702                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
703            }
704        }
705    }
706}
707
#[cfg(test)]
mod tests {
    // Smoke test: both construction paths should succeed with any key string.
    #[test]
    fn test_client_initialization() {
        let direct =
            crate::providers::hyperbolic::Client::new("dummy-key").expect("Client::new() failed");
        drop(direct);

        let built = crate::providers::hyperbolic::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
        drop(built);
    }
}
719}