rig/providers/
hyperbolic.rs

//! Hyperbolic Inference API client and Rig integration.
//!
//! # Example
//! ```
//! use rig::client::CompletionClient;
//! use rig::providers::hyperbolic;
//!
//! let client = hyperbolic::Client::new("YOUR_API_KEY");
//!
//! let llama_3_1_8b = client.completion_model(hyperbolic::LLAMA_3_1_8B);
//! ```
11
12use super::openai::{AssistantContent, send_compatible_streaming_request};
13
14use crate::client::{CompletionClient, ProviderClient};
15use crate::json_utils::merge_inplace;
16use crate::message;
17use crate::streaming::StreamingCompletionResponse;
18
19use crate::impl_conversion_traits;
20use crate::providers::openai;
21use crate::{
22    OneOrMany,
23    completion::{self, CompletionError, CompletionRequest},
24    json_utils,
25    providers::openai::Message,
26};
27use serde::Deserialize;
28use serde_json::{Value, json};
29
30// ================================================================
31// Main Hyperbolic Client
32// ================================================================
33const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz/v1";
34
35#[derive(Clone)]
36pub struct Client {
37    base_url: String,
38    api_key: String,
39    http_client: reqwest::Client,
40}
41
42impl std::fmt::Debug for Client {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        f.debug_struct("Client")
45            .field("base_url", &self.base_url)
46            .field("http_client", &self.http_client)
47            .field("api_key", &"<REDACTED>")
48            .finish()
49    }
50}
51
52impl Client {
53    /// Create a new Hyperbolic client with the given API key.
54    pub fn new(api_key: &str) -> Self {
55        Self::from_url(api_key, HYPERBOLIC_API_BASE_URL)
56    }
57
58    /// Create a new OpenAI client with the given API key and base API URL.
59    pub fn from_url(api_key: &str, base_url: &str) -> Self {
60        Self {
61            base_url: base_url.to_string(),
62            api_key: api_key.to_string(),
63            http_client: reqwest::Client::builder()
64                .build()
65                .expect("OpenAI reqwest client should build"),
66        }
67    }
68
69    /// Use your own `reqwest::Client`.
70    /// The required headers will be automatically attached upon trying to make a request.
71    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
72        self.http_client = client;
73
74        self
75    }
76
77    fn post(&self, path: &str) -> reqwest::RequestBuilder {
78        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
79        self.http_client.post(url).bearer_auth(&self.api_key)
80    }
81}
82
83impl ProviderClient for Client {
84    /// Create a new Hyperbolic client from the `HYPERBOLIC_API_KEY` environment variable.
85    /// Panics if the environment variable is not set.
86    fn from_env() -> Self {
87        let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
88        Self::new(&api_key)
89    }
90}
91
92impl CompletionClient for Client {
93    type CompletionModel = CompletionModel;
94
95    /// Create a completion model with the given name.
96    ///
97    /// # Example
98    /// ```
99    /// use rig::providers::hyperbolic::{Client, self};
100    ///
101    /// // Initialize the Hyperbolic client
102    /// let hyperbolic = Client::new("your-hyperbolic-api-key");
103    ///
104    /// let llama_3_1_8b = hyperbolic.completion_model(hyperbolic::LLAMA_3_1_8B);
105    /// ```
106    fn completion_model(&self, model: &str) -> CompletionModel {
107        CompletionModel::new(self.clone(), model)
108    }
109}
110
111impl_conversion_traits!(
112    AsEmbeddings,
113    AsTranscription for Client
114);
115
116#[derive(Debug, Deserialize)]
117struct ApiErrorResponse {
118    message: String,
119}
120
121#[derive(Debug, Deserialize)]
122#[serde(untagged)]
123enum ApiResponse<T> {
124    Ok(T),
125    Err(ApiErrorResponse),
126}
127
128#[derive(Debug, Deserialize)]
129pub struct EmbeddingData {
130    pub object: String,
131    pub embedding: Vec<f64>,
132    pub index: usize,
133}
134
135#[derive(Clone, Debug, Deserialize)]
136pub struct Usage {
137    pub prompt_tokens: usize,
138    pub total_tokens: usize,
139}
140
141impl std::fmt::Display for Usage {
142    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143        write!(
144            f,
145            "Prompt tokens: {} Total tokens: {}",
146            self.prompt_tokens, self.total_tokens
147        )
148    }
149}
150
// ================================================================
// Hyperbolic Completion API
// ================================================================
/// Meta Llama 3.1 Instruct model with 8B parameters.
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Meta Llama 3.3 Instruct model with 70B parameters.
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// Meta Llama 3.1 Instruct model with 70B parameters.
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// Meta Llama 3 Instruct model with 70B parameters.
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// Nous Research Hermes 3 model (Llama 3.1 based) with 70B parameters.
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// DeepSeek V2.5 model.
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// Qwen 2.5 Instruct model with 72B parameters.
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// Meta Llama 3.2 Instruct model with 3B parameters.
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// Qwen 2.5 Coder Instruct model with 32B parameters.
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// Qwen QwQ 32B preview model.
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// DeepSeek R1 Zero model.
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// DeepSeek R1 model.
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
178
179/// A Hyperbolic completion object.
180///
181/// For more information, see this link: <https://docs.hyperbolic.xyz/reference/create_chat_completion_v1_chat_completions_post>
182#[derive(Debug, Deserialize)]
183pub struct CompletionResponse {
184    pub id: String,
185    pub object: String,
186    pub created: u64,
187    pub model: String,
188    pub choices: Vec<Choice>,
189    pub usage: Option<Usage>,
190}
191
192impl From<ApiErrorResponse> for CompletionError {
193    fn from(err: ApiErrorResponse) -> Self {
194        CompletionError::ProviderError(err.message)
195    }
196}
197
198impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
199    type Error = CompletionError;
200
201    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
202        let choice = response.choices.first().ok_or_else(|| {
203            CompletionError::ResponseError("Response contained no choices".to_owned())
204        })?;
205
206        let content = match &choice.message {
207            Message::Assistant {
208                content,
209                tool_calls,
210                ..
211            } => {
212                let mut content = content
213                    .iter()
214                    .map(|c| match c {
215                        AssistantContent::Text { text } => completion::AssistantContent::text(text),
216                        AssistantContent::Refusal { refusal } => {
217                            completion::AssistantContent::text(refusal)
218                        }
219                    })
220                    .collect::<Vec<_>>();
221
222                content.extend(
223                    tool_calls
224                        .iter()
225                        .map(|call| {
226                            completion::AssistantContent::tool_call(
227                                &call.id,
228                                &call.function.name,
229                                call.function.arguments.clone(),
230                            )
231                        })
232                        .collect::<Vec<_>>(),
233                );
234                Ok(content)
235            }
236            _ => Err(CompletionError::ResponseError(
237                "Response did not contain a valid message or tool call".into(),
238            )),
239        }?;
240
241        let choice = OneOrMany::many(content).map_err(|_| {
242            CompletionError::ResponseError(
243                "Response contained no message or tool call (empty)".to_owned(),
244            )
245        })?;
246
247        Ok(completion::CompletionResponse {
248            choice,
249            raw_response: response,
250        })
251    }
252}
253
254#[derive(Debug, Deserialize)]
255pub struct Choice {
256    pub index: usize,
257    pub message: Message,
258    pub finish_reason: String,
259}
260
261#[derive(Clone)]
262pub struct CompletionModel {
263    client: Client,
264    /// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
265    pub model: String,
266}
267
268impl CompletionModel {
269    pub(crate) fn create_completion_request(
270        &self,
271        completion_request: CompletionRequest,
272    ) -> Result<Value, CompletionError> {
273        // Build up the order of messages (context, chat_history, prompt)
274        let mut partial_history = vec![];
275        if let Some(docs) = completion_request.normalized_documents() {
276            partial_history.push(docs);
277        }
278        partial_history.extend(completion_request.chat_history);
279
280        // Initialize full history with preamble (or empty if non-existent)
281        let mut full_history: Vec<Message> = completion_request
282            .preamble
283            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
284
285        // Convert and extend the rest of the history
286        full_history.extend(
287            partial_history
288                .into_iter()
289                .map(message::Message::try_into)
290                .collect::<Result<Vec<Vec<Message>>, _>>()?
291                .into_iter()
292                .flatten()
293                .collect::<Vec<_>>(),
294        );
295
296        let request = json!({
297            "model": self.model,
298            "messages": full_history,
299            "temperature": completion_request.temperature,
300        });
301
302        let request = if let Some(params) = completion_request.additional_params {
303            json_utils::merge(request, params)
304        } else {
305            request
306        };
307
308        Ok(request)
309    }
310}
311
312impl CompletionModel {
313    pub fn new(client: Client, model: &str) -> Self {
314        Self {
315            client,
316            model: model.to_string(),
317        }
318    }
319}
320
321impl completion::CompletionModel for CompletionModel {
322    type Response = CompletionResponse;
323    type StreamingResponse = openai::StreamingCompletionResponse;
324
325    #[cfg_attr(feature = "worker", worker::send)]
326    async fn completion(
327        &self,
328        completion_request: CompletionRequest,
329    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
330        let request = self.create_completion_request(completion_request)?;
331
332        let response = self
333            .client
334            .post("/chat/completions")
335            .json(&request)
336            .send()
337            .await?;
338
339        if response.status().is_success() {
340            match response.json::<ApiResponse<CompletionResponse>>().await? {
341                ApiResponse::Ok(response) => {
342                    tracing::info!(target: "rig",
343                        "Hyperbolic completion token usage: {:?}",
344                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
345                    );
346
347                    response.try_into()
348                }
349                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
350            }
351        } else {
352            Err(CompletionError::ProviderError(response.text().await?))
353        }
354    }
355
356    #[cfg_attr(feature = "worker", worker::send)]
357    async fn stream(
358        &self,
359        completion_request: CompletionRequest,
360    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
361        let mut request = self.create_completion_request(completion_request)?;
362
363        merge_inplace(
364            &mut request,
365            json!({"stream": true, "stream_options": {"include_usage": true}}),
366        );
367
368        let builder = self.client.post("/chat/completions").json(&request);
369
370        send_compatible_streaming_request(builder).await
371    }
372}
373
// =======================================
// Hyperbolic Image Generation API
// =======================================

