1use super::openai::{AssistantContent, send_compatible_streaming_request};
12
13use crate::client::{CompletionClient, ProviderClient, VerifyClient, VerifyError};
14use crate::http_client::{self, HttpClientExt};
15use crate::json_utils::merge_inplace;
16use crate::message;
17use crate::streaming::StreamingCompletionResponse;
18
19use crate::impl_conversion_traits;
20use crate::providers::openai;
21use crate::{
22 OneOrMany,
23 completion::{self, CompletionError, CompletionRequest},
24 json_utils,
25 providers::openai::Message,
26};
27use http::Method;
28use serde::{Deserialize, Serialize};
29use serde_json::{Value, json};
30
/// Default base URL for the Hyperbolic API.
///
/// Used by [`ClientBuilder`] unless overridden through `ClientBuilder::base_url`.
/// Request paths (e.g. `/v1/chat/completions`) are appended to this value.
const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz";
35
/// Builder for a Hyperbolic [`Client`].
///
/// Borrows the API key and base URL for `'a`; owned copies are only made when
/// [`ClientBuilder::build`] is called. `T` is the HTTP client type and defaults
/// to `reqwest::Client`.
pub struct ClientBuilder<'a, T = reqwest::Client> {
    /// API key sent as a bearer token by the built client.
    api_key: &'a str,
    /// Base URL that request paths are appended to.
    base_url: &'a str,
    /// HTTP client handed over to the built [`Client`].
    http_client: T,
}
41
42impl<'a, T> ClientBuilder<'a, T>
43where
44 T: Default,
45{
46 pub fn new(api_key: &'a str) -> Self {
47 Self {
48 api_key,
49 base_url: HYPERBOLIC_API_BASE_URL,
50 http_client: Default::default(),
51 }
52 }
53}
54
55impl<'a, T> ClientBuilder<'a, T> {
56 pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
57 Self {
58 api_key,
59 base_url: HYPERBOLIC_API_BASE_URL,
60 http_client,
61 }
62 }
63
64 pub fn base_url(mut self, base_url: &'a str) -> Self {
65 self.base_url = base_url;
66 self
67 }
68
69 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
70 ClientBuilder {
71 api_key: self.api_key,
72 base_url: self.base_url,
73 http_client,
74 }
75 }
76
77 pub fn build(self) -> Client<T> {
78 Client {
79 base_url: self.base_url.to_string(),
80 api_key: self.api_key.to_string(),
81 http_client: self.http_client,
82 }
83 }
84}
85
/// Hyperbolic API client.
///
/// Owns the base URL, the API key and the HTTP client `T` (defaults to
/// `reqwest::Client`). Construct one through [`ClientBuilder`].
#[derive(Clone)]
pub struct Client<T = reqwest::Client> {
    /// Base URL that request paths are appended to.
    base_url: String,
    /// API key attached to every request as a bearer token.
    api_key: String,
    /// Underlying HTTP client used to send requests.
    http_client: T,
}
92
93impl<T> std::fmt::Debug for Client<T>
94where
95 T: std::fmt::Debug,
96{
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 f.debug_struct("Client")
99 .field("base_url", &self.base_url)
100 .field("http_client", &self.http_client)
101 .field("api_key", &"<REDACTED>")
102 .finish()
103 }
104}
105
106impl Client<reqwest::Client> {
107 pub fn builder(api_key: &str) -> ClientBuilder<'_, reqwest::Client> {
108 ClientBuilder::new(api_key)
109 }
110
111 pub fn new(api_key: &str) -> Self {
112 Self::builder(api_key).build()
113 }
114
115 pub fn from_env() -> Self {
116 <Self as ProviderClient>::from_env()
117 }
118}
119
120impl<T> Client<T>
121where
122 T: HttpClientExt,
123{
124 fn req(
125 &self,
126 method: http_client::Method,
127 path: &str,
128 ) -> http_client::Result<http_client::Builder> {
129 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
130
131 http_client::with_bearer_auth(
132 http_client::Builder::new().method(method).uri(url),
133 &self.api_key,
134 )
135 }
136
137 fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
138 self.req(http_client::Method::GET, path)
139 }
140}
141
142impl<T> ProviderClient for Client<T>
143where
144 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
145{
146 fn from_env() -> Self {
149 let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
150 ClientBuilder::<T>::new(&api_key).build()
151 }
152
153 fn from_val(input: crate::client::ProviderValue) -> Self {
154 let crate::client::ProviderValue::Simple(api_key) = input else {
155 panic!("Incorrect provider value type")
156 };
157 ClientBuilder::<T>::new(&api_key).build()
158 }
159}
160
161impl<T> CompletionClient for Client<T>
162where
163 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
164{
165 type CompletionModel = CompletionModel<T>;
166
167 fn completion_model(&self, model: &str) -> Self::CompletionModel {
179 CompletionModel::new(self.clone(), model)
180 }
181}
182
183impl VerifyClient for Client<reqwest::Client> {
184 #[cfg_attr(feature = "worker", worker::send)]
185 async fn verify(&self) -> Result<(), VerifyError> {
186 let req = self
187 .get("/models")?
188 .body(http_client::NoBody)
189 .map_err(http_client::Error::from)?;
190
191 let response = HttpClientExt::send(&self.http_client, req).await?;
192
193 match response.status() {
194 reqwest::StatusCode::OK => Ok(()),
195 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
196 reqwest::StatusCode::INTERNAL_SERVER_ERROR
197 | reqwest::StatusCode::SERVICE_UNAVAILABLE
198 | reqwest::StatusCode::BAD_GATEWAY => {
199 let text = http_client::text(response).await?;
200 Err(VerifyError::ProviderError(text))
201 }
202 _ => {
203 Ok(())
205 }
206 }
207 }
208}
209
// Implements the `AsEmbeddings` and `AsTranscription` conversion traits for
// `Client<T>` through the crate's `impl_conversion_traits!` macro.
// NOTE(review): presumably these are "capability not supported" stubs, since this
// file defines no embedding or transcription model — confirm against the macro.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription for Client<T>
);
214
/// Error payload returned by the API: a JSON object with a `message` field.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error description, forwarded verbatim into this crate's error types.
    message: String,
}
219
/// Response body that is either the expected payload `T` or an [`ApiErrorResponse`].
///
/// The enum is `untagged`, so serde tries the variants in declaration order: the
/// body is first parsed as `T` and only then as an error object.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    /// The body deserialized as the expected payload.
    Ok(T),
    /// The body deserialized as an error object.
    Err(ApiErrorResponse),
}
226
/// Deserializable record holding one embedding vector.
///
/// NOTE(review): nothing in the visible part of this file uses this type — confirm
/// whether it is still needed.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    /// Value of the `object` field of the record.
    pub object: String,
    /// The embedding vector.
    pub embedding: Vec<f64>,
    /// Value of the `index` field of the record.
    pub index: usize,
}
233
/// Token usage reported with a completion response.
///
/// Only prompt and total counts are carried; the number of output tokens is derived
/// elsewhere in this file as `total_tokens - prompt_tokens`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Number of tokens in the prompt.
    pub prompt_tokens: usize,
    /// Total number of tokens reported for the request.
    pub total_tokens: usize,
}
239
240impl std::fmt::Display for Usage {
241 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
242 write!(
243 f,
244 "Prompt tokens: {} Total tokens: {}",
245 self.prompt_tokens, self.total_tokens
246 )
247 }
248}
249
// Model identifiers that can be passed to `Client::completion_model`.

