// rig/providers/hyperbolic.rs

//! Hyperbolic Inference API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::hyperbolic;
//!
//! let client = hyperbolic::Client::new("YOUR_API_KEY");
//!
//! let llama_3_1_8b = client.completion_model(hyperbolic::LLAMA_3_1_8B);
//! ```
11use super::openai::{AssistantContent, send_compatible_streaming_request};
12
13use crate::client::{self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder};
14use crate::client::{BearerAuth, ProviderClient};
15use crate::http_client::{self, HttpClientExt};
16use crate::streaming::StreamingCompletionResponse;
17
18use crate::providers::openai;
19use crate::{
20    OneOrMany,
21    completion::{self, CompletionError, CompletionRequest},
22    json_utils,
23    providers::openai::Message,
24};
25use serde::{Deserialize, Serialize};
26
// ================================================================
// Main Hyperbolic Client
// ================================================================
/// Base URL for all Hyperbolic Inference API requests.
const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz";
31
/// Stateless marker type carrying the Hyperbolic provider implementation.
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicExt;
/// Stateless marker type carrying the Hyperbolic client-builder implementation.
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicBuilder;

// Hyperbolic authenticates with a bearer token (the API key).
type HyperbolicApiKey = BearerAuth;
38
impl Provider for HyperbolicExt {
    type Builder = HyperbolicBuilder;

    // Endpoint probed to verify the client / API key is usable.
    const VERIFY_PATH: &'static str = "/models";

    /// The provider extension is stateless, so construction ignores the
    /// client builder entirely and cannot fail.
    fn build<H>(
        _: &crate::client::ClientBuilder<
            Self::Builder,
            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        Ok(Self)
    }
}
54
// Hyperbolic always supports chat completions; image and audio generation are
// available behind the corresponding cargo features. Embeddings and
// transcription are not offered by this integration.
impl<H> Capabilities<H> for HyperbolicExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<ImageGenerationModel<H>>;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<AudioGenerationModel<H>>;
}

// Use the default debug-formatting extension (no provider-specific output).
impl DebugExt for HyperbolicExt {}
66
// Wires the builder marker to its output provider, key type, and base URL.
impl ProviderBuilder for HyperbolicBuilder {
    type Output = HyperbolicExt;
    type ApiKey = HyperbolicApiKey;

    const BASE_URL: &'static str = HYPERBOLIC_API_BASE_URL;
}
73
/// A Hyperbolic client, generic over the HTTP backend (defaults to `reqwest`).
pub type Client<H = reqwest::Client> = client::Client<HyperbolicExt, H>;
/// Builder for [`Client`]; the API key is supplied as a `String`.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<HyperbolicBuilder, String, H>;
76
impl ProviderClient for Client {
    type Input = HyperbolicApiKey;

    /// Create a new Hyperbolic client from the `HYPERBOLIC_API_KEY` environment variable.
    /// Panics if the environment variable is not set.
    fn from_env() -> Self {
        let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
        Self::new(&api_key).unwrap()
    }

