// rig/providers/hyperbolic.rs

//! Hyperbolic Inference API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::hyperbolic;
//!
//! let client = hyperbolic::Client::new("YOUR_API_KEY");
//!
//! let llama_3_1_8b = client.completion_model(hyperbolic::LLAMA_3_1_8B);
//! ```

12use super::openai::{AssistantContent, send_compatible_streaming_request};
13
14use crate::client::{CompletionClient, ProviderClient};
15use crate::json_utils::merge_inplace;
16use crate::message;
17use crate::streaming::StreamingCompletionResponse;
18
19use crate::impl_conversion_traits;
20use crate::providers::openai;
21use crate::{
22    OneOrMany,
23    completion::{self, CompletionError, CompletionRequest},
24    json_utils,
25    providers::openai::Message,
26};
27use serde::Deserialize;
28use serde_json::{Value, json};
29
// ================================================================
// Main Hyperbolic Client
// ================================================================
/// Default base URL of the Hyperbolic Inference API (OpenAI-compatible, v1).
const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz/v1";
34
/// Client for the Hyperbolic Inference API.
#[derive(Clone)]
pub struct Client {
    // Base URL requests are sent to; defaults to `HYPERBOLIC_API_BASE_URL`.
    base_url: String,
    // API key; attached as a bearer token on every request. Never printed
    // (see the manual `Debug` impl below).
    api_key: String,
    // Underlying HTTP client; replaceable via `with_custom_client`.
    http_client: reqwest::Client,
}
41
// Manual `Debug` impl (instead of `#[derive(Debug)]`) so the API key —
// a credential — never leaks into logs or error output.
impl std::fmt::Debug for Client {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("base_url", &self.base_url)
            .field("http_client", &self.http_client)
            // Deliberately redacted; only the placeholder is printed.
            .field("api_key", &"<REDACTED>")
            .finish()
    }
}
51
52impl Client {
53    /// Create a new Hyperbolic client with the given API key.
54    pub fn new(api_key: &str) -> Self {
55        Self::from_url(api_key, HYPERBOLIC_API_BASE_URL)
56    }
57
58    /// Create a new OpenAI client with the given API key and base API URL.
59    pub fn from_url(api_key: &str, base_url: &str) -> Self {
60        Self {
61            base_url: base_url.to_string(),
62            api_key: api_key.to_string(),
63            http_client: reqwest::Client::builder()
64                .build()
65                .expect("OpenAI reqwest client should build"),
66        }
67    }
68
69    /// Use your own `reqwest::Client`.
70    /// The required headers will be automatically attached upon trying to make a request.
71    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
72        self.http_client = client;
73
74        self
75    }
76
77    fn post(&self, path: &str) -> reqwest::RequestBuilder {
78        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
79        self.http_client.post(url).bearer_auth(&self.api_key)
80    }
81}
82
impl ProviderClient for Client {
    /// Create a new Hyperbolic client from the `HYPERBOLIC_API_KEY` environment variable.
    /// Panics if the environment variable is not set.
    fn from_env() -> Self {
        let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
        Self::new(&api_key)
    }

    /// Create a client from a dynamically supplied provider value.
    /// Panics if the value is not the `Simple` (plain API key) variant.
    fn from_val(input: crate::client::ProviderValue) -> Self {
        let crate::client::ProviderValue::Simple(api_key) = input else {
            panic!("Incorrect provider value type")
        };
        Self::new(&api_key)
    }
}
98
impl CompletionClient for Client {
    type CompletionModel = CompletionModel;

