1use super::openai::{send_compatible_streaming_request, AssistantContent};
13
14use crate::json_utils::merge_inplace;
15use crate::message;
16use crate::streaming::{StreamingCompletionModel, StreamingResult};
17use crate::{
18 agent::AgentBuilder,
19 completion::{self, CompletionError, CompletionRequest},
20 extractor::ExtractorBuilder,
21 json_utils,
22 providers::openai::Message,
23 OneOrMany,
24};
25
26use schemars::JsonSchema;
27use serde::{Deserialize, Serialize};
28use serde_json::{json, Value};
29
/// Base URL of the Hyperbolic OpenAI-compatible REST API (no trailing slash).
const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz/v1";
34
/// HTTP client for the Hyperbolic API.
///
/// Holds the API base URL plus a `reqwest::Client` pre-configured with the
/// `Authorization: Bearer <key>` header (see [`Client::from_url`]).
#[derive(Clone)]
pub struct Client {
    base_url: String,
    http_client: reqwest::Client,
}
40
41impl Client {
42 pub fn new(api_key: &str) -> Self {
44 Self::from_url(api_key, HYPERBOLIC_API_BASE_URL)
45 }
46
47 pub fn from_url(api_key: &str, base_url: &str) -> Self {
49 Self {
50 base_url: base_url.to_string(),
51 http_client: reqwest::Client::builder()
52 .default_headers({
53 let mut headers = reqwest::header::HeaderMap::new();
54 headers.insert(
55 "Authorization",
56 format!("Bearer {}", api_key)
57 .parse()
58 .expect("Bearer token should parse"),
59 );
60 headers
61 })
62 .build()
63 .expect("OpenAI reqwest client should build"),
64 }
65 }
66
67 pub fn from_env() -> Self {
70 let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
71 Self::new(&api_key)
72 }
73
74 fn post(&self, path: &str) -> reqwest::RequestBuilder {
75 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
76 self.http_client.post(url)
77 }
78
79 pub fn completion_model(&self, model: &str) -> CompletionModel {
91 CompletionModel::new(self.clone(), model)
92 }
93
94 #[cfg(feature = "image")]
106 pub fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
107 ImageGenerationModel::new(self.clone(), model)
108 }
109
110 #[cfg(feature = "audio")]
122 pub fn audio_generation_model(&self, language: &str) -> AudioGenerationModel {
123 AudioGenerationModel::new(self.clone(), language)
124 }
125
126 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
141 AgentBuilder::new(self.completion_model(model))
142 }
143
144 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
146 &self,
147 model: &str,
148 ) -> ExtractorBuilder<T, CompletionModel> {
149 ExtractorBuilder::new(self.completion_model(model))
150 }
151}
152
/// Error payload returned by the Hyperbolic API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    // Human-readable error description from the provider.
    message: String,
}
157
/// Either a successful payload of type `T` or a provider error payload.
///
/// `#[serde(untagged)]`: deserialization tries the `Ok(T)` shape first and
/// falls back to the error shape.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
164
/// A single embedding entry as returned by an OpenAI-compatible embeddings API.
/// NOTE(review): not referenced elsewhere in this chunk — presumably used by an
/// embeddings implementation outside this view; confirm before removing.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}
171
/// Token usage reported by the Hyperbolic completion API.
/// NOTE(review): only prompt and total counts are modeled; a completion-token
/// field (if the API returns one) is ignored — confirm against the API docs.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
177
178impl std::fmt::Display for Usage {
179 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
180 write!(
181 f,
182 "Prompt tokens: {} Total tokens: {}",
183 self.prompt_tokens, self.total_tokens
184 )
185 }
186}
187
// Chat model identifiers accepted by the Hyperbolic completion API.
/// `meta-llama/Meta-Llama-3.1-8B-Instruct` model.
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// `meta-llama/Llama-3.3-70B-Instruct` model.
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// `meta-llama/Meta-Llama-3.1-70B-Instruct` model.
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// `meta-llama/Meta-Llama-3-70B-Instruct` model.
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// `NousResearch/Hermes-3-Llama-3.1-70b` model.
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// `deepseek-ai/DeepSeek-V2.5` model.
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// `Qwen/Qwen2.5-72B-Instruct` model.
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// `meta-llama/Llama-3.2-3B-Instruct` model.
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// `Qwen/Qwen2.5-Coder-32B-Instruct` model.
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// `Qwen/QwQ-32B-Preview` model.
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// `deepseek-ai/DeepSeek-R1-Zero` model.
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// `deepseek-ai/DeepSeek-R1` model.
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
215
/// Raw (OpenAI-compatible) completion response body from Hyperbolic.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    // Unix timestamp of creation, per the OpenAI-compatible schema.
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    // Token usage; optional because the API may omit it.
    pub usage: Option<Usage>,
}
228
229impl From<ApiErrorResponse> for CompletionError {
230 fn from(err: ApiErrorResponse) -> Self {
231 CompletionError::ProviderError(err.message)
232 }
233}
234
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;

    /// Converts the raw provider response into rig's internal completion
    /// response, keeping the raw body in `raw_response`.
    ///
    /// # Errors
    /// Returns `ResponseError` when there are no choices, the first choice is
    /// not an assistant message, or the message has no content and no tool calls.
    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
        // Only the first choice is converted; any additional choices are dropped.
        let choice = response.choices.first().ok_or_else(|| {
            CompletionError::ResponseError("Response contained no choices".to_owned())
        })?;

        let content = match &choice.message {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                // Both text and refusal parts are mapped to plain text content.
                let mut content = content
                    .iter()
                    .map(|c| match c {
                        AssistantContent::Text { text } => completion::AssistantContent::text(text),
                        AssistantContent::Refusal { refusal } => {
                            completion::AssistantContent::text(refusal)
                        }
                    })
                    .collect::<Vec<_>>();

                // Tool calls are appended after the text parts.
                content.extend(
                    tool_calls
                        .iter()
                        .map(|call| {
                            completion::AssistantContent::tool_call(
                                &call.id,
                                &call.function.name,
                                call.function.arguments.clone(),
                            )
                        })
                        .collect::<Vec<_>>(),
                );
                Ok(content)
            }
            // Non-assistant roles are not valid as a completion result.
            _ => Err(CompletionError::ResponseError(
                "Response did not contain a valid message or tool call".into(),
            )),
        }?;

        // `OneOrMany::many` rejects an empty vec — i.e. an assistant message
        // with neither content nor tool calls.
        let choice = OneOrMany::many(content).map_err(|_| {
            CompletionError::ResponseError(
                "Response contained no message or tool call (empty)".to_owned(),
            )
        })?;

        Ok(completion::CompletionResponse {
            choice,
            raw_response: response,
        })
    }
}
290
/// A single completion choice in an OpenAI-compatible response.
#[derive(Debug, Deserialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    pub finish_reason: String,
}
297
/// A Hyperbolic completion model, i.e. a [`Client`] paired with a model name.
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    // Model identifier, e.g. one of the `pub const` model names in this module.
    pub model: String,
}
304
impl CompletionModel {
    /// Builds the JSON request body for a chat completion against this model.
    ///
    /// Message order: optional system preamble first, then the normalized
    /// documents (if any), then the chat history. `additional_params` (if
    /// provided) is merged into the base request last.
    ///
    /// # Errors
    /// Fails if any history message cannot be converted to the provider format.
    pub(crate) fn create_completion_request(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<Value, CompletionError> {
        // Documents are placed before the chat history so the model sees them first.
        let mut partial_history = vec![];
        if let Some(docs) = completion_request.normalized_documents() {
            partial_history.push(docs);
        }
        partial_history.extend(completion_request.chat_history);

        // The preamble (if any) becomes the leading system message.
        let mut full_history: Vec<Message> = completion_request
            .preamble
            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);

        // One internal message may expand into several provider messages;
        // any conversion failure aborts the whole request.
        full_history.extend(
            partial_history
                .into_iter()
                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<Message>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<_>>(),
        );

        let request = json!({
            "model": self.model,
            "messages": full_history,
            "temperature": completion_request.temperature,
        });

        // NOTE(review): merge semantics (which side wins on key conflicts)
        // depend on `json_utils::merge` — confirm before relying on overrides.
        let request = if let Some(params) = completion_request.additional_params {
            json_utils::merge(request, params)
        } else {
            request
        };

        Ok(request)
    }
}
348
349impl CompletionModel {
350 pub fn new(client: Client, model: &str) -> Self {
351 Self {
352 client,
353 model: model.to_string(),
354 }
355 }
356}
357
impl completion::CompletionModel for CompletionModel {
    type Response = CompletionResponse;

