1use super::openai::{AssistantContent, send_compatible_streaming_request};
13
14use crate::client::{CompletionClient, ProviderClient};
15use crate::json_utils::merge_inplace;
16use crate::message;
17use crate::streaming::StreamingCompletionResponse;
18
19use crate::impl_conversion_traits;
20use crate::providers::openai;
21use crate::{
22 OneOrMany,
23 completion::{self, CompletionError, CompletionRequest},
24 json_utils,
25 providers::openai::Message,
26};
27use serde::Deserialize;
28use serde_json::{Value, json};
29
/// Base URL used by [`Client::new`] when no custom endpoint is supplied.
const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz/v1";
34
/// HTTP client for the Hyperbolic API.
///
/// Cloning copies the URL and key strings and clones the `reqwest::Client`.
#[derive(Clone)]
pub struct Client {
    /// Endpoint prefix that request paths are appended to.
    base_url: String,
    /// Key sent as a bearer token on every request built by `post`.
    api_key: String,
    /// Underlying HTTP client used to issue requests.
    http_client: reqwest::Client,
}
41
42impl std::fmt::Debug for Client {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 f.debug_struct("Client")
45 .field("base_url", &self.base_url)
46 .field("http_client", &self.http_client)
47 .field("api_key", &"<REDACTED>")
48 .finish()
49 }
50}
51
52impl Client {
53 pub fn new(api_key: &str) -> Self {
55 Self::from_url(api_key, HYPERBOLIC_API_BASE_URL)
56 }
57
58 pub fn from_url(api_key: &str, base_url: &str) -> Self {
60 Self {
61 base_url: base_url.to_string(),
62 api_key: api_key.to_string(),
63 http_client: reqwest::Client::builder()
64 .build()
65 .expect("OpenAI reqwest client should build"),
66 }
67 }
68
69 pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
72 self.http_client = client;
73
74 self
75 }
76
77 fn post(&self, path: &str) -> reqwest::RequestBuilder {
78 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
79 self.http_client.post(url).bearer_auth(&self.api_key)
80 }
81}
82
83impl ProviderClient for Client {
84 fn from_env() -> Self {
87 let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
88 Self::new(&api_key)
89 }
90}
91
92impl CompletionClient for Client {
93 type CompletionModel = CompletionModel;
94
95 fn completion_model(&self, model: &str) -> CompletionModel {
107 CompletionModel::new(self.clone(), model)
108 }
109}
110
// NOTE(review): the macro is defined elsewhere in the crate; presumably it
// generates the `AsEmbeddings` / `AsTranscription` conversion impls for
// `Client` — confirm against the macro definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription for Client
);
115
/// Error payload deserialized from a response body: an object carrying a
/// `message` string.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error text; forwarded into the crate's error types by callers in this file.
    message: String,
}
120
/// Response body that is either the expected payload `T` or an error object.
///
/// Being `untagged`, serde tries the variants in declaration order: the body
/// is first parsed as `T`, and only if that fails as [`ApiErrorResponse`].
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    /// The body matched the expected payload shape.
    Ok(T),
    /// The body matched the error shape instead.
    Err(ApiErrorResponse),
}
127
/// A single embedding entry as deserialized from a response.
///
/// Not referenced by any other item in the visible part of this file.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    /// Value of the `object` field of the entry.
    pub object: String,
    /// The embedding vector.
    pub embedding: Vec<f64>,
    /// Value of the `index` field — presumably the position of the matching
    /// input; confirm against the API.
    pub index: usize,
}
134
/// Token usage figures reported alongside a completion response.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    /// Value of the `prompt_tokens` field.
    pub prompt_tokens: usize,
    /// Value of the `total_tokens` field.
    pub total_tokens: usize,
}
140
141impl std::fmt::Display for Usage {
142 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143 write!(
144 f,
145 "Prompt tokens: {} Total tokens: {}",
146 self.prompt_tokens, self.total_tokens
147 )
148 }
149}
150
// Model name strings. Presumably intended to be passed as the `model`
// argument of `Client::completion_model` — confirm against callers.