#[cfg(feature = "image")]
pub use image_generation::*;

#[cfg(feature = "image")]
mod image_generation {
    use super::{ApiResponse, Client};
    use crate::client::ImageGenerationClient;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    // Model identifiers for the `model_name` field of `/image/generation`.
    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    pub const SD2: &str = "SD2";
    pub const SD1_5: &str = "SD1.5";
    pub const SSD: &str = "SSD";
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    /// Handle to a Hyperbolic image generation model, bound to the client that
    /// sends its requests.
    #[derive(Clone)]
    pub struct ImageGenerationModel {
        client: Client,
        /// Model identifier sent as `model_name` (e.g. [`SSD`]).
        pub model: String,
    }

    impl ImageGenerationModel {
        pub(crate) fn new(client: Client, model: &str) -> ImageGenerationModel {
            Self {
                client,
                model: model.to_string(),
            }
        }
    }

    /// One generated image as returned by the API (base64-encoded).
    #[derive(Clone, Deserialize)]
    pub struct Image {
        image: String,
    }

    /// Raw body of a successful `/image/generation` response.
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        images: Vec<Image>,
    }

    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// Decode the first image of the response into raw bytes.
        ///
        /// # Errors
        ///
        /// Returns [`ImageGenerationError::ResponseError`] when the response holds
        /// no images, or when the first image is not valid standard base64.
        /// A malformed provider reply is reported as an error, not a panic.
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            let first = value.images.first().ok_or_else(|| {
                ImageGenerationError::ResponseError("Response contained no images".to_owned())
            })?;

            let data = BASE64_STANDARD.decode(&first.image).map_err(|err| {
                ImageGenerationError::ResponseError(format!("Could not decode image: {err}"))
            })?;

            Ok(Self {
                image: data,
                response: value,
            })
        }
    }

    impl image_generation::ImageGenerationModel for ImageGenerationModel {
        type Response = ImageGenerationResponse;

        /// Request an image from `/image/generation`.
        ///
        /// `additional_params` are merged over the base body.
        ///
        /// Error handling:
        /// - a non-2xx status becomes `ProviderError("<status>: <body>")`;
        /// - an error object in a 2xx body becomes `ResponseError`.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let response = self
                .client
                .post("/image/generation")
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status().as_str(),
                    response.text().await?
                )));
            }

            match response
                .json::<ApiResponse<ImageGenerationResponse>>()
                .await?
            {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
            }
        }
    }

    impl ImageGenerationClient for Client {
        type ImageGenerationModel = ImageGenerationModel;

        /// Create an image generation model with the given name.
        ///
        /// # Example
        /// ```
        /// use rig::client::ImageGenerationClient;
        /// use rig::providers::hyperbolic::{self, Client};
        ///
        /// // Initialize the Hyperbolic client
        /// let hyperbolic = Client::new("your-hyperbolic-api-key");
        ///
        /// let ssd = hyperbolic.image_generation_model(hyperbolic::SSD);
        /// ```
        fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
            ImageGenerationModel::new(self.clone(), model)
        }
    }
}
512
// ======================================
// Hyperbolic Audio Generation API
// ======================================
#[cfg(feature = "audio")]
pub use audio_generation::*;

