// rig/providers/hyperbolic.rs

1//! Hyperbolic Inference API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::hyperbolic;
6//!
7//! let client = hyperbolic::Client::new("YOUR_API_KEY");
8//!
9//! let llama_3_1_8b = client.completion_model(hyperbolic::LLAMA_3_1_8B);
10//! ```
11use super::openai::{AssistantContent, send_compatible_streaming_request};
12
13use crate::client::{self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder};
14use crate::client::{BearerAuth, ProviderClient};
15use crate::http_client::{self, HttpClientExt};
16use crate::streaming::StreamingCompletionResponse;
17
18use crate::providers::openai;
19use crate::{
20    OneOrMany,
21    completion::{self, CompletionError, CompletionRequest},
22    json_utils,
23    providers::openai::Message,
24};
25use serde::{Deserialize, Serialize};
26
27// ================================================================
28// Main Hyperbolic Client
29// ================================================================
30const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz";
31
32#[derive(Debug, Default, Clone, Copy)]
33pub struct HyperbolicExt;
34#[derive(Debug, Default, Clone, Copy)]
35pub struct HyperbolicBuilder;
36
37type HyperbolicApiKey = BearerAuth;
38
impl Provider for HyperbolicExt {
    type Builder = HyperbolicBuilder;

    // Endpoint probed by the generic client to verify connectivity/credentials.
    const VERIFY_PATH: &'static str = "/models";

    /// Build the provider extension. Hyperbolic carries no per-client state,
    /// so the builder is ignored and construction cannot fail.
    fn build<H>(
        _: &crate::client::ClientBuilder<
            Self::Builder,
            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        Ok(Self)
    }
}
54
/// Capability matrix for Hyperbolic: chat completions always, image and audio
/// generation behind their respective cargo features; no embeddings or
/// transcription support.
impl<H> Capabilities<H> for HyperbolicExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<ImageGenerationModel<H>>;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<AudioGenerationModel<H>>;
}
64
impl DebugExt for HyperbolicExt {}

impl ProviderBuilder for HyperbolicBuilder {
    type Output = HyperbolicExt;
    type ApiKey = HyperbolicApiKey;

    // All requests target the public Hyperbolic API host.
    const BASE_URL: &'static str = HYPERBOLIC_API_BASE_URL;
}

/// Hyperbolic client over an arbitrary HTTP backend (defaults to `reqwest`).
pub type Client<H = reqwest::Client> = client::Client<HyperbolicExt, H>;
/// Builder for [`Client`]; the API key is supplied as a `String`.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<HyperbolicBuilder, String, H>;
76
impl ProviderClient for Client {
    type Input = HyperbolicApiKey;

    /// Create a new Hyperbolic client from the `HYPERBOLIC_API_KEY` environment variable.
    /// Panics if the environment variable is not set.
    fn from_env() -> Self {
        let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
        // NOTE(review): `new` returns a Result; with the compiled-in base URL
        // a failure here is presumably unreachable — confirm before relying on it.
        Self::new(&api_key).unwrap()
    }

    /// Create a client from an already-constructed bearer-auth key.
    /// Panics if client construction fails (mirrors `from_env`).
    fn from_val(input: Self::Input) -> Self {
        Self::new(input).unwrap()
    }
}
91
/// Error payload returned by the Hyperbolic API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Either a successful payload of type `T` or an API-level error.
///
/// `untagged` makes serde try the success shape first and fall back to the
/// error shape, since the API does not wrap responses in a discriminant.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
103
/// A single embedding entry (OpenAI-compatible shape).
///
/// NOTE(review): embeddings are not wired into the capability matrix
/// (`type Embeddings = Nothing` above), so this type appears unused here.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}

