rig/providers/
hyperbolic.rs

1//! Hyperbolic Inference API client and Rig integration
2//!
3//! # Example
4//! ```
5//! use rig::providers::hyperbolic;
6//!
7//! let client = hyperbolic::Client::new("YOUR_API_KEY");
8//!
9//! let llama_3_1_8b = client.completion_model(hyperbolic::LLAMA_3_1_8B);
10//! ```
11
12use super::openai::{send_compatible_streaming_request, AssistantContent};
13
14use crate::json_utils::merge_inplace;
15use crate::message;
16use crate::streaming::{StreamingCompletionModel, StreamingResult};
17use crate::{
18    agent::AgentBuilder,
19    completion::{self, CompletionError, CompletionRequest},
20    extractor::ExtractorBuilder,
21    json_utils,
22    providers::openai::Message,
23    OneOrMany,
24};
25
26use schemars::JsonSchema;
27use serde::{Deserialize, Serialize};
28use serde_json::{json, Value};
29
30// ================================================================
31// Main Hyperbolic Client
32// ================================================================
// Base URL of the Hyperbolic REST API (v1); all endpoint paths are joined onto this.
const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz/v1";
34
35#[derive(Clone)]
36pub struct Client {
37    base_url: String,
38    http_client: reqwest::Client,
39}
40
41impl Client {
42    /// Create a new Hyperbolic client with the given API key.
43    pub fn new(api_key: &str) -> Self {
44        Self::from_url(api_key, HYPERBOLIC_API_BASE_URL)
45    }
46
47    /// Create a new OpenAI client with the given API key and base API URL.
48    pub fn from_url(api_key: &str, base_url: &str) -> Self {
49        Self {
50            base_url: base_url.to_string(),
51            http_client: reqwest::Client::builder()
52                .default_headers({
53                    let mut headers = reqwest::header::HeaderMap::new();
54                    headers.insert(
55                        "Authorization",
56                        format!("Bearer {}", api_key)
57                            .parse()
58                            .expect("Bearer token should parse"),
59                    );
60                    headers
61                })
62                .build()
63                .expect("OpenAI reqwest client should build"),
64        }
65    }
66
67    /// Create a new Hyperbolic client from the `HYPERBOLIC_API_KEY` environment variable.
68    /// Panics if the environment variable is not set.
69    pub fn from_env() -> Self {
70        let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
71        Self::new(&api_key)
72    }
73
74    fn post(&self, path: &str) -> reqwest::RequestBuilder {
75        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
76        self.http_client.post(url)
77    }
78
79    /// Create a completion model with the given name.
80    ///
81    /// # Example
82    /// ```
83    /// use rig::providers::hyperbolic::{Client, self};
84    ///
85    /// // Initialize the Hyperbolic client
86    /// let hyperbolic = Client::new("your-hyperbolic-api-key");
87    ///
88    /// let llama_3_1_8b = hyperbolic.completion_model(hyperbolic::LLAMA_3_1_8B);
89    /// ```
90    pub fn completion_model(&self, model: &str) -> CompletionModel {
91        CompletionModel::new(self.clone(), model)
92    }
93
94    /// Create an image generation model with the given name.
95    ///
96    /// # Example
97    /// ```
98    /// use rig::providers::hyperbolic::{Client, self};
99    ///
100    /// // Initialize the Hyperbolic client
101    /// let hyperbolic = Client::new("your-hyperbolic-api-key");
102    ///
103    /// let llama_3_1_8b = hyperbolic.image_generation_model(hyperbolic::SSD);
104    /// ```
105    #[cfg(feature = "image")]
106    pub fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
107        ImageGenerationModel::new(self.clone(), model)
108    }
109
110    /// Create a completion model with the given name.
111    ///
112    /// # Example
113    /// ```
114    /// use rig::providers::hyperbolic::{Client, self};
115    ///
116    /// // Initialize the Hyperbolic client
117    /// let hyperbolic = Client::new("your-hyperbolic-api-key");
118    ///
119    /// let tts = hyperbolic.audio_generation_model("EN");
120    /// ```
121    #[cfg(feature = "audio")]
122    pub fn audio_generation_model(&self, language: &str) -> AudioGenerationModel {
123        AudioGenerationModel::new(self.clone(), language)
124    }
125
126    /// Create an agent builder with the given completion model.
127    ///
128    /// # Example
129    /// ```
130    /// use rig::providers::hyperbolic::{Client, self};
131    ///
132    /// // Initialize the Eternal client
133    /// let hyperbolic = Client::new("your-hyperbolic-api-key");
134    ///
135    /// let agent = hyperbolic.agent(hyperbolic::LLAMA_3_1_8B)
136    ///    .preamble("You are comedian AI with a mission to make people laugh.")
137    ///    .temperature(0.0)
138    ///    .build();
139    /// ```
140    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
141        AgentBuilder::new(self.completion_model(model))
142    }
143
144    /// Create an extractor builder with the given completion model.
145    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
146        &self,
147        model: &str,
148    ) -> ExtractorBuilder<T, CompletionModel> {
149        ExtractorBuilder::new(self.completion_model(model))
150    }
151}
152
/// Error body returned by the Hyperbolic API on failure.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
157
/// Either a successful payload or an API error body.
///
/// `#[serde(untagged)]` makes serde try the `Ok(T)` shape first and fall
/// back to the error shape when deserialization of `T` fails.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
164
/// A single embedding entry as returned by the API.
// NOTE(review): not referenced anywhere in this file's visible code — presumably
// used by an embeddings integration elsewhere; confirm before removing.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}
171
/// Token usage accounting reported alongside a completion response.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
177
178impl std::fmt::Display for Usage {
179    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
180        write!(
181            f,
182            "Prompt tokens: {} Total tokens: {}",
183            self.prompt_tokens, self.total_tokens
184        )
185    }
186}
187
188// ================================================================
189// Hyperbolic Completion API
190// ================================================================
/// Meta Llama 3.1 Instruct model with 8B parameters.
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Meta Llama 3.3 Instruct model with 70B parameters.
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// Meta Llama 3.1 Instruct model with 70B parameters.
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// Meta Llama 3 Instruct model with 70B parameters.
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// Hermes 3 (Llama 3.1) Instruct model with 70B parameters.
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// DeepSeek V2.5 model.
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// Qwen 2.5 Instruct model with 72B parameters.
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// Meta Llama 3.2 Instruct model with 3B parameters.
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// Qwen 2.5 Coder Instruct model with 32B parameters.
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// Preview (latest) version of the Qwen QwQ model with 32B parameters.
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// DeepSeek R1 Zero model.
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// DeepSeek R1 model.
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
215
/// A Hyperbolic completion object.
///
/// For more information, see this link: <https://docs.hyperbolic.xyz/reference/create_chat_completion_v1_chat_completions_post>
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    // Candidate completions; only the first choice is consumed by `TryFrom`.
    pub choices: Vec<Choice>,
    // Token accounting; may be absent from the response.
    pub usage: Option<Usage>,
}
228
229impl From<ApiErrorResponse> for CompletionError {
230    fn from(err: ApiErrorResponse) -> Self {
231        CompletionError::ProviderError(err.message)
232    }
233}
234
235impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
236    type Error = CompletionError;
237
238    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
239        let choice = response.choices.first().ok_or_else(|| {
240            CompletionError::ResponseError("Response contained no choices".to_owned())
241        })?;
242
243        let content = match &choice.message {
244            Message::Assistant {
245                content,
246                tool_calls,
247                ..
248            } => {
249                let mut content = content
250                    .iter()
251                    .map(|c| match c {
252                        AssistantContent::Text { text } => completion::AssistantContent::text(text),
253                        AssistantContent::Refusal { refusal } => {
254                            completion::AssistantContent::text(refusal)
255                        }
256                    })
257                    .collect::<Vec<_>>();
258
259                content.extend(
260                    tool_calls
261                        .iter()
262                        .map(|call| {
263                            completion::AssistantContent::tool_call(
264                                &call.id,
265                                &call.function.name,
266                                call.function.arguments.clone(),
267                            )
268                        })
269                        .collect::<Vec<_>>(),
270                );
271                Ok(content)
272            }
273            _ => Err(CompletionError::ResponseError(
274                "Response did not contain a valid message or tool call".into(),
275            )),
276        }?;
277
278        let choice = OneOrMany::many(content).map_err(|_| {
279            CompletionError::ResponseError(
280                "Response contained no message or tool call (empty)".to_owned(),
281            )
282        })?;
283
284        Ok(completion::CompletionResponse {
285            choice,
286            raw_response: response,
287        })
288    }
289}
290
/// One candidate completion within a [`CompletionResponse`].
#[derive(Debug, Deserialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    pub finish_reason: String,
}
297
/// A Hyperbolic completion model handle, bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    /// Name of the model (e.g.: deepseek-ai/DeepSeek-R1)
    pub model: String,
}
304
305impl CompletionModel {
306    pub(crate) fn create_completion_request(
307        &self,
308        completion_request: CompletionRequest,
309    ) -> Result<Value, CompletionError> {
310        // Build up the order of messages (context, chat_history, prompt)
311        let mut partial_history = vec![];
312        if let Some(docs) = completion_request.normalized_documents() {
313            partial_history.push(docs);
314        }
315        partial_history.extend(completion_request.chat_history);
316
317        // Initialize full history with preamble (or empty if non-existent)
318        let mut full_history: Vec<Message> = completion_request
319            .preamble
320            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
321
322        // Convert and extend the rest of the history
323        full_history.extend(
324            partial_history
325                .into_iter()
326                .map(message::Message::try_into)
327                .collect::<Result<Vec<Vec<Message>>, _>>()?
328                .into_iter()
329                .flatten()
330                .collect::<Vec<_>>(),
331        );
332
333        let request = json!({
334            "model": self.model,
335            "messages": full_history,
336            "temperature": completion_request.temperature,
337        });
338
339        let request = if let Some(params) = completion_request.additional_params {
340            json_utils::merge(request, params)
341        } else {
342            request
343        };
344
345        Ok(request)
346    }
347}
348
349impl CompletionModel {
350    pub fn new(client: Client, model: &str) -> Self {
351        Self {
352            client,
353            model: model.to_string(),
354        }
355    }
356}
357
358impl completion::CompletionModel for CompletionModel {
359    type Response = CompletionResponse;
360
361    #[cfg_attr(feature = "worker", worker::send)]
362    async fn completion(
363        &self,
364        completion_request: CompletionRequest,
365    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
366        let request = self.create_completion_request(completion_request)?;
367
368        let response = self
369            .client
370            .post("/chat/completions")
371            .json(&request)
372            .send()
373            .await?;
374
375        if response.status().is_success() {
376            match response.json::<ApiResponse<CompletionResponse>>().await? {
377                ApiResponse::Ok(response) => {
378                    tracing::info!(target: "rig",
379                        "Hyperbolic completion token usage: {:?}",
380                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
381                    );
382
383                    response.try_into()
384                }
385                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
386            }
387        } else {
388            Err(CompletionError::ProviderError(response.text().await?))
389        }
390    }
391}
392
393impl StreamingCompletionModel for CompletionModel {
394    async fn stream(
395        &self,
396        completion_request: CompletionRequest,
397    ) -> Result<StreamingResult, CompletionError> {
398        let mut request = self.create_completion_request(completion_request)?;
399
400        merge_inplace(&mut request, json!({"stream": true}));
401
402        let builder = self.client.post("/chat/completions").json(&request);
403
404        send_compatible_streaming_request(builder).await
405    }
406}
407
408// =======================================
409// Hyperbolic Image Generation API
410// =======================================
411
#[cfg(feature = "image")]
pub use image_generation::*;

