Skip to main content

rig/providers/
hyperbolic.rs

1//! Hyperbolic Inference API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::hyperbolic;
6//!
7//! let client = hyperbolic::Client::new("YOUR_API_KEY");
8//!
9//! let llama_3_1_8b = client.completion_model(hyperbolic::LLAMA_3_1_8B);
10//! ```
11use super::openai::{AssistantContent, send_compatible_streaming_request};
12
13use crate::client::{self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder};
14use crate::client::{BearerAuth, ProviderClient};
15use crate::http_client::{self, HttpClientExt};
16use crate::streaming::StreamingCompletionResponse;
17
18use crate::providers::openai;
19use crate::{
20    OneOrMany,
21    completion::{self, CompletionError, CompletionRequest},
22    json_utils,
23    providers::openai::Message,
24};
25use serde::{Deserialize, Serialize};
26
// ================================================================
// Main Hyperbolic Client
// ================================================================
/// Default base URL for the Hyperbolic Inference API; all request paths
/// below are joined onto this.
const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz";
31
/// Zero-sized marker type that wires the Hyperbolic platform into rig's
/// generic client machinery (see the [`Provider`] impl below).
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicExt;
/// Zero-sized builder marker for Hyperbolic clients (see the
/// [`ProviderBuilder`] impl below for the base URL and key type).
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicBuilder;
36
37type HyperbolicApiKey = BearerAuth;
38
impl Provider for HyperbolicExt {
    type Builder = HyperbolicBuilder;

    // Endpoint used to verify that a configured API key is accepted.
    const VERIFY_PATH: &'static str = "/models";

    fn build<H>(
        _: &crate::client::ClientBuilder<
            Self::Builder,
            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        // The extension is stateless, so the builder is ignored and
        // construction cannot fail.
        Ok(Self)
    }
}
54
// Capability matrix for Hyperbolic: chat completions always; image and audio
// generation only behind their cargo features. Embeddings, transcription, and
// model listing are not wired up in this integration.
impl<H> Capabilities<H> for HyperbolicExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<ImageGenerationModel<H>>;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<AudioGenerationModel<H>>;
}
65
66impl DebugExt for HyperbolicExt {}
67
impl ProviderBuilder for HyperbolicBuilder {
    type Output = HyperbolicExt;
    type ApiKey = HyperbolicApiKey;

    // All request paths are resolved relative to this base URL.
    const BASE_URL: &'static str = HYPERBOLIC_API_BASE_URL;
}
74
/// A Hyperbolic client over an arbitrary HTTP backend (defaults to `reqwest`).
pub type Client<H = reqwest::Client> = client::Client<HyperbolicExt, H>;
/// Builder for [`Client`]; the API key is supplied as a plain `String`.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<HyperbolicBuilder, String, H>;
77
78impl ProviderClient for Client {
79    type Input = HyperbolicApiKey;
80
81    /// Create a new Hyperbolic client from the `HYPERBOLIC_API_KEY` environment variable.
82    /// Panics if the environment variable is not set.
83    fn from_env() -> Self {
84        let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
85        Self::new(&api_key).unwrap()
86    }
87
88    fn from_val(input: Self::Input) -> Self {
89        Self::new(input).unwrap()
90    }
91}
92
/// Error body returned by the Hyperbolic API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    // Human-readable error description from the provider.
    message: String,
}
97
/// Either a successful payload or a Hyperbolic API error body.
///
/// `untagged` makes serde try the variants in order, so a body that does not
/// parse as `T` is re-attempted as an error response.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
104
/// A single embedding entry in an embeddings response.
///
/// NOTE(review): this integration declares `Embeddings = Nothing` above, so
/// this type appears to be unused here — confirm before removing.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}
111
/// Token usage statistics reported by the Hyperbolic API.
///
/// Only prompt and total counts are reported; output tokens are derived as
/// `total_tokens - prompt_tokens` during response conversion.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    // Tokens consumed by the input prompt.
    pub prompt_tokens: usize,
    // Prompt tokens plus generated tokens.
    pub total_tokens: usize,
}
117
118impl std::fmt::Display for Usage {
119    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120        write!(
121            f,
122            "Prompt tokens: {} Total tokens: {}",
123            self.prompt_tokens, self.total_tokens
124        )
125    }
126}
127
// ================================================================
// Hyperbolic Completion API
// ================================================================