/// `meta-llama/Meta-Llama-3.1-8B-Instruct` model name.
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// `meta-llama/Llama-3.3-70B-Instruct` model name.
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// `meta-llama/Meta-Llama-3.1-70B-Instruct` model name.
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// `meta-llama/Meta-Llama-3-70B-Instruct` model name.
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// `NousResearch/Hermes-3-Llama-3.1-70b` model name.
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// `deepseek-ai/DeepSeek-V2.5` model name.
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// `Qwen/Qwen2.5-72B-Instruct` model name.
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// `meta-llama/Llama-3.2-3B-Instruct` model name.
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// `Qwen/Qwen2.5-Coder-32B-Instruct` model name.
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// `Qwen/QwQ-32B-Preview` model name.
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// `deepseek-ai/DeepSeek-R1-Zero` model name.
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// `deepseek-ai/DeepSeek-R1` model name.
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
178
/// Raw chat completion response body as deserialized from the API.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    /// Value of the `id` field.
    pub id: String,
    /// Value of the `object` field.
    pub object: String,
    /// Value of the `created` field — presumably a Unix timestamp; confirm
    /// against the API.
    pub created: u64,
    /// Value of the `model` field.
    pub model: String,
    /// Returned choices; only the first one is used when converting into the
    /// crate-level completion response.
    pub choices: Vec<Choice>,
    /// Token usage, when the response includes it.
    pub usage: Option<Usage>,
}
191
192impl From<ApiErrorResponse> for CompletionError {
193 fn from(err: ApiErrorResponse) -> Self {
194 CompletionError::ProviderError(err.message)
195 }
196}
197
198impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
199 type Error = CompletionError;
200
201 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
202 let choice = response.choices.first().ok_or_else(|| {
203 CompletionError::ResponseError("Response contained no choices".to_owned())
204 })?;
205
206 let content = match &choice.message {
207 Message::Assistant {
208 content,
209 tool_calls,
210 ..
211 } => {
212 let mut content = content
213 .iter()
214 .map(|c| match c {
215 AssistantContent::Text { text } => completion::AssistantContent::text(text),
216 AssistantContent::Refusal { refusal } => {
217 completion::AssistantContent::text(refusal)
218 }
219 })
220 .collect::<Vec<_>>();
221
222 content.extend(
223 tool_calls
224 .iter()
225 .map(|call| {
226 completion::AssistantContent::tool_call(
227 &call.id,
228 &call.function.name,
229 call.function.arguments.clone(),
230 )
231 })
232 .collect::<Vec<_>>(),
233 );
234 Ok(content)
235 }
236 _ => Err(CompletionError::ResponseError(
237 "Response did not contain a valid message or tool call".into(),
238 )),
239 }?;
240
241 let choice = OneOrMany::many(content).map_err(|_| {
242 CompletionError::ResponseError(
243 "Response contained no message or tool call (empty)".to_owned(),
244 )
245 })?;
246
247 Ok(completion::CompletionResponse {
248 choice,
249 raw_response: response,
250 })
251 }
252}
253
/// One entry of the `choices` array in a completion response.
#[derive(Debug, Deserialize)]
pub struct Choice {
    /// Value of the `index` field.
    pub index: usize,
    /// The message of this choice, in the OpenAI-provider message format.
    pub message: Message,
    /// Value of the `finish_reason` field.
    pub finish_reason: String,
}
260
/// A completion model bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send the requests.
    client: Client,
    /// Model name sent as the `model` field of each request.
    pub model: String,
}
267
268impl CompletionModel {
269 pub(crate) fn create_completion_request(
270 &self,
271 completion_request: CompletionRequest,
272 ) -> Result<Value, CompletionError> {
273 let mut partial_history = vec![];
275 if let Some(docs) = completion_request.normalized_documents() {
276 partial_history.push(docs);
277 }
278 partial_history.extend(completion_request.chat_history);
279
280 let mut full_history: Vec<Message> = completion_request
282 .preamble
283 .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
284
285 full_history.extend(
287 partial_history
288 .into_iter()
289 .map(message::Message::try_into)
290 .collect::<Result<Vec<Vec<Message>>, _>>()?
291 .into_iter()
292 .flatten()
293 .collect::<Vec<_>>(),
294 );
295
296 let request = json!({
297 "model": self.model,
298 "messages": full_history,
299 "temperature": completion_request.temperature,
300 });
301
302 let request = if let Some(params) = completion_request.additional_params {
303 json_utils::merge(request, params)
304 } else {
305 request
306 };
307
308 Ok(request)
309 }
310}
311
312impl CompletionModel {
313 pub fn new(client: Client, model: &str) -> Self {
314 Self {
315 client,
316 model: model.to_string(),
317 }
318 }
319}
320
321impl completion::CompletionModel for CompletionModel {
322 type Response = CompletionResponse;
323 type StreamingResponse = openai::StreamingCompletionResponse;
324
325 #[cfg_attr(feature = "worker", worker::send)]
326 async fn completion(
327 &self,
328 completion_request: CompletionRequest,
329 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
330 let request = self.create_completion_request(completion_request)?;
331
332 let response = self
333 .client
334 .post("/chat/completions")
335 .json(&request)
336 .send()
337 .await?;
338
339 if response.status().is_success() {
340 match response.json::<ApiResponse<CompletionResponse>>().await? {
341 ApiResponse::Ok(response) => {
342 tracing::info!(target: "rig",
343 "Hyperbolic completion token usage: {:?}",
344 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
345 );
346
347 response.try_into()
348 }
349 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
350 }
351 } else {
352 Err(CompletionError::ProviderError(response.text().await?))
353 }
354 }
355
356 #[cfg_attr(feature = "worker", worker::send)]
357 async fn stream(
358 &self,
359 completion_request: CompletionRequest,
360 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
361 let mut request = self.create_completion_request(completion_request)?;
362
363 merge_inplace(
364 &mut request,
365 json!({"stream": true, "stream_options": {"include_usage": true}}),
366 );
367
368 let builder = self.client.post("/chat/completions").json(&request);
369
370 send_compatible_streaming_request(builder).await
371 }
372}
373
374#[cfg(feature = "image")]
379pub use image_generation::*;
380
#[cfg(feature = "image")]
mod image_generation {
    //! Image generation support for the Hyperbolic client.
    //!
    //! The whole module is gated on the `image` feature, so the items inside
    //! need no feature attributes of their own.

