1use super::openai::{AssistantContent, send_compatible_streaming_request};
12
13use crate::client::{ClientBuilderError, CompletionClient, ProviderClient};
14use crate::json_utils::merge_inplace;
15use crate::message;
16use crate::streaming::StreamingCompletionResponse;
17
18use crate::impl_conversion_traits;
19use crate::providers::openai;
20use crate::{
21 OneOrMany,
22 completion::{self, CompletionError, CompletionRequest},
23 json_utils,
24 providers::openai::Message,
25};
26use serde::{Deserialize, Serialize};
27use serde_json::{Value, json};
28
/// Default base URL for Hyperbolic's OpenAI-compatible REST API.
const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz/v1";
33
/// Builder for [`Client`]; borrows its configuration until `build()` copies
/// it into an owned client.
pub struct ClientBuilder<'a> {
    // Hyperbolic API key; sent as a bearer token on every request.
    api_key: &'a str,
    // Defaults to `HYPERBOLIC_API_BASE_URL`; overridable for proxies/tests.
    base_url: &'a str,
    // Optional caller-supplied HTTP client; a default is built if `None`.
    http_client: Option<reqwest::Client>,
}
39
40impl<'a> ClientBuilder<'a> {
41 pub fn new(api_key: &'a str) -> Self {
42 Self {
43 api_key,
44 base_url: HYPERBOLIC_API_BASE_URL,
45 http_client: None,
46 }
47 }
48
49 pub fn base_url(mut self, base_url: &'a str) -> Self {
50 self.base_url = base_url;
51 self
52 }
53
54 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
55 self.http_client = Some(client);
56 self
57 }
58
59 pub fn build(self) -> Result<Client, ClientBuilderError> {
60 let http_client = if let Some(http_client) = self.http_client {
61 http_client
62 } else {
63 reqwest::Client::builder().build()?
64 };
65
66 Ok(Client {
67 base_url: self.base_url.to_string(),
68 api_key: self.api_key.to_string(),
69 http_client,
70 })
71 }
72}
73
/// Hyperbolic API client: owned base URL, API key, and a shared
/// (cheaply clonable) `reqwest` HTTP client.
#[derive(Clone)]
pub struct Client {
    base_url: String,
    // Held in plain text; the manual `Debug` impl below redacts it.
    api_key: String,
    http_client: reqwest::Client,
}
80
81impl std::fmt::Debug for Client {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 f.debug_struct("Client")
84 .field("base_url", &self.base_url)
85 .field("http_client", &self.http_client)
86 .field("api_key", &"<REDACTED>")
87 .finish()
88 }
89}
90
91impl Client {
92 pub fn builder(api_key: &str) -> ClientBuilder<'_> {
103 ClientBuilder::new(api_key)
104 }
105
106 pub fn new(api_key: &str) -> Self {
111 Self::builder(api_key)
112 .build()
113 .expect("Hyperbolic client should build")
114 }
115
116 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
117 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
118 self.http_client.post(url).bearer_auth(&self.api_key)
119 }
120}
121
impl ProviderClient for Client {
    /// Build a client from the `HYPERBOLIC_API_KEY` environment variable.
    ///
    /// # Panics
    /// Panics if the variable is not set.
    fn from_env() -> Self {
        let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
        Self::new(&api_key)
    }

    /// Build a client from a provider value; only the simple API-key form
    /// is supported.
    ///
    /// # Panics
    /// Panics if `input` is not `ProviderValue::Simple`.
    fn from_val(input: crate::client::ProviderValue) -> Self {
        let crate::client::ProviderValue::Simple(api_key) = input else {
            panic!("Incorrect provider value type")
        };
        Self::new(&api_key)
    }
}
137
impl CompletionClient for Client {
    type CompletionModel = CompletionModel;