// The whole module is gated on `feature = "image"`, so the per-item
// `#[cfg(feature = "image")]` attributes the original carried inside it
// were redundant and have been removed.
#[cfg(feature = "image")]
mod image_generation {
    use super::{ApiResponse, Client};
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use base64::prelude::BASE64_STANDARD;
    use base64::Engine;
    use serde::Deserialize;
    use serde_json::json;

    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    pub const SD2: &str = "SD2";
    pub const SD1_5: &str = "SD1.5";
    pub const SSD: &str = "SSD";
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    /// A Hyperbolic image generation model handle, bound to a [`Client`].
    #[derive(Clone)]
    pub struct ImageGenerationModel {
        client: Client,
        /// Name of the model (e.g. "SDXL1.0-base").
        pub model: String,
    }

    impl ImageGenerationModel {
        /// Create an image generation model bound to `client` for `model`.
        pub(crate) fn new(client: Client, model: &str) -> ImageGenerationModel {
            Self {
                client,
                model: model.to_string(),
            }
        }
    }

    /// One base64-encoded image in the response payload.
    #[derive(Clone, Deserialize)]
    pub struct Image {
        image: String,
    }

    /// Raw response body of the image generation endpoint.
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        images: Vec<Image>,
    }

    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// Decode the first returned image from base64.
        ///
        /// # Errors
        /// Returns `ResponseError` when the image list is empty or the base64
        /// payload is malformed. (The original code indexed `images[0]` and
        /// called `expect`, panicking on remote-controlled data even though
        /// this conversion already returns a `Result`.)
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            let first = value.images.first().ok_or_else(|| {
                ImageGenerationError::ResponseError("Response contained no images".to_string())
            })?;

            let data = BASE64_STANDARD.decode(&first.image).map_err(|e| {
                ImageGenerationError::ResponseError(format!("Could not decode image: {e}"))
            })?;

            Ok(Self {
                image: data,
                response: value,
            })
        }
    }

    impl image_generation::ImageGenerationModel for ImageGenerationModel {
        type Response = ImageGenerationResponse;

        /// Generate an image and return the decoded bytes plus raw response.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            // Caller-supplied params override/extend the defaults above.
            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let response = self
                .client
                .post("/image/generation")
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status().as_str(),
                    response.text().await?
                )));
            }

            match response
                .json::<ApiResponse<ImageGenerationResponse>>()
                .await?
            {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
            }
        }
    }
}
526
527// ======================================
528// Hyperbolic Audio Generation API
529// ======================================
#[cfg(feature = "audio")]
pub use audio_generation::*;
#[cfg(feature = "audio")]
mod audio_generation {
    use super::{ApiResponse, Client};
    use crate::audio_generation;
    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
    use base64::prelude::BASE64_STANDARD;
    use base64::Engine;
    use serde::Deserialize;
    use serde_json::json;

