1use super::openai::{AssistantContent, send_compatible_streaming_request};
13
14use crate::client::{CompletionClient, ProviderClient};
15use crate::json_utils::merge_inplace;
16use crate::message;
17use crate::streaming::StreamingCompletionResponse;
18
19use crate::impl_conversion_traits;
20use crate::providers::openai;
21use crate::{
22 OneOrMany,
23 completion::{self, CompletionError, CompletionRequest},
24 json_utils,
25 providers::openai::Message,
26};
27use serde::Deserialize;
28use serde_json::{Value, json};
29
/// Default base URL for the Hyperbolic API.
const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz/v1";
34
/// Client for the Hyperbolic API (<https://api.hyperbolic.xyz>).
///
/// Cheap to clone: the inner `reqwest::Client` is reference-counted.
#[derive(Clone)]
pub struct Client {
    // Base URL all request paths are joined onto, e.g. `https://api.hyperbolic.xyz/v1`.
    base_url: String,
    // Bearer token sent with every request; redacted in the `Debug` impl.
    api_key: String,
    // HTTP client shared across all requests made through this `Client`.
    http_client: reqwest::Client,
}
41
42impl std::fmt::Debug for Client {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 f.debug_struct("Client")
45 .field("base_url", &self.base_url)
46 .field("http_client", &self.http_client)
47 .field("api_key", &"<REDACTED>")
48 .finish()
49 }
50}
51
52impl Client {
53 pub fn new(api_key: &str) -> Self {
55 Self::from_url(api_key, HYPERBOLIC_API_BASE_URL)
56 }
57
58 pub fn from_url(api_key: &str, base_url: &str) -> Self {
60 Self {
61 base_url: base_url.to_string(),
62 api_key: api_key.to_string(),
63 http_client: reqwest::Client::builder()
64 .build()
65 .expect("OpenAI reqwest client should build"),
66 }
67 }
68
69 pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
72 self.http_client = client;
73
74 self
75 }
76
77 fn post(&self, path: &str) -> reqwest::RequestBuilder {
78 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
79 self.http_client.post(url).bearer_auth(&self.api_key)
80 }
81}
82
83impl ProviderClient for Client {
84 fn from_env() -> Self {
87 let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
88 Self::new(&api_key)
89 }
90
91 fn from_val(input: crate::client::ProviderValue) -> Self {
92 let crate::client::ProviderValue::Simple(api_key) = input else {
93 panic!("Incorrect provider value type")
94 };
95 Self::new(&api_key)
96 }
97}
98
impl CompletionClient for Client {
    type CompletionModel = CompletionModel;