    /// Create a completion model handle for the given Hyperbolic model id
    /// (e.g. one of the `LLAMA_*` / `DEEPSEEK_*` / `QWEN_*` constants in
    /// this module).
    fn completion_model(&self, model: &str) -> CompletionModel {
        CompletionModel::new(self.clone(), model)
    }
}
156
// Generate the `AsEmbeddings` / `AsTranscription` conversion impls for
// `Client` via the shared crate macro; no embedding or transcription model
// is wired up in this file.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription for Client
);
161
/// Error payload returned by the Hyperbolic API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Either a successful payload of type `T` or an API-level error.
///
/// Untagged: serde attempts the variants in order, so a payload is first
/// tried as `T` and only then as the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

/// A single embedding vector entry.
///
/// NOTE(review): nothing in this file calls an embeddings endpoint
/// (`AsEmbeddings` is generated as a conversion-only impl above), so this
/// type appears unused here — confirm before removing.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}
180
/// Token usage as reported by Hyperbolic.
///
/// Only prompt and total counts are present; output tokens are derived as
/// `total - prompt` when converting to the provider-agnostic usage type.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
186
187impl std::fmt::Display for Usage {
188 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189 write!(
190 f,
191 "Prompt tokens: {} Total tokens: {}",
192 self.prompt_tokens, self.total_tokens
193 )
194 }
195}
196
// Chat model identifiers accepted by the Hyperbolic completions API.
/// Meta Llama 3.1 8B Instruct
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Meta Llama 3.3 70B Instruct
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// Meta Llama 3.1 70B Instruct
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// Meta Llama 3 70B Instruct
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// NousResearch Hermes 3 (Llama 3.1 70B base)
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// DeepSeek V2.5
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// Qwen 2.5 72B Instruct
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// Meta Llama 3.2 3B Instruct
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// Qwen 2.5 Coder 32B Instruct
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// Qwen QwQ 32B (preview)
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// DeepSeek R1 Zero
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// DeepSeek R1
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
224
/// Raw chat completion response from Hyperbolic (OpenAI-compatible shape).
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    // Token usage may be absent from the response.
    pub usage: Option<Usage>,
}
237
238impl From<ApiErrorResponse> for CompletionError {
239 fn from(err: ApiErrorResponse) -> Self {
240 CompletionError::ProviderError(err.message)
241 }
242}
243
244impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
245 type Error = CompletionError;
246
247 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
248 let choice = response.choices.first().ok_or_else(|| {
249 CompletionError::ResponseError("Response contained no choices".to_owned())
250 })?;
251
252 let content = match &choice.message {
253 Message::Assistant {
254 content,
255 tool_calls,
256 ..
257 } => {
258 let mut content = content
259 .iter()
260 .map(|c| match c {
261 AssistantContent::Text { text } => completion::AssistantContent::text(text),
262 AssistantContent::Refusal { refusal } => {
263 completion::AssistantContent::text(refusal)
264 }
265 })
266 .collect::<Vec<_>>();
267
268 content.extend(
269 tool_calls
270 .iter()
271 .map(|call| {
272 completion::AssistantContent::tool_call(
273 &call.id,
274 &call.function.name,
275 call.function.arguments.clone(),
276 )
277 })
278 .collect::<Vec<_>>(),
279 );
280 Ok(content)
281 }
282 _ => Err(CompletionError::ResponseError(
283 "Response did not contain a valid message or tool call".into(),
284 )),
285 }?;
286
287 let choice = OneOrMany::many(content).map_err(|_| {
288 CompletionError::ResponseError(
289 "Response contained no message or tool call (empty)".to_owned(),
290 )
291 })?;
292
293 let usage = response
294 .usage
295 .as_ref()
296 .map(|usage| completion::Usage {
297 input_tokens: usage.prompt_tokens as u64,
298 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
299 total_tokens: usage.total_tokens as u64,
300 })
301 .unwrap_or_default();
302
303 Ok(completion::CompletionResponse {
304 choice,
305 usage,
306 raw_response: response,
307 })
308 }
309}
310
/// One completion choice: its index, the assistant message, and the
/// model's finish reason.
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    pub finish_reason: String,
}
317
/// A completion model bound to a [`Client`] and a Hyperbolic model id.
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    // Model identifier, e.g. one of the constants in this module.
    pub model: String,
}
324
325impl CompletionModel {
326 pub(crate) fn create_completion_request(
327 &self,
328 completion_request: CompletionRequest,
329 ) -> Result<Value, CompletionError> {
330 let mut partial_history = vec![];
332 if let Some(docs) = completion_request.normalized_documents() {
333 partial_history.push(docs);
334 }
335 partial_history.extend(completion_request.chat_history);
336
337 let mut full_history: Vec<Message> = completion_request
339 .preamble
340 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
341
342 full_history.extend(
344 partial_history
345 .into_iter()
346 .map(message::Message::try_into)
347 .collect::<Result<Vec<Vec<Message>>, _>>()?
348 .into_iter()
349 .flatten()
350 .collect::<Vec<_>>(),
351 );
352
353 let request = json!({
354 "model": self.model,
355 "messages": full_history,
356 "temperature": completion_request.temperature,
357 });
358
359 let request = if let Some(params) = completion_request.additional_params {
360 json_utils::merge(request, params)
361 } else {
362 request
363 };
364
365 Ok(request)
366 }
367}
368
369impl CompletionModel {
370 pub fn new(client: Client, model: &str) -> Self {
371 Self {
372 client,
373 model: model.to_string(),
374 }
375 }
376}
377
378impl completion::CompletionModel for CompletionModel {
379 type Response = CompletionResponse;
380 type StreamingResponse = openai::StreamingCompletionResponse;
381
382 #[cfg_attr(feature = "worker", worker::send)]
383 async fn completion(
384 &self,
385 completion_request: CompletionRequest,
386 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
387 let request = self.create_completion_request(completion_request)?;
388
389 let response = self
390 .client
391 .post("/chat/completions")
392 .json(&request)
393 .send()
394 .await?;
395
396 if response.status().is_success() {
397 match response.json::<ApiResponse<CompletionResponse>>().await? {
398 ApiResponse::Ok(response) => {
399 tracing::info!(target: "rig",
400 "Hyperbolic completion token usage: {:?}",
401 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
402 );
403
404 response.try_into()
405 }
406 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
407 }
408 } else {
409 Err(CompletionError::ProviderError(response.text().await?))
410 }
411 }
412
413 #[cfg_attr(feature = "worker", worker::send)]
414 async fn stream(
415 &self,
416 completion_request: CompletionRequest,
417 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
418 let mut request = self.create_completion_request(completion_request)?;
419
420 merge_inplace(
421 &mut request,
422 json!({"stream": true, "stream_options": {"include_usage": true}}),
423 );
424
425 let builder = self.client.post("/chat/completions").json(&request);
426
427 send_compatible_streaming_request(builder).await
428 }
429}
430
431#[cfg(feature = "image")]
436pub use image_generation::*;
437
#[cfg(feature = "image")]
mod image_generation {
    use super::{ApiResponse, Client};
    use crate::client::ImageGenerationClient;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    // Image model identifiers accepted by the Hyperbolic image API.
    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    pub const SD2: &str = "SD2";
    pub const SD1_5: &str = "SD1.5";
    pub const SSD: &str = "SSD";
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    /// An image generation model bound to a [`Client`] and a model name.
    #[cfg(feature = "image")]
    #[derive(Clone)]
    pub struct ImageGenerationModel {
        client: Client,
        pub model: String,
    }