    /// Sends a non-streaming chat completion request to Hyperbolic's
    /// OpenAI-compatible `/chat/completions` endpoint.
    ///
    /// # Errors
    /// Returns `ProviderError` for non-success HTTP statuses or an API error
    /// payload; request and deserialization failures are propagated via `?`.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let request = self.create_completion_request(completion_request)?;

        let response = self
            .client
            .post("/chat/completions")
            .json(&request)
            .send()
            .await?;

        if response.status().is_success() {
            // A 2xx body may still carry an error payload; the untagged
            // `ApiResponse` enum distinguishes the two shapes.
            match response.json::<ApiResponse<CompletionResponse>>().await? {
                ApiResponse::Ok(response) => {
                    // Log token usage for observability; "N/A" when omitted by the API.
                    tracing::info!(target: "rig",
                        "Hyperbolic completion token usage: {:?}",
                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
                    );

                    response.try_into()
                }
                ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
            }
        } else {
            // Non-success status: surface the raw body as the provider error.
            Err(CompletionError::ProviderError(response.text().await?))
        }
    }
}
392
impl StreamingCompletionModel for CompletionModel {
    /// Sends a streaming chat completion request to `/chat/completions`.
    ///
    /// Reuses the non-streaming request body with `"stream": true` merged in,
    /// then delegates SSE handling to the shared OpenAI-compatible helper.
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingResult, CompletionError> {
        let mut request = self.create_completion_request(completion_request)?;

        merge_inplace(&mut request, json!({"stream": true}));

        let builder = self.client.post("/chat/completions").json(&request);

        send_compatible_streaming_request(builder).await
    }
}
407
#[cfg(feature = "image")]
pub use image_generation::*;