    /// Create a completion model with the given name.
    ///
    /// No network call is made here; the model simply shares this client's
    /// connection settings.
    ///
    /// # Example
    /// ```
    /// use rig::providers::hyperbolic::{Client, self};
    ///
    /// // Initialize the Hyperbolic client
    /// let hyperbolic = Client::new("your-hyperbolic-api-key");
    ///
    /// let llama_3_1_8b = hyperbolic.completion_model(hyperbolic::LLAMA_3_1_8B);
    /// ```
    fn completion_model(&self, model: &str) -> CompletionModel {
        CompletionModel::new(self.clone(), model)
    }
}
117
// Generate the conversion-trait impls for capabilities this provider does not
// implement here. NOTE(review): presumably these expand to "unsupported"
// stubs — confirm against the `impl_conversion_traits!` macro definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription for Client
);
122
/// Error body returned by the Hyperbolic API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
127
/// Untagged wrapper: a response body deserializes either as the expected
/// payload `T` or, failing that, as an API error body.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
134
/// A single embedding datum in an OpenAI-style embeddings response.
/// NOTE(review): no embedding endpoint is implemented in this file — confirm
/// whether this type is still needed.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}
141
142#[derive(Clone, Debug, Deserialize)]
143pub struct Usage {
144    pub prompt_tokens: usize,
145    pub total_tokens: usize,
146}
147
148impl std::fmt::Display for Usage {
149    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150        write!(
151            f,
152            "Prompt tokens: {} Total tokens: {}",
153            self.prompt_tokens, self.total_tokens
154        )
155    }
156}
157
// ================================================================
// Hyperbolic Completion API
// ================================================================
/// Meta Llama 3.1 Instruct model with 8B parameters.
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Meta Llama 3.3 Instruct model with 70B parameters.
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// Meta Llama 3.1 Instruct model with 70B parameters.
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// Meta Llama 3 Instruct model with 70B parameters.
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// Hermes 3 (Llama 3.1) Instruct model with 70B parameters.
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// Deepseek V2.5 model.
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// Qwen 2.5 Instruct model with 72B parameters.
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// Meta Llama 3.2 Instruct model with 3B parameters.
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// Qwen 2.5 Coder Instruct model with 32B parameters.
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// Preview (latest) version of the Qwen QwQ model with 32B parameters.
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// Deepseek R1 Zero model.
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// Deepseek R1 model.
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
185
/// A Hyperbolic completion object.
///
/// For more information, see this link: <https://docs.hyperbolic.xyz/reference/create_chat_completion_v1_chat_completions_post>
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    /// Unique identifier of the completion.
    pub id: String,
    /// Object type marker returned by the API.
    pub object: String,
    /// Creation timestamp (Unix seconds, per the OpenAI-compatible schema).
    pub created: u64,
    /// Name of the model that produced the completion.
    pub model: String,
    /// Candidate completions; Rig uses the first one (see `TryFrom` below).
    pub choices: Vec<Choice>,
    /// Token usage, when reported by the provider.
    pub usage: Option<Usage>,
}
198
199impl From<ApiErrorResponse> for CompletionError {
200    fn from(err: ApiErrorResponse) -> Self {
201        CompletionError::ProviderError(err.message)
202    }
203}
204
205impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
206    type Error = CompletionError;
207
208    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
209        let choice = response.choices.first().ok_or_else(|| {
210            CompletionError::ResponseError("Response contained no choices".to_owned())
211        })?;
212
213        let content = match &choice.message {
214            Message::Assistant {
215                content,
216                tool_calls,
217                ..
218            } => {
219                let mut content = content
220                    .iter()
221                    .map(|c| match c {
222                        AssistantContent::Text { text } => completion::AssistantContent::text(text),
223                        AssistantContent::Refusal { refusal } => {
224                            completion::AssistantContent::text(refusal)
225                        }
226                    })
227                    .collect::<Vec<_>>();
228
229                content.extend(
230                    tool_calls
231                        .iter()
232                        .map(|call| {
233                            completion::AssistantContent::tool_call(
234                                &call.id,
235                                &call.function.name,
236                                call.function.arguments.clone(),
237                            )
238                        })
239                        .collect::<Vec<_>>(),
240                );
241                Ok(content)
242            }
243            _ => Err(CompletionError::ResponseError(
244                "Response did not contain a valid message or tool call".into(),
245            )),
246        }?;
247
248        let choice = OneOrMany::many(content).map_err(|_| {
249            CompletionError::ResponseError(
250                "Response contained no message or tool call (empty)".to_owned(),
251            )
252        })?;
253
254        let usage = response
255            .usage
256            .as_ref()
257            .map(|usage| completion::Usage {
258                input_tokens: usage.prompt_tokens as u64,
259                output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
260                total_tokens: usage.total_tokens as u64,
261            })
262            .unwrap_or_default();
263
264        Ok(completion::CompletionResponse {
265            choice,
266            usage,
267            raw_response: response,
268        })
269    }
270}
271
/// A single completion choice within a `CompletionResponse`.
#[derive(Debug, Deserialize)]
pub struct Choice {
    /// Position of this choice in the response.
    pub index: usize,
    /// The generated message, in the OpenAI-compatible format.
    pub message: Message,
    /// Why generation stopped (e.g. "stop" or "length" — per the
    /// OpenAI-compatible schema; confirm against the Hyperbolic docs).
    pub finish_reason: String,
}
278
/// A completion model backed by the Hyperbolic chat-completions endpoint.
#[derive(Clone)]
pub struct CompletionModel {
    // Client whose connection settings and credentials are used for requests.
    client: Client,
    /// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
    pub model: String,
}
285
286impl CompletionModel {
287    pub(crate) fn create_completion_request(
288        &self,
289        completion_request: CompletionRequest,
290    ) -> Result<Value, CompletionError> {
291        // Build up the order of messages (context, chat_history, prompt)
292        let mut partial_history = vec![];
293        if let Some(docs) = completion_request.normalized_documents() {
294            partial_history.push(docs);
295        }
296        partial_history.extend(completion_request.chat_history);
297
298        // Initialize full history with preamble (or empty if non-existent)
299        let mut full_history: Vec<Message> = completion_request
300            .preamble
301            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
302
303        // Convert and extend the rest of the history
304        full_history.extend(
305            partial_history
306                .into_iter()
307                .map(message::Message::try_into)
308                .collect::<Result<Vec<Vec<Message>>, _>>()?
309                .into_iter()
310                .flatten()
311                .collect::<Vec<_>>(),
312        );
313
314        let request = json!({
315            "model": self.model,
316            "messages": full_history,
317            "temperature": completion_request.temperature,
318        });
319
320        let request = if let Some(params) = completion_request.additional_params {
321            json_utils::merge(request, params)
322        } else {
323            request
324        };
325
326        Ok(request)
327    }
328}
329
330impl CompletionModel {
331    pub fn new(client: Client, model: &str) -> Self {
332        Self {
333            client,
334            model: model.to_string(),
335        }
336    }
337}
338
impl completion::CompletionModel for CompletionModel {
    type Response = CompletionResponse;
    type StreamingResponse = openai::StreamingCompletionResponse;

