1use super::openai::{AssistantContent, send_compatible_streaming_request};
15
16use crate::client::{self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder};
17use crate::client::{BearerAuth, ProviderClient};
18use crate::http_client::{self, HttpClientExt};
19use crate::streaming::StreamingCompletionResponse;
20
21use crate::providers::openai;
22use crate::{
23 OneOrMany,
24 completion::{self, CompletionError, CompletionRequest},
25 json_utils,
26 providers::openai::Message,
27};
28use serde::{Deserialize, Serialize};
29
/// Base URL for the Hyperbolic REST API.
const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz";

/// Zero-sized marker type representing the Hyperbolic provider extension.
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicExt;
/// Zero-sized marker type used to build [`HyperbolicExt`]-backed clients.
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicBuilder;

// Hyperbolic authenticates with a standard `Authorization: Bearer <key>` header.
type HyperbolicApiKey = BearerAuth;
41
impl Provider for HyperbolicExt {
    type Builder = HyperbolicBuilder;

    // Lightweight endpoint used to verify that an API key is accepted.
    const VERIFY_PATH: &'static str = "/models";
}

/// Capabilities exposed by Hyperbolic: chat completions always; image and
/// audio generation only behind their respective cargo features. Embeddings,
/// transcription, and model listing are not wired up for this provider.
impl<H> Capabilities<H> for HyperbolicExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<ImageGenerationModel<H>>;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<AudioGenerationModel<H>>;
}

impl DebugExt for HyperbolicExt {}
60
/// Builder wiring for the Hyperbolic provider: fixes the base URL and the
/// bearer-token API-key scheme, and produces a [`HyperbolicExt`] for any
/// HTTP backend `H`.
impl ProviderBuilder for HyperbolicBuilder {
    type Extension<H>
        = HyperbolicExt
    where
        H: HttpClientExt;
    type ApiKey = HyperbolicApiKey;

    const BASE_URL: &'static str = HYPERBOLIC_API_BASE_URL;

    // The extension is stateless, so the builder is never inspected and
    // construction cannot fail.
    fn build<H>(
        _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(HyperbolicExt)
    }
}
79
/// Hyperbolic client over a pluggable HTTP backend (`reqwest` by default).
pub type Client<H = reqwest::Client> = client::Client<HyperbolicExt, H>;
/// Builder for [`Client`]; `H` starts as `Missing` until a backend is chosen.
pub type ClientBuilder<H = crate::markers::Missing> =
    client::ClientBuilder<HyperbolicBuilder, String, H>;

impl ProviderClient for Client {
    type Input = HyperbolicApiKey;
    type Error = crate::client::ProviderClientError;

    /// Builds a client from the `HYPERBOLIC_API_KEY` environment variable.
    ///
    /// Errors if the variable is unset or client construction fails.
    fn from_env() -> Result<Self, Self::Error> {
        let api_key = crate::client::required_env_var("HYPERBOLIC_API_KEY")?;
        Self::new(&api_key).map_err(Into::into)
    }

    /// Builds a client directly from an API-key value.
    fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
        Self::new(input).map_err(Into::into)
    }
}
98
/// Error payload returned by the Hyperbolic API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Untagged wrapper: deserialization tries the success shape `T` first and
/// falls back to the error shape if `T` does not match.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}

// NOTE(review): embeddings are not enabled for this provider
// (`Capabilities::Embeddings = Nothing`); this type appears unused in the
// visible code — confirm before removing.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}