/// Image generation via Hyperbolic's `/image/generation` endpoint.
// NOTE: the module itself is feature-gated, so the per-item
// `#[cfg(feature = "image")]` attributes the original carried were redundant
// and have been removed.
#[cfg(feature = "image")]
mod image_generation {
    use super::{ApiResponse, Client};
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use base64::prelude::BASE64_STANDARD;
    use base64::Engine;
    use serde::Deserialize;
    use serde_json::json;

    // Image model identifiers accepted by the Hyperbolic image API.
    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    pub const SD2: &str = "SD2";
    pub const SD1_5: &str = "SD1.5";
    pub const SSD: &str = "SSD";
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    /// A Hyperbolic image generation model: a [`Client`] plus a model name.
    #[derive(Clone)]
    pub struct ImageGenerationModel {
        client: Client,
        pub model: String,
    }

    impl ImageGenerationModel {
        /// Creates an image generation model handle for `model`.
        pub(crate) fn new(client: Client, model: &str) -> ImageGenerationModel {
            Self {
                client,
                model: model.to_string(),
            }
        }
    }

    /// A single generated image, base64-encoded.
    #[derive(Clone, Deserialize)]
    pub struct Image {
        image: String,
    }

    /// Raw response body of the image generation endpoint.
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        images: Vec<Image>,
    }

    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// Decodes the first returned image from base64.
        ///
        /// BUG FIX: the original indexed `images[0]` (panicking on an empty
        /// list) and `expect`ed the base64 decode (panicking on a malformed
        /// payload). Both cases now surface as `ResponseError`s instead.
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            let first = value.images.first().ok_or_else(|| {
                ImageGenerationError::ResponseError("Response contained no images".to_owned())
            })?;

            let data = BASE64_STANDARD.decode(&first.image).map_err(|e| {
                ImageGenerationError::ResponseError(format!("Could not decode image: {e}"))
            })?;

            Ok(Self {
                image: data,
                response: value,
            })
        }
    }

    impl image_generation::ImageGenerationModel for ImageGenerationModel {
        type Response = ImageGenerationResponse;

        /// Sends an image generation request to `/image/generation`.
        ///
        /// # Errors
        /// Returns `ProviderError` on non-success HTTP statuses and
        /// `ResponseError` on error payloads or undecodable images.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            // Caller-supplied params are merged on top of the base fields.
            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let response = self
                .client
                .post("/image/generation")
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status().as_str(),
                    response.text().await?
                )));
            }

            // A 2xx body may still carry an error payload (untagged enum).
            match response
                .json::<ApiResponse<ImageGenerationResponse>>()
                .await?
            {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
            }
        }
    }
}
526
#[cfg(feature = "audio")]
pub use audio_generation::*;
/// Audio (text-to-speech) generation via Hyperbolic's `/audio/generation` endpoint.
#[cfg(feature = "audio")]
mod audio_generation {
    use super::{ApiResponse, Client};
    use crate::audio_generation;
    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
    use base64::prelude::BASE64_STANDARD;
    use base64::Engine;
    use serde::Deserialize;
    use serde_json::json;