    /// Send a (non-streaming) chat-completion request and convert the
    /// provider payload into Rig's generic completion response.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let request = self.create_completion_request(completion_request)?;

        let response = self
            .client
            .post("/chat/completions")
            .json(&request)
            .send()
            .await?;

        if response.status().is_success() {
            // A 2xx body can still carry an API-level error; the untagged
            // `ApiResponse` distinguishes the two shapes on deserialization.
            match response.json::<ApiResponse<CompletionResponse>>().await? {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "Hyperbolic completion token usage: {:?}",
                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
                    );

                    response.try_into()
                }
                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
            }
        } else {
            // Non-2xx: surface the raw body as the provider error message.
            Err(CompletionError::ProviderError(response.text().await?))
        }
    }

    /// Send a streaming chat-completion request via the shared
    /// OpenAI-compatible streaming machinery.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let mut request = self.create_completion_request(completion_request)?;

        // Request a stream, and token usage on the terminal chunk.
        merge_inplace(
            &mut request,
            json!({"stream": true, "stream_options": {"include_usage": true}}),
        );

        let builder = self.client.post("/chat/completions").json(&request);

        send_compatible_streaming_request(builder).await
    }
}
391
392// =======================================
393// Hyperbolic Image Generation API
394// =======================================
395
396#[cfg(feature = "image")]
397pub use image_generation::*;
398
#[cfg(feature = "image")]
mod image_generation {
    use super::{ApiResponse, Client};
    use crate::client::ImageGenerationClient;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    // Stable Diffusion model identifiers accepted by the Hyperbolic image API.
    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    pub const SD2: &str = "SD2";
    pub const SD1_5: &str = "SD1.5";
    pub const SSD: &str = "SSD";
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    // NOTE: the `#[cfg(feature = "image")]` attributes previously repeated on
    // every item were redundant — the entire module is already feature-gated.

    /// An image generation model bound to a Hyperbolic [`Client`].
    #[derive(Clone)]
    pub struct ImageGenerationModel {
        client: Client,
        /// Name of the model (e.g. `SDXL1.0-base`).
        pub model: String,
    }

    impl ImageGenerationModel {
        pub(crate) fn new(client: Client, model: &str) -> ImageGenerationModel {
            Self {
                client,
                model: model.to_string(),
            }
        }
    }

    /// A single base64-encoded image in the provider response.
    #[derive(Clone, Deserialize)]
    pub struct Image {
        image: String,
    }

