1use super::openai::{AssistantContent, send_compatible_streaming_request};
12
13use crate::client::{
14 ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError,
15};
16use crate::json_utils::merge_inplace;
17use crate::message;
18use crate::streaming::StreamingCompletionResponse;
19
20use crate::impl_conversion_traits;
21use crate::providers::openai;
22use crate::{
23 OneOrMany,
24 completion::{self, CompletionError, CompletionRequest},
25 json_utils,
26 providers::openai::Message,
27};
28use serde::{Deserialize, Serialize};
29use serde_json::{Value, json};
30
/// Default base URL for the Hyperbolic API; used by [`ClientBuilder::new`] unless
/// overridden with [`ClientBuilder::base_url`].
const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz";
35
36pub struct ClientBuilder<'a> {
37 api_key: &'a str,
38 base_url: &'a str,
39 http_client: Option<reqwest::Client>,
40}
41
42impl<'a> ClientBuilder<'a> {
43 pub fn new(api_key: &'a str) -> Self {
44 Self {
45 api_key,
46 base_url: HYPERBOLIC_API_BASE_URL,
47 http_client: None,
48 }
49 }
50
51 pub fn base_url(mut self, base_url: &'a str) -> Self {
52 self.base_url = base_url;
53 self
54 }
55
56 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
57 self.http_client = Some(client);
58 self
59 }
60
61 pub fn build(self) -> Result<Client, ClientBuilderError> {
62 let http_client = if let Some(http_client) = self.http_client {
63 http_client
64 } else {
65 reqwest::Client::builder().build()?
66 };
67
68 Ok(Client {
69 base_url: self.base_url.to_string(),
70 api_key: self.api_key.to_string(),
71 http_client,
72 })
73 }
74}
75
/// Hyperbolic API client.
///
/// Holds the base URL, the API key (attached as a bearer token by the request
/// helpers on this type) and the underlying `reqwest::Client`. The `Debug` impl
/// redacts the API key.
#[derive(Clone)]
pub struct Client {
    base_url: String,
    api_key: String,
    http_client: reqwest::Client,
}
82
83impl std::fmt::Debug for Client {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 f.debug_struct("Client")
86 .field("base_url", &self.base_url)
87 .field("http_client", &self.http_client)
88 .field("api_key", &"<REDACTED>")
89 .finish()
90 }
91}
92
93impl Client {
94 pub fn builder(api_key: &str) -> ClientBuilder<'_> {
105 ClientBuilder::new(api_key)
106 }
107
108 pub fn new(api_key: &str) -> Self {
113 Self::builder(api_key)
114 .build()
115 .expect("Hyperbolic client should build")
116 }
117
118 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
119 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
120 self.http_client.post(url).bearer_auth(&self.api_key)
121 }
122
123 pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder {
124 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
125 self.http_client.get(url).bearer_auth(&self.api_key)
126 }
127}
128
129impl ProviderClient for Client {
130 fn from_env() -> Self {
133 let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
134 Self::new(&api_key)
135 }
136
137 fn from_val(input: crate::client::ProviderValue) -> Self {
138 let crate::client::ProviderValue::Simple(api_key) = input else {
139 panic!("Incorrect provider value type")
140 };
141 Self::new(&api_key)
142 }
143}
144
145impl CompletionClient for Client {
146 type CompletionModel = CompletionModel;
147
148 fn completion_model(&self, model: &str) -> CompletionModel {
160 CompletionModel::new(self.clone(), model)
161 }
162}
163
164impl VerifyClient for Client {
165 #[cfg_attr(feature = "worker", worker::send)]
166 async fn verify(&self) -> Result<(), VerifyError> {
167 let response = self.get("/users/me").send().await?;
168 match response.status() {
169 reqwest::StatusCode::OK => Ok(()),
170 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
171 reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
172 Err(VerifyError::ProviderError(response.text().await?))
173 }
174 _ => {
175 response.error_for_status()?;
176 Ok(())
177 }
178 }
179 }
180}
181
// Expand the crate's `impl_conversion_traits!` macro for the `AsEmbeddings` and
// `AsTranscription` conversion traits on `Client`.
// NOTE(review): presumably these are the "capability not supported" impls, since this
// file defines no embedding or transcription model — confirm against the macro.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription for Client
);
186
/// Error payload returned by the Hyperbolic API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error message text reported by the provider.
    message: String,
}