    /// A Hyperbolic text-to-speech model handle, bound to a [`Client`].
    #[derive(Clone)]
    pub struct AudioGenerationModel {
        client: Client,
        /// Language code sent as the `language` field (e.g. "EN").
        // NOTE(review): field name is misspelled ("langauge"), but it is `pub`,
        // so renaming it would break downstream users; kept for compatibility.
        pub langauge: String,
    }

    impl AudioGenerationModel {
        /// Create an audio generation model for the given language code.
        pub(crate) fn new(client: Client, language: &str) -> AudioGenerationModel {
            Self {
                client,
                langauge: language.to_string(),
            }
        }
    }

    /// Raw response body: a base64-encoded audio payload.
    #[derive(Clone, Deserialize)]
    pub struct AudioGenerationResponse {
        audio: String,
    }

    impl TryFrom<AudioGenerationResponse>
        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
    {
        type Error = AudioGenerationError;

        /// Decode the base64 audio payload.
        ///
        /// # Errors
        /// Returns a provider error when the base64 payload is malformed.
        /// (The original called `expect`, panicking on remote-controlled data
        /// even though this conversion already returns a `Result`.)
        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
            let data = BASE64_STANDARD.decode(&value.audio).map_err(|e| {
                AudioGenerationError::ProviderError(format!("Could not decode audio: {e}"))
            })?;

            Ok(Self {
                audio: data,
                response: value,
            })
        }
    }

    impl audio_generation::AudioGenerationModel for AudioGenerationModel {
        type Response = AudioGenerationResponse;

        /// Generate speech audio and return the decoded bytes plus raw response.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            let request = json!({
                "language": self.langauge,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let response = self
                .client
                .post("/audio/generation")
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status(),
                    response.text().await?
                )));
            }

            match serde_json::from_str::<ApiResponse<AudioGenerationResponse>>(
                &response.text().await?,
            )? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }
}
617}