1use super::openai::{send_compatible_streaming_request, AssistantContent};
13
14use crate::client::{CompletionClient, ProviderClient};
15use crate::json_utils::merge_inplace;
16use crate::message;
17use crate::streaming::StreamingCompletionResponse;
18
19use crate::impl_conversion_traits;
20use crate::providers::openai;
21use crate::{
22 completion::{self, CompletionError, CompletionRequest},
23 json_utils,
24 providers::openai::Message,
25 OneOrMany,
26};
27use serde::Deserialize;
28use serde_json::{json, Value};
29
30const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz/v1";
34
/// Client for the Hyperbolic API (OpenAI-compatible chat completions, plus
/// image and audio generation behind the `image`/`audio` feature flags).
#[derive(Clone, Debug)]
pub struct Client {
    // Base URL that request paths are joined onto in `post`.
    base_url: String,
    // reqwest client preconfigured with the `Authorization: Bearer …` header.
    http_client: reqwest::Client,
}
40
41impl Client {
42 pub fn new(api_key: &str) -> Self {
44 Self::from_url(api_key, HYPERBOLIC_API_BASE_URL)
45 }
46
47 pub fn from_url(api_key: &str, base_url: &str) -> Self {
49 Self {
50 base_url: base_url.to_string(),
51 http_client: reqwest::Client::builder()
52 .default_headers({
53 let mut headers = reqwest::header::HeaderMap::new();
54 headers.insert(
55 "Authorization",
56 format!("Bearer {api_key}")
57 .parse()
58 .expect("Bearer token should parse"),
59 );
60 headers
61 })
62 .build()
63 .expect("OpenAI reqwest client should build"),
64 }
65 }
66
67 fn post(&self, path: &str) -> reqwest::RequestBuilder {
68 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
69 self.http_client.post(url)
70 }
71}
72
73impl ProviderClient for Client {
74 fn from_env() -> Self {
77 let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
78 Self::new(&api_key)
79 }
80}
81
impl CompletionClient for Client {
    type CompletionModel = CompletionModel;

    /// Create a completion model handle for the given Hyperbolic model name
    /// (see the `LLAMA_*` / `DEEPSEEK_*` / `QWEN_*` constants below).
    fn completion_model(&self, model: &str) -> CompletionModel {
        CompletionModel::new(self.clone(), model)
    }
}
100
// NOTE(review): generates the AsEmbeddings/AsTranscription conversion impls
// for Client via the shared macro — confirm semantics in the macro's docs.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription for Client
);
105
/// Error body returned by the Hyperbolic API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    // Provider-supplied human-readable error description.
    message: String,
}
110
/// Either a successful payload of type `T` or a provider error body.
///
/// `untagged` makes serde try the variants in order: the success shape `T`
/// first, falling back to the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
117
/// A single embedding entry as returned by an embeddings endpoint.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    /// Object type tag reported by the API.
    pub object: String,
    /// The embedding vector itself.
    pub embedding: Vec<f64>,
    /// Position of this embedding in the request batch.
    pub index: usize,
}
124
/// Token usage accounting attached to a completion response.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: usize,
    /// Total tokens for the request (presumably prompt + completion — confirm
    /// against the Hyperbolic API docs).
    pub total_tokens: usize,
}
130
131impl std::fmt::Display for Usage {
132 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133 write!(
134 f,
135 "Prompt tokens: {} Total tokens: {}",
136 self.prompt_tokens, self.total_tokens
137 )
138 }
139}
140
// Model identifiers for Hyperbolic's hosted chat-completion models.
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
168
/// Raw (deserialized) Hyperbolic chat-completion response.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    /// Provider-assigned response id.
    pub id: String,
    /// Object type tag.
    pub object: String,
    /// Creation timestamp (presumably seconds since epoch — confirm).
    pub created: u64,
    /// Model that produced the response.
    pub model: String,
    /// Candidate completions; only the first is used downstream.
    pub choices: Vec<Choice>,
    /// Token usage, when reported by the provider.
    pub usage: Option<Usage>,
}
181
182impl From<ApiErrorResponse> for CompletionError {
183 fn from(err: ApiErrorResponse) -> Self {
184 CompletionError::ProviderError(err.message)
185 }
186}
187
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;

    /// Convert a raw Hyperbolic response into the generic
    /// `completion::CompletionResponse`, keeping the raw response attached.
    ///
    /// Fails when the response has no choices, when the first choice is not
    /// an assistant message, or when that message carries neither content
    /// nor tool calls.
    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
        // Only the first choice is used; any additional choices are ignored.
        let choice = response.choices.first().ok_or_else(|| {
            CompletionError::ResponseError("Response contained no choices".to_owned())
        })?;

        let content = match &choice.message {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                // Both text and refusal parts are surfaced as plain text.
                let mut content = content
                    .iter()
                    .map(|c| match c {
                        AssistantContent::Text { text } => completion::AssistantContent::text(text),
                        AssistantContent::Refusal { refusal } => {
                            completion::AssistantContent::text(refusal)
                        }
                    })
                    .collect::<Vec<_>>();

                // Tool calls are appended after the textual content.
                content.extend(
                    tool_calls
                        .iter()
                        .map(|call| {
                            completion::AssistantContent::tool_call(
                                &call.id,
                                &call.function.name,
                                call.function.arguments.clone(),
                            )
                        })
                        .collect::<Vec<_>>(),
                );
                Ok(content)
            }
            // Any non-assistant message in the first choice is a protocol
            // violation from our perspective.
            _ => Err(CompletionError::ResponseError(
                "Response did not contain a valid message or tool call".into(),
            )),
        }?;

        // OneOrMany::many rejects an empty vec; map that to a response error.
        let choice = OneOrMany::many(content).map_err(|_| {
            CompletionError::ResponseError(
                "Response contained no message or tool call (empty)".to_owned(),
            )
        })?;

        Ok(completion::CompletionResponse {
            choice,
            raw_response: response,
        })
    }
}
243
/// One candidate completion within a `CompletionResponse`.
#[derive(Debug, Deserialize)]
pub struct Choice {
    /// Position of this choice in the response.
    pub index: usize,
    /// The message for this choice, in OpenAI wire format.
    pub message: Message,
    /// Why generation stopped (string values set by the provider).
    pub finish_reason: String,
}
250
/// Hyperbolic completion model handle: a client plus a model name.
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    /// Model identifier, e.g. `deepseek-ai/DeepSeek-R1`.
    pub model: String,
}
257
258impl CompletionModel {
259 pub(crate) fn create_completion_request(
260 &self,
261 completion_request: CompletionRequest,
262 ) -> Result<Value, CompletionError> {
263 let mut partial_history = vec![];
265 if let Some(docs) = completion_request.normalized_documents() {
266 partial_history.push(docs);
267 }
268 partial_history.extend(completion_request.chat_history);
269
270 let mut full_history: Vec<Message> = completion_request
272 .preamble
273 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
274
275 full_history.extend(
277 partial_history
278 .into_iter()
279 .map(message::Message::try_into)
280 .collect::<Result<Vec<Vec<Message>>, _>>()?
281 .into_iter()
282 .flatten()
283 .collect::<Vec<_>>(),
284 );
285
286 let request = json!({
287 "model": self.model,
288 "messages": full_history,
289 "temperature": completion_request.temperature,
290 });
291
292 let request = if let Some(params) = completion_request.additional_params {
293 json_utils::merge(request, params)
294 } else {
295 request
296 };
297
298 Ok(request)
299 }
300}
301
302impl CompletionModel {
303 pub fn new(client: Client, model: &str) -> Self {
304 Self {
305 client,
306 model: model.to_string(),
307 }
308 }
309}
310
impl completion::CompletionModel for CompletionModel {
    type Response = CompletionResponse;
    type StreamingResponse = openai::StreamingCompletionResponse;