/// Token usage as reported by Hyperbolic. Only prompt and total counts are
/// provided; completion tokens are derived downstream as `total - prompt`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
116
117impl std::fmt::Display for Usage {
118    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119        write!(
120            f,
121            "Prompt tokens: {} Total tokens: {}",
122            self.prompt_tokens, self.total_tokens
123        )
124    }
125}
126
127// ================================================================
128// Hyperbolic Completion API
129// ================================================================
130
131/// Meta Llama 3.1b Instruct model with 8B parameters.
132pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
133/// Meta Llama 3.3b Instruct model with 70B parameters.
134pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
135/// Meta Llama 3.1b Instruct model with 70B parameters.
136pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
137/// Meta Llama 3 Instruct model with 70B parameters.
138pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
139/// Hermes 3 Instruct model with 70B parameters.
140pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
141/// Deepseek v2.5 model.
142pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
143/// Qwen 2.5 model with 72B parameters.
144pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
145/// Meta Llama 3.2b Instruct model with 3B parameters.
146pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
147/// Qwen 2.5 Coder Instruct model with 32B parameters.
148pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
149/// Preview (latest) version of Qwen model with 32B parameters.
150pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
151/// Deepseek R1 Zero model.
152pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
153/// Deepseek R1 model.
154pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
155
/// A Hyperbolic completion object.
///
/// For more information, see this link: <https://docs.hyperbolic.xyz/reference/create_chat_completion_v1_chat_completions_post>
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Unique identifier of this completion.
    pub id: String,
    /// Object type tag (presumably "chat.completion" — confirm against API docs).
    pub object: String,
    /// Creation time (presumably a Unix timestamp — confirm against API docs).
    pub created: u64,
    /// Model that produced the completion.
    pub model: String,
    /// Candidate completions; only the first is consumed downstream.
    pub choices: Vec<Choice>,
    /// Token accounting; may be absent from the payload.
    pub usage: Option<Usage>,
}
168
169impl From<ApiErrorResponse> for CompletionError {
170    fn from(err: ApiErrorResponse) -> Self {
171        CompletionError::ProviderError(err.message)
172    }
173}
174
175impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
176    type Error = CompletionError;
177
178    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
179        let choice = response.choices.first().ok_or_else(|| {
180            CompletionError::ResponseError("Response contained no choices".to_owned())
181        })?;
182
183        let content = match &choice.message {
184            Message::Assistant {
185                content,
186                tool_calls,
187                ..
188            } => {
189                let mut content = content
190                    .iter()
191                    .map(|c| match c {
192                        AssistantContent::Text { text } => completion::AssistantContent::text(text),
193                        AssistantContent::Refusal { refusal } => {
194                            completion::AssistantContent::text(refusal)
195                        }
196                    })
197                    .collect::<Vec<_>>();
198
199                content.extend(
200                    tool_calls
201                        .iter()
202                        .map(|call| {
203                            completion::AssistantContent::tool_call(
204                                &call.id,
205                                &call.function.name,
206                                call.function.arguments.clone(),
207                            )
208                        })
209                        .collect::<Vec<_>>(),
210                );
211                Ok(content)
212            }
213            _ => Err(CompletionError::ResponseError(
214                "Response did not contain a valid message or tool call".into(),
215            )),
216        }?;
217
218        let choice = OneOrMany::many(content).map_err(|_| {
219            CompletionError::ResponseError(
220                "Response contained no message or tool call (empty)".to_owned(),
221            )
222        })?;
223
224        let usage = response
225            .usage
226            .as_ref()
227            .map(|usage| completion::Usage {
228                input_tokens: usage.prompt_tokens as u64,
229                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
230                total_tokens: usage.total_tokens as u64,
231            })
232            .unwrap_or_default();
233
234        Ok(completion::CompletionResponse {
235            choice,
236            usage,
237            raw_response: response,
238        })
239    }
240}
241
/// One completion candidate within a [`CompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    /// Position of this choice in the `choices` array.
    pub index: usize,
    /// The generated assistant message.
    pub message: Message,
    /// Why generation stopped (presumably "stop"/"length" — confirm values).
    pub finish_reason: String,
}
248
249#[derive(Debug, Serialize, Deserialize)]
250pub(super) struct HyperbolicCompletionRequest {
251    model: String,
252    pub messages: Vec<Message>,
253    #[serde(flatten, skip_serializing_if = "Option::is_none")]
254    temperature: Option<f64>,
255    #[serde(flatten, skip_serializing_if = "Option::is_none")]
256    pub additional_params: Option<serde_json::Value>,
257}
258
259impl TryFrom<(&str, CompletionRequest)> for HyperbolicCompletionRequest {
260    type Error = CompletionError;
261
262    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
263        if req.tool_choice.is_some() {
264            tracing::warn!("WARNING: `tool_choice` not supported on Hyperbolic");
265        }
266
267        if !req.tools.is_empty() {
268            tracing::warn!("WARNING: `tools` not supported on Hyperbolic");
269        }
270
271        let mut full_history: Vec<Message> = match &req.preamble {
272            Some(preamble) => vec![Message::system(preamble)],
273            None => vec![],
274        };
275
276        if let Some(docs) = req.normalized_documents() {
277            let docs: Vec<Message> = docs.try_into()?;
278            full_history.extend(docs);
279        }
280
281        let chat_history: Vec<Message> = req
282            .chat_history
283            .clone()
284            .into_iter()
285            .map(|message| message.try_into())
286            .collect::<Result<Vec<Vec<Message>>, _>>()?
287            .into_iter()
288            .flatten()
289            .collect();
290
291        full_history.extend(chat_history);
292
293        Ok(Self {
294            model: model.to_string(),
295            messages: full_history,
296            temperature: req.temperature,
297            additional_params: req.additional_params,
298        })
299    }
300}
301
/// Handle for issuing chat completions against a specific Hyperbolic model.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
    pub model: String,
}
308
309impl<T> CompletionModel<T> {
310    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
311        Self {
312            client,
313            model: model.into(),
314        }
315    }
316
317    pub fn with_model(client: Client<T>, model: &str) -> Self {
318        Self {
319            client,
320            model: model.into(),
321        }
322    }
323}
324
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = openai::StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Send a non-streaming chat completion request.
    ///
    /// Builds the provider request, POSTs it to `/v1/chat/completions`, and
    /// converts the JSON payload into Rig's generic response type. The whole
    /// exchange runs inside a `gen_ai` tracing span (reusing the caller's span
    /// when one is already active).
    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Keep the preamble for span metadata before the request is consumed.
        let preamble = completion_request.preamble.clone();
        let request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
        let body = serde_json::to_vec(&request)?;

        // Only open a fresh span when the caller hasn't already set one up.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        let async_block = async move {
            let response = self.client.send::<_, bytes::Bytes>(req).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                // A 2xx body can still carry an API-level error (untagged enum).
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        tracing::info!(target: "rig",
                            "Hyperbolic completion token usage: {:?}",
                            response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
                        );

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        };

        async_block.instrument(span).await
    }

    /// Send a streaming chat completion request.
    ///
    /// Forces `stream: true` (with usage reporting) into the request's extra
    /// parameters and delegates SSE handling to the OpenAI-compatible
    /// streaming helper.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();
        let mut request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Only open a fresh span when the caller hasn't already set one up.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        // Merge the streaming flags into any user-supplied extra parameters.
        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        send_compatible_streaming_request(self.client.http_client().clone(), req)
            .instrument(span)
            .await
    }
}
449
450// =======================================
451// Hyperbolic Image Generation API
452// =======================================
453
454#[cfg(feature = "image")]
455pub use image_generation::*;
456
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use super::{ApiResponse, Client};
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    pub const SD2: &str = "SD2";
    pub const SD1_5: &str = "SD1.5";
    pub const SSD: &str = "SSD";
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    /// Handle for generating images with a specific Hyperbolic model.
    #[derive(Clone)]
    pub struct ImageGenerationModel<T> {
        client: Client<T>,
        pub model: String,
    }

    impl<T> ImageGenerationModel<T> {
        /// Construct an image-generation handle for `model` on the given client.
        pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
            Self {
                client,
                model: model.into(),
            }
        }

        /// Convenience constructor taking the model name as `&str`.
        pub fn with_model(client: Client<T>, model: &str) -> Self {
            Self {
                client,
                model: model.into(),
            }
        }
    }

    /// A single generated image, base64-encoded.
    #[derive(Clone, Deserialize)]
    pub struct Image {
        image: String,
    }

    /// Raw response of Hyperbolic's image generation endpoint.
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        images: Vec<Image>,
    }

    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// Decode the first returned image.
        ///
        /// # Errors
        /// Returns [`ImageGenerationError::ResponseError`] when the response
        /// contains no images or the payload is not valid base64. (Previously
        /// both cases panicked: an out-of-bounds index and an `expect`.)
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            let first = value.images.first().ok_or_else(|| {
                ImageGenerationError::ResponseError("Response contained no images".to_string())
            })?;

            let data = BASE64_STANDARD.decode(&first.image).map_err(|e| {
                ImageGenerationError::ResponseError(format!("Could not decode image: {e}"))
            })?;

            Ok(Self {
                image: data,
                response: value,
            })
        }
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        /// POST the generation request to `/v1/image/generation` and decode
        /// the first image of the response.
        #[cfg_attr(feature = "worker", worker::send)]
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            // User-supplied parameters are merged over the defaults above.
            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let body = serde_json::to_vec(&request)?;

            let request = self
                .client
                .post("/v1/image/generation")?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, bytes::Bytes>(request).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
            }
        }
    }
}
584
585// ======================================
586// Hyperbolic Audio Generation API
587// ======================================
588#[cfg(feature = "audio")]
589pub use audio_generation::*;
590use tracing::{Instrument, info_span};
591
#[cfg(feature = "audio")]
// Fixed: this doc(cfg) previously advertised the wrong feature ("image");
// audio items are gated behind the "audio" feature.
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::{ApiResponse, Client};
    use crate::audio_generation;
    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
    use crate::http_client::{self, HttpClientExt};
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use bytes::Bytes;
    use serde::Deserialize;
    use serde_json::json;

    /// Handle for text-to-speech generation.
    ///
    /// NOTE(review): the "model" slot of this provider is the target
    /// `language`, which is sent as the `language` field of the request.
    #[derive(Clone)]
    pub struct AudioGenerationModel<T> {
        client: Client<T>,
        pub language: String,
    }

    /// Raw response of the audio generation endpoint: base64-encoded audio.
    #[derive(Clone, Deserialize)]
    pub struct AudioGenerationResponse {
        audio: String,
    }

    impl TryFrom<AudioGenerationResponse>
        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
    {
        type Error = AudioGenerationError;

        /// Decode the base64 audio payload.
        ///
        /// # Errors
        /// Returns [`AudioGenerationError::ProviderError`] when the payload is
        /// not valid base64. (Previously this `expect`ed and panicked.)
        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
            let data = BASE64_STANDARD.decode(&value.audio).map_err(|e| {
                AudioGenerationError::ProviderError(format!("Could not decode audio: {e}"))
            })?;

            Ok(Self {
                audio: data,
                response: value,
            })
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = AudioGenerationResponse;
        type Client = Client<T>;

        fn make(client: &Self::Client, language: impl Into<String>) -> Self {
            Self {
                client: client.clone(),
                language: language.into(),
            }
        }

        /// POST the TTS request to `/v1/audio/generation` and decode the
        /// resulting audio bytes.
        #[cfg_attr(feature = "worker", worker::send)]
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            let request = json!({
                "language": self.language,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post("/v1/audio/generation")?
                .body(body)
                .map_err(http_client::Error::from)?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<AudioGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }
}