1use super::openai::{AssistantContent, send_compatible_streaming_request};
12
13use crate::client::{CompletionClient, ProviderClient, VerifyClient, VerifyError};
14use crate::http_client::{self, HttpClientExt};
15use crate::json_utils::merge_inplace;
16use crate::message;
17use crate::streaming::StreamingCompletionResponse;
18
19use crate::impl_conversion_traits;
20use crate::providers::openai;
21use crate::{
22 OneOrMany,
23 completion::{self, CompletionError, CompletionRequest},
24 json_utils,
25 providers::openai::Message,
26};
27use serde::{Deserialize, Serialize};
28use serde_json::{Value, json};
29
30const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz";
34
35pub struct ClientBuilder<'a, T = reqwest::Client> {
36 api_key: &'a str,
37 base_url: &'a str,
38 http_client: T,
39}
40
41impl<'a, T> ClientBuilder<'a, T>
42where
43 T: Default,
44{
45 pub fn new(api_key: &'a str) -> Self {
46 Self {
47 api_key,
48 base_url: HYPERBOLIC_API_BASE_URL,
49 http_client: Default::default(),
50 }
51 }
52}
53
54impl<'a, T> ClientBuilder<'a, T> {
55 pub fn base_url(mut self, base_url: &'a str) -> Self {
56 self.base_url = base_url;
57 self
58 }
59
60 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
61 ClientBuilder {
62 api_key: self.api_key,
63 base_url: self.base_url,
64 http_client,
65 }
66 }
67
68 pub fn build(self) -> Client<T> {
69 Client {
70 base_url: self.base_url.to_string(),
71 api_key: self.api_key.to_string(),
72 http_client: self.http_client,
73 }
74 }
75}
76
77#[derive(Clone)]
78pub struct Client<T = reqwest::Client> {
79 base_url: String,
80 api_key: String,
81 http_client: T,
82}
83
84impl<T> std::fmt::Debug for Client<T>
85where
86 T: std::fmt::Debug,
87{
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 f.debug_struct("Client")
90 .field("base_url", &self.base_url)
91 .field("http_client", &self.http_client)
92 .field("api_key", &"<REDACTED>")
93 .finish()
94 }
95}
96
97impl<T> Client<T>
98where
99 T: Default,
100{
101 pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
112 ClientBuilder::new(api_key)
113 }
114
115 pub fn new(api_key: &str) -> Self {
120 Self::builder(api_key).build()
121 }
122}
123
124impl<T> Client<T>
125where
126 T: HttpClientExt,
127{
128 fn req(
129 &self,
130 method: http_client::Method,
131 path: &str,
132 ) -> http_client::Result<http_client::Builder> {
133 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
134
135 http_client::with_bearer_auth(
136 http_client::Builder::new().method(method).uri(url),
137 &self.api_key,
138 )
139 }
140
141 fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
142 self.req(http_client::Method::GET, path)
143 }
144}
145
146impl Client<reqwest::Client> {
147 fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder {
148 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
149
150 self.http_client.post(url).bearer_auth(&self.api_key)
151 }
152}
153
154impl ProviderClient for Client<reqwest::Client> {
155 fn from_env() -> Self {
158 let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
159 Self::new(&api_key)
160 }
161
162 fn from_val(input: crate::client::ProviderValue) -> Self {
163 let crate::client::ProviderValue::Simple(api_key) = input else {
164 panic!("Incorrect provider value type")
165 };
166 Self::new(&api_key)
167 }
168}
169
170impl CompletionClient for Client<reqwest::Client> {
171 type CompletionModel = CompletionModel<reqwest::Client>;
172
173 fn completion_model(&self, model: &str) -> CompletionModel<reqwest::Client> {
185 CompletionModel::new(self.clone(), model)
186 }
187}
188
189impl VerifyClient for Client<reqwest::Client> {
190 #[cfg_attr(feature = "worker", worker::send)]
191 async fn verify(&self) -> Result<(), VerifyError> {
192 let req = self
193 .get("/models")?
194 .body(http_client::NoBody)
195 .map_err(http_client::Error::from)?;
196
197 let response = HttpClientExt::send(&self.http_client, req).await?;
198
199 match response.status() {
200 reqwest::StatusCode::OK => Ok(()),
201 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
202 reqwest::StatusCode::INTERNAL_SERVER_ERROR
203 | reqwest::StatusCode::SERVICE_UNAVAILABLE
204 | reqwest::StatusCode::BAD_GATEWAY => {
205 let text = http_client::text(response).await?;
206 Err(VerifyError::ProviderError(text))
207 }
208 _ => {
209 Ok(())
211 }
212 }
213 }
214}
215
216impl_conversion_traits!(
217 AsEmbeddings,
218 AsTranscription for Client<T>
219);
220
221#[derive(Debug, Deserialize)]
222struct ApiErrorResponse {
223 message: String,
224}
225
226#[derive(Debug, Deserialize)]
227#[serde(untagged)]
228enum ApiResponse<T> {
229 Ok(T),
230 Err(ApiErrorResponse),
231}
232
233#[derive(Debug, Deserialize)]
234pub struct EmbeddingData {
235 pub object: String,
236 pub embedding: Vec<f64>,
237 pub index: usize,
238}
239
240#[derive(Clone, Debug, Deserialize, Serialize)]
241pub struct Usage {
242 pub prompt_tokens: usize,
243 pub total_tokens: usize,
244}
245
246impl std::fmt::Display for Usage {
247 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
248 write!(
249 f,
250 "Prompt tokens: {} Total tokens: {}",
251 self.prompt_tokens, self.total_tokens
252 )
253 }
254}
255
// Model identifier strings, suitable as the `model` argument of
// `Client::completion_model`.