    /// Send a non-streaming completion request to Hyperbolic's
    /// OpenAI-compatible `/chat/completions` endpoint.
    ///
    /// # Errors
    /// Returns a `CompletionError` on transport failures, non-success HTTP
    /// statuses, provider-reported errors, or unparseable responses.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let request = self.create_completion_request(completion_request)?;

        let response = self
            .client
            .post("/chat/completions")
            .json(&request)
            .send()
            .await?;

        if response.status().is_success() {
            // A 2xx body may still carry an error payload, hence the
            // untagged ApiResponse decode.
            match response.json::<ApiResponse<CompletionResponse>>().await? {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "Hyperbolic completion token usage: {:?}",
                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
                    );

                    response.try_into()
                }
                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
            }
        } else {
            Err(CompletionError::ProviderError(response.text().await?))
        }
    }

    /// Send a streaming completion request, delegating SSE handling to the
    /// shared OpenAI-compatible implementation.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let mut request = self.create_completion_request(completion_request)?;

        // `include_usage` makes the final stream chunk report token usage.
        merge_inplace(
            &mut request,
            json!({"stream": true, "stream_options": {"include_usage": true}}),
        );

        let builder = self.client.post("/chat/completions").json(&request);

        send_compatible_streaming_request(builder).await
    }
}
363
364#[cfg(feature = "image")]
369pub use image_generation::*;
370
371#[cfg(feature = "image")]
372mod image_generation {
373 use super::{ApiResponse, Client};
374 use crate::client::ImageGenerationClient;
375 use crate::image_generation;
376 use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
377 use crate::json_utils::merge_inplace;
378 use base64::prelude::BASE64_STANDARD;
379 use base64::Engine;
380 use serde::Deserialize;
381 use serde_json::json;
382
    // Stable Diffusion model identifiers accepted by the image endpoint.
    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    pub const SD2: &str = "SD2";
    pub const SD1_5: &str = "SD1.5";
    pub const SSD: &str = "SSD";
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";
390
    /// Hyperbolic image-generation model handle: a client plus a model name.
    #[cfg(feature = "image")]
    #[derive(Clone)]
    pub struct ImageGenerationModel {
        client: Client,
        /// Model name (see the `SD*` constants in this module).
        pub model: String,
    }
397
398 #[cfg(feature = "image")]
399 impl ImageGenerationModel {
400 pub(crate) fn new(client: Client, model: &str) -> ImageGenerationModel {
401 Self {
402 client,
403 model: model.to_string(),
404 }
405 }
406 }
407
    /// One generated image, as a base64-encoded string.
    #[cfg(feature = "image")]
    #[derive(Clone, Deserialize)]
    pub struct Image {
        // Base64 payload decoded in the TryFrom impl below.
        image: String,
    }
413
    /// Response body from the image generation endpoint.
    #[cfg(feature = "image")]
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        // All images returned; only the first is surfaced to callers.
        images: Vec<Image>,
    }
