// rig/providers/hyperbolic.rs
1//! Hyperbolic Inference API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::hyperbolic;
6//!
7//! let client = hyperbolic::Client::new("YOUR_API_KEY");
8//!
9//! let llama_3_1_8b = client.completion_model(hyperbolic::LLAMA_3_1_8B);
10//! ```
11use super::openai::{AssistantContent, send_compatible_streaming_request};
12
13use crate::client::{self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder};
14use crate::client::{BearerAuth, ProviderClient};
15use crate::http_client::{self, HttpClientExt};
16use crate::streaming::StreamingCompletionResponse;
17
18use crate::providers::openai;
19use crate::{
20    OneOrMany,
21    completion::{self, CompletionError, CompletionRequest},
22    json_utils,
23    providers::openai::Message,
24};
25use serde::{Deserialize, Serialize};
26
27// ================================================================
28// Main Hyperbolic Client
29// ================================================================
30const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz";
31
/// Marker type implementing the Hyperbolic provider extension.
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicExt;
/// Marker type used to build a Hyperbolic [`Client`].
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicBuilder;

// Hyperbolic authenticates with a standard `Authorization: Bearer <key>` header.
type HyperbolicApiKey = BearerAuth;
38
39impl Provider for HyperbolicExt {
40    type Builder = HyperbolicBuilder;
41
42    const VERIFY_PATH: &'static str = "/models";
43
44    fn build<H>(
45        _: &crate::client::ClientBuilder<
46            Self::Builder,
47            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
48            H,
49        >,
50    ) -> http_client::Result<Self> {
51        Ok(Self)
52    }
53}
54
// Capability matrix for Hyperbolic: completions always; image/audio generation
// behind their feature flags; embeddings and transcription are not offered.
impl<H> Capabilities<H> for HyperbolicExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<ImageGenerationModel<H>>;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<AudioGenerationModel<H>>;
}

impl DebugExt for HyperbolicExt {}
66
impl ProviderBuilder for HyperbolicBuilder {
    type Output = HyperbolicExt;
    type ApiKey = HyperbolicApiKey;