    /// A Hyperbolic audio generation model: a [`Client`] plus a target language.
    #[derive(Clone)]
    pub struct AudioGenerationModel {
        client: Client,
        // NOTE(review): "langauge" is a typo for "language", but the field is
        // `pub`, so renaming it would break the public API — fix in a breaking release.
        pub langauge: String,
    }

    impl AudioGenerationModel {
        /// Creates an audio generation model handle for `language`.
        pub(crate) fn new(client: Client, language: &str) -> AudioGenerationModel {
            Self {
                client,
                langauge: language.to_string(),
            }
        }
    }

    /// Raw response body of the audio generation endpoint (base64-encoded audio).
    #[derive(Clone, Deserialize)]
    pub struct AudioGenerationResponse {
        audio: String,
    }

    impl TryFrom<AudioGenerationResponse>
        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
    {
        type Error = AudioGenerationError;

        /// Decodes the returned audio from base64.
        ///
        /// BUG FIX: the original `expect`ed the decode, panicking on a
        /// malformed payload; it now surfaces as a `ProviderError`.
        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
            let data = BASE64_STANDARD.decode(&value.audio).map_err(|e| {
                AudioGenerationError::ProviderError(format!("Could not decode audio: {e}"))
            })?;

            Ok(Self {
                audio: data,
                response: value,
            })
        }
    }

    impl audio_generation::AudioGenerationModel for AudioGenerationModel {
        type Response = AudioGenerationResponse;

        /// Sends a text-to-speech request to `/audio/generation`.
        ///
        /// # Errors
        /// Returns `ProviderError` on non-success HTTP statuses, error
        /// payloads, or undecodable audio.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            let request = json!({
                "language": self.langauge,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let response = self
                .client
                .post("/audio/generation")
                .json(&request)
                .send()
                .await?;

            if !response.status().is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{}: {}",
                    response.status(),
                    response.text().await?
                )));
            }

            // A 2xx body may still carry an error payload (untagged enum).
            match serde_json::from_str::<ApiResponse<AudioGenerationResponse>>(
                &response.text().await?,
            )? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }
}