    #[cfg(feature = "image")]
    impl ImageGenerationModel {
        pub(crate) fn new(client: Client, model: &str) -> ImageGenerationModel {
            Self {
                client,
                model: model.to_string(),
            }
        }
    }

    /// A single base64-encoded image in the API response.
    #[cfg(feature = "image")]
    #[derive(Clone, Deserialize)]
    pub struct Image {
        image: String,
    }

    /// Raw response payload from the image generation endpoint.
    #[cfg(feature = "image")]
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        images: Vec<Image>,
    }

    #[cfg(feature = "image")]
    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// Decode the first returned image from base64.
        ///
        /// # Errors
        /// Returns `ResponseError` if the response contains no images or
        /// the base64 payload is malformed (previously both cases panicked,
        /// via `[0]` indexing and `expect`).
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            let first = value.images.first().ok_or_else(|| {
                ImageGenerationError::ResponseError("Response contained no images".to_string())
            })?;

            let data = BASE64_STANDARD.decode(&first.image).map_err(|e| {
                ImageGenerationError::ResponseError(format!("Could not decode image: {e}"))
            })?;

            Ok(Self {
                image: data,
                response: value,
            })
        }
    }

    #[cfg(feature = "image")]
    impl image_generation::ImageGenerationModel for ImageGenerationModel {
        type Response = ImageGenerationResponse;

        /// Request an image from Hyperbolic's `/image/generation` endpoint.
        ///
        /// # Errors
        /// `ProviderError` on a non-success HTTP status; `ResponseError`
        /// when the payload is an API error or cannot be decoded.
        #[cfg_attr(feature = "worker", worker::send)]
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            // Caller-supplied extras are merged over the defaults.
            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let response = self
                .client
                .post("/image/generation")
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status().as_str(),
                    response.text().await?
                )));
            }

            match response
                .json::<ApiResponse<ImageGenerationResponse>>()
                .await?
            {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
            }
        }
    }

    impl ImageGenerationClient for Client {
        type ImageGenerationModel = ImageGenerationModel;

        fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
            ImageGenerationModel::new(self.clone(), model)
        }
    }
}
570
571#[cfg(feature = "audio")]
575pub use audio_generation::*;
576
#[cfg(feature = "audio")]
mod audio_generation {
    use super::{ApiResponse, Client};
    use crate::audio_generation;
    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
    use crate::client::AudioGenerationClient;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    /// A text-to-speech model bound to a [`Client`] and a language code.
    ///
    /// NOTE(review): this endpoint is parameterized by language rather than
    /// a model name — confirm against the Hyperbolic TTS API docs.
    #[derive(Clone)]
    pub struct AudioGenerationModel {
        client: Client,
        pub language: String,
    }