/// Model id `meta-llama/Meta-Llama-3.1-8B-Instruct`.
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Model id `meta-llama/Llama-3.3-70B-Instruct`.
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// Model id `meta-llama/Meta-Llama-3.1-70B-Instruct`.
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// Model id `meta-llama/Meta-Llama-3-70B-Instruct`.
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// Model id `NousResearch/Hermes-3-Llama-3.1-70b`.
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// Model id `deepseek-ai/DeepSeek-V2.5`.
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// Model id `Qwen/Qwen2.5-72B-Instruct`.
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// Model id `meta-llama/Llama-3.2-3B-Instruct`.
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// Model id `Qwen/Qwen2.5-Coder-32B-Instruct`.
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// Model id `Qwen/QwQ-32B-Preview`.
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// Model id `deepseek-ai/DeepSeek-R1-Zero`.
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// Model id `deepseek-ai/DeepSeek-R1`.
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
277
/// Raw body of a successful `/v1/chat/completions` response.
///
/// Converted into the crate-level `completion::CompletionResponse` through the
/// `TryFrom` implementation in this file, which keeps this value as `raw_response`.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Value of the `id` field of the response.
    pub id: String,
    /// Value of the `object` field of the response.
    pub object: String,
    /// Value of the `created` field of the response.
    pub created: u64,
    /// Value of the `model` field of the response.
    pub model: String,
    /// Returned choices; only the first one is used when converting.
    pub choices: Vec<Choice>,
    /// Token usage, if the response included it.
    pub usage: Option<Usage>,
}
290
291impl From<ApiErrorResponse> for CompletionError {
292 fn from(err: ApiErrorResponse) -> Self {
293 CompletionError::ProviderError(err.message)
294 }
295}
296
297impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
298 type Error = CompletionError;
299
300 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
301 let choice = response.choices.first().ok_or_else(|| {
302 CompletionError::ResponseError("Response contained no choices".to_owned())
303 })?;
304
305 let content = match &choice.message {
306 Message::Assistant {
307 content,
308 tool_calls,
309 ..
310 } => {
311 let mut content = content
312 .iter()
313 .map(|c| match c {
314 AssistantContent::Text { text } => completion::AssistantContent::text(text),
315 AssistantContent::Refusal { refusal } => {
316 completion::AssistantContent::text(refusal)
317 }
318 })
319 .collect::<Vec<_>>();
320
321 content.extend(
322 tool_calls
323 .iter()
324 .map(|call| {
325 completion::AssistantContent::tool_call(
326 &call.id,
327 &call.function.name,
328 call.function.arguments.clone(),
329 )
330 })
331 .collect::<Vec<_>>(),
332 );
333 Ok(content)
334 }
335 _ => Err(CompletionError::ResponseError(
336 "Response did not contain a valid message or tool call".into(),
337 )),
338 }?;
339
340 let choice = OneOrMany::many(content).map_err(|_| {
341 CompletionError::ResponseError(
342 "Response contained no message or tool call (empty)".to_owned(),
343 )
344 })?;
345
346 let usage = response
347 .usage
348 .as_ref()
349 .map(|usage| completion::Usage {
350 input_tokens: usage.prompt_tokens as u64,
351 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
352 total_tokens: usage.total_tokens as u64,
353 })
354 .unwrap_or_default();
355
356 Ok(completion::CompletionResponse {
357 choice,
358 usage,
359 raw_response: response,
360 })
361 }
362}
363
/// One entry of the `choices` array of a [`CompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    /// Value of the `index` field of the choice.
    pub index: usize,
    /// The message of this choice, in the OpenAI-compatible message format.
    pub message: Message,
    /// Value of the `finish_reason` field of the choice.
    pub finish_reason: String,
}
370
/// Handle for running completions against one Hyperbolic model.
#[derive(Clone)]
pub struct CompletionModel<T> {
    /// Client used to build and send requests.
    client: Client<T>,
    /// Model name sent as the `model` field of each request.
    pub model: String,
}
377
378impl<T> CompletionModel<T> {
379 pub fn new(client: Client<T>, model: &str) -> Self {
380 Self {
381 client,
382 model: model.to_string(),
383 }
384 }
385
386 pub(crate) fn create_completion_request(
387 &self,
388 completion_request: CompletionRequest,
389 ) -> Result<Value, CompletionError> {
390 if completion_request.tool_choice.is_some() {
391 tracing::warn!("WARNING: `tool_choice` not supported on Hyperbolic");
392 }
393 let mut partial_history = vec![];
395 if let Some(docs) = completion_request.normalized_documents() {
396 partial_history.push(docs);
397 }
398 partial_history.extend(completion_request.chat_history);
399
400 let mut full_history: Vec<Message> = completion_request
402 .preamble
403 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
404
405 full_history.extend(
407 partial_history
408 .into_iter()
409 .map(message::Message::try_into)
410 .collect::<Result<Vec<Vec<Message>>, _>>()?
411 .into_iter()
412 .flatten()
413 .collect::<Vec<_>>(),
414 );
415
416 let request = json!({
417 "model": self.model,
418 "messages": full_history,
419 "temperature": completion_request.temperature,
420 });
421
422 let request = if let Some(params) = completion_request.additional_params {
423 json_utils::merge(request, params)
424 } else {
425 request
426 };
427
428 Ok(request)
429 }
430}
431
/// Completion and streaming support for Hyperbolic chat models.
///
/// Both methods build their body with `create_completion_request` and POST it to
/// `/v1/chat/completions`.
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = openai::StreamingCompletionResponse;

    /// Sends a non-streaming chat completion request and converts the reply.
    ///
    /// # Errors
    ///
    /// * Errors from building, serializing or sending the request.
    /// * [`CompletionError::ProviderError`] with the raw body when the HTTP status is
    ///   not a success, or with the API's message when a successful response carries
    ///   an error object.
    /// * Errors from deserializing or converting the response.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Cloned up front because the request is consumed on the next line.
        let preamble = completion_request.preamble.clone();
        let request = self.create_completion_request(completion_request)?;
        let body = serde_json::to_vec(&request)?;

        // Open a dedicated span only when the current span is disabled; otherwise
        // keep recording into the caller's span.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let req = self
            .client
            .req(Method::POST, "/v1/chat/completions")?
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(http_client::Error::from)?;

        let async_block = async move {
            let response = self.client.http_client.send::<_, bytes::Bytes>(req).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                // A successful status can still carry an error object in the body.
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        tracing::info!(target: "rig",
                            "Hyperbolic completion token usage: {:?}",
                            response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
                        );

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                // Non-success status: surface the raw body as the provider error.
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        };

        // The request/response handling runs inside the span chosen above.
        async_block.instrument(span).await
    }

    /// Sends a streaming chat completion request.
    ///
    /// The body is the same as for `completion`, with `stream: true` and
    /// `stream_options.include_usage: true` merged in. Sending and stream handling are
    /// delegated to the OpenAI-compatible `send_compatible_streaming_request` helper.
    ///
    /// # Errors
    ///
    /// Errors from building or serializing the request, plus whatever the streaming
    /// helper returns.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        // Cloned up front because the request is consumed on the next line.
        let preamble = completion_request.preamble.clone();
        let mut request = self.create_completion_request(completion_request)?;

        // Open a dedicated span only when the current span is disabled; otherwise
        // keep recording into the caller's span.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        // Merged after the span is created, so the logged input messages are unaffected.
        merge_inplace(
            &mut request,
            json!({"stream": true, "stream_options": {"include_usage": true}}),
        );

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .req(Method::POST, "/v1/chat/completions")?
            .header("Content-Type", "application/json")
            .body(body)
            .map_err(http_client::Error::from)?;

        send_compatible_streaming_request(self.client.http_client.clone(), req)
            .instrument(span)
            .await
    }
}
548
549#[cfg(feature = "image")]
554pub use image_generation::*;
555
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    //! Image generation support for Hyperbolic (`/v1/image/generation`).

    use super::{ApiResponse, Client};
    use crate::client::ImageGenerationClient;
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use http::Method;
    use serde::Deserialize;
    use serde_json::json;

    /// Model name `SDXL1.0-base`.
    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    /// Model name `SD2`.
    pub const SD2: &str = "SD2";
    /// Model name `SD1.5`.
    pub const SD1_5: &str = "SD1.5";
    /// Model name `SSD`.
    pub const SSD: &str = "SSD";
    /// Model name `SDXL-turbo`.
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    /// Model name `SDXL-ControlNet`.
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    /// Model name `SD1.5-ControlNet`.
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    /// Handle for generating images with one Hyperbolic model.
    #[derive(Clone)]
    pub struct ImageGenerationModel<T> {
        /// Client used to build and send requests.
        client: Client<T>,
        /// Model name sent as the `model_name` field of each request.
        pub model: String,
    }

    impl<T> ImageGenerationModel<T> {
        /// Creates an image generation handle for `model` that sends requests via
        /// `client`.
        pub(crate) fn new(client: Client<T>, model: &str) -> ImageGenerationModel<T> {
            Self {
                client,
                model: model.to_string(),
            }
        }
    }

    /// One generated image, base64-encoded.
    #[derive(Clone, Deserialize)]
    pub struct Image {
        /// Base64 (standard alphabet) encoded image bytes.
        image: String,
    }

    /// Raw body of a successful image generation response.
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        /// Generated images; only the first one is decoded when converting.
        images: Vec<Image>,
    }

    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// Decodes the first image of the response and keeps the raw response.
        ///
        /// # Errors
        ///
        /// Returns [`ImageGenerationError::ResponseError`] if the response contains
        /// no images or if the first image is not valid base64. Both cases are
        /// provider-controlled data, so they are reported instead of panicking.
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            let first = value.images.first().ok_or_else(|| {
                ImageGenerationError::ResponseError("Response contained no images".to_owned())
            })?;

            let data = BASE64_STANDARD.decode(&first.image).map_err(|err| {
                ImageGenerationError::ResponseError(format!("Could not decode image: {err}"))
            })?;

            Ok(Self {
                image: data,
                response: value,
            })
        }
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        /// Requests an image for `generation_request` and decodes the result.
        ///
        /// The body carries `model_name`, `prompt`, `height` and `width`;
        /// `additional_params`, when present, are merged on top.
        ///
        /// # Errors
        ///
        /// * Errors from building, serializing or sending the request.
        /// * [`ImageGenerationError::ProviderError`] (`"<status>: <body>"`) when the
        ///   HTTP status is not a success.
        /// * [`ImageGenerationError::ResponseError`] when a successful response
        ///   carries an error object, has no images, or holds invalid base64.
        #[cfg_attr(feature = "worker", worker::send)]
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let body = serde_json::to_vec(&request)?;

            let request = self
                .client
                .req(Method::POST, "/v1/image/generation")?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self
                .client
                .http_client
                .send::<_, bytes::Bytes>(request)
                .await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            // A successful status can still carry an error object in the body.
            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
            }
        }
    }

    impl<T> ImageGenerationClient for Client<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type ImageGenerationModel = ImageGenerationModel<T>;

        /// Creates an image generation handle for `model`, backed by a clone of this
        /// client.
        fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
            ImageGenerationModel::new(self.clone(), model)
        }
    }
}
698
699#[cfg(feature = "audio")]
703pub use audio_generation::*;
704use tracing::{Instrument, info_span};
705
#[cfg(feature = "audio")]
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    //! Audio generation support for Hyperbolic (`/v1/audio/generation`).

    use super::{ApiResponse, Client};
    use crate::audio_generation;
    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
    use crate::client::AudioGenerationClient;
    use crate::http_client::{self, HttpClientExt};
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use bytes::Bytes;
    use http::Method;
    use serde::Deserialize;
    use serde_json::json;

    /// Handle for generating audio in one language.
    #[derive(Clone)]
    pub struct AudioGenerationModel<T> {
        /// Client used to build and send requests.
        client: Client<T>,
        /// Language sent as the `language` field of each request.
        pub language: String,
    }

    impl<T> AudioGenerationModel<T> {
        /// Creates an audio generation handle for `language` that sends requests via
        /// `client`.
        pub(crate) fn new(client: Client<T>, language: &str) -> AudioGenerationModel<T> {
            Self {
                client,
                language: language.to_string(),
            }
        }
    }

    /// Raw body of a successful audio generation response.
    #[derive(Clone, Deserialize)]
    pub struct AudioGenerationResponse {
        /// Base64 (standard alphabet) encoded audio bytes.
        audio: String,
    }

    impl TryFrom<AudioGenerationResponse>
        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
    {
        type Error = AudioGenerationError;

        /// Decodes the audio payload and keeps the raw response.
        ///
        /// # Errors
        ///
        /// Returns [`AudioGenerationError::ProviderError`] if the payload is not
        /// valid base64. The payload is provider-controlled data, so the failure is
        /// reported instead of panicking.
        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
            let data = BASE64_STANDARD.decode(&value.audio).map_err(|err| {
                AudioGenerationError::ProviderError(format!("Could not decode audio: {err}"))
            })?;

            Ok(Self {
                audio: data,
                response: value,
            })
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = AudioGenerationResponse;

        /// Requests speech for `request` and decodes the result.
        ///
        /// The body carries `language` (from this model), `speaker` (the request's
        /// `voice`), `text` and `speed`.
        ///
        /// # Errors
        ///
        /// * Errors from building, serializing or sending the request.
        /// * [`AudioGenerationError::ProviderError`] when the HTTP status is not a
        ///   success (`"<status>: <body>"`), when a successful response carries an
        ///   error object, or when the audio payload is not valid base64.
        #[cfg_attr(feature = "worker", worker::send)]
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            let request = json!({
                "language": self.language,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .req(Method::POST, "/v1/audio/generation")?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(http_client::Error::from)?;

            let response = self.client.http_client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            // A successful status can still carry an error object in the body.
            match serde_json::from_slice::<ApiResponse<AudioGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }

    impl<T> AudioGenerationClient for Client<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type AudioGenerationModel = AudioGenerationModel<T>;

        /// Creates an audio generation handle for `language`, backed by a clone of
        /// this client.
        fn audio_generation_model(&self, language: &str) -> Self::AudioGenerationModel {
            AudioGenerationModel::new(self.clone(), language)
        }
    }
}