1use super::openai::{AssistantContent, send_compatible_streaming_request};
15
16use crate::client::{self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder};
17use crate::client::{BearerAuth, ProviderClient};
18use crate::http_client::{self, HttpClientExt};
19use crate::streaming::StreamingCompletionResponse;
20
21use crate::providers::openai;
22use crate::{
23 OneOrMany,
24 completion::{self, CompletionError, CompletionRequest},
25 json_utils,
26 providers::openai::Message,
27};
28use serde::{Deserialize, Serialize};
29
/// Root URL of the Hyperbolic API; exposed to the client machinery as
/// `ProviderBuilder::BASE_URL` for `HyperbolicBuilder`.
const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz";
34
/// Field-less marker type identifying the Hyperbolic provider.
///
/// It is the `Provider` implementor for this module and the extension value
/// produced by `HyperbolicBuilder::build`.
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicExt;
/// Field-less builder type whose `ProviderBuilder` implementation yields a
/// [`HyperbolicExt`].
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicBuilder;
39
/// API-key type used for Hyperbolic clients; an alias for `BearerAuth`.
type HyperbolicApiKey = BearerAuth;
41
// Registers `HyperbolicExt` as a provider and ties it to its builder type.
impl Provider for HyperbolicExt {
    type Builder = HyperbolicBuilder;

    // NOTE(review): presumably the path the client requests to verify
    // credentials; confirm against the `Provider` trait definition.
    const VERIFY_PATH: &'static str = "/models";
}
47
// Capability table for the Hyperbolic provider: completions are always
// available; embeddings, transcription and model listing are declared as
// `Nothing`; image and audio generation are only declared when the `image` /
// `audio` cargo features are enabled.
impl<H> Capabilities<H> for HyperbolicExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<ImageGenerationModel<H>>;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<AudioGenerationModel<H>>;
}
58
// Empty impl: nothing is overridden, so any behaviour comes from `DebugExt` itself.
impl DebugExt for HyperbolicExt {}
60
// Builder wiring for Hyperbolic: fixes the extension type, the API-key type
// and the default base URL.
impl ProviderBuilder for HyperbolicBuilder {
    type Extension<H>
        = HyperbolicExt
    where
        H: HttpClientExt;
    type ApiKey = HyperbolicApiKey;

    const BASE_URL: &'static str = HYPERBOLIC_API_BASE_URL;

    /// Produces the provider extension.
    ///
    /// The builder argument is ignored because `HyperbolicExt` has no state,
    /// so this always returns `Ok(HyperbolicExt)`.
    fn build<H>(
        _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(HyperbolicExt)
    }
}
79
/// Hyperbolic client: the generic `client::Client` specialised with
/// [`HyperbolicExt`]. The HTTP backend `H` defaults to `reqwest::Client`.
pub type Client<H = reqwest::Client> = client::Client<HyperbolicExt, H>;
/// Builder for [`Client`]: the generic `client::ClientBuilder` specialised with
/// [`HyperbolicBuilder`] and a `String` key parameter. `H` defaults to
/// `crate::markers::Missing`.
pub type ClientBuilder<H = crate::markers::Missing> =
    client::ClientBuilder<HyperbolicBuilder, String, H>;
83
// Convenience constructors for the default (`reqwest`-backed) Hyperbolic client.
impl ProviderClient for Client {
    type Input = HyperbolicApiKey;
    type Error = crate::client::ProviderClientError;

    /// Builds a client from the `HYPERBOLIC_API_KEY` environment variable.
    ///
    /// # Errors
    /// Propagates the error from `required_env_var` when the variable cannot
    /// be read, and converts any error from `Self::new` into `Self::Error`.
    fn from_env() -> Result<Self, Self::Error> {
        let api_key = crate::client::required_env_var("HYPERBOLIC_API_KEY")?;
        Self::new(&api_key).map_err(Into::into)
    }

    /// Builds a client from an already-available API key value.
    ///
    /// # Errors
    /// Converts any error from `Self::new` into `Self::Error`.
    fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
        Self::new(input).map_err(Into::into)
    }
}
98
/// Error payload shape: a JSON object carrying a `message` string.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error text as reported in the response body.
    message: String,
}
103
/// Response body that is either the expected payload `T` or an error object.
///
/// The enum is `untagged`, so serde tries the variants in declaration order:
/// `Ok(T)` first, then `Err(ApiErrorResponse)`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
110
/// Deserializable embedding entry (`object`, `embedding`, `index`).
///
/// NOTE(review): nothing in the visible part of this file uses this type, and
/// the provider declares `Embeddings = Nothing`; confirm whether it is still needed.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}
117
/// Token accounting attached to a completion response.
///
/// Only prompt and total counts are present; the conversion into the generic
/// usage type derives output tokens from their difference.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Number of tokens in the prompt.
    pub prompt_tokens: usize,
    /// Total number of tokens reported for the request.
    pub total_tokens: usize,
}
123
124impl std::fmt::Display for Usage {
125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 write!(
127 f,
128 "Prompt tokens: {} Total tokens: {}",
129 self.prompt_tokens, self.total_tokens
130 )
131 }
132}
133
// Model identifier strings that can be passed as the `model` of a completion model.

