1use super::openai::{AssistantContent, send_compatible_streaming_request};
12
13use crate::client::{self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder};
14use crate::client::{BearerAuth, ProviderClient};
15use crate::http_client::{self, HttpClientExt};
16use crate::streaming::StreamingCompletionResponse;
17
18use crate::providers::openai;
19use crate::{
20 OneOrMany,
21 completion::{self, CompletionError, CompletionRequest},
22 json_utils,
23 providers::openai::Message,
24};
25use serde::{Deserialize, Serialize};
26
/// Base URL for the Hyperbolic REST API.
const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz";
31
/// Marker type implementing the Hyperbolic provider extension.
///
/// Stateless: all runtime configuration lives in the surrounding
/// [`client::Client`] wrapper.
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicExt;
/// Builder marker used to construct [`HyperbolicExt`] via [`ProviderBuilder`].
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicBuilder;

/// Hyperbolic authenticates with a standard bearer token.
type HyperbolicApiKey = BearerAuth;
38
impl Provider for HyperbolicExt {
    type Builder = HyperbolicBuilder;

    // Listing models is a cheap authenticated call, so it doubles as the
    // connectivity/credentials verification endpoint.
    const VERIFY_PATH: &'static str = "/models";
}
44
// Capability matrix: Hyperbolic supports chat completions and, behind the
// `image`/`audio` features, image and audio generation. Embeddings,
// transcription and model listing are not wired up (`Nothing`).
impl<H> Capabilities<H> for HyperbolicExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<ImageGenerationModel<H>>;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<AudioGenerationModel<H>>;
}
55
// Default `DebugExt` behavior suffices for this stateless extension.
impl DebugExt for HyperbolicExt {}
57
impl ProviderBuilder for HyperbolicBuilder {
    // The extension is stateless and independent of the HTTP client type `H`.
    type Extension<H>
        = HyperbolicExt
    where
        H: HttpClientExt;
    type ApiKey = HyperbolicApiKey;

    const BASE_URL: &'static str = HYPERBOLIC_API_BASE_URL;

    /// Builds the (stateless) extension; infallible in practice but returns
    /// `Result` to satisfy the trait contract.
    fn build<H>(
        _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(HyperbolicExt)
    }
}
76
/// Hyperbolic client; defaults to `reqwest` as the HTTP backend.
pub type Client<H = reqwest::Client> = client::Client<HyperbolicExt, H>;
/// Builder for [`Client`]; the API key is supplied as a `String`.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<HyperbolicBuilder, String, H>;
79
impl ProviderClient for Client {
    type Input = HyperbolicApiKey;

    /// Builds a client from the `HYPERBOLIC_API_KEY` environment variable.
    ///
    /// # Panics
    /// Panics if the variable is unset or if client construction fails.
    fn from_env() -> Self {
        let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
        Self::new(&api_key).unwrap()
    }