/// `meta-llama/Meta-Llama-3.1-8B-Instruct`
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// `meta-llama/Llama-3.3-70B-Instruct`
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// `meta-llama/Meta-Llama-3.1-70B-Instruct`
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// `meta-llama/Meta-Llama-3-70B-Instruct`
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// `NousResearch/Hermes-3-Llama-3.1-70b`
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// `deepseek-ai/DeepSeek-V2.5`
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// `Qwen/Qwen2.5-72B-Instruct`
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// `meta-llama/Llama-3.2-3B-Instruct`
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// `Qwen/Qwen2.5-Coder-32B-Instruct`
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// `Qwen/QwQ-32B-Preview`
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// `deepseek-ai/DeepSeek-R1-Zero`
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// `deepseek-ai/DeepSeek-R1`
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
283
284#[derive(Debug, Deserialize, Serialize)]
288pub struct CompletionResponse {
289 pub id: String,
290 pub object: String,
291 pub created: u64,
292 pub model: String,
293 pub choices: Vec<Choice>,
294 pub usage: Option<Usage>,
295}
296
297impl From<ApiErrorResponse> for CompletionError {
298 fn from(err: ApiErrorResponse) -> Self {
299 CompletionError::ProviderError(err.message)
300 }
301}
302
303impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
304 type Error = CompletionError;
305
306 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
307 let choice = response.choices.first().ok_or_else(|| {
308 CompletionError::ResponseError("Response contained no choices".to_owned())
309 })?;
310
311 let content = match &choice.message {
312 Message::Assistant {
313 content,
314 tool_calls,
315 ..
316 } => {
317 let mut content = content
318 .iter()
319 .map(|c| match c {
320 AssistantContent::Text { text } => completion::AssistantContent::text(text),
321 AssistantContent::Refusal { refusal } => {
322 completion::AssistantContent::text(refusal)
323 }
324 })
325 .collect::<Vec<_>>();
326
327 content.extend(
328 tool_calls
329 .iter()
330 .map(|call| {
331 completion::AssistantContent::tool_call(
332 &call.id,
333 &call.function.name,
334 call.function.arguments.clone(),
335 )
336 })
337 .collect::<Vec<_>>(),
338 );
339 Ok(content)
340 }
341 _ => Err(CompletionError::ResponseError(
342 "Response did not contain a valid message or tool call".into(),
343 )),
344 }?;
345
346 let choice = OneOrMany::many(content).map_err(|_| {
347 CompletionError::ResponseError(
348 "Response contained no message or tool call (empty)".to_owned(),
349 )
350 })?;
351
352 let usage = response
353 .usage
354 .as_ref()
355 .map(|usage| completion::Usage {
356 input_tokens: usage.prompt_tokens as u64,
357 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
358 total_tokens: usage.total_tokens as u64,
359 })
360 .unwrap_or_default();
361
362 Ok(completion::CompletionResponse {
363 choice,
364 usage,
365 raw_response: response,
366 })
367 }
368}
369
370#[derive(Debug, Deserialize, Serialize)]
371pub struct Choice {
372 pub index: usize,
373 pub message: Message,
374 pub finish_reason: String,
375}
376
377#[derive(Clone)]
378pub struct CompletionModel<T> {
379 client: Client<T>,
380 pub model: String,
382}
383
384impl<T> CompletionModel<T> {
385 pub fn new(client: Client<T>, model: &str) -> Self {
386 Self {
387 client,
388 model: model.to_string(),
389 }
390 }
391
392 pub(crate) fn create_completion_request(
393 &self,
394 completion_request: CompletionRequest,
395 ) -> Result<Value, CompletionError> {
396 if completion_request.tool_choice.is_some() {
397 tracing::warn!("WARNING: `tool_choice` not supported on Hyperbolic");
398 }
399 let mut partial_history = vec![];
401 if let Some(docs) = completion_request.normalized_documents() {
402 partial_history.push(docs);
403 }
404 partial_history.extend(completion_request.chat_history);
405
406 let mut full_history: Vec<Message> = completion_request
408 .preamble
409 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
410
411 full_history.extend(
413 partial_history
414 .into_iter()
415 .map(message::Message::try_into)
416 .collect::<Result<Vec<Vec<Message>>, _>>()?
417 .into_iter()
418 .flatten()
419 .collect::<Vec<_>>(),
420 );
421
422 let request = json!({
423 "model": self.model,
424 "messages": full_history,
425 "temperature": completion_request.temperature,
426 });
427
428 let request = if let Some(params) = completion_request.additional_params {
429 json_utils::merge(request, params)
430 } else {
431 request
432 };
433
434 Ok(request)
435 }
436}
437
438impl completion::CompletionModel for CompletionModel<reqwest::Client> {
439 type Response = CompletionResponse;
440 type StreamingResponse = openai::StreamingCompletionResponse;
441
442 #[cfg_attr(feature = "worker", worker::send)]
443 async fn completion(
444 &self,
445 completion_request: CompletionRequest,
446 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
447 let preamble = completion_request.preamble.clone();
448 let request = self.create_completion_request(completion_request)?;
449
450 let span = if tracing::Span::current().is_disabled() {
451 info_span!(
452 target: "rig::completions",
453 "chat",
454 gen_ai.operation.name = "chat",
455 gen_ai.provider.name = "hyperbolic",
456 gen_ai.request.model = self.model,
457 gen_ai.system_instructions = preamble,
458 gen_ai.response.id = tracing::field::Empty,
459 gen_ai.response.model = tracing::field::Empty,
460 gen_ai.usage.output_tokens = tracing::field::Empty,
461 gen_ai.usage.input_tokens = tracing::field::Empty,
462 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
463 gen_ai.output.messages = tracing::field::Empty,
464 )
465 } else {
466 tracing::Span::current()
467 };
468
469 let async_block = async move {
470 let response = self
471 .client
472 .reqwest_post("/v1/chat/completions")
473 .json(&request)
474 .send()
475 .await
476 .map_err(|e| http_client::Error::Instance(e.into()))?;
477
478 if response.status().is_success() {
479 match response
480 .json::<ApiResponse<CompletionResponse>>()
481 .await
482 .map_err(|e| http_client::Error::Instance(e.into()))?
483 {
484 ApiResponse::Ok(response) => {
485 tracing::info!(target: "rig",
486 "Hyperbolic completion token usage: {:?}",
487 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
488 );
489
490 response.try_into()
491 }
492 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
493 }
494 } else {
495 Err(CompletionError::ProviderError(
496 response
497 .text()
498 .await
499 .map_err(|e| http_client::Error::Instance(e.into()))?,
500 ))
501 }
502 };
503
504 async_block.instrument(span).await
505 }
506
507 #[cfg_attr(feature = "worker", worker::send)]
508 async fn stream(
509 &self,
510 completion_request: CompletionRequest,
511 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
512 let preamble = completion_request.preamble.clone();
513 let mut request = self.create_completion_request(completion_request)?;
514
515 let span = if tracing::Span::current().is_disabled() {
516 info_span!(
517 target: "rig::completions",
518 "chat_streaming",
519 gen_ai.operation.name = "chat_streaming",
520 gen_ai.provider.name = "hyperbolic",
521 gen_ai.request.model = self.model,
522 gen_ai.system_instructions = preamble,
523 gen_ai.response.id = tracing::field::Empty,
524 gen_ai.response.model = tracing::field::Empty,
525 gen_ai.usage.output_tokens = tracing::field::Empty,
526 gen_ai.usage.input_tokens = tracing::field::Empty,
527 gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
528 gen_ai.output.messages = tracing::field::Empty,
529 )
530 } else {
531 tracing::Span::current()
532 };
533
534 merge_inplace(
535 &mut request,
536 json!({"stream": true, "stream_options": {"include_usage": true}}),
537 );
538
539 let builder = self
540 .client
541 .reqwest_post("/v1/chat/completions")
542 .json(&request);
543
544 send_compatible_streaming_request(builder)
545 .instrument(span)
546 .await
547 }
548}
549
550#[cfg(feature = "image")]
555pub use image_generation::*;
556
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    //! Image generation through the `/v1/image/generation` endpoint.

    use super::{ApiResponse, Client};
    use crate::client::ImageGenerationClient;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use crate::{http_client, image_generation};
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    // Image model name strings, suitable as the `model` argument of
    // `image_generation_model`.

    /// `SDXL1.0-base`
    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    /// `SD2`
    pub const SD2: &str = "SD2";
    /// `SD1.5`
    pub const SD1_5: &str = "SD1.5";
    /// `SSD`
    pub const SSD: &str = "SSD";
    /// `SDXL-turbo`
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    /// `SDXL-ControlNet`
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    /// `SD1.5-ControlNet`
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    /// A Hyperbolic image generation model bound to a [`Client`].
    #[derive(Clone)]
    pub struct ImageGenerationModel<T> {
        /// Client used to send this model's requests.
        client: Client<T>,
        /// Model name, sent as the `model_name` field of each request.
        pub model: String,
    }

    impl<T> ImageGenerationModel<T> {
        /// Creates an image generation model named `model` that sends
        /// requests through `client`.
        pub(crate) fn new(client: Client<T>, model: &str) -> ImageGenerationModel<T> {
            Self {
                client,
                model: model.to_string(),
            }
        }
    }

    /// One generated image as returned by the API.
    #[derive(Clone, Deserialize)]
    pub struct Image {
        /// Base64-encoded image data.
        image: String,
    }

    /// Raw response body of the image generation endpoint.
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        /// Returned images; only the first one is decoded.
        images: Vec<Image>,
    }

    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// Decodes the first returned image from base64.
        ///
        /// # Errors
        ///
        /// Returns [`ImageGenerationError::ResponseError`] when the response
        /// contains no images or when the first image is not valid base64.
        /// Both conditions depend on provider data, so they are reported as
        /// errors rather than panics.
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            let first = value.images.first().ok_or_else(|| {
                ImageGenerationError::ResponseError("Response contained no images".to_owned())
            })?;

            let data = BASE64_STANDARD.decode(&first.image).map_err(|e| {
                ImageGenerationError::ResponseError(format!("Could not decode image: {e}"))
            })?;

            Ok(Self {
                image: data,
                response: value,
            })
        }
    }

    impl image_generation::ImageGenerationModel for ImageGenerationModel<reqwest::Client> {
        type Response = ImageGenerationResponse;

        /// Requests an image from `/v1/image/generation`.
        ///
        /// The body holds `model_name`, `prompt`, `height` and `width`, with
        /// the request's `additional_params` merged in when present.
        ///
        /// # Errors
        ///
        /// * [`ImageGenerationError::HttpError`] when sending or reading the
        ///   response fails.
        /// * [`ImageGenerationError::ProviderError`] (`"<status>: <body>"`)
        ///   on a non-success status.
        /// * [`ImageGenerationError::ResponseError`] when a successful
        ///   response carries an API error payload or cannot be decoded.
        #[cfg_attr(feature = "worker", worker::send)]
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let response = self
                .client
                .reqwest_post("/v1/image/generation")
                .json(&request)
                .send()
                .await
                .map_err(|e| {
                    ImageGenerationError::HttpError(http_client::Error::Instance(e.into()))
                })?;

            // Capture the status first: reading the body consumes the response.
            let status = response.status();
            if !status.is_success() {
                let body = response.text().await.map_err(|e| {
                    ImageGenerationError::HttpError(http_client::Error::Instance(e.into()))
                })?;
                return Err(ImageGenerationError::ProviderError(format!(
                    "{}: {}",
                    status.as_str(),
                    body
                )));
            }

            match response
                .json::<ApiResponse<ImageGenerationResponse>>()
                .await
                .map_err(|e| {
                    ImageGenerationError::HttpError(http_client::Error::Instance(e.into()))
                })? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
            }
        }
    }

    impl ImageGenerationClient for Client<reqwest::Client> {
        type ImageGenerationModel = ImageGenerationModel<reqwest::Client>;

        /// Creates an image generation model named `model` that sends its
        /// requests through a clone of this client.
        fn image_generation_model(&self, model: &str) -> ImageGenerationModel<reqwest::Client> {
            ImageGenerationModel::new(self.clone(), model)
        }
    }
}
691
692#[cfg(feature = "audio")]
696pub use audio_generation::*;
697use tracing::{Instrument, info_span};
698
#[cfg(feature = "audio")]
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    //! Audio generation through the `/v1/audio/generation` endpoint.

    use super::{ApiResponse, Client};
    use crate::audio_generation;
    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
    use crate::client::AudioGenerationClient;
    use crate::http_client;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    /// A Hyperbolic audio generation model bound to a [`Client`].
    ///
    /// The model is identified by a language rather than by a model name.
    #[derive(Clone)]
    pub struct AudioGenerationModel<T> {
        /// Client used to send this model's requests.
        client: Client<T>,
        /// Language, sent as the `language` field of each request.
        pub language: String,
    }

    impl<T> AudioGenerationModel<T> {
        /// Creates an audio generation model for `language` that sends
        /// requests through `client`.
        pub(crate) fn new(client: Client<T>, language: &str) -> AudioGenerationModel<T> {
            Self {
                client,
                language: language.to_string(),
            }
        }
    }

    /// Raw response body of the audio generation endpoint.
    #[derive(Clone, Deserialize)]
    pub struct AudioGenerationResponse {
        /// Base64-encoded audio data.
        audio: String,
    }

    impl TryFrom<AudioGenerationResponse>
        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
    {
        type Error = AudioGenerationError;

        /// Decodes the returned audio from base64.
        ///
        /// # Errors
        ///
        /// Returns [`AudioGenerationError::ProviderError`] when the audio
        /// payload is not valid base64. The payload comes from the provider,
        /// so a malformed one is reported as an error rather than a panic.
        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
            let data = BASE64_STANDARD.decode(&value.audio).map_err(|e| {
                AudioGenerationError::ProviderError(format!("Could not decode audio: {e}"))
            })?;

            Ok(Self {
                audio: data,
                response: value,
            })
        }
    }

    impl audio_generation::AudioGenerationModel for AudioGenerationModel<reqwest::Client> {
        type Response = AudioGenerationResponse;

        /// Requests generated audio from `/v1/audio/generation`.
        ///
        /// The body holds `language` (from this model) plus `speaker`, `text`
        /// and `speed` taken from the request's `voice`, `text` and `speed`.
        ///
        /// # Errors
        ///
        /// * [`AudioGenerationError::HttpError`] when sending or reading the
        ///   response fails.
        /// * [`AudioGenerationError::ProviderError`] (`"<status>: <body>"`)
        ///   on a non-success status, when a successful response carries an
        ///   API error payload, or when the audio cannot be decoded.
        /// * A JSON error when the body cannot be parsed.
        #[cfg_attr(feature = "worker", worker::send)]
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            let request = json!({
                "language": self.language,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let response = self
                .client
                .reqwest_post("/v1/audio/generation")
                .json(&request)
                .send()
                .await
                .map_err(|e| {
                    AudioGenerationError::HttpError(http_client::Error::Instance(e.into()))
                })?;

            // Capture the status first: reading the body consumes the response.
            let status = response.status();
            let body = response.text().await.map_err(|e| {
                AudioGenerationError::HttpError(http_client::Error::Instance(e.into()))
            })?;

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{}: {}",
                    status, body
                )));
            }

            match serde_json::from_str::<ApiResponse<AudioGenerationResponse>>(&body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }

    impl AudioGenerationClient for Client<reqwest::Client> {
        type AudioGenerationModel = AudioGenerationModel<reqwest::Client>;

        /// Creates an audio generation model for `language` that sends its
        /// requests through a clone of this client.
        fn audio_generation_model(&self, language: &str) -> AudioGenerationModel<reqwest::Client> {
            AudioGenerationModel::new(self.clone(), language)
        }
    }
}