#[cfg(feature = "audio")]
mod audio_generation {
    use super::{ApiResponse, Client};
    use crate::audio_generation;
    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
    use crate::client::AudioGenerationClient;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    /// Handle to Hyperbolic text-to-speech for a given language, bound to the
    /// client that sends its requests.
    #[derive(Clone)]
    pub struct AudioGenerationModel {
        client: Client,
        /// Language code sent as `language` (e.g. `"EN"`).
        pub language: String,
    }

    impl AudioGenerationModel {
        pub(crate) fn new(client: Client, language: &str) -> AudioGenerationModel {
            Self {
                client,
                language: language.to_string(),
            }
        }
    }

    /// Raw body of a successful `/audio/generation` response (base64-encoded audio).
    #[derive(Clone, Deserialize)]
    pub struct AudioGenerationResponse {
        audio: String,
    }

    impl TryFrom<AudioGenerationResponse>
        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
    {
        type Error = AudioGenerationError;

        /// Decode the base64 audio payload into raw bytes.
        ///
        /// # Errors
        ///
        /// Returns `AudioGenerationError::ProviderError` when the provider's
        /// `audio` field is not valid standard base64. A malformed reply is
        /// reported as an error, not a panic.
        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
            let data = BASE64_STANDARD.decode(&value.audio).map_err(|err| {
                AudioGenerationError::ProviderError(format!("Could not decode audio: {err}"))
            })?;

            Ok(Self {
                audio: data,
                response: value,
            })
        }
    }

    impl audio_generation::AudioGenerationModel for AudioGenerationModel {
        type Response = AudioGenerationResponse;

        /// Request speech synthesis from `/audio/generation`.
        ///
        /// Error handling:
        /// - a non-2xx status becomes `ProviderError("<status>: <body>")`;
        /// - an error object in a 2xx body becomes `ProviderError` with its message.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            let request = json!({
                "language": self.language,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let response = self
                .client
                .post("/audio/generation")
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status(),
                    response.text().await?
                )));
            }

            match serde_json::from_str::<ApiResponse<AudioGenerationResponse>>(
                &response.text().await?,
            )? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }

    impl AudioGenerationClient for Client {
        type AudioGenerationModel = AudioGenerationModel;

        /// Create an audio generation (text-to-speech) model for the given language.
        ///
        /// # Example
        /// ```
        /// use rig::client::AudioGenerationClient;
        /// use rig::providers::hyperbolic::{self, Client};
        ///
        /// // Initialize the Hyperbolic client
        /// let hyperbolic = Client::new("your-hyperbolic-api-key");
        ///
        /// let tts = hyperbolic.audio_generation_model("EN");
        /// ```
        fn audio_generation_model(&self, language: &str) -> AudioGenerationModel {
            AudioGenerationModel::new(self.clone(), language)
        }
    }
}