419
420 #[cfg(feature = "image")]
421 impl TryFrom<ImageGenerationResponse>
422 for image_generation::ImageGenerationResponse<ImageGenerationResponse>
423 {
424 type Error = ImageGenerationError;
425
426 fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
427 let data = BASE64_STANDARD
428 .decode(&value.images[0].image)
429 .expect("Could not decode image.");
430
431 Ok(Self {
432 image: data,
433 response: value,
434 })
435 }
436 }
437
    #[cfg(feature = "image")]
    impl image_generation::ImageGenerationModel for ImageGenerationModel {
        type Response = ImageGenerationResponse;

        /// Generate an image via Hyperbolic's `/image/generation` endpoint.
        ///
        /// # Errors
        /// Returns `ImageGenerationError` on transport failures, non-success
        /// HTTP statuses, or provider-reported errors.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            // This endpoint takes `model_name`, not OpenAI's `model` key.
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            // Caller-supplied params override the defaults above.
            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let response = self
                .client
                .post("/image/generation")
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status().as_str(),
                    response.text().await?
                )));
            }

            // A 2xx body may still carry an error payload (untagged decode).
            match response
                .json::<ApiResponse<ImageGenerationResponse>>()
                .await?
            {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
            }
        }
    }
482
    impl ImageGenerationClient for Client {
        type ImageGenerationModel = ImageGenerationModel;

        /// Create an image-generation model handle for the given model name
        /// (see the `SD*` constants in this module).
        fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
            ImageGenerationModel::new(self.clone(), model)
        }
    }
501}
502
503#[cfg(feature = "audio")]
507pub use audio_generation::*;
508
509#[cfg(feature = "audio")]
510mod audio_generation {
511 use super::{ApiResponse, Client};
512 use crate::audio_generation;
513 use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
514 use crate::client::AudioGenerationClient;
515 use base64::prelude::BASE64_STANDARD;
516 use base64::Engine;
517 use serde::Deserialize;
518 use serde_json::json;
519
    /// Hyperbolic audio-generation (TTS) model handle.
    ///
    /// NOTE(review): unlike the other models, this handle is keyed by
    /// `language` rather than a model name — confirm against the API docs.
    #[derive(Clone)]
    pub struct AudioGenerationModel {
        client: Client,
        /// Language sent in the request body.
        pub language: String,
    }
525
526 impl AudioGenerationModel {
527 pub(crate) fn new(client: Client, language: &str) -> AudioGenerationModel {
528 Self {
529 client,
530 language: language.to_string(),
531 }
532 }
533 }
534
    /// Response body from the audio generation endpoint.
    #[derive(Clone, Deserialize)]
    pub struct AudioGenerationResponse {
        // Base64-encoded audio payload decoded in the TryFrom impl below.
        audio: String,
    }
539
540 impl TryFrom<AudioGenerationResponse>
541 for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
542 {
543 type Error = AudioGenerationError;
544
545 fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
546 let data = BASE64_STANDARD
547 .decode(&value.audio)
548 .expect("Could not decode audio.");
549
550 Ok(Self {
551 audio: data,
552 response: value,
553 })
554 }
555 }
556
    impl audio_generation::AudioGenerationModel for AudioGenerationModel {
        type Response = AudioGenerationResponse;

        /// Generate speech via Hyperbolic's `/audio/generation` endpoint.
        ///
        /// The request's `voice` is sent as `speaker`; the language comes
        /// from this model handle rather than from the request.
        ///
        /// # Errors
        /// Returns `AudioGenerationError` on transport failures, non-success
        /// HTTP statuses, or provider-reported errors.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            let request = json!({
                "language": self.language,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let response = self
                .client
                .post("/audio/generation")
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status(),
                    response.text().await?
                )));
            }

            // A 2xx body may still carry an error payload (untagged decode).
            match serde_json::from_str::<ApiResponse<AudioGenerationResponse>>(
                &response.text().await?,
            )? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }
    impl AudioGenerationClient for Client {
        type AudioGenerationModel = AudioGenerationModel;

        /// Create an audio-generation model handle for the given language.
        fn audio_generation_model(&self, language: &str) -> AudioGenerationModel {
            AudioGenerationModel::new(self.clone(), language)
        }
    }
613}