    /// Raw response of the image generation endpoint.
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        images: Vec<Image>,
    }

    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// Decode the first returned image from base64.
        ///
        /// # Errors
        /// Returns `ResponseError` when the response contains no images or
        /// when the payload is not valid base64. (Previously both cases
        /// panicked, via `images[0]` indexing and `expect` respectively.)
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            let first = value.images.first().ok_or_else(|| {
                ImageGenerationError::ResponseError("Response contained no images".to_owned())
            })?;

            let data = BASE64_STANDARD.decode(&first.image).map_err(|e| {
                ImageGenerationError::ResponseError(format!("Could not decode image: {e}"))
            })?;

            Ok(Self {
                image: data,
                response: value,
            })
        }
    }

    impl image_generation::ImageGenerationModel for ImageGenerationModel {
        type Response = ImageGenerationResponse;

        /// Call the Hyperbolic image generation endpoint.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            // Caller-supplied params take precedence over the defaults above.
            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let response = self
                .client
                .post("/image/generation")
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status().as_str(),
                    response.text().await?
                )));
            }

            match response
                .json::<ApiResponse<ImageGenerationResponse>>()
                .await?
            {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
            }
        }
    }

    impl ImageGenerationClient for Client {
        type ImageGenerationModel = ImageGenerationModel;

        /// Create an image generation model with the given name.
        ///
        /// # Example
        /// ```
        /// use rig::providers::hyperbolic::{Client, self};
        ///
        /// // Initialize the Hyperbolic client
        /// let hyperbolic = Client::new("your-hyperbolic-api-key");
        ///
        /// let ssd = hyperbolic.image_generation_model(hyperbolic::SSD);
        /// ```
        fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
            ImageGenerationModel::new(self.clone(), model)
        }
    }
}
530
531// ======================================
532// Hyperbolic Audio Generation API
533// ======================================
534#[cfg(feature = "audio")]
535pub use audio_generation::*;
536
#[cfg(feature = "audio")]
mod audio_generation {
    use super::{ApiResponse, Client};
    use crate::audio_generation;
    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
    use crate::client::AudioGenerationClient;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    /// A text-to-speech model bound to a Hyperbolic [`Client`].
    ///
    /// Hyperbolic's TTS endpoint is parameterized by language rather than by
    /// a model name.
    #[derive(Clone)]
    pub struct AudioGenerationModel {
        client: Client,
        /// Language code sent as the request's `language` field (e.g. "EN").
        pub language: String,
    }

    impl AudioGenerationModel {
        pub(crate) fn new(client: Client, language: &str) -> AudioGenerationModel {
            Self {
                client,
                language: language.to_string(),
            }
        }
    }

    /// Raw response of the audio generation endpoint.
    #[derive(Clone, Deserialize)]
    pub struct AudioGenerationResponse {
        // Base64-encoded audio payload.
        audio: String,
    }

    impl TryFrom<AudioGenerationResponse>
        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
    {
        type Error = AudioGenerationError;

        /// Decode the base64 audio payload.
        ///
        /// # Errors
        /// Returns `ProviderError` when the payload is not valid base64.
        /// (Previously this panicked via `expect`.)
        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
            let data = BASE64_STANDARD.decode(&value.audio).map_err(|e| {
                AudioGenerationError::ProviderError(format!("Could not decode audio: {e}"))
            })?;

            Ok(Self {
                audio: data,
                response: value,
            })
        }
    }

    impl audio_generation::AudioGenerationModel for AudioGenerationModel {
        type Response = AudioGenerationResponse;

        /// Call the Hyperbolic audio generation (TTS) endpoint.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            let request = json!({
                "language": self.language,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let response = self
                .client
                .post("/audio/generation")
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status(),
                    response.text().await?
                )));
            }

            match serde_json::from_str::<ApiResponse<AudioGenerationResponse>>(
                &response.text().await?,
            )? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }
    impl AudioGenerationClient for Client {
        type AudioGenerationModel = AudioGenerationModel;

        /// Create an audio generation (TTS) model for the given language.
        ///
        /// # Example
        /// ```
        /// use rig::providers::hyperbolic::{Client, self};
        ///
        /// // Initialize the Hyperbolic client
        /// let hyperbolic = Client::new("your-hyperbolic-api-key");
        ///
        /// let tts = hyperbolic.audio_generation_model("EN");
        /// ```
        fn audio_generation_model(&self, language: &str) -> AudioGenerationModel {
            AudioGenerationModel::new(self.clone(), language)
        }
    }
}