/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// With `#[serde(untagged)]`, serde tries the variants in declaration order. A body is
/// therefore parsed as `T` first and only falls back to `Err` if that fails.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
198
/// A single embedding entry in an embeddings response.
///
/// NOTE(review): not referenced anywhere in this file's visible code — confirm whether
/// it is still needed before relying on or removing it.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    /// Object type tag as reported by the API.
    pub object: String,
    /// The embedding vector.
    pub embedding: Vec<f64>,
    /// Position of this entry in the response list.
    pub index: usize,
}
205
/// Token usage reported by the Hyperbolic API.
///
/// Output (completion) tokens are not reported separately. The conversion into rig's
/// usage type derives them from `total_tokens` and `prompt_tokens`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: usize,
    /// Total tokens for the request (prompt + output).
    pub total_tokens: usize,
}
211
212impl std::fmt::Display for Usage {
213 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214 write!(
215 f,
216 "Prompt tokens: {} Total tokens: {}",
217 self.prompt_tokens, self.total_tokens
218 )
219 }
220}
221
/// Hyperbolic model id `meta-llama/Meta-Llama-3.1-8B-Instruct`.
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Hyperbolic model id `meta-llama/Llama-3.3-70B-Instruct`.
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// Hyperbolic model id `meta-llama/Meta-Llama-3.1-70B-Instruct`.
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// Hyperbolic model id `meta-llama/Meta-Llama-3-70B-Instruct`.
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// Hyperbolic model id `NousResearch/Hermes-3-Llama-3.1-70b`.
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// Hyperbolic model id `deepseek-ai/DeepSeek-V2.5`.
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// Hyperbolic model id `Qwen/Qwen2.5-72B-Instruct`.
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// Hyperbolic model id `meta-llama/Llama-3.2-3B-Instruct`.
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// Hyperbolic model id `Qwen/Qwen2.5-Coder-32B-Instruct`.
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// Hyperbolic model id `Qwen/QwQ-32B-Preview`.
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// Hyperbolic model id `deepseek-ai/DeepSeek-R1-Zero`.
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// Hyperbolic model id `deepseek-ai/DeepSeek-R1`.
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
249
/// Raw chat-completion response body from `POST /v1/chat/completions`.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Response id assigned by the provider.
    pub id: String,
    /// Object type tag as reported by the API.
    pub object: String,
    /// Creation timestamp as reported by the API.
    /// NOTE(review): presumably Unix seconds — confirm.
    pub created: u64,
    /// Model that produced the response.
    pub model: String,
    /// Returned choices; only the first is used when converting to rig's response.
    pub choices: Vec<Choice>,
    /// Token usage, if the API included it.
    pub usage: Option<Usage>,
}
262
263impl From<ApiErrorResponse> for CompletionError {
264 fn from(err: ApiErrorResponse) -> Self {
265 CompletionError::ProviderError(err.message)
266 }
267}
268
269impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
270 type Error = CompletionError;
271
272 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
273 let choice = response.choices.first().ok_or_else(|| {
274 CompletionError::ResponseError("Response contained no choices".to_owned())
275 })?;
276
277 let content = match &choice.message {
278 Message::Assistant {
279 content,
280 tool_calls,
281 ..
282 } => {
283 let mut content = content
284 .iter()
285 .map(|c| match c {
286 AssistantContent::Text { text } => completion::AssistantContent::text(text),
287 AssistantContent::Refusal { refusal } => {
288 completion::AssistantContent::text(refusal)
289 }
290 })
291 .collect::<Vec<_>>();
292
293 content.extend(
294 tool_calls
295 .iter()
296 .map(|call| {
297 completion::AssistantContent::tool_call(
298 &call.id,
299 &call.function.name,
300 call.function.arguments.clone(),
301 )
302 })
303 .collect::<Vec<_>>(),
304 );
305 Ok(content)
306 }
307 _ => Err(CompletionError::ResponseError(
308 "Response did not contain a valid message or tool call".into(),
309 )),
310 }?;
311
312 let choice = OneOrMany::many(content).map_err(|_| {
313 CompletionError::ResponseError(
314 "Response contained no message or tool call (empty)".to_owned(),
315 )
316 })?;
317
318 let usage = response
319 .usage
320 .as_ref()
321 .map(|usage| completion::Usage {
322 input_tokens: usage.prompt_tokens as u64,
323 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
324 total_tokens: usage.total_tokens as u64,
325 })
326 .unwrap_or_default();
327
328 Ok(completion::CompletionResponse {
329 choice,
330 usage,
331 raw_response: response,
332 })
333 }
334}
335
/// One choice in a chat-completion response.
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    /// Position of this choice in the response.
    pub index: usize,
    /// The generated message (OpenAI-compatible format).
    pub message: Message,
    /// Why generation stopped, as reported by the API.
    /// NOTE(review): typed as a non-optional `String`; if the API can send `null` here,
    /// deserialization will fail — confirm.
    pub finish_reason: String,
}
342
/// Chat-completion model handle for the Hyperbolic API.
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    /// Model id sent as `model` in each request (e.g. [`LLAMA_3_1_8B`]).
    pub model: String,
}
349
350impl CompletionModel {
351 pub(crate) fn create_completion_request(
352 &self,
353 completion_request: CompletionRequest,
354 ) -> Result<Value, CompletionError> {
355 let mut partial_history = vec![];
357 if let Some(docs) = completion_request.normalized_documents() {
358 partial_history.push(docs);
359 }
360 partial_history.extend(completion_request.chat_history);
361
362 let mut full_history: Vec<Message> = completion_request
364 .preamble
365 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
366
367 full_history.extend(
369 partial_history
370 .into_iter()
371 .map(message::Message::try_into)
372 .collect::<Result<Vec<Vec<Message>>, _>>()?
373 .into_iter()
374 .flatten()
375 .collect::<Vec<_>>(),
376 );
377
378 let request = json!({
379 "model": self.model,
380 "messages": full_history,
381 "temperature": completion_request.temperature,
382 });
383
384 let request = if let Some(params) = completion_request.additional_params {
385 json_utils::merge(request, params)
386 } else {
387 request
388 };
389
390 Ok(request)
391 }
392}
393
394impl CompletionModel {
395 pub fn new(client: Client, model: &str) -> Self {
396 Self {
397 client,
398 model: model.to_string(),
399 }
400 }
401}
402
403impl completion::CompletionModel for CompletionModel {
404 type Response = CompletionResponse;
405 type StreamingResponse = openai::StreamingCompletionResponse;
406
407 #[cfg_attr(feature = "worker", worker::send)]
408 async fn completion(
409 &self,
410 completion_request: CompletionRequest,
411 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
412 let request = self.create_completion_request(completion_request)?;
413
414 let response = self
415 .client
416 .post("/v1/chat/completions")
417 .json(&request)
418 .send()
419 .await?;
420
421 if response.status().is_success() {
422 match response.json::<ApiResponse<CompletionResponse>>().await? {
423 ApiResponse::Ok(response) => {
424 tracing::info!(target: "rig",
425 "Hyperbolic completion token usage: {:?}",
426 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
427 );
428
429 response.try_into()
430 }
431 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
432 }
433 } else {
434 Err(CompletionError::ProviderError(response.text().await?))
435 }
436 }
437
438 #[cfg_attr(feature = "worker", worker::send)]
439 async fn stream(
440 &self,
441 completion_request: CompletionRequest,
442 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
443 let mut request = self.create_completion_request(completion_request)?;
444
445 merge_inplace(
446 &mut request,
447 json!({"stream": true, "stream_options": {"include_usage": true}}),
448 );
449
450 let builder = self.client.post("/v1/chat/completions").json(&request);
451
452 send_compatible_streaming_request(builder).await
453 }
454}
455
// Re-export the image-generation model, response types and model ids at the provider
// level when the `image` feature is enabled.
#[cfg(feature = "image")]
pub use image_generation::*;
462
#[cfg(feature = "image")]
mod image_generation {
    //! Hyperbolic image generation via `POST /v1/image/generation`.