    /// Default base URL for the Hyperbolic API.
    const BASE_URL: &'static str = HYPERBOLIC_API_BASE_URL;
}
73
/// Hyperbolic client over a pluggable HTTP backend (defaults to `reqwest`).
pub type Client<H = reqwest::Client> = client::Client<HyperbolicExt, H>;
/// Builder for [`Client`]; the API key is supplied as a `String`.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<HyperbolicBuilder, String, H>;
76
77impl ProviderClient for Client {
78    type Input = HyperbolicApiKey;
79
80    /// Create a new Hyperbolic client from the `HYPERBOLIC_API_KEY` environment variable.
81    /// Panics if the environment variable is not set.
82    fn from_env() -> Self {
83        let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
84        Self::new(&api_key).unwrap()
85    }
86
87    fn from_val(input: Self::Input) -> Self {
88        Self::new(input).unwrap()
89    }
90}
91
/// Error payload returned by the Hyperbolic API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Either a successful payload or an API error body.
///
/// `untagged`: serde tries to deserialize `Ok(T)` first and falls back to the
/// error shape, so API errors can be detected even on 2xx responses.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
103
/// Single embedding entry in an OpenAI-style embeddings response.
/// NOTE(review): embeddings are not wired up for Hyperbolic
/// (`Capabilities::Embeddings = Nothing`), so this struct appears unused
/// here — confirm before removing.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}
110
/// Token usage reported by Hyperbolic (prompt and total only; no separate
/// output count in the payload).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
116
117impl std::fmt::Display for Usage {
118    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119        write!(
120            f,
121            "Prompt tokens: {} Total tokens: {}",
122            self.prompt_tokens, self.total_tokens
123        )
124    }
125}
126
127// ================================================================
128// Hyperbolic Completion API
129// ================================================================
130
/// Meta Llama 3.1 Instruct model with 8B parameters.
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Meta Llama 3.3 Instruct model with 70B parameters.
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// Meta Llama 3.1 Instruct model with 70B parameters.
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// Meta Llama 3 Instruct model with 70B parameters.
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// Hermes 3 (Llama 3.1 based) Instruct model with 70B parameters.
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// Deepseek v2.5 model.
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// Qwen 2.5 model with 72B parameters.
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// Meta Llama 3.2 Instruct model with 3B parameters.
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// Qwen 2.5 Coder Instruct model with 32B parameters.
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// Preview (latest) version of Qwen model with 32B parameters.
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// Deepseek R1 Zero model.
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// Deepseek R1 model.
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
155
/// A Hyperbolic completion object.
///
/// For more information, see this link: <https://docs.hyperbolic.xyz/reference/create_chat_completion_v1_chat_completions_post>
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    /// Generated completions; only the first choice is consumed downstream.
    pub choices: Vec<Choice>,
    /// Token usage; may be absent on some responses.
    pub usage: Option<Usage>,
}
168
169impl From<ApiErrorResponse> for CompletionError {
170    fn from(err: ApiErrorResponse) -> Self {
171        CompletionError::ProviderError(err.message)
172    }
173}
174
175impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
176    type Error = CompletionError;
177
178    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
179        let choice = response.choices.first().ok_or_else(|| {
180            CompletionError::ResponseError("Response contained no choices".to_owned())
181        })?;
182
183        let content = match &choice.message {
184            Message::Assistant {
185                content,
186                tool_calls,
187                ..
188            } => {
189                let mut content = content
190                    .iter()
191                    .map(|c| match c {
192                        AssistantContent::Text { text } => completion::AssistantContent::text(text),
193                        AssistantContent::Refusal { refusal } => {
194                            completion::AssistantContent::text(refusal)
195                        }
196                    })
197                    .collect::<Vec<_>>();
198
199                content.extend(
200                    tool_calls
201                        .iter()
202                        .map(|call| {
203                            completion::AssistantContent::tool_call(
204                                &call.id,
205                                &call.function.name,
206                                call.function.arguments.clone(),
207                            )
208                        })
209                        .collect::<Vec<_>>(),
210                );
211                Ok(content)
212            }
213            _ => Err(CompletionError::ResponseError(
214                "Response did not contain a valid message or tool call".into(),
215            )),
216        }?;
217
218        let choice = OneOrMany::many(content).map_err(|_| {
219            CompletionError::ResponseError(
220                "Response contained no message or tool call (empty)".to_owned(),
221            )
222        })?;
223
224        let usage = response
225            .usage
226            .as_ref()
227            .map(|usage| completion::Usage {
228                input_tokens: usage.prompt_tokens as u64,
229                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
230                total_tokens: usage.total_tokens as u64,
231            })
232            .unwrap_or_default();
233
234        Ok(completion::CompletionResponse {
235            choice,
236            usage,
237            raw_response: response,
238        })
239    }
240}
241
/// One completion choice within a [`CompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    pub finish_reason: String,
}
248
/// Request body for Hyperbolic's OpenAI-compatible chat completions endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct HyperbolicCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Extra provider parameters are flattened into the top-level JSON object.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
258
259impl TryFrom<(&str, CompletionRequest)> for HyperbolicCompletionRequest {
260    type Error = CompletionError;
261
262    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
263        if req.tool_choice.is_some() {
264            tracing::warn!("WARNING: `tool_choice` not supported on Hyperbolic");
265        }
266
267        if !req.tools.is_empty() {
268            tracing::warn!("WARNING: `tools` not supported on Hyperbolic");
269        }
270
271        let mut full_history: Vec<Message> = match &req.preamble {
272            Some(preamble) => vec![Message::system(preamble)],
273            None => vec![],
274        };
275
276        if let Some(docs) = req.normalized_documents() {
277            let docs: Vec<Message> = docs.try_into()?;
278            full_history.extend(docs);
279        }
280
281        let chat_history: Vec<Message> = req
282            .chat_history
283            .clone()
284            .into_iter()
285            .map(|message| message.try_into())
286            .collect::<Result<Vec<Vec<Message>>, _>>()?
287            .into_iter()
288            .flatten()
289            .collect();
290
291        full_history.extend(chat_history);
292
293        Ok(Self {
294            model: model.to_string(),
295            messages: full_history,
296            temperature: req.temperature,
297            additional_params: req.additional_params,
298        })
299    }
300}
301
/// Hyperbolic completion model handle, generic over the HTTP client `T`.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
    pub model: String,
}
308
309impl<T> CompletionModel<T> {
310    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
311        Self {
312            client,
313            model: model.into(),
314        }
315    }
316
317    pub fn with_model(client: Client<T>, model: &str) -> Self {
318        Self {
319            client,
320            model: model.into(),
321        }
322    }
323}
324
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = openai::StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Send a (non-streaming) chat completion request to
    /// `POST /v1/chat/completions` and convert the response.
    ///
    /// # Errors
    /// Returns `ProviderError` on non-success HTTP status or API-level errors,
    /// and `ResponseError` when the payload cannot be converted.
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Reuse the caller's span when inside one; otherwise open a fresh
        // gen_ai-style span for this call.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Pretty-printed body is only serialized when TRACE is enabled.
        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Hyperbolic completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        let async_block = async move {
            let response = self.client.send::<_, bytes::Bytes>(req).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                // `ApiResponse` is untagged: API errors can arrive with a 2xx
                // status and are surfaced here as `ProviderError`.
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        if tracing::enabled!(tracing::Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Hyperbolic completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        };

        async_block.instrument(span).await
    }

    /// Send a streaming chat completion request.
    ///
    /// Forces `stream: true` and `stream_options.include_usage: true` into the
    /// request's additional parameters, then delegates to the shared
    /// OpenAI-compatible streaming handler.
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let mut request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Merge the streaming flags on top of any caller-supplied extras.
        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Hyperbolic streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        send_compatible_streaming_request(self.client.clone(), req)
            .instrument(span)
            .await
    }
}
460
461// =======================================
462// Hyperbolic Image Generation API
463// =======================================
464
465#[cfg(feature = "image")]
466pub use image_generation::*;
467
468#[cfg(feature = "image")]
469#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
470mod image_generation {
471    use super::{ApiResponse, Client};
472    use crate::http_client::HttpClientExt;
473    use crate::image_generation;
474    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
475    use crate::json_utils::merge_inplace;
476    use base64::Engine;
477    use base64::prelude::BASE64_STANDARD;
478    use serde::Deserialize;
479    use serde_json::json;
480
    /// Stable Diffusion XL 1.0 base model.
    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    /// Stable Diffusion 2 model.
    pub const SD2: &str = "SD2";
    /// Stable Diffusion 1.5 model.
    pub const SD1_5: &str = "SD1.5";
    /// SSD model (presumably Segmind SSD — confirm against Hyperbolic docs).
    pub const SSD: &str = "SSD";
    /// Stable Diffusion XL Turbo model.
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    /// Stable Diffusion XL with ControlNet.
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    /// Stable Diffusion 1.5 with ControlNet.
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";
488
489    #[derive(Clone)]
490    pub struct ImageGenerationModel<T> {
491        client: Client<T>,
492        pub model: String,
493    }
494
495    impl<T> ImageGenerationModel<T> {
496        pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
497            Self {
498                client,
499                model: model.into(),
500            }
501        }
502
503        pub fn with_model(client: Client<T>, model: &str) -> Self {
504            Self {
505                client,
506                model: model.into(),
507            }
508        }
509    }
510
    /// Single generated image, base64-encoded.
    #[derive(Clone, Deserialize)]
    pub struct Image {
        image: String,
    }

    /// Raw image generation response from the Hyperbolic API.
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        images: Vec<Image>,
    }