/// `meta-llama/Meta-Llama-3.1-8B-Instruct` model identifier.
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// `meta-llama/Llama-3.3-70B-Instruct` model identifier.
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// `meta-llama/Meta-Llama-3.1-70B-Instruct` model identifier.
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// `meta-llama/Meta-Llama-3-70B-Instruct` model identifier.
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// `NousResearch/Hermes-3-Llama-3.1-70b` model identifier.
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// `deepseek-ai/DeepSeek-V2.5` model identifier.
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// `Qwen/Qwen2.5-72B-Instruct` model identifier.
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// `meta-llama/Llama-3.2-3B-Instruct` model identifier.
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// `Qwen/Qwen2.5-Coder-32B-Instruct` model identifier.
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// `Qwen/QwQ-32B-Preview` model identifier.
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// `deepseek-ai/DeepSeek-R1-Zero` model identifier.
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// `deepseek-ai/DeepSeek-R1` model identifier.
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
162
/// Raw chat-completion response body as deserialized from
/// `/v1/chat/completions`.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    /// Candidate answers; only the first one is used when converting into the
    /// generic completion response.
    pub choices: Vec<Choice>,
    /// Token usage, when the response includes it.
    pub usage: Option<Usage>,
}
175
176impl From<ApiErrorResponse> for CompletionError {
177 fn from(err: ApiErrorResponse) -> Self {
178 CompletionError::ProviderError(err.message)
179 }
180}
181
182impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
183 type Error = CompletionError;
184
185 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
186 let choice = response.choices.first().ok_or_else(|| {
187 CompletionError::ResponseError("Response contained no choices".to_owned())
188 })?;
189
190 let content = match &choice.message {
191 Message::Assistant {
192 content,
193 tool_calls,
194 ..
195 } => {
196 let mut content = content
197 .iter()
198 .map(|c| match c {
199 AssistantContent::Text { text } => completion::AssistantContent::text(text),
200 AssistantContent::Refusal { refusal } => {
201 completion::AssistantContent::text(refusal)
202 }
203 })
204 .collect::<Vec<_>>();
205
206 content.extend(
207 tool_calls
208 .iter()
209 .map(|call| {
210 completion::AssistantContent::tool_call(
211 &call.id,
212 &call.function.name,
213 call.function.arguments.clone(),
214 )
215 })
216 .collect::<Vec<_>>(),
217 );
218 Ok(content)
219 }
220 _ => Err(CompletionError::ResponseError(
221 "Response did not contain a valid message or tool call".into(),
222 )),
223 }?;
224
225 let choice = OneOrMany::many(content).map_err(|_| {
226 CompletionError::ResponseError(
227 "Response contained no message or tool call (empty)".to_owned(),
228 )
229 })?;
230
231 let usage = response
232 .usage
233 .as_ref()
234 .map(|usage| completion::Usage {
235 input_tokens: usage.prompt_tokens as u64,
236 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
237 total_tokens: usage.total_tokens as u64,
238 cached_input_tokens: 0,
239 cache_creation_input_tokens: 0,
240 tool_use_prompt_tokens: 0,
241 reasoning_tokens: 0,
242 })
243 .unwrap_or_default();
244
245 Ok(completion::CompletionResponse {
246 choice,
247 usage,
248 raw_response: response,
249 message_id: None,
250 })
251 }
252}
253
/// One entry of `CompletionResponse::choices`.
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    pub index: usize,
    /// The message for this choice, using the OpenAI-compatible message type.
    pub message: Message,
    pub finish_reason: String,
}
260
/// JSON request body sent to `/v1/chat/completions`.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct HyperbolicCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    // Omitted from the serialized body when not set.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Flattened: the keys of this JSON value are serialized at the top level
    // of the request body rather than under an `additional_params` key.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
