rig/providers/
hyperbolic.rs

1//! Hyperbolic Inference API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::hyperbolic;
6//!
7//! let client = hyperbolic::Client::new("YOUR_API_KEY");
8//!
9//! let llama_3_1_8b = client.completion_model(hyperbolic::LLAMA_3_1_8B);
10//! ```
11use super::openai::{AssistantContent, send_compatible_streaming_request};
12
13use crate::client::{ClientBuilderError, CompletionClient, ProviderClient};
14use crate::json_utils::merge_inplace;
15use crate::message;
16use crate::streaming::StreamingCompletionResponse;
17
18use crate::impl_conversion_traits;
19use crate::providers::openai;
20use crate::{
21    OneOrMany,
22    completion::{self, CompletionError, CompletionRequest},
23    json_utils,
24    providers::openai::Message,
25};
26use serde::{Deserialize, Serialize};
27use serde_json::{Value, json};
28
29// ================================================================
30// Main Hyperbolic Client
31// ================================================================
32const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz/v1";
33
/// Builder for [`Client`]: configures the API key, base URL, and optionally
/// a custom HTTP client before constructing the client.
pub struct ClientBuilder<'a> {
    // API key attached as a bearer token on every request.
    api_key: &'a str,
    // Base URL of the Hyperbolic API; defaults to the public endpoint.
    base_url: &'a str,
    // Optional pre-configured reqwest client; a default is built when `None`.
    http_client: Option<reqwest::Client>,
}
39
40impl<'a> ClientBuilder<'a> {
41    pub fn new(api_key: &'a str) -> Self {
42        Self {
43            api_key,
44            base_url: HYPERBOLIC_API_BASE_URL,
45            http_client: None,
46        }
47    }
48
49    pub fn base_url(mut self, base_url: &'a str) -> Self {
50        self.base_url = base_url;
51        self
52    }
53
54    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
55        self.http_client = Some(client);
56        self
57    }
58
59    pub fn build(self) -> Result<Client, ClientBuilderError> {
60        let http_client = if let Some(http_client) = self.http_client {
61            http_client
62        } else {
63            reqwest::Client::builder().build()?
64        };
65
66        Ok(Client {
67            base_url: self.base_url.to_string(),
68            api_key: self.api_key.to_string(),
69            http_client,
70        })
71    }
72}
73
/// Hyperbolic API client. Cheap to clone; holds the base URL, the API key,
/// and a shared `reqwest` HTTP client.
#[derive(Clone)]
pub struct Client {
    base_url: String,
    // Never printed: the manual `Debug` impl below redacts it.
    api_key: String,
    http_client: reqwest::Client,
}
80
81impl std::fmt::Debug for Client {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        f.debug_struct("Client")
84            .field("base_url", &self.base_url)
85            .field("http_client", &self.http_client)
86            .field("api_key", &"<REDACTED>")
87            .finish()
88    }
89}
90
91impl Client {
92    /// Create a new Hyperbolic client builder.
93    ///
94    /// # Example
95    /// ```
96    /// use rig::providers::hyperbolic::{ClientBuilder, self};
97    ///
98    /// // Initialize the Hyperbolic client
99    /// let hyperbolic = Client::builder("your-hyperbolic-api-key")
100    ///    .build()
101    /// ```
102    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
103        ClientBuilder::new(api_key)
104    }
105
106    /// Create a new Hyperbolic client. For more control, use the `builder` method.
107    ///
108    /// # Panics
109    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
110    pub fn new(api_key: &str) -> Self {
111        Self::builder(api_key)
112            .build()
113            .expect("Hyperbolic client should build")
114    }
115
116    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
117        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
118        self.http_client.post(url).bearer_auth(&self.api_key)
119    }
120}
121
122impl ProviderClient for Client {
123    /// Create a new Hyperbolic client from the `HYPERBOLIC_API_KEY` environment variable.
124    /// Panics if the environment variable is not set.
125    fn from_env() -> Self {
126        let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
127        Self::new(&api_key)
128    }
129
130    fn from_val(input: crate::client::ProviderValue) -> Self {
131        let crate::client::ProviderValue::Simple(api_key) = input else {
132            panic!("Incorrect provider value type")
133        };
134        Self::new(&api_key)
135    }
136}
137
impl CompletionClient for Client {
    type CompletionModel = CompletionModel;