    /// Create a new Hyperbolic client from an API key value.
    ///
    /// # Panics
    /// Panics if the underlying client fails to build.
    fn from_val(input: Self::Input) -> Self {
        Self::new(input).unwrap()
    }
}
91
/// Error payload returned by the Hyperbolic API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
96
/// Either a successful payload of type `T` or a provider error.
///
/// `untagged` makes serde try each variant in order and pick whichever
/// matches the JSON shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
103
/// A single embedding entry (OpenAI-compatible shape).
///
/// NOTE(review): embeddings are not wired up (`Capabilities::Embeddings =
/// Nothing`), so this type appears unused within this file — confirm against
/// external users before removing.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}
110
111#[derive(Clone, Debug, Deserialize, Serialize)]
112pub struct Usage {
113    pub prompt_tokens: usize,
114    pub total_tokens: usize,
115}
116
117impl std::fmt::Display for Usage {
118    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119        write!(
120            f,
121            "Prompt tokens: {} Total tokens: {}",
122            self.prompt_tokens, self.total_tokens
123        )
124    }
125}
126
127// ================================================================
128// Hyperbolic Completion API
129// ================================================================
130
/// Meta Llama 3.1 Instruct model with 8B parameters.
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Meta Llama 3.3 Instruct model with 70B parameters.
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// Meta Llama 3.1 Instruct model with 70B parameters.
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// Meta Llama 3 Instruct model with 70B parameters.
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// Hermes 3 Instruct model with 70B parameters.
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// Deepseek v2.5 model.
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// Qwen 2.5 model with 72B parameters.
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// Meta Llama 3.2 Instruct model with 3B parameters.
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// Qwen 2.5 Coder Instruct model with 32B parameters.
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// Preview (latest) version of Qwen model with 32B parameters.
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// Deepseek R1 Zero model.
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// Deepseek R1 model.
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
155
/// A Hyperbolic completion object.
///
/// For more information, see this link: <https://docs.hyperbolic.xyz/reference/create_chat_completion_v1_chat_completions_post>
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    /// Token usage; may be absent from some responses.
    pub usage: Option<Usage>,
}
168
// Map a Hyperbolic API error payload onto the generic provider error.
impl From<ApiErrorResponse> for CompletionError {
    fn from(err: ApiErrorResponse) -> Self {
        CompletionError::ProviderError(err.message)
    }
}
174
175impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
176    type Error = CompletionError;
177
178    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
179        let choice = response.choices.first().ok_or_else(|| {
180            CompletionError::ResponseError("Response contained no choices".to_owned())
181        })?;
182
183        let content = match &choice.message {
184            Message::Assistant {
185                content,
186                tool_calls,
187                ..
188            } => {
189                let mut content = content
190                    .iter()
191                    .map(|c| match c {
192                        AssistantContent::Text { text } => completion::AssistantContent::text(text),
193                        AssistantContent::Refusal { refusal } => {
194                            completion::AssistantContent::text(refusal)
195                        }
196                    })
197                    .collect::<Vec<_>>();
198
199                content.extend(
200                    tool_calls
201                        .iter()
202                        .map(|call| {
203                            completion::AssistantContent::tool_call(
204                                &call.id,
205                                &call.function.name,
206                                call.function.arguments.clone(),
207                            )
208                        })
209                        .collect::<Vec<_>>(),
210                );
211                Ok(content)
212            }
213            _ => Err(CompletionError::ResponseError(
214                "Response did not contain a valid message or tool call".into(),
215            )),
216        }?;
217
218        let choice = OneOrMany::many(content).map_err(|_| {
219            CompletionError::ResponseError(
220                "Response contained no message or tool call (empty)".to_owned(),
221            )
222        })?;
223
224        let usage = response
225            .usage
226            .as_ref()
227            .map(|usage| completion::Usage {
228                input_tokens: usage.prompt_tokens as u64,
229                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
230                total_tokens: usage.total_tokens as u64,
231                cached_input_tokens: 0,
232            })
233            .unwrap_or_default();
234
235        Ok(completion::CompletionResponse {
236            choice,
237            usage,
238            raw_response: response,
239        })
240    }
241}
242
/// A single completion choice (OpenAI-compatible shape).
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    pub finish_reason: String,
}
249
/// Request body for `POST /v1/chat/completions`.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct HyperbolicCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Arbitrary extra parameters are flattened into the top-level JSON object.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
259
260impl TryFrom<(&str, CompletionRequest)> for HyperbolicCompletionRequest {
261    type Error = CompletionError;
262
263    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
264        if req.tool_choice.is_some() {
265            tracing::warn!("WARNING: `tool_choice` not supported on Hyperbolic");
266        }
267
268        if !req.tools.is_empty() {
269            tracing::warn!("WARNING: `tools` not supported on Hyperbolic");
270        }
271
272        let mut full_history: Vec<Message> = match &req.preamble {
273            Some(preamble) => vec![Message::system(preamble)],
274            None => vec![],
275        };
276
277        if let Some(docs) = req.normalized_documents() {
278            let docs: Vec<Message> = docs.try_into()?;
279            full_history.extend(docs);
280        }
281
282        let chat_history: Vec<Message> = req
283            .chat_history
284            .clone()
285            .into_iter()
286            .map(|message| message.try_into())
287            .collect::<Result<Vec<Vec<Message>>, _>>()?
288            .into_iter()
289            .flatten()
290            .collect();
291
292        full_history.extend(chat_history);
293
294        Ok(Self {
295            model: model.to_string(),
296            messages: full_history,
297            temperature: req.temperature,
298            additional_params: req.additional_params,
299        })
300    }
301}
302
/// A Hyperbolic completion model bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
    pub model: String,
}
309
310impl<T> CompletionModel<T> {
311    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
312        Self {
313            client,
314            model: model.into(),
315        }
316    }
317
318    pub fn with_model(client: Client<T>, model: &str) -> Self {
319        Self {
320            client,
321            model: model.into(),
322        }
323    }
324}
325
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = openai::StreamingCompletionResponse;

    type Client = Client<T>;

    /// Builds a model handle for `model` from an existing client.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Sends a non-streaming chat completion to `POST /v1/chat/completions`
    /// and converts the response into the provider-agnostic form.
    ///
    /// # Errors
    /// Fails if request serialization, the HTTP call, or response decoding
    /// fails, or if the provider returns an error payload or a non-success
    /// HTTP status.
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Reuse the caller's active span if there is one; otherwise open a
        // fresh gen_ai-style span so the call is traced either way.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        // Convert the generic request into the Hyperbolic wire format.
        let request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Only pay the pretty-printing cost when TRACE logging is on.
        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Hyperbolic completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        let async_block = async move {
            let response = self.client.send::<_, bytes::Bytes>(req).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                // A 2xx status can still carry an error payload; the untagged
                // `ApiResponse` distinguishes the two shapes.
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        if tracing::enabled!(tracing::Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Hyperbolic completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        };

        async_block.instrument(span).await
    }

    /// Sends a streaming chat completion. Identical to [`Self::completion`]
    /// except that `stream: true` (with usage reporting) is merged into the
    /// additional parameters and the SSE body is handled by the shared
    /// OpenAI-compatible streaming path.
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        // Same span strategy as `completion`, with streaming-specific names.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let mut request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Force streaming mode and ask the provider to report token usage in
        // the final stream chunk.
        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Hyperbolic streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        send_compatible_streaming_request(self.client.clone(), req)
            .instrument(span)
            .await
    }
}
461
462// =======================================
463// Hyperbolic Image Generation API
464// =======================================
465
466#[cfg(feature = "image")]
467pub use image_generation::*;
468
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use super::{ApiResponse, Client};
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    /// Model identifier for `SDXL1.0-base`.
    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    /// Model identifier for `SD2`.
    pub const SD2: &str = "SD2";
    /// Model identifier for `SD1.5`.
    pub const SD1_5: &str = "SD1.5";
    /// Model identifier for `SSD`.
    pub const SSD: &str = "SSD";
    /// Model identifier for `SDXL-turbo`.
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    /// Model identifier for `SDXL-ControlNet`.
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    /// Model identifier for `SD1.5-ControlNet`.
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    /// An image generation model bound to a Hyperbolic client.
    #[derive(Clone)]
    pub struct ImageGenerationModel<T> {
        client: Client<T>,
        pub model: String,
    }

    impl<T> ImageGenerationModel<T> {
        /// Creates an image generation model for `model`, backed by `client`.
        pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
            Self {
                client,
                model: model.into(),
            }
        }

        /// Convenience constructor taking the model name as `&str`.
        pub fn with_model(client: Client<T>, model: &str) -> Self {
            Self {
                client,
                model: model.into(),
            }
        }
    }

    /// A single base64-encoded image within a generation response.
    #[derive(Clone, Deserialize)]
    pub struct Image {
        image: String,
    }

    /// Response body of `POST /v1/image/generation`.
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        images: Vec<Image>,
    }

    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// Decodes the first returned image from base64.
        ///
        /// Previously this indexed `images[0]` and `expect`ed the decode,
        /// panicking on an empty or malformed response; both cases now
        /// return `ImageGenerationError::ResponseError` instead.
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            let first = value.images.first().ok_or_else(|| {
                ImageGenerationError::ResponseError("Response contained no images".to_owned())
            })?;

            let data = BASE64_STANDARD.decode(&first.image).map_err(|e| {
                ImageGenerationError::ResponseError(format!("Could not decode image: {e}"))
            })?;

            Ok(Self {
                image: data,
                response: value,
            })
        }
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        /// Builds a model handle for `model` from an existing client.
        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        /// Sends an image generation request to `POST /v1/image/generation`.
        ///
        /// # Errors
        /// Fails on serialization/HTTP errors, non-success statuses, provider
        /// error payloads, or undecodable image data.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            // Caller-supplied extras extend/override the defaults in place.
            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let body = serde_json::to_vec(&request)?;

            let request = self
                .client
                .post("/v1/image/generation")?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, bytes::Bytes>(request).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
            }
        }
    }
}
595
596// ======================================
597// Hyperbolic Audio Generation API
598// ======================================
599#[cfg(feature = "audio")]
600pub use audio_generation::*;
601use tracing::{Instrument, info_span};
602
#[cfg(feature = "audio")]
// Fixed: this previously said `feature = "image"`, so rustdoc mislabeled the
// audio module's feature gate on docs.rs.
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::{ApiResponse, Client};
    use crate::audio_generation;
    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
    use crate::http_client::{self, HttpClientExt};
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use bytes::Bytes;
    use serde::Deserialize;
    use serde_json::json;

    /// An audio (text-to-speech) generation model bound to a Hyperbolic client.
    ///
    /// NOTE(review): the "model" dimension here is the target language rather
    /// than a model name — confirm against the Hyperbolic TTS API docs.
    #[derive(Clone)]
    pub struct AudioGenerationModel<T> {
        client: Client<T>,
        pub language: String,
    }

    /// Response body of `POST /v1/audio/generation`.
    #[derive(Clone, Deserialize)]
    pub struct AudioGenerationResponse {
        audio: String,
    }

    impl TryFrom<AudioGenerationResponse>
        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
    {
        type Error = AudioGenerationError;

        /// Decodes the base64 audio payload.
        ///
        /// Previously this `expect`ed the decode and panicked on malformed
        /// data; it now returns an error instead.
        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
            let data = BASE64_STANDARD.decode(&value.audio).map_err(|e| {
                AudioGenerationError::ProviderError(format!("Could not decode audio: {e}"))
            })?;

            Ok(Self {
                audio: data,
                response: value,
            })
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = AudioGenerationResponse;
        type Client = Client<T>;

        /// Builds a model handle for `language` from an existing client.
        fn make(client: &Self::Client, language: impl Into<String>) -> Self {
            Self {
                client: client.clone(),
                language: language.into(),
            }
        }

        /// Sends a TTS request to `POST /v1/audio/generation`.
        ///
        /// # Errors
        /// Fails on serialization/HTTP errors, non-success statuses, provider
        /// error payloads, or undecodable audio data.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            let request = json!({
                "language": self.language,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post("/v1/audio/generation")?
                .body(body)
                .map_err(http_client::Error::from)?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<AudioGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }
}
695}