/// Token accounting as reported by Hyperbolic. Only prompt and total counts
/// are returned; output tokens are derived elsewhere as the difference.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
123
124impl std::fmt::Display for Usage {
125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 write!(
127 f,
128 "Prompt tokens: {} Total tokens: {}",
129 self.prompt_tokens, self.total_tokens
130 )
131 }
132}
133
/// Llama 3.1 8B Instruct
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Llama 3.3 70B Instruct
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// Llama 3.1 70B Instruct
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// Llama 3 70B Instruct
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// Hermes 3 (Llama 3.1 70B base)
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// DeepSeek V2.5
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// Qwen 2.5 72B Instruct
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// Llama 3.2 3B Instruct
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// Qwen 2.5 Coder 32B Instruct
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// Qwen QwQ 32B (preview)
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// DeepSeek R1 Zero
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// DeepSeek R1
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
162
/// Raw chat-completion response body (OpenAI-compatible shape).
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Option<Usage>,
}
175
176impl From<ApiErrorResponse> for CompletionError {
177 fn from(err: ApiErrorResponse) -> Self {
178 CompletionError::ProviderError(err.message)
179 }
180}
181
182impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
183 type Error = CompletionError;
184
185 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
186 let choice = response.choices.first().ok_or_else(|| {
187 CompletionError::ResponseError("Response contained no choices".to_owned())
188 })?;
189
190 let content = match &choice.message {
191 Message::Assistant {
192 content,
193 tool_calls,
194 ..
195 } => {
196 let mut content = content
197 .iter()
198 .map(|c| match c {
199 AssistantContent::Text { text } => completion::AssistantContent::text(text),
200 AssistantContent::Refusal { refusal } => {
201 completion::AssistantContent::text(refusal)
202 }
203 })
204 .collect::<Vec<_>>();
205
206 content.extend(
207 tool_calls
208 .iter()
209 .map(|call| {
210 completion::AssistantContent::tool_call(
211 &call.id,
212 &call.function.name,
213 call.function.arguments.clone(),
214 )
215 })
216 .collect::<Vec<_>>(),
217 );
218 Ok(content)
219 }
220 _ => Err(CompletionError::ResponseError(
221 "Response did not contain a valid message or tool call".into(),
222 )),
223 }?;
224
225 let choice = OneOrMany::many(content).map_err(|_| {
226 CompletionError::ResponseError(
227 "Response contained no message or tool call (empty)".to_owned(),
228 )
229 })?;
230
231 let usage = response
232 .usage
233 .as_ref()
234 .map(|usage| completion::Usage {
235 input_tokens: usage.prompt_tokens as u64,
236 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
237 total_tokens: usage.total_tokens as u64,
238 cached_input_tokens: 0,
239 cache_creation_input_tokens: 0,
240 reasoning_tokens: 0,
241 })
242 .unwrap_or_default();
243
244 Ok(completion::CompletionResponse {
245 choice,
246 usage,
247 raw_response: response,
248 message_id: None,
249 })
250 }
251}
252
/// A single completion choice within a [`CompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    pub finish_reason: String,
}