    /// Builds a client from an explicit API key.
    ///
    /// # Panics
    /// Panics if client construction fails.
    fn from_val(input: Self::Input) -> Self {
        Self::new(input).unwrap()
    }
}
94
/// Error payload returned by the Hyperbolic API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
99
/// A provider response that is either the expected payload or an API error.
///
/// `untagged` makes serde try the success shape first and fall back to the
/// error shape, since Hyperbolic does not use a tagged envelope.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
106
/// Single embedding entry in an embeddings-style response.
///
/// NOTE(review): embeddings are not wired up for this provider
/// (`Capabilities::Embeddings = Nothing`); this type appears unused in this
/// file — confirm against the rest of the crate before removing.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}
113
/// Token usage as reported by Hyperbolic. Only prompt and total counts are
/// provided; completion tokens are derived as the difference elsewhere.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
119
120impl std::fmt::Display for Usage {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 write!(
123 f,
124 "Prompt tokens: {} Total tokens: {}",
125 self.prompt_tokens, self.total_tokens
126 )
127 }
128}
129
// Chat model identifiers available on Hyperbolic's completion endpoint.
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
158
/// Raw chat completion response from `/v1/chat/completions`
/// (OpenAI-compatible shape).
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    // Optional: some responses may omit usage accounting.
    pub usage: Option<Usage>,
}
171
impl From<ApiErrorResponse> for CompletionError {
    /// Maps a Hyperbolic API error payload onto a provider error.
    fn from(err: ApiErrorResponse) -> Self {
        CompletionError::ProviderError(err.message)
    }
}
177
178impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
179 type Error = CompletionError;
180
181 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
182 let choice = response.choices.first().ok_or_else(|| {
183 CompletionError::ResponseError("Response contained no choices".to_owned())
184 })?;
185
186 let content = match &choice.message {
187 Message::Assistant {
188 content,
189 tool_calls,
190 ..
191 } => {
192 let mut content = content
193 .iter()
194 .map(|c| match c {
195 AssistantContent::Text { text } => completion::AssistantContent::text(text),
196 AssistantContent::Refusal { refusal } => {
197 completion::AssistantContent::text(refusal)
198 }
199 })
200 .collect::<Vec<_>>();
201
202 content.extend(
203 tool_calls
204 .iter()
205 .map(|call| {
206 completion::AssistantContent::tool_call(
207 &call.id,
208 &call.function.name,
209 call.function.arguments.clone(),
210 )
211 })
212 .collect::<Vec<_>>(),
213 );
214 Ok(content)
215 }
216 _ => Err(CompletionError::ResponseError(
217 "Response did not contain a valid message or tool call".into(),
218 )),
219 }?;
220
221 let choice = OneOrMany::many(content).map_err(|_| {
222 CompletionError::ResponseError(
223 "Response contained no message or tool call (empty)".to_owned(),
224 )
225 })?;
226
227 let usage = response
228 .usage
229 .as_ref()
230 .map(|usage| completion::Usage {
231 input_tokens: usage.prompt_tokens as u64,
232 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
233 total_tokens: usage.total_tokens as u64,
234 cached_input_tokens: 0,
235 })
236 .unwrap_or_default();
237
238 Ok(completion::CompletionResponse {
239 choice,
240 usage,
241 raw_response: response,
242 message_id: None,
243 })
244 }
245}
246
/// A single completion choice (OpenAI-compatible).
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    pub finish_reason: String,
}
253
/// Wire-format request body for Hyperbolic chat completions.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct HyperbolicCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Extra caller-supplied parameters are flattened into the top-level
    // JSON object (e.g. `stream`, `stream_options`).
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
263
264impl TryFrom<(&str, CompletionRequest)> for HyperbolicCompletionRequest {
265 type Error = CompletionError;
266
267 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
268 if req.output_schema.is_some() {
269 tracing::warn!("Structured outputs currently not supported for Hyperbolic");
270 }
271
272 let model = req.model.clone().unwrap_or_else(|| model.to_string());
273 if req.tool_choice.is_some() {
274 tracing::warn!("WARNING: `tool_choice` not supported on Hyperbolic");
275 }
276
277 if !req.tools.is_empty() {
278 tracing::warn!("WARNING: `tools` not supported on Hyperbolic");
279 }
280
281 let mut full_history: Vec<Message> = match &req.preamble {
282 Some(preamble) => vec![Message::system(preamble)],
283 None => vec![],
284 };
285
286 if let Some(docs) = req.normalized_documents() {
287 let docs: Vec<Message> = docs.try_into()?;
288 full_history.extend(docs);
289 }
290
291 let chat_history: Vec<Message> = req
292 .chat_history
293 .clone()
294 .into_iter()
295 .map(|message| message.try_into())
296 .collect::<Result<Vec<Vec<Message>>, _>>()?
297 .into_iter()
298 .flatten()
299 .collect();
300
301 full_history.extend(chat_history);
302
303 Ok(Self {
304 model: model.to_string(),
305 messages: full_history,
306 temperature: req.temperature,
307 additional_params: req.additional_params,
308 })
309 }
310}
311
/// Chat completion model bound to a Hyperbolic [`Client`].
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}
318
319impl<T> CompletionModel<T> {
320 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
321 Self {
322 client,
323 model: model.into(),
324 }
325 }
326
327 pub fn with_model(client: Client<T>, model: &str) -> Self {
328 Self {
329 client,
330 model: model.into(),
331 }
332 }
333}
334
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    // Hyperbolic's streaming wire format is OpenAI-compatible, so the shared
    // OpenAI streaming response type is reused.
    type StreamingResponse = openai::StreamingCompletionResponse;

    type Client = Client<T>;

    /// Creates a model handle from a client and a model identifier.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Sends a non-streaming chat completion request to
    /// `/v1/chat/completions` and converts the reply into rig's
    /// provider-agnostic response type.
    ///
    /// # Errors
    /// Fails on serialization problems, HTTP transport errors, non-success
    /// status codes, or an unparseable/invalid provider response.
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Reuse the caller's tracing span when one is active; otherwise open a
        // fresh OpenTelemetry-style `gen_ai` span for this chat call.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Pretty-printing is gated so the serialization cost is only paid
        // when TRACE logging is actually enabled.
        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Hyperbolic completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        // Request/response handling runs inside the span via `instrument`.
        let async_block = async move {
            let response = self.client.send::<_, bytes::Bytes>(req).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                // A 2xx body can still carry an API-level error payload, so
                // decode into the untagged Ok/Err enum first.
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        if tracing::enabled!(tracing::Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Hyperbolic completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                // Non-success statuses surface the raw body as the error text.
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        };

        async_block.instrument(span).await
    }

    /// Sends a streaming chat completion request; SSE deltas are handled by
    /// the shared OpenAI-compatible streaming machinery.
    ///
    /// # Errors
    /// Fails on serialization problems, HTTP transport errors, or a failed
    /// stream handshake.
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        // Same span-reuse pattern as `completion`, with streaming-specific
        // operation naming.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let mut request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Force SSE streaming and ask the provider to include token usage in
        // the final chunk; merged over any caller-supplied extra params.
        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Hyperbolic streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        send_compatible_streaming_request(self.client.clone(), req)
            .instrument(span)
            .await
    }
}
470
471#[cfg(feature = "image")]
476pub use image_generation::*;
477
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use super::{ApiResponse, Client};
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    // Model identifiers accepted by Hyperbolic's image generation endpoint.
    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    pub const SD2: &str = "SD2";
    pub const SD1_5: &str = "SD1.5";
    pub const SSD: &str = "SSD";
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    /// Image generation model bound to a Hyperbolic [`Client`].
    #[derive(Clone)]
    pub struct ImageGenerationModel<T> {
        client: Client<T>,
        pub model: String,
    }

    impl<T> ImageGenerationModel<T> {
        /// Creates a model handle for the given model identifier.
        pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
            Self {
                client,
                model: model.into(),
            }
        }

        /// Convenience constructor taking the model name as a string slice.
        pub fn with_model(client: Client<T>, model: &str) -> Self {
            Self {
                client,
                model: model.into(),
            }
        }
    }

    /// A single base64-encoded image in the provider response.
    #[derive(Clone, Deserialize)]
    pub struct Image {
        image: String,
    }

    /// Raw response payload from `/v1/image/generation`.
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        images: Vec<Image>,
    }

    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// Decodes the first returned image from base64.
        ///
        /// # Errors
        /// Returns [`ImageGenerationError::ResponseError`] when the response
        /// contains no images or the base64 payload is malformed.
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            // Guard against an empty image list instead of panicking on
            // `images[0]`.
            let first = value.images.first().ok_or_else(|| {
                ImageGenerationError::ResponseError("Response contained no images".to_string())
            })?;

            // Malformed base64 is a recoverable provider-response problem, so
            // surface it as an error instead of panicking via `expect`.
            let data = BASE64_STANDARD.decode(&first.image).map_err(|e| {
                ImageGenerationError::ResponseError(format!("Could not decode image: {e}"))
            })?;

            Ok(Self {
                image: data,
                response: value,
            })
        }
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        /// Sends an image generation request to `/v1/image/generation`.
        ///
        /// # Errors
        /// Fails on serialization problems, HTTP transport errors,
        /// non-success status codes, or an invalid provider response.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            // Hyperbolic expects `model_name` (not `model`) in the payload.
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let body = serde_json::to_vec(&request)?;

            let request = self
                .client
                .post("/v1/image/generation")?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, bytes::Bytes>(request).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            // A 2xx body can still carry an API-level error payload.
            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
            }
        }
    }
}
604
605#[cfg(feature = "audio")]
609pub use audio_generation::*;
610use tracing::{Instrument, info_span};
611
#[cfg(feature = "audio")]
// Fixed: this previously advertised `feature = "image"` in docs, a copy-paste
// error from the image module.
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::{ApiResponse, Client};
    use crate::audio_generation;
    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
    use crate::http_client::{self, HttpClientExt};
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use bytes::Bytes;
    use serde::Deserialize;
    use serde_json::json;

    /// Audio (TTS) generation model bound to a Hyperbolic [`Client`].
    ///
    /// Note: Hyperbolic's TTS API is keyed by `language` rather than a model
    /// name, which is why this type stores a language instead of a model id.
    #[derive(Clone)]
    pub struct AudioGenerationModel<T> {
        client: Client<T>,
        pub language: String,
    }

    /// Raw response payload from `/v1/audio/generation`.
    #[derive(Clone, Deserialize)]
    pub struct AudioGenerationResponse {
        audio: String,
    }

    impl TryFrom<AudioGenerationResponse>
        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
    {
        type Error = AudioGenerationError;

        /// Decodes the base64 audio payload.
        ///
        /// # Errors
        /// Returns [`AudioGenerationError::ProviderError`] when the base64
        /// payload is malformed.
        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
            // Malformed base64 is a recoverable provider-response problem, so
            // surface it as an error instead of panicking via `expect`.
            let data = BASE64_STANDARD.decode(&value.audio).map_err(|e| {
                AudioGenerationError::ProviderError(format!("Could not decode audio: {e}"))
            })?;

            Ok(Self {
                audio: data,
                response: value,
            })
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = AudioGenerationResponse;
        type Client = Client<T>;

        fn make(client: &Self::Client, language: impl Into<String>) -> Self {
            Self {
                client: client.clone(),
                language: language.into(),
            }
        }

        /// Sends a TTS request to `/v1/audio/generation`.
        ///
        /// # Errors
        /// Fails on serialization problems, HTTP transport errors,
        /// non-success status codes, or an invalid provider response.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            // rig's generic `voice` maps onto Hyperbolic's `speaker` field.
            let request = json!({
                "language": self.language,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post("/v1/audio/generation")?
                .body(body)
                .map_err(http_client::Error::from)?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            // A 2xx body can still carry an API-level error payload.
            match serde_json::from_slice::<ApiResponse<AudioGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }
}
705
#[cfg(test)]
mod tests {
    /// Smoke test: both construction paths (direct and via the builder)
    /// succeed with a placeholder API key. No network traffic is performed.
    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::hyperbolic::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = crate::providers::hyperbolic::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}