    impl AudioGenerationModel {
        pub(crate) fn new(client: Client, language: &str) -> AudioGenerationModel {
            Self {
                client,
                language: language.to_string(),
            }
        }
    }

    /// Raw response payload: base64-encoded audio.
    #[derive(Clone, Deserialize)]
    pub struct AudioGenerationResponse {
        audio: String,
    }

    impl TryFrom<AudioGenerationResponse>
        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
    {
        type Error = AudioGenerationError;

        /// Decode the returned audio from base64.
        ///
        /// # Errors
        /// Returns `ProviderError` when the base64 payload is malformed
        /// (previously this panicked via `expect`).
        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
            let data = BASE64_STANDARD.decode(&value.audio).map_err(|e| {
                AudioGenerationError::ProviderError(format!("Could not decode audio: {e}"))
            })?;

            Ok(Self {
                audio: data,
                response: value,
            })
        }
    }

    impl audio_generation::AudioGenerationModel for AudioGenerationModel {
        type Response = AudioGenerationResponse;

        /// Request speech synthesis from the `/audio/generation` endpoint.
        ///
        /// # Errors
        /// `ProviderError` on a non-success HTTP status, an API error
        /// payload, or a malformed audio payload.
        #[cfg_attr(feature = "worker", worker::send)]
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            let request = json!({
                "language": self.language,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let response = self
                .client
                .post("/audio/generation")
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status(),
                    response.text().await?
                )));
            }

            match serde_json::from_str::<ApiResponse<AudioGenerationResponse>>(
                &response.text().await?,
            )? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }
    impl AudioGenerationClient for Client {
        type AudioGenerationModel = AudioGenerationModel;

        fn audio_generation_model(&self, language: &str) -> AudioGenerationModel {
            AudioGenerationModel::new(self.clone(), language)
        }
    }
}