/// Meta Llama 3.1 Instruct model with 8B parameters.
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Meta Llama 3.3 Instruct model with 70B parameters.
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// Meta Llama 3.1 Instruct model with 70B parameters.
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// Meta Llama 3 Instruct model with 70B parameters.
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// Hermes 3 Instruct model with 70B parameters.
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// Deepseek v2.5 model.
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// Qwen 2.5 model with 72B parameters.
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// Meta Llama 3.2 Instruct model with 3B parameters.
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// Qwen 2.5 Coder Instruct model with 32B parameters.
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// Preview (latest) version of Qwen model with 32B parameters.
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// Deepseek R1 Zero model.
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// Deepseek R1 model.
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
156
/// A Hyperbolic completion object.
///
/// For more information, see this link: <https://docs.hyperbolic.xyz/reference/create_chat_completion_v1_chat_completions_post>
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    // Unique identifier of the completion.
    pub id: String,
    // Object type tag returned by the API.
    pub object: String,
    // Creation timestamp — presumably Unix seconds; verify against the API docs.
    pub created: u64,
    // Model that produced the completion.
    pub model: String,
    // Candidate completions; only the first is consumed during conversion.
    pub choices: Vec<Choice>,
    // Token accounting; may be absent from the payload.
    pub usage: Option<Usage>,
}
169
170impl From<ApiErrorResponse> for CompletionError {
171    fn from(err: ApiErrorResponse) -> Self {
172        CompletionError::ProviderError(err.message)
173    }
174}
175
176impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
177    type Error = CompletionError;
178
179    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
180        let choice = response.choices.first().ok_or_else(|| {
181            CompletionError::ResponseError("Response contained no choices".to_owned())
182        })?;
183
184        let content = match &choice.message {
185            Message::Assistant {
186                content,
187                tool_calls,
188                ..
189            } => {
190                let mut content = content
191                    .iter()
192                    .map(|c| match c {
193                        AssistantContent::Text { text } => completion::AssistantContent::text(text),
194                        AssistantContent::Refusal { refusal } => {
195                            completion::AssistantContent::text(refusal)
196                        }
197                    })
198                    .collect::<Vec<_>>();
199
200                content.extend(
201                    tool_calls
202                        .iter()
203                        .map(|call| {
204                            completion::AssistantContent::tool_call(
205                                &call.id,
206                                &call.function.name,
207                                call.function.arguments.clone(),
208                            )
209                        })
210                        .collect::<Vec<_>>(),
211                );
212                Ok(content)
213            }
214            _ => Err(CompletionError::ResponseError(
215                "Response did not contain a valid message or tool call".into(),
216            )),
217        }?;
218
219        let choice = OneOrMany::many(content).map_err(|_| {
220            CompletionError::ResponseError(
221                "Response contained no message or tool call (empty)".to_owned(),
222            )
223        })?;
224
225        let usage = response
226            .usage
227            .as_ref()
228            .map(|usage| completion::Usage {
229                input_tokens: usage.prompt_tokens as u64,
230                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
231                total_tokens: usage.total_tokens as u64,
232                cached_input_tokens: 0,
233            })
234            .unwrap_or_default();
235
236        Ok(completion::CompletionResponse {
237            choice,
238            usage,
239            raw_response: response,
240            message_id: None,
241        })
242    }
243}
244
/// A single completion candidate within a [`CompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    // Position of this choice in the returned list.
    pub index: usize,
    // The assistant message, in the OpenAI-compatible shape.
    pub message: Message,
    // Why generation stopped (e.g. "stop" or "length" — TODO confirm values).
    pub finish_reason: String,
}
251
/// Request body for Hyperbolic's OpenAI-compatible chat completion endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct HyperbolicCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    // Omitted from the JSON body when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Extra parameters flattened into the top-level JSON object; streaming
    // uses this to inject `stream`/`stream_options` flags.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