270
271impl TryFrom<(&str, CompletionRequest)> for HyperbolicCompletionRequest {
272 type Error = CompletionError;
273
274 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
275 if req.output_schema.is_some() {
276 tracing::warn!("Structured outputs currently not supported for Hyperbolic");
277 }
278
279 let model = req.model.clone().unwrap_or_else(|| model.to_string());
280 if req.tool_choice.is_some() {
281 tracing::warn!("WARNING: `tool_choice` not supported on Hyperbolic");
282 }
283
284 if !req.tools.is_empty() {
285 tracing::warn!("WARNING: `tools` not supported on Hyperbolic");
286 }
287
288 let mut full_history: Vec<Message> = match &req.preamble {
289 Some(preamble) => vec![Message::system(preamble)],
290 None => vec![],
291 };
292
293 if let Some(docs) = req.normalized_documents() {
294 let docs: Vec<Message> = docs.try_into()?;
295 full_history.extend(docs);
296 }
297
298 let chat_history: Vec<Message> = req
299 .chat_history
300 .clone()
301 .into_iter()
302 .map(|message| message.try_into())
303 .collect::<Result<Vec<Vec<Message>>, _>>()?
304 .into_iter()
305 .flatten()
306 .collect();
307
308 full_history.extend(chat_history);
309
310 Ok(Self {
311 model: model.to_string(),
312 messages: full_history,
313 temperature: req.temperature,
314 additional_params: req.additional_params,
315 })
316 }
317}
318
/// Completion model handle: a Hyperbolic [`Client`] paired with a model name.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    /// Default model name; sent unless the request supplies its own model.
    pub model: String,
}
325
326impl<T> CompletionModel<T> {
327 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
328 Self {
329 client,
330 model: model.into(),
331 }
332 }
333
334 pub fn with_model(client: Client<T>, model: &str) -> Self {
335 Self {
336 client,
337 model: model.into(),
338 }
339 }
340}
341
// Completion and streaming support backed by Hyperbolic's
// `/v1/chat/completions` endpoint.
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = openai::StreamingCompletionResponse;

    type Client = Client<T>;

    /// Creates a model from a borrowed client (cloned) and a model name.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Sends a non-streaming chat completion request and converts the reply
    /// into the provider-agnostic completion response.
    ///
    /// # Errors
    /// Returns a `CompletionError` when building or serializing the request
    /// fails, when the HTTP exchange fails, when the status is not a success
    /// (`ProviderError` carrying the raw body text), when the body parses as an
    /// error payload (`ProviderError` carrying its message), or when the
    /// response cannot be converted.
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Reuse the current span when there is an enabled one; only open a
        // dedicated `chat` span when the current span is disabled.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        // Record the preamble before the request is consumed by the conversion below.
        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Pretty-printing is only paid for when TRACE logging is enabled.
        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Hyperbolic completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        // The network exchange is wrapped in a future so it can be instrumented
        // with the span chosen above.
        let async_block = async move {
            let response = self.client.send::<_, bytes::Bytes>(req).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                // A success status may still carry an error payload, hence the
                // two-shaped `ApiResponse`.
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        if tracing::enabled!(tracing::Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Hyperbolic completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                // Non-success status: surface the body text as-is (lossy UTF-8).
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        };

        async_block.instrument(span).await
    }

    /// Sends a streaming chat completion request.
    ///
    /// The request body is the same as for [`Self::completion`], with
    /// `"stream": true` and `"stream_options": {"include_usage": true}` merged
    /// into the additional parameters. Sending and response handling are
    /// delegated to `send_compatible_streaming_request`.
    ///
    /// # Errors
    /// Returns a `CompletionError` when building or serializing the request
    /// fails, or whatever error the streaming helper reports.
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        // Same span selection as in `completion`, with a `chat_streaming` name.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let mut request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // NOTE(review): which side wins on a key conflict depends on
        // `json_utils::merge`; confirm caller-supplied `stream` values cannot
        // override the flags set here.
        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Hyperbolic streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        send_compatible_streaming_request(self.client.clone(), req)
            .instrument(span)
            .await
    }
}
479
480#[cfg(feature = "image")]
485pub use image_generation::*;
486
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use super::{ApiResponse, Client};
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    // Model names accepted as the `model_name` of an image generation request.

    /// `SDXL1.0-base` model name.
    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    /// `SD2` model name.
    pub const SD2: &str = "SD2";
    /// `SD1.5` model name.
    pub const SD1_5: &str = "SD1.5";
    /// `SSD` model name.
    pub const SSD: &str = "SSD";
    /// `SDXL-turbo` model name.
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    /// `SDXL-ControlNet` model name.
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    /// `SD1.5-ControlNet` model name.
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    /// Image generation handle: a Hyperbolic client paired with a model name.
    #[derive(Clone)]
    pub struct ImageGenerationModel<T> {
        client: Client<T>,
        /// Value sent as `model_name` in the request body.
        pub model: String,
    }

    impl<T> ImageGenerationModel<T> {
        /// Creates an image generation model for `model` using `client`.
        pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
            let model = model.into();
            Self { client, model }
        }

        /// Public constructor taking the model name as a `&str`.
        pub fn with_model(client: Client<T>, model: &str) -> Self {
            Self::new(client, model)
        }
    }

    /// One generated image, carried as a base64-encoded string.
    #[derive(Clone, Deserialize)]
    pub struct Image {
        image: String,
    }

    /// Raw response body of the image generation endpoint.
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        images: Vec<Image>,
    }

    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// Decodes the first image of the response from base64.
        ///
        /// Fails with `ResponseError` when there is no image or the payload is
        /// not valid base64.
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            let Some(first) = value.images.first() else {
                return Err(ImageGenerationError::ResponseError(
                    "missing image data".into(),
                ));
            };

            let decoded = match BASE64_STANDARD.decode(&first.image) {
                Ok(bytes) => bytes,
                Err(err) => return Err(ImageGenerationError::ResponseError(err.to_string())),
            };

            Ok(Self {
                image: decoded,
                response: value,
            })
        }
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        /// Creates a model from a borrowed client (cloned) and a model name.
        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        /// Posts the request to `/v1/image/generation` and decodes the first
        /// returned image.
        ///
        /// Caller-supplied additional parameters are merged into the JSON body
        /// built from model, prompt, height and width.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let mut payload = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            if let Some(extra) = generation_request.additional_params {
                merge_inplace(&mut payload, extra);
            }

            let body = serde_json::to_vec(&payload)?;

            let http_request = self
                .client
                .post("/v1/image/generation")?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, bytes::Bytes>(http_request).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                let parsed: ApiResponse<ImageGenerationResponse> =
                    serde_json::from_slice(&response_body)?;

                match parsed {
                    ApiResponse::Ok(response) => response.try_into(),
                    ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
                }
            } else {
                // Report the status together with the body text (lossy UTF-8).
                Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )))
            }
        }
    }
}
617
618#[cfg(feature = "audio")]
622pub use audio_generation::*;
623use tracing::{Instrument, info_span};
624
#[cfg(feature = "audio")]
// The docs.rs badge must name the feature that actually gates this module.
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::{ApiResponse, Client};
    use crate::audio_generation;
    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
    use crate::http_client::{self, HttpClientExt};
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use bytes::Bytes;
    use serde::Deserialize;
    use serde_json::json;

    /// Audio generation handle: a Hyperbolic client paired with a language.
    ///
    /// The string given to `make` (the trait's model-name slot) is stored as
    /// `language` and sent as the `language` field of each request.
    #[derive(Clone)]
    pub struct AudioGenerationModel<T> {
        client: Client<T>,
        /// Value sent as `language` in the request body.
        pub language: String,
    }

    /// Raw response body of the audio generation endpoint; `audio` is a
    /// base64-encoded string.
    #[derive(Clone, Deserialize)]
    pub struct AudioGenerationResponse {
        audio: String,
    }

    impl TryFrom<AudioGenerationResponse>
        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
    {
        type Error = AudioGenerationError;

        /// Decodes the base64 audio payload.
        ///
        /// Fails with `ResponseError` when the payload is not valid base64.
        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
            let data = BASE64_STANDARD
                .decode(&value.audio)
                .map_err(|err| AudioGenerationError::ResponseError(err.to_string()))?;

            Ok(Self {
                audio: data,
                response: value,
            })
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = AudioGenerationResponse;
        type Client = Client<T>;

        /// Creates a model from a borrowed client (cloned) and a language string.
        fn make(client: &Self::Client, language: impl Into<String>) -> Self {
            Self {
                client: client.clone(),
                language: language.into(),
            }
        }

        /// Posts the request to `/v1/audio/generation` and decodes the returned
        /// audio.
        ///
        /// # Errors
        /// Returns `ProviderError` for a non-success status (status plus body
        /// text) or when the body parses as an error payload; other failures
        /// (serialization, HTTP, base64 decoding) are propagated.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            // The generic request's `voice` maps onto Hyperbolic's `speaker` field.
            let request = json!({
                "language": self.language,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post("/v1/audio/generation")?
                .body(body)
                .map_err(http_client::Error::from)?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<AudioGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }
}
718
#[cfg(test)]
mod tests {
    use crate::providers::hyperbolic::Client;

    /// Smoke test: a client can be constructed both directly and through the
    /// builder when given a placeholder key.
    #[test]
    fn test_client_initialization() {
        let direct = Client::new("dummy-key");
        let _client = direct.expect("Client::new() failed");

        let built = Client::builder().api_key("dummy-key").build();
        let _client_from_builder = built.expect("Client::builder() failed");
    }
}