    use super::{ApiResponse, Client};
    use crate::client::ImageGenerationClient;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    /// `SDXL1.0-base` image model name.
    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    /// `SD2` image model name.
    pub const SD2: &str = "SD2";
    /// `SD1.5` image model name.
    pub const SD1_5: &str = "SD1.5";
    /// `SSD` image model name.
    pub const SSD: &str = "SSD";
    /// `SDXL-turbo` image model name.
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    /// `SDXL-ControlNet` image model name.
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    /// `SD1.5-ControlNet` image model name.
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    /// An image generation model bound to a [`Client`].
    #[derive(Clone)]
    pub struct ImageGenerationModel {
        /// Client used to send the requests.
        client: Client,
        /// Model name sent as the `model_name` field of each request.
        pub model: String,
    }

    impl ImageGenerationModel {
        /// Creates an image generation model that sends requests for `model` through `client`.
        pub(crate) fn new(client: Client, model: &str) -> ImageGenerationModel {
            Self {
                client,
                model: model.to_string(),
            }
        }
    }

    /// One generated image as returned by the API.
    #[derive(Clone, Deserialize)]
    pub struct Image {
        /// Base64-encoded image data.
        image: String,
    }

    /// Raw image generation response body.
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        /// Generated images; only the first one is used on conversion.
        images: Vec<Image>,
    }

    /// Decodes the first returned image into raw bytes.
    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// # Errors
        /// Returns `ImageGenerationError::ResponseError` when the response
        /// contains no images or the first image is not valid base64.
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            // Provider data is untrusted: report problems instead of panicking.
            let first = value.images.first().ok_or_else(|| {
                ImageGenerationError::ResponseError("Response contained no images".to_owned())
            })?;

            let data = BASE64_STANDARD.decode(&first.image).map_err(|err| {
                ImageGenerationError::ResponseError(format!("Could not decode image: {err}"))
            })?;

            Ok(Self {
                image: data,
                response: value,
            })
        }
    }

    impl image_generation::ImageGenerationModel for ImageGenerationModel {
        type Response = ImageGenerationResponse;

        /// Sends an image generation request to `/image/generation`.
        ///
        /// The body carries `model_name`, `prompt`, `height` and `width`;
        /// `additional_params`, when present, are merged into it.
        ///
        /// # Errors
        /// Returns `ProviderError` (`"<status>: <body>"`) for a non-success
        /// status and `ResponseError` when the body parses as an API error
        /// object or cannot be converted. Transport and decoding errors are
        /// propagated.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let response = self
                .client
                .post("/image/generation")
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status().as_str(),
                    response.text().await?
                )));
            }

            match response
                .json::<ApiResponse<ImageGenerationResponse>>()
                .await?
            {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
            }
        }
    }

    impl ImageGenerationClient for Client {
        type ImageGenerationModel = ImageGenerationModel;

        /// Creates an image generation model handle for `model`, backed by a clone of this client.
        fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
            ImageGenerationModel::new(self.clone(), model)
        }
    }
}
512
513#[cfg(feature = "audio")]
517pub use audio_generation::*;
518
#[cfg(feature = "audio")]
mod audio_generation {
    //! Audio generation support for the Hyperbolic client.