/// Wire-format request body sent to `/v1/chat/completions`.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct HyperbolicCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Flattened into the top-level JSON object (e.g. `stream`, `max_tokens`).
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
269
270impl TryFrom<(&str, CompletionRequest)> for HyperbolicCompletionRequest {
271 type Error = CompletionError;
272
273 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
274 if req.output_schema.is_some() {
275 tracing::warn!("Structured outputs currently not supported for Hyperbolic");
276 }
277
278 let model = req.model.clone().unwrap_or_else(|| model.to_string());
279 if req.tool_choice.is_some() {
280 tracing::warn!("WARNING: `tool_choice` not supported on Hyperbolic");
281 }
282
283 if !req.tools.is_empty() {
284 tracing::warn!("WARNING: `tools` not supported on Hyperbolic");
285 }
286
287 let mut full_history: Vec<Message> = match &req.preamble {
288 Some(preamble) => vec![Message::system(preamble)],
289 None => vec![],
290 };
291
292 if let Some(docs) = req.normalized_documents() {
293 let docs: Vec<Message> = docs.try_into()?;
294 full_history.extend(docs);
295 }
296
297 let chat_history: Vec<Message> = req
298 .chat_history
299 .clone()
300 .into_iter()
301 .map(|message| message.try_into())
302 .collect::<Result<Vec<Vec<Message>>, _>>()?
303 .into_iter()
304 .flatten()
305 .collect();
306
307 full_history.extend(chat_history);
308
309 Ok(Self {
310 model: model.to_string(),
311 messages: full_history,
312 temperature: req.temperature,
313 additional_params: req.additional_params,
314 })
315 }
316}
317
/// Chat-completion model handle: a [`Client`] plus a model identifier.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}
324
325impl<T> CompletionModel<T> {
326 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
327 Self {
328 client,
329 model: model.into(),
330 }
331 }
332
333 pub fn with_model(client: Client<T>, model: &str) -> Self {
334 Self {
335 client,
336 model: model.into(),
337 }
338 }
339}
340
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    // Streaming re-uses the OpenAI-compatible SSE chunk format.
    type StreamingResponse = openai::StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Sends a non-streaming chat completion to `/v1/chat/completions`.
    ///
    /// # Errors
    /// `ProviderError` on a non-2xx status or an API-level error body;
    /// `ResponseError` if the body cannot be converted.
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Open a fresh GenAI-semconv span only when the caller has not
        // already established one.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Guarded so the pretty-printing only runs when TRACE is active.
        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Hyperbolic completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        let async_block = async move {
            let response = self.client.send::<_, bytes::Bytes>(req).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                // A 2xx body may still carry an API-level error payload.
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        if tracing::enabled!(tracing::Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Hyperbolic completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        };

        async_block.instrument(span).await
    }

    /// Sends a streaming chat completion: forces `stream: true` and usage
    /// reporting via `stream_options`, then delegates to the shared
    /// OpenAI-compatible SSE handler.
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let mut request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Merge the streaming flags into any caller-supplied extra params;
        // the right-hand side wins, so `stream` cannot be overridden off.
        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Hyperbolic streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        send_compatible_streaming_request(self.client.clone(), req)
            .instrument(span)
            .await
    }
}
478
479#[cfg(feature = "image")]
484pub use image_generation::*;
485
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use super::{ApiResponse, Client};
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    // Stable Diffusion model identifiers accepted by Hyperbolic.
    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    pub const SD2: &str = "SD2";
    pub const SD1_5: &str = "SD1.5";
    pub const SSD: &str = "SSD";
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    /// Image-generation model handle: a [`Client`] plus a model identifier.
    #[derive(Clone)]
    pub struct ImageGenerationModel<T> {
        client: Client<T>,
        pub model: String,
    }

    impl<T> ImageGenerationModel<T> {
        /// Creates an image-generation model bound to `client`.
        pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
            Self {
                client,
                model: model.into(),
            }
        }

        /// Convenience constructor taking the model id as a `&str`.
        pub fn with_model(client: Client<T>, model: &str) -> Self {
            Self {
                client,
                model: model.into(),
            }
        }
    }

    /// One generated image, base64-encoded.
    #[derive(Clone, Deserialize)]
    pub struct Image {
        image: String,
    }

    /// Raw response body from `/v1/image/generation`.
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        images: Vec<Image>,
    }

    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        // Decodes only the first image; the full raw response is kept in
        // `response` alongside the decoded bytes.
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            let image = value
                .images
                .first()
                .ok_or_else(|| ImageGenerationError::ResponseError("missing image data".into()))?;
            let data = BASE64_STANDARD
                .decode(&image.image)
                .map_err(|err| ImageGenerationError::ResponseError(err.to_string()))?;

            Ok(Self {
                image: data,
                response: value,
            })
        }
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        /// POSTs to `/v1/image/generation` and decodes the base64 image in
        /// the response.
        ///
        /// # Errors
        /// `ProviderError` on a non-2xx status; `ResponseError` on an
        /// API-level error body or a missing/undecodable image payload.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            // NOTE: Hyperbolic expects `model_name`, not OpenAI-style `model`.
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            // Caller-supplied params are merged over the defaults above.
            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let body = serde_json::to_vec(&request)?;

            let request = self
                .client
                .post("/v1/image/generation")?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, bytes::Bytes>(request).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
            }
        }
    }
}
616
617#[cfg(feature = "audio")]
621pub use audio_generation::*;
622use tracing::{Instrument, info_span};
623
#[cfg(feature = "audio")]
// Fixed: the docsrs cfg hint below previously advertised `feature = "image"`
// (copy-pasted from the image module); it must match the `audio` gate above
// so docs.rs labels this module with the correct required feature.
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::{ApiResponse, Client};
    use crate::audio_generation;
    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
    use crate::http_client::{self, HttpClientExt};
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use bytes::Bytes;
    use serde::Deserialize;
    use serde_json::json;

    /// Text-to-speech model handle: a [`Client`] plus a language selector
    /// (Hyperbolic picks the voice model by language, not by model name).
    #[derive(Clone)]
    pub struct AudioGenerationModel<T> {
        client: Client<T>,
        pub language: String,
    }

    /// Raw response body: the audio is returned as a base64-encoded string.
    #[derive(Clone, Deserialize)]
    pub struct AudioGenerationResponse {
        audio: String,
    }

    impl TryFrom<AudioGenerationResponse>
        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
    {
        type Error = AudioGenerationError;

        // Decodes the base64 payload into raw audio bytes, keeping the raw
        // response alongside.
        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
            let data = BASE64_STANDARD
                .decode(&value.audio)
                .map_err(|err| AudioGenerationError::ResponseError(err.to_string()))?;

            Ok(Self {
                audio: data,
                response: value,
            })
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = AudioGenerationResponse;
        type Client = Client<T>;

        fn make(client: &Self::Client, language: impl Into<String>) -> Self {
            Self {
                client: client.clone(),
                language: language.into(),
            }
        }

        /// POSTs to `/v1/audio/generation` and decodes the base64 audio in
        /// the response.
        ///
        /// # Errors
        /// `ProviderError` on a non-2xx status or an API-level error body;
        /// `ResponseError` if the audio payload fails to decode.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            let request = json!({
                "language": self.language,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post("/v1/audio/generation")?
                .body(body)
                .map_err(http_client::Error::from)?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<AudioGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }
}
717
#[cfg(test)]
mod tests {
    /// Smoke test: both client-construction paths succeed with a dummy key.
    #[test]
    fn test_client_initialization() {
        let direct = crate::providers::hyperbolic::Client::new("dummy-key")
            .expect("Client::new() failed");
        let built = crate::providers::hyperbolic::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
        let _ = (direct, built);
    }
}