    /// Creates a completion model handle bound to this client for `model`
    /// (e.g. one of the `pub const` model identifiers in this module).
    fn completion_model(&self, model: &str) -> CompletionModel {
        CompletionModel::new(self.clone(), model)
    }
}
117
// NOTE(review): presumably generates stub impls marking embeddings and
// transcription as unsupported for this provider — confirm against the
// `impl_conversion_traits!` macro definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription for Client
);
122
/// Error payload returned by the Hyperbolic API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    // Human-readable error description from the provider.
    message: String,
}
127
/// Untagged response wrapper: deserialization first tries the success payload
/// `T`, then falls back to the provider's error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
134
/// A single embedding entry in an OpenAI-style embeddings response.
// NOTE(review): appears unused in this module — embeddings are declared
// unsupported via `impl_conversion_traits!`; confirm before removing.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}
141
/// Token usage as reported by Hyperbolic. Only prompt and total counts are
/// deserialized here; output tokens are derived from their difference.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
147
148impl std::fmt::Display for Usage {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 write!(
151 f,
152 "Prompt tokens: {} Total tokens: {}",
153 self.prompt_tokens, self.total_tokens
154 )
155 }
156}
157
// Chat-completion model identifiers accepted by the Hyperbolic API.
/// Meta Llama 3.1 8B Instruct.
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Meta Llama 3.3 70B Instruct.
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// Meta Llama 3.1 70B Instruct.
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// Meta Llama 3 70B Instruct.
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// Nous Research Hermes 3 (Llama 3.1 70B).
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// DeepSeek V2.5.
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// Qwen 2.5 72B Instruct.
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// Meta Llama 3.2 3B Instruct.
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// Qwen 2.5 Coder 32B Instruct.
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// Qwen QwQ 32B (preview).
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// DeepSeek R1 Zero.
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// DeepSeek R1.
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
185
/// Raw (OpenAI-compatible) chat-completion response body from Hyperbolic.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    // Unix timestamp of creation.
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    // Token accounting; may be absent, in which case usage defaults to zero.
    pub usage: Option<Usage>,
}
198
199impl From<ApiErrorResponse> for CompletionError {
200 fn from(err: ApiErrorResponse) -> Self {
201 CompletionError::ProviderError(err.message)
202 }
203}
204
205impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
206 type Error = CompletionError;
207
208 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
209 let choice = response.choices.first().ok_or_else(|| {
210 CompletionError::ResponseError("Response contained no choices".to_owned())
211 })?;
212
213 let content = match &choice.message {
214 Message::Assistant {
215 content,
216 tool_calls,
217 ..
218 } => {
219 let mut content = content
220 .iter()
221 .map(|c| match c {
222 AssistantContent::Text { text } => completion::AssistantContent::text(text),
223 AssistantContent::Refusal { refusal } => {
224 completion::AssistantContent::text(refusal)
225 }
226 })
227 .collect::<Vec<_>>();
228
229 content.extend(
230 tool_calls
231 .iter()
232 .map(|call| {
233 completion::AssistantContent::tool_call(
234 &call.id,
235 &call.function.name,
236 call.function.arguments.clone(),
237 )
238 })
239 .collect::<Vec<_>>(),
240 );
241 Ok(content)
242 }
243 _ => Err(CompletionError::ResponseError(
244 "Response did not contain a valid message or tool call".into(),
245 )),
246 }?;
247
248 let choice = OneOrMany::many(content).map_err(|_| {
249 CompletionError::ResponseError(
250 "Response contained no message or tool call (empty)".to_owned(),
251 )
252 })?;
253
254 let usage = response
255 .usage
256 .as_ref()
257 .map(|usage| completion::Usage {
258 input_tokens: usage.prompt_tokens as u64,
259 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
260 total_tokens: usage.total_tokens as u64,
261 })
262 .unwrap_or_default();
263
264 Ok(completion::CompletionResponse {
265 choice,
266 usage,
267 raw_response: response,
268 })
269 }
270}
271
/// One candidate completion inside a [`CompletionResponse`].
#[derive(Debug, Deserialize)]
pub struct Choice {
    pub index: usize,
    // The assistant message (or other role) produced for this choice.
    pub message: Message,
    // e.g. "stop", "length" — passed through verbatim from the API.
    pub finish_reason: String,
}
278
/// A chat-completion model bound to a [`Client`] and a model identifier.
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    // Model identifier sent in the request body, e.g. [`LLAMA_3_1_8B`].
    pub model: String,
}
285
286impl CompletionModel {
287 pub(crate) fn create_completion_request(
288 &self,
289 completion_request: CompletionRequest,
290 ) -> Result<Value, CompletionError> {
291 let mut partial_history = vec![];
293 if let Some(docs) = completion_request.normalized_documents() {
294 partial_history.push(docs);
295 }
296 partial_history.extend(completion_request.chat_history);
297
298 let mut full_history: Vec<Message> = completion_request
300 .preamble
301 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
302
303 full_history.extend(
305 partial_history
306 .into_iter()
307 .map(message::Message::try_into)
308 .collect::<Result<Vec<Vec<Message>>, _>>()?
309 .into_iter()
310 .flatten()
311 .collect::<Vec<_>>(),
312 );
313
314 let request = json!({
315 "model": self.model,
316 "messages": full_history,
317 "temperature": completion_request.temperature,
318 });
319
320 let request = if let Some(params) = completion_request.additional_params {
321 json_utils::merge(request, params)
322 } else {
323 request
324 };
325
326 Ok(request)
327 }
328}
329
330impl CompletionModel {
331 pub fn new(client: Client, model: &str) -> Self {
332 Self {
333 client,
334 model: model.to_string(),
335 }
336 }
337}
338
339impl completion::CompletionModel for CompletionModel {
340 type Response = CompletionResponse;
341 type StreamingResponse = openai::StreamingCompletionResponse;
342
343 #[cfg_attr(feature = "worker", worker::send)]
344 async fn completion(
345 &self,
346 completion_request: CompletionRequest,
347 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
348 let request = self.create_completion_request(completion_request)?;
349
350 let response = self
351 .client
352 .post("/chat/completions")
353 .json(&request)
354 .send()
355 .await?;
356
357 if response.status().is_success() {
358 match response.json::<ApiResponse<CompletionResponse>>().await? {
359 ApiResponse::Ok(response) => {
360 tracing::info!(target: "rig",
361 "Hyperbolic completion token usage: {:?}",
362 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
363 );
364
365 response.try_into()
366 }
367 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
368 }
369 } else {
370 Err(CompletionError::ProviderError(response.text().await?))
371 }
372 }
373
374 #[cfg_attr(feature = "worker", worker::send)]
375 async fn stream(
376 &self,
377 completion_request: CompletionRequest,
378 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
379 let mut request = self.create_completion_request(completion_request)?;
380
381 merge_inplace(
382 &mut request,
383 json!({"stream": true, "stream_options": {"include_usage": true}}),
384 );
385
386 let builder = self.client.post("/chat/completions").json(&request);
387
388 send_compatible_streaming_request(builder).await
389 }
390}
391
392#[cfg(feature = "image")]
397pub use image_generation::*;
398
#[cfg(feature = "image")]
mod image_generation {
    //! Image generation support for the Hyperbolic provider.
    use super::{ApiResponse, Client};
    use crate::client::ImageGenerationClient;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    // Image model identifiers accepted by Hyperbolic's `/image/generation` endpoint.
    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    pub const SD2: &str = "SD2";
    pub const SD1_5: &str = "SD1.5";
    pub const SSD: &str = "SSD";
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    /// An image generation model bound to a [`Client`].
    #[cfg(feature = "image")]
    #[derive(Clone)]
    pub struct ImageGenerationModel {
        client: Client,
        pub model: String,
    }