    use super::{ApiResponse, Client};
    use crate::client::ImageGenerationClient;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    /// Hyperbolic image model id `SDXL1.0-base`.
    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    /// Hyperbolic image model id `SD2`.
    pub const SD2: &str = "SD2";
    /// Hyperbolic image model id `SD1.5`.
    pub const SD1_5: &str = "SD1.5";
    /// Hyperbolic image model id `SSD`.
    pub const SSD: &str = "SSD";
    /// Hyperbolic image model id `SDXL-turbo`.
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    /// Hyperbolic image model id `SDXL-ControlNet`.
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    /// Hyperbolic image model id `SD1.5-ControlNet`.
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    // The whole module is already gated on `feature = "image"`, so the items below
    // need no per-item `cfg` attributes.

    /// Image-generation model handle bound to a Hyperbolic [`Client`].
    #[derive(Clone)]
    pub struct ImageGenerationModel {
        client: Client,
        /// Model id sent as `model_name` in the request body.
        pub model: String,
    }

    impl ImageGenerationModel {
        pub(crate) fn new(client: Client, model: &str) -> ImageGenerationModel {
            Self {
                client,
                model: model.to_string(),
            }
        }
    }

    /// One base64-encoded image as returned by the API.
    #[derive(Clone, Debug, Deserialize)]
    pub struct Image {
        image: String,
    }