520
521    impl TryFrom<ImageGenerationResponse>
522        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
523    {
524        type Error = ImageGenerationError;
525
526        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
527            let data = BASE64_STANDARD
528                .decode(&value.images[0].image)
529                .expect("Could not decode image.");
530
531            Ok(Self {
532                image: data,
533                response: value,
534            })
535        }
536    }
537
    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        /// Generate an image via `POST /v1/image/generation`.
        ///
        /// Note: Hyperbolic expects `model_name` (not `model`) in the body.
        /// Caller-supplied `additional_params` are merged over the defaults.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let body = serde_json::to_vec(&request)?;

            let request = self
                .client
                .post("/v1/image/generation")?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, bytes::Bytes>(request).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            // API errors can arrive with a 2xx status via the untagged `ApiResponse`.
            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
            }
        }
    }
593}
594
595// ======================================
596// Hyperbolic Audio Generation API
597// ======================================
598#[cfg(feature = "audio")]
599pub use audio_generation::*;
600use tracing::{Instrument, info_span};
601
#[cfg(feature = "audio")]
// Fixed copy-paste bug: the docsrs cfg previously referenced feature = "image".
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::{ApiResponse, Client};
    use crate::audio_generation;
    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
    use crate::http_client::{self, HttpClientExt};
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use bytes::Bytes;
    use serde::Deserialize;
    use serde_json::json;

    /// Hyperbolic audio generation (TTS) model, generic over the HTTP client `T`.
    ///
    /// Unlike the other models, this endpoint is parameterized by `language`
    /// rather than a model name.
    #[derive(Clone)]
    pub struct AudioGenerationModel<T> {
        client: Client<T>,
        pub language: String,
    }

    /// Raw TTS response: base64-encoded audio.
    #[derive(Clone, Deserialize)]
    pub struct AudioGenerationResponse {
        audio: String,
    }

    impl TryFrom<AudioGenerationResponse>
        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
    {
        type Error = AudioGenerationError;

        /// Decode the base64 audio payload.
        ///
        /// # Errors
        /// Returns `ProviderError` if the payload is not valid base64
        /// (previously this panicked via `.expect`).
        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
            let data = BASE64_STANDARD.decode(&value.audio).map_err(|e| {
                AudioGenerationError::ProviderError(format!("Could not decode audio: {e}"))
            })?;

            Ok(Self {
                audio: data,
                response: value,
            })
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = AudioGenerationResponse;
        type Client = Client<T>;

        fn make(client: &Self::Client, language: impl Into<String>) -> Self {
            Self {
                client: client.clone(),
                language: language.into(),
            }
        }

        /// Generate speech via `POST /v1/audio/generation`.
        ///
        /// # Errors
        /// Returns `ProviderError` on non-success HTTP status or API-level errors.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            let request = json!({
                "language": self.language,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post("/v1/audio/generation")?
                .body(body)
                .map_err(http_client::Error::from)?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            // API errors can arrive with a 2xx status via the untagged `ApiResponse`.
            match serde_json::from_slice::<ApiResponse<AudioGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }
}