    #[cfg(feature = "image")]
    impl ImageGenerationModel {
        /// Creates a model handle for the given image model identifier.
        pub(crate) fn new(client: Client, model: &str) -> ImageGenerationModel {
            Self {
                client,
                model: model.to_string(),
            }
        }
    }

    /// One base64-encoded image in a generation response.
    #[cfg(feature = "image")]
    #[derive(Clone, Deserialize)]
    pub struct Image {
        image: String,
    }

    /// Raw response body of the `/image/generation` endpoint.
    #[cfg(feature = "image")]
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        images: Vec<Image>,
    }

    #[cfg(feature = "image")]
    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// Decodes the first returned image from base64.
        ///
        /// # Errors
        /// Returns [`ImageGenerationError::ResponseError`] when the response
        /// contains no images or the payload is not valid base64 — previously
        /// both cases panicked (`[0]` indexing and `expect`).
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            let first = value.images.first().ok_or_else(|| {
                ImageGenerationError::ResponseError("Response contained no images".to_owned())
            })?;

            let data = BASE64_STANDARD.decode(&first.image).map_err(|e| {
                ImageGenerationError::ResponseError(format!("Could not decode image: {e}"))
            })?;

            Ok(Self {
                image: data,
                response: value,
            })
        }
    }

    #[cfg(feature = "image")]
    impl image_generation::ImageGenerationModel for ImageGenerationModel {
        type Response = ImageGenerationResponse;

        /// Sends an image generation request to `/image/generation`.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            // Hyperbolic expects `model_name` (not `model`) for this endpoint.
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let response = self
                .client
                .post("/image/generation")
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status().as_str(),
                    response.text().await?
                )));
            }

            match response
                .json::<ApiResponse<ImageGenerationResponse>>()
                .await?
            {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
            }
        }
    }

    impl ImageGenerationClient for Client {
        type ImageGenerationModel = ImageGenerationModel;

        /// Creates an image generation model handle bound to this client.
        fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
            ImageGenerationModel::new(self.clone(), model)
        }
    }
}
530
531#[cfg(feature = "audio")]
535pub use audio_generation::*;
536
#[cfg(feature = "audio")]
mod audio_generation {
    //! Audio (text-to-speech) support for the Hyperbolic provider.
    use super::{ApiResponse, Client};
    use crate::audio_generation;
    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
    use crate::client::AudioGenerationClient;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    /// An audio generation model bound to a [`Client`].
    ///
    /// Hyperbolic's TTS endpoint is parameterized by language rather than a
    /// model name.
    #[derive(Clone)]
    pub struct AudioGenerationModel {
        client: Client,
        pub language: String,
    }

    impl AudioGenerationModel {
        /// Creates a model handle for the given language code.
        pub(crate) fn new(client: Client, language: &str) -> AudioGenerationModel {
            Self {
                client,
                language: language.to_string(),
            }
        }
    }

    /// Raw response body of the `/audio/generation` endpoint.
    #[derive(Clone, Deserialize)]
    pub struct AudioGenerationResponse {
        // Base64-encoded audio payload.
        audio: String,
    }

    impl TryFrom<AudioGenerationResponse>
        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
    {
        type Error = AudioGenerationError;

        /// Decodes the base64 audio payload.
        ///
        /// # Errors
        /// Returns [`AudioGenerationError::ProviderError`] when the payload is
        /// not valid base64 — previously this panicked via `expect`.
        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
            let data = BASE64_STANDARD.decode(&value.audio).map_err(|e| {
                AudioGenerationError::ProviderError(format!("Could not decode audio: {e}"))
            })?;

            Ok(Self {
                audio: data,
                response: value,
            })
        }
    }

    impl audio_generation::AudioGenerationModel for AudioGenerationModel {
        type Response = AudioGenerationResponse;

        /// Sends a text-to-speech request to `/audio/generation`.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            let request = json!({
                "language": self.language,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let response = self
                .client
                .post("/audio/generation")
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status(),
                    response.text().await?
                )));
            }

            match serde_json::from_str::<ApiResponse<AudioGenerationResponse>>(
                &response.text().await?,
            )? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }
    impl AudioGenerationClient for Client {
        type AudioGenerationModel = AudioGenerationModel;

        /// Creates an audio generation model handle bound to this client.
        fn audio_generation_model(&self, language: &str) -> AudioGenerationModel {
            AudioGenerationModel::new(self.clone(), language)
        }
    }
}