261
262impl TryFrom<(&str, CompletionRequest)> for HyperbolicCompletionRequest {
263    type Error = CompletionError;
264
265    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
266        if req.output_schema.is_some() {
267            tracing::warn!("Structured outputs currently not supported for Hyperbolic");
268        }
269
270        let model = req.model.clone().unwrap_or_else(|| model.to_string());
271        if req.tool_choice.is_some() {
272            tracing::warn!("WARNING: `tool_choice` not supported on Hyperbolic");
273        }
274
275        if !req.tools.is_empty() {
276            tracing::warn!("WARNING: `tools` not supported on Hyperbolic");
277        }
278
279        let mut full_history: Vec<Message> = match &req.preamble {
280            Some(preamble) => vec![Message::system(preamble)],
281            None => vec![],
282        };
283
284        if let Some(docs) = req.normalized_documents() {
285            let docs: Vec<Message> = docs.try_into()?;
286            full_history.extend(docs);
287        }
288
289        let chat_history: Vec<Message> = req
290            .chat_history
291            .clone()
292            .into_iter()
293            .map(|message| message.try_into())
294            .collect::<Result<Vec<Vec<Message>>, _>>()?
295            .into_iter()
296            .flatten()
297            .collect();
298
299        full_history.extend(chat_history);
300
301        Ok(Self {
302            model: model.to_string(),
303            messages: full_history,
304            temperature: req.temperature,
305            additional_params: req.additional_params,
306        })
307    }
308}
309
/// A chat completion model bound to a Hyperbolic [`Client`].
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
    pub model: String,
}
316
317impl<T> CompletionModel<T> {
318    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
319        Self {
320            client,
321            model: model.into(),
322        }
323    }
324
325    pub fn with_model(client: Client<T>, model: &str) -> Self {
326        Self {
327            client,
328            model: model.into(),
329        }
330    }
331}
332
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    // Streaming reuses the OpenAI-compatible streaming response type.
    type StreamingResponse = openai::StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Execute a non-streaming chat completion against
    /// `/v1/chat/completions`, instrumented with a `gen_ai` tracing span.
    ///
    /// # Errors
    /// Returns [`CompletionError`] on request-building, transport,
    /// deserialization, or provider-reported failures.
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Reuse an ambient span if the caller already opened one; otherwise
        // open a fresh OpenTelemetry-style `gen_ai` span.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Only pay for pretty-printing when TRACE logging is enabled.
        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Hyperbolic completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        // The whole send/parse path is wrapped in one block so it all runs
        // inside `span` via `instrument` below.
        let async_block = async move {
            let response = self.client.send::<_, bytes::Bytes>(req).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                // A 2xx body may still carry a provider error payload, hence
                // the untagged ApiResponse parse.
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        if tracing::enabled!(tracing::Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Hyperbolic completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                // Non-2xx: surface the raw body as a provider error.
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        };

        async_block.instrument(span).await
    }

    /// Execute a streaming chat completion; `stream: true` and usage
    /// reporting are injected into the request's `additional_params`.
    ///
    /// # Errors
    /// Returns [`CompletionError`] on request-building, transport, or
    /// provider-reported failures.
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let mut request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Merge streaming flags into any caller-supplied additional params;
        // caller keys are preserved where they do not collide.
        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Hyperbolic streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        // SSE parsing is shared with the OpenAI-compatible implementation.
        send_compatible_streaming_request(self.client.clone(), req)
            .instrument(span)
            .await
    }
}
468
469// =======================================
470// Hyperbolic Image Generation API
471// =======================================
472
473#[cfg(feature = "image")]
474pub use image_generation::*;
475
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use super::{ApiResponse, Client};
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    // Stable Diffusion model identifiers accepted by the Hyperbolic image API.
    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    pub const SD2: &str = "SD2";
    pub const SD1_5: &str = "SD1.5";
    pub const SSD: &str = "SSD";
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    /// An image generation model bound to a Hyperbolic [`Client`].
    #[derive(Clone)]
    pub struct ImageGenerationModel<T> {
        client: Client<T>,
        /// Name of the model (e.g.: SDXL1.0-base)
        pub model: String,
    }

    impl<T> ImageGenerationModel<T> {
        /// Create an image generation model bound to `client` for `model`.
        pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
            Self {
                client,
                model: model.into(),
            }
        }

        /// Convenience constructor taking the model name as a `&str`.
        pub fn with_model(client: Client<T>, model: &str) -> Self {
            Self::new(client, model)
        }
    }

    /// A single generated image in the API response.
    #[derive(Clone, Deserialize)]
    pub struct Image {
        // Base64-encoded image bytes.
        image: String,
    }

    /// Raw response of the Hyperbolic image generation endpoint.
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        images: Vec<Image>,
    }

    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// Decode the first returned image from base64.
        ///
        /// # Errors
        /// Returns [`ImageGenerationError::ResponseError`] when the payload
        /// contains no images or the base64 data is invalid. (Previously both
        /// cases panicked inside a `TryFrom` that already returns `Result`.)
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            let first = value.images.first().ok_or_else(|| {
                ImageGenerationError::ResponseError("Response contained no images".to_owned())
            })?;

            let data = BASE64_STANDARD.decode(&first.image).map_err(|e| {
                ImageGenerationError::ResponseError(format!("Could not decode image: {e}"))
            })?;

            Ok(Self {
                image: data,
                response: value,
            })
        }
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        /// Generate an image via `/v1/image/generation`.
        ///
        /// # Errors
        /// Returns [`ImageGenerationError`] on request-building, transport,
        /// deserialization, or provider-reported failures.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            // Hyperbolic's image endpoint uses `model_name`, not `model`.
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let body = serde_json::to_vec(&request)?;

            let request = self
                .client
                .post("/v1/image/generation")?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, bytes::Bytes>(request).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            // A 2xx body may still carry a provider error payload.
            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
            }
        }
    }
}
602
603// ======================================
604// Hyperbolic Audio Generation API
605// ======================================
606#[cfg(feature = "audio")]
607pub use audio_generation::*;
608use tracing::{Instrument, info_span};
609
#[cfg(feature = "audio")]
// Fix: the docsrs gate previously advertised `feature = "image"` for this
// audio module; it must match the actual cfg feature.
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::{ApiResponse, Client};
    use crate::audio_generation;
    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
    use crate::http_client::{self, HttpClientExt};
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use bytes::Bytes;
    use serde::Deserialize;
    use serde_json::json;

    /// An audio (text-to-speech) generation model bound to a Hyperbolic
    /// [`Client`]. Note that `make` binds a *language* rather than a model
    /// name.
    #[derive(Clone)]
    pub struct AudioGenerationModel<T> {
        client: Client<T>,
        pub language: String,
    }

    /// Raw response of the Hyperbolic audio generation endpoint.
    #[derive(Clone, Deserialize)]
    pub struct AudioGenerationResponse {
        // Base64-encoded audio bytes.
        audio: String,
    }

    impl TryFrom<AudioGenerationResponse>
        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
    {
        type Error = AudioGenerationError;

        /// Decode the returned audio from base64.
        ///
        /// # Errors
        /// Returns [`AudioGenerationError::ProviderError`] when the base64
        /// data is invalid. (Previously this panicked inside a `TryFrom` that
        /// already returns `Result`.)
        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
            let data = BASE64_STANDARD.decode(&value.audio).map_err(|e| {
                AudioGenerationError::ProviderError(format!("Could not decode audio: {e}"))
            })?;

            Ok(Self {
                audio: data,
                response: value,
            })
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = AudioGenerationResponse;
        type Client = Client<T>;

        fn make(client: &Self::Client, language: impl Into<String>) -> Self {
            Self {
                client: client.clone(),
                language: language.into(),
            }
        }

        /// Generate speech via `/v1/audio/generation`.
        ///
        /// # Errors
        /// Returns [`AudioGenerationError`] on request-building, transport,
        /// deserialization, or provider-reported failures.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            // The request's `voice` maps onto Hyperbolic's `speaker` field.
            let request = json!({
                "language": self.language,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post("/v1/audio/generation")?
                .body(body)
                .map_err(http_client::Error::from)?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            // A 2xx body may still carry a provider error payload.
            match serde_json::from_slice::<ApiResponse<AudioGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }
}
703
#[cfg(test)]
mod tests {
    use super::Client;

    /// Smoke test: both construction paths produce a client without error.
    #[test]
    fn test_client_initialization() {
        let _client: Client = Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder: Client = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}