    /// Create a completion model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::providers::hyperbolic::{Client, self};
    ///
    /// // Initialize the Hyperbolic client
    /// let hyperbolic = Client::new("your-hyperbolic-api-key");
    ///
    /// let llama_3_1_8b = hyperbolic.completion_model(hyperbolic::LLAMA_3_1_8B);
    /// ```
    fn completion_model(&self, model: &str) -> CompletionModel {
        CompletionModel::new(self.clone(), model)
    }
}
156
// Hyperbolic exposes no embeddings or transcription endpoints, so these
// capability conversions are stubbed out via the macro.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription for Client
);
161
/// Error payload returned by the Hyperbolic API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
166
/// Untagged response wrapper: deserializes as `Ok(T)` when the payload
/// matches `T`, otherwise falls through to the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
173
/// A single embedding entry in an OpenAI-compatible embeddings payload.
// NOTE(review): no embeddings endpoint is wired up in this module
// (`AsEmbeddings` is stubbed) — this type appears unused here; confirm
// before removing.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}
180
/// Token usage reported by the API. Only prompt and total counts are
/// carried; completion tokens are derived as `total - prompt` downstream.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
186
187impl std::fmt::Display for Usage {
188    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189        write!(
190            f,
191            "Prompt tokens: {} Total tokens: {}",
192            self.prompt_tokens, self.total_tokens
193        )
194    }
195}
196
197// ================================================================
198// Hyperbolic Completion API
199// ================================================================
200/// Meta Llama 3.1b Instruct model with 8B parameters.
201pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
202/// Meta Llama 3.3b Instruct model with 70B parameters.
203pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
204/// Meta Llama 3.1b Instruct model with 70B parameters.
205pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
206/// Meta Llama 3 Instruct model with 70B parameters.
207pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
208/// Hermes 3 Instruct model with 70B parameters.
209pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
210/// Deepseek v2.5 model.
211pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
212/// Qwen 2.5 model with 72B parameters.
213pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
214/// Meta Llama 3.2b Instruct model with 3B parameters.
215pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
216/// Qwen 2.5 Coder Instruct model with 32B parameters.
217pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
218/// Preview (latest) version of Qwen model with 32B parameters.
219pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
220/// Deepseek R1 Zero model.
221pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
222/// Deepseek R1 model.
223pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
224
/// A Hyperbolic completion object.
///
/// For more information, see this link: <https://docs.hyperbolic.xyz/reference/create_chat_completion_v1_chat_completions_post>
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    // presumably a Unix timestamp of creation time — confirm against API docs
    pub created: u64,
    pub model: String,
    /// Candidate completions; the rig conversion uses the first one.
    pub choices: Vec<Choice>,
    /// Token usage; may be absent from the response.
    pub usage: Option<Usage>,
}
237
238impl From<ApiErrorResponse> for CompletionError {
239    fn from(err: ApiErrorResponse) -> Self {
240        CompletionError::ProviderError(err.message)
241    }
242}
243
244impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
245    type Error = CompletionError;
246
247    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
248        let choice = response.choices.first().ok_or_else(|| {
249            CompletionError::ResponseError("Response contained no choices".to_owned())
250        })?;
251
252        let content = match &choice.message {
253            Message::Assistant {
254                content,
255                tool_calls,
256                ..
257            } => {
258                let mut content = content
259                    .iter()
260                    .map(|c| match c {
261                        AssistantContent::Text { text } => completion::AssistantContent::text(text),
262                        AssistantContent::Refusal { refusal } => {
263                            completion::AssistantContent::text(refusal)
264                        }
265                    })
266                    .collect::<Vec<_>>();
267
268                content.extend(
269                    tool_calls
270                        .iter()
271                        .map(|call| {
272                            completion::AssistantContent::tool_call(
273                                &call.id,
274                                &call.function.name,
275                                call.function.arguments.clone(),
276                            )
277                        })
278                        .collect::<Vec<_>>(),
279                );
280                Ok(content)
281            }
282            _ => Err(CompletionError::ResponseError(
283                "Response did not contain a valid message or tool call".into(),
284            )),
285        }?;
286
287        let choice = OneOrMany::many(content).map_err(|_| {
288            CompletionError::ResponseError(
289                "Response contained no message or tool call (empty)".to_owned(),
290            )
291        })?;
292
293        let usage = response
294            .usage
295            .as_ref()
296            .map(|usage| completion::Usage {
297                input_tokens: usage.prompt_tokens as u64,
298                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
299                total_tokens: usage.total_tokens as u64,
300            })
301            .unwrap_or_default();
302
303        Ok(completion::CompletionResponse {
304            choice,
305            usage,
306            raw_response: response,
307        })
308    }
309}
310
/// A single completion candidate within a [`CompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    pub index: usize,
    /// The assistant message, in the OpenAI-compatible shape.
    pub message: Message,
    pub finish_reason: String,
}
317
/// A completion model handle bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    /// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
    pub model: String,
}
324
325impl CompletionModel {
326    pub(crate) fn create_completion_request(
327        &self,
328        completion_request: CompletionRequest,
329    ) -> Result<Value, CompletionError> {
330        // Build up the order of messages (context, chat_history, prompt)
331        let mut partial_history = vec![];
332        if let Some(docs) = completion_request.normalized_documents() {
333            partial_history.push(docs);
334        }
335        partial_history.extend(completion_request.chat_history);
336
337        // Initialize full history with preamble (or empty if non-existent)
338        let mut full_history: Vec<Message> = completion_request
339            .preamble
340            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
341
342        // Convert and extend the rest of the history
343        full_history.extend(
344            partial_history
345                .into_iter()
346                .map(message::Message::try_into)
347                .collect::<Result<Vec<Vec<Message>>, _>>()?
348                .into_iter()
349                .flatten()
350                .collect::<Vec<_>>(),
351        );
352
353        let request = json!({
354            "model": self.model,
355            "messages": full_history,
356            "temperature": completion_request.temperature,
357        });
358
359        let request = if let Some(params) = completion_request.additional_params {
360            json_utils::merge(request, params)
361        } else {
362            request
363        };
364
365        Ok(request)
366    }
367}
368
369impl CompletionModel {
370    pub fn new(client: Client, model: &str) -> Self {
371        Self {
372            client,
373            model: model.to_string(),
374        }
375    }
376}
377
378impl completion::CompletionModel for CompletionModel {
379    type Response = CompletionResponse;
380    type StreamingResponse = openai::StreamingCompletionResponse;
381
382    #[cfg_attr(feature = "worker", worker::send)]
383    async fn completion(
384        &self,
385        completion_request: CompletionRequest,
386    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
387        let request = self.create_completion_request(completion_request)?;
388
389        let response = self
390            .client
391            .post("/chat/completions")
392            .json(&request)
393            .send()
394            .await?;
395
396        if response.status().is_success() {
397            match response.json::<ApiResponse<CompletionResponse>>().await? {
398                ApiResponse::Ok(response) => {
399                    tracing::info!(target: "rig",
400                        "Hyperbolic completion token usage: {:?}",
401                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
402                    );
403
404                    response.try_into()
405                }
406                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
407            }
408        } else {
409            Err(CompletionError::ProviderError(response.text().await?))
410        }
411    }
412
413    #[cfg_attr(feature = "worker", worker::send)]
414    async fn stream(
415        &self,
416        completion_request: CompletionRequest,
417    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
418        let mut request = self.create_completion_request(completion_request)?;
419
420        merge_inplace(
421            &mut request,
422            json!({"stream": true, "stream_options": {"include_usage": true}}),
423        );
424
425        let builder = self.client.post("/chat/completions").json(&request);
426
427        send_compatible_streaming_request(builder).await
428    }
429}
430
431// =======================================
432// Hyperbolic Image Generation API
433// =======================================
434
435#[cfg(feature = "image")]
436pub use image_generation::*;
437
438#[cfg(feature = "image")]
439mod image_generation {
440    use super::{ApiResponse, Client};
441    use crate::client::ImageGenerationClient;
442    use crate::image_generation;
443    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
444    use crate::json_utils::merge_inplace;
445    use base64::Engine;
446    use base64::prelude::BASE64_STANDARD;
447    use serde::Deserialize;
448    use serde_json::json;
449
450    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
451    pub const SD2: &str = "SD2";
452    pub const SD1_5: &str = "SD1.5";
453    pub const SSD: &str = "SSD";
454    pub const SDXL_TURBO: &str = "SDXL-turbo";
455    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
456    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";
457
458    #[cfg(feature = "image")]
459    #[derive(Clone)]
460    pub struct ImageGenerationModel {
461        client: Client,
462        pub model: String,
463    }
464
465    #[cfg(feature = "image")]
466    impl ImageGenerationModel {
467        pub(crate) fn new(client: Client, model: &str) -> ImageGenerationModel {
468            Self {
469                client,
470                model: model.to_string(),
471            }
472        }
473    }
474
475    #[cfg(feature = "image")]
476    #[derive(Clone, Deserialize)]
477    pub struct Image {
478        image: String,
479    }
480
481    #[cfg(feature = "image")]
482    #[derive(Clone, Deserialize)]
483    pub struct ImageGenerationResponse {
484        images: Vec<Image>,
485    }
486
487    #[cfg(feature = "image")]
488    impl TryFrom<ImageGenerationResponse>
489        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
490    {
491        type Error = ImageGenerationError;
492
493        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
494            let data = BASE64_STANDARD
495                .decode(&value.images[0].image)
496                .expect("Could not decode image.");
497
498            Ok(Self {
499                image: data,
500                response: value,
501            })
502        }
503    }
504
505    #[cfg(feature = "image")]
506    impl image_generation::ImageGenerationModel for ImageGenerationModel {
507        type Response = ImageGenerationResponse;
508
509        #[cfg_attr(feature = "worker", worker::send)]
510        async fn image_generation(
511            &self,
512            generation_request: ImageGenerationRequest,
513        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
514        {
515            let mut request = json!({
516                "model_name": self.model,
517                "prompt": generation_request.prompt,
518                "height": generation_request.height,
519                "width": generation_request.width,
520            });
521
522            if let Some(params) = generation_request.additional_params {
523                merge_inplace(&mut request, params);
524            }
525
526            let response = self
527                .client
528                .post("/image/generation")
529                .json(&request)
530                .send()
531                .await?;
532
533            if !response.status().is_success() {
534                return Err(ImageGenerationError::ProviderError(format!(
535                    "{}: {}",
536                    response.status().as_str(),
537                    response.text().await?
538                )));
539            }
540
541            match response
542                .json::<ApiResponse<ImageGenerationResponse>>()
543                .await?
544            {
545                ApiResponse::Ok(response) => response.try_into(),
546                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
547            }
548        }
549    }
550
551    impl ImageGenerationClient for Client {
552        type ImageGenerationModel = ImageGenerationModel;
553
554        /// Create an image generation model with the given name.
555        ///
556        /// # Example
557        /// ```
558        /// use rig::providers::hyperbolic::{Client, self};
559        ///
560        /// // Initialize the Hyperbolic client
561        /// let hyperbolic = Client::new("your-hyperbolic-api-key");
562        ///
563        /// let llama_3_1_8b = hyperbolic.image_generation_model(hyperbolic::SSD);
564        /// ```
565        fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
566            ImageGenerationModel::new(self.clone(), model)
567        }
568    }
569}
570
571// ======================================
572// Hyperbolic Audio Generation API
573// ======================================
574#[cfg(feature = "audio")]
575pub use audio_generation::*;
576
577#[cfg(feature = "audio")]
578mod audio_generation {
579    use super::{ApiResponse, Client};
580    use crate::audio_generation;
581    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
582    use crate::client::AudioGenerationClient;
583    use base64::Engine;
584    use base64::prelude::BASE64_STANDARD;
585    use serde::Deserialize;
586    use serde_json::json;
587
588    #[derive(Clone)]
589    pub struct AudioGenerationModel {
590        client: Client,
591        pub language: String,
592    }
593
594    impl AudioGenerationModel {
595        pub(crate) fn new(client: Client, language: &str) -> AudioGenerationModel {
596            Self {
597                client,
598                language: language.to_string(),
599            }
600        }
601    }
602
603    #[derive(Clone, Deserialize)]
604    pub struct AudioGenerationResponse {
605        audio: String,
606    }
607
608    impl TryFrom<AudioGenerationResponse>
609        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
610    {
611        type Error = AudioGenerationError;
612
613        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
614            let data = BASE64_STANDARD
615                .decode(&value.audio)
616                .expect("Could not decode audio.");
617
618            Ok(Self {
619                audio: data,
620                response: value,
621            })
622        }
623    }
624
625    impl audio_generation::AudioGenerationModel for AudioGenerationModel {
626        type Response = AudioGenerationResponse;
627
628        #[cfg_attr(feature = "worker", worker::send)]
629        async fn audio_generation(
630            &self,
631            request: AudioGenerationRequest,
632        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
633        {
634            let request = json!({
635                "language": self.language,
636                "speaker": request.voice,
637                "text": request.text,
638                "speed": request.speed
639            });
640
641            let response = self
642                .client
643                .post("/audio/generation")
644                .json(&request)
645                .send()
646                .await?;
647
648            if !response.status().is_success() {
649                return Err(AudioGenerationError::ProviderError(format!(
650                    "{}: {}",
651                    response.status(),
652                    response.text().await?
653                )));
654            }
655
656            match serde_json::from_str::<ApiResponse<AudioGenerationResponse>>(
657                &response.text().await?,
658            )? {
659                ApiResponse::Ok(response) => response.try_into(),
660                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
661            }
662        }
663    }
664    impl AudioGenerationClient for Client {
665        type AudioGenerationModel = AudioGenerationModel;
666
667        /// Create a completion model with the given name.
668        ///
669        /// # Example
670        /// ```
671        /// use rig::providers::hyperbolic::{Client, self};
672        ///
673        /// // Initialize the Hyperbolic client
674        /// let hyperbolic = Client::new("your-hyperbolic-api-key");
675        ///
676        /// let tts = hyperbolic.audio_generation_model("EN");
677        /// ```
678        fn audio_generation_model(&self, language: &str) -> AudioGenerationModel {
679            AudioGenerationModel::new(self.clone(), language)
680        }
681    }
682}