    use super::{ApiResponse, Client};
    use crate::audio_generation;
    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
    use crate::client::AudioGenerationClient;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    /// An audio generation model bound to a [`Client`].
    ///
    /// The model is keyed by language rather than by a model name.
    #[derive(Clone)]
    pub struct AudioGenerationModel {
        /// Client used to send the requests.
        client: Client,
        /// Language sent as the `language` field of each request.
        pub language: String,
    }

    impl AudioGenerationModel {
        /// Creates an audio generation model for `language` that sends requests through `client`.
        pub(crate) fn new(client: Client, language: &str) -> AudioGenerationModel {
            Self {
                client,
                language: language.to_string(),
            }
        }
    }

    /// Raw audio generation response body.
    #[derive(Clone, Deserialize)]
    pub struct AudioGenerationResponse {
        /// Base64-encoded audio data.
        audio: String,
    }

    /// Decodes the returned audio into raw bytes.
    impl TryFrom<AudioGenerationResponse>
        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
    {
        type Error = AudioGenerationError;

        /// # Errors
        /// Returns `AudioGenerationError::ProviderError` when the audio
        /// payload is not valid base64.
        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
            // Provider data is untrusted: report a malformed payload instead of panicking.
            let data = BASE64_STANDARD.decode(&value.audio).map_err(|err| {
                AudioGenerationError::ProviderError(format!("Could not decode audio: {err}"))
            })?;

            Ok(Self {
                audio: data,
                response: value,
            })
        }
    }

    impl audio_generation::AudioGenerationModel for AudioGenerationModel {
        type Response = AudioGenerationResponse;

        /// Sends an audio generation request to `/audio/generation`.
        ///
        /// The body carries `language`, `speaker` (the request's voice),
        /// `text` and `speed`.
        ///
        /// # Errors
        /// Returns `ProviderError` (`"<status>: <body>"`) for a non-success
        /// status, or with the API error message when the body parses as an
        /// error object. Transport, JSON and conversion errors are propagated.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            let request = json!({
                "language": self.language,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let response = self
                .client
                .post("/audio/generation")
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status(),
                    response.text().await?
                )));
            }

            match serde_json::from_str::<ApiResponse<AudioGenerationResponse>>(
                &response.text().await?,
            )? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }

    impl AudioGenerationClient for Client {
        type AudioGenerationModel = AudioGenerationModel;

        /// Creates an audio generation model handle for `language`, backed by a clone of this client.
        fn audio_generation_model(&self, language: &str) -> AudioGenerationModel {
            AudioGenerationModel::new(self.clone(), language)
        }
    }
}