    /// Raw Hyperbolic image-generation response body.
    #[derive(Clone, Debug, Deserialize)]
    pub struct ImageGenerationResponse {
        images: Vec<Image>,
    }

    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// Decode the first returned image from standard base64.
        ///
        /// # Errors
        ///
        /// Returns [`ImageGenerationError::ResponseError`] if the response contains no
        /// images, or if the first image is not valid base64. This replaces the earlier
        /// panics on provider-controlled data.
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            let first = value.images.first().ok_or_else(|| {
                ImageGenerationError::ResponseError("Response contained no images".to_owned())
            })?;

            let data = BASE64_STANDARD.decode(&first.image).map_err(|err| {
                ImageGenerationError::ResponseError(format!("Could not decode image: {err}"))
            })?;

            Ok(Self {
                image: data,
                response: value,
            })
        }
    }

    impl image_generation::ImageGenerationModel for ImageGenerationModel {
        type Response = ImageGenerationResponse;

        /// Send `generation_request` to `POST /v1/image/generation`.
        ///
        /// The body carries `model_name`, `prompt`, `height` and `width`;
        /// `additional_params` are merged over them. A non-success status becomes
        /// `ProviderError("<status>: <body>")`. An API error payload becomes
        /// `ResponseError`.
        #[cfg_attr(feature = "worker", worker::send)]
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let response = self
                .client
                .post("/v1/image/generation")
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status().as_str(),
                    response.text().await?
                )));
            }

            match response
                .json::<ApiResponse<ImageGenerationResponse>>()
                .await?
            {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
            }
        }
    }

    impl ImageGenerationClient for Client {
        type ImageGenerationModel = ImageGenerationModel;

        /// Create an image-generation model handle for `model` (e.g. [`SDXL1_0_BASE`]).
        fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
            ImageGenerationModel::new(self.clone(), model)
        }
    }
}
595
// Re-export the audio-generation model and response types at the provider level when
// the `audio` feature is enabled.
#[cfg(feature = "audio")]
pub use audio_generation::*;
601
#[cfg(feature = "audio")]
mod audio_generation {
    //! Hyperbolic text-to-speech via `POST /v1/audio/generation`.

    use super::{ApiResponse, Client};
    use crate::audio_generation;
    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
    use crate::client::AudioGenerationClient;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    /// Audio-generation model handle bound to a Hyperbolic [`Client`].
    ///
    /// The handle is keyed by language rather than by a model id.
    #[derive(Clone)]
    pub struct AudioGenerationModel {
        client: Client,
        /// Language sent as `language` in every request.
        pub language: String,
    }

    impl AudioGenerationModel {
        pub(crate) fn new(client: Client, language: &str) -> AudioGenerationModel {
            Self {
                client,
                language: language.to_string(),
            }
        }
    }

    /// Raw Hyperbolic audio-generation response body (base64-encoded audio).
    #[derive(Clone, Debug, Deserialize)]
    pub struct AudioGenerationResponse {
        audio: String,
    }

    impl TryFrom<AudioGenerationResponse>
        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
    {
        type Error = AudioGenerationError;

        /// Decode the returned audio from standard base64.
        ///
        /// # Errors
        ///
        /// Returns [`AudioGenerationError::ProviderError`] if the payload is not valid
        /// base64. The data comes from the network, so this must not panic inside a
        /// fallible conversion.
        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
            let data = BASE64_STANDARD.decode(&value.audio).map_err(|err| {
                AudioGenerationError::ProviderError(format!("Could not decode audio: {err}"))
            })?;

            Ok(Self {
                audio: data,
                response: value,
            })
        }
    }

    impl audio_generation::AudioGenerationModel for AudioGenerationModel {
        type Response = AudioGenerationResponse;

        /// Send `request` to `POST /v1/audio/generation`.
        ///
        /// The body carries `language`, `speaker` (from `voice`), `text` and `speed`.
        /// A non-success status becomes `ProviderError("<status>: <body>")`, as does an
        /// API error payload (carrying its message).
        #[cfg_attr(feature = "worker", worker::send)]
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            let request = json!({
                "language": self.language,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let response = self
                .client
                .post("/v1/audio/generation")
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status(),
                    response.text().await?
                )));
            }

            match serde_json::from_str::<ApiResponse<AudioGenerationResponse>>(
                &response.text().await?,
            )? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }

    impl AudioGenerationClient for Client {
        type AudioGenerationModel = AudioGenerationModel;

        /// Create an audio-generation model handle for `language`.
        fn audio_generation_model(&self, language: &str) -> AudioGenerationModel {
            AudioGenerationModel::new(self.clone(), language)
        }
    }
}