1use super::openai::{AssistantContent, send_compatible_streaming_request};
15
16use crate::client::{self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder};
17use crate::client::{BearerAuth, ProviderClient};
18use crate::http_client::{self, HttpClientExt};
19use crate::streaming::StreamingCompletionResponse;
20
21use crate::providers::openai;
22use crate::{
23 OneOrMany,
24 completion::{self, CompletionError, CompletionRequest},
25 json_utils,
26 providers::openai::Message,
27};
28use serde::{Deserialize, Serialize};
29
/// Root URL of the Hyperbolic HTTP API.
///
/// Exposed to the client machinery as the builder's `BASE_URL`.
const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz";
34
/// Zero-sized marker type standing for the Hyperbolic provider.
///
/// It holds no data; the provider traits implemented on it carry the
/// provider-specific configuration.
#[derive(Clone, Copy, Debug, Default)]
pub struct HyperbolicExt;
/// Zero-sized builder marker used to construct a [`HyperbolicExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct HyperbolicBuilder;
39
40type HyperbolicApiKey = BearerAuth;
41
42impl Provider for HyperbolicExt {
43 type Builder = HyperbolicBuilder;
44
45 const VERIFY_PATH: &'static str = "/models";
46}
47
48impl<H> Capabilities<H> for HyperbolicExt {
49 type Completion = Capable<CompletionModel<H>>;
50 type Embeddings = Nothing;
51 type Transcription = Nothing;
52 type ModelListing = Nothing;
53 #[cfg(feature = "image")]
54 type ImageGeneration = Capable<ImageGenerationModel<H>>;
55 #[cfg(feature = "audio")]
56 type AudioGeneration = Capable<AudioGenerationModel<H>>;
57 type Rerank = Nothing;
58}
59
60impl DebugExt for HyperbolicExt {}
61
62impl ProviderBuilder for HyperbolicBuilder {
63 type Extension<H>
64 = HyperbolicExt
65 where
66 H: HttpClientExt;
67 type ApiKey = HyperbolicApiKey;
68
69 const BASE_URL: &'static str = HYPERBOLIC_API_BASE_URL;
70
71 fn build<H>(
72 _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
73 ) -> http_client::Result<Self::Extension<H>>
74 where
75 H: HttpClientExt,
76 {
77 Ok(HyperbolicExt)
78 }
79}
80
81pub type Client<H = reqwest::Client> = client::Client<HyperbolicExt, H>;
82pub type ClientBuilder<H = crate::markers::Missing> =
83 client::ClientBuilder<HyperbolicBuilder, String, H>;
84
85impl ProviderClient for Client {
86 type Input = HyperbolicApiKey;
87 type Error = crate::client::ProviderClientError;
88
89 fn from_env() -> Result<Self, Self::Error> {
91 let api_key = crate::client::required_env_var("HYPERBOLIC_API_KEY")?;
92 Self::new(&api_key).map_err(Into::into)
93 }
94
95 fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
96 Self::new(input).map_err(Into::into)
97 }
98}
99
100#[derive(Debug, Deserialize)]
101struct ApiErrorResponse {
102 message: String,
103}
104
105#[derive(Debug, Deserialize)]
106#[serde(untagged)]
107enum ApiResponse<T> {
108 Ok(T),
109 Err(ApiErrorResponse),
110}
111
112#[derive(Debug, Deserialize)]
113pub struct EmbeddingData {
114 pub object: String,
115 pub embedding: Vec<f64>,
116 pub index: usize,
117}
118
119#[derive(Clone, Debug, Deserialize, Serialize)]
120pub struct Usage {
121 pub prompt_tokens: usize,
122 pub total_tokens: usize,
123}
124
125impl std::fmt::Display for Usage {
126 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127 write!(
128 f,
129 "Prompt tokens: {} Total tokens: {}",
130 self.prompt_tokens, self.total_tokens
131 )
132 }
133}
134
// Model identifier strings, usable as the `model` argument of a completion model.

/// `meta-llama/Meta-Llama-3.1-8B-Instruct` model identifier.
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// `meta-llama/Llama-3.3-70B-Instruct` model identifier.
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// `meta-llama/Meta-Llama-3.1-70B-Instruct` model identifier.
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// `meta-llama/Meta-Llama-3-70B-Instruct` model identifier.
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// `NousResearch/Hermes-3-Llama-3.1-70b` model identifier.
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// `deepseek-ai/DeepSeek-V2.5` model identifier.
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// `Qwen/Qwen2.5-72B-Instruct` model identifier.
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// `meta-llama/Llama-3.2-3B-Instruct` model identifier.
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// `Qwen/Qwen2.5-Coder-32B-Instruct` model identifier.
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// `Qwen/QwQ-32B-Preview` model identifier.
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// `deepseek-ai/DeepSeek-R1-Zero` model identifier.
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// `deepseek-ai/DeepSeek-R1` model identifier.
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
163
164#[derive(Debug, Deserialize, Serialize)]
168pub struct CompletionResponse {
169 pub id: String,
170 pub object: String,
171 pub created: u64,
172 pub model: String,
173 pub choices: Vec<Choice>,
174 pub usage: Option<Usage>,
175}
176
177impl From<ApiErrorResponse> for CompletionError {
178 fn from(err: ApiErrorResponse) -> Self {
179 CompletionError::ProviderError(err.message)
180 }
181}
182
183impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
184 type Error = CompletionError;
185
186 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
187 let choice = response.choices.first().ok_or_else(|| {
188 CompletionError::ResponseError("Response contained no choices".to_owned())
189 })?;
190
191 let content = match &choice.message {
192 Message::Assistant {
193 content,
194 tool_calls,
195 ..
196 } => {
197 let mut content = content
198 .iter()
199 .map(|c| match c {
200 AssistantContent::Text { text } => completion::AssistantContent::text(text),
201 AssistantContent::Refusal { refusal } => {
202 completion::AssistantContent::text(refusal)
203 }
204 })
205 .collect::<Vec<_>>();
206
207 content.extend(
208 tool_calls
209 .iter()
210 .map(|call| {
211 completion::AssistantContent::tool_call(
212 &call.id,
213 &call.function.name,
214 call.function.arguments.clone(),
215 )
216 })
217 .collect::<Vec<_>>(),
218 );
219 Ok(content)
220 }
221 _ => Err(CompletionError::ResponseError(
222 "Response did not contain a valid message or tool call".into(),
223 )),
224 }?;
225
226 let choice = OneOrMany::many(content).map_err(|_| {
227 CompletionError::ResponseError(
228 "Response contained no message or tool call (empty)".to_owned(),
229 )
230 })?;
231
232 let usage = response
233 .usage
234 .as_ref()
235 .map(|usage| completion::Usage {
236 input_tokens: usage.prompt_tokens as u64,
237 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
238 total_tokens: usage.total_tokens as u64,
239 cached_input_tokens: 0,
240 cache_creation_input_tokens: 0,
241 tool_use_prompt_tokens: 0,
242 reasoning_tokens: 0,
243 })
244 .unwrap_or_default();
245
246 Ok(completion::CompletionResponse {
247 choice,
248 usage,
249 raw_response: response,
250 message_id: None,
251 })
252 }
253}
254
255#[derive(Debug, Deserialize, Serialize)]
256pub struct Choice {
257 pub index: usize,
258 pub message: Message,
259 pub finish_reason: String,
260}
261
262#[derive(Debug, Serialize, Deserialize)]
263pub(super) struct HyperbolicCompletionRequest {
264 model: String,
265 pub messages: Vec<Message>,
266 #[serde(skip_serializing_if = "Option::is_none")]
267 temperature: Option<f64>,
268 #[serde(flatten, skip_serializing_if = "Option::is_none")]
269 pub additional_params: Option<serde_json::Value>,
270}
271
272impl TryFrom<(&str, CompletionRequest)> for HyperbolicCompletionRequest {
273 type Error = CompletionError;
274
275 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
276 let chat_history = req.chat_history_with_documents();
277 if req.output_schema.is_some() {
278 tracing::warn!("Structured outputs currently not supported for Hyperbolic");
279 }
280
281 let model = req.model.clone().unwrap_or_else(|| model.to_string());
282 if req.tool_choice.is_some() {
283 tracing::warn!("WARNING: `tool_choice` not supported on Hyperbolic");
284 }
285
286 if !req.tools.is_empty() {
287 tracing::warn!("WARNING: `tools` not supported on Hyperbolic");
288 }
289
290 let mut full_history: Vec<Message> = match &req.preamble {
291 Some(preamble) => vec![Message::system(preamble)],
292 None => vec![],
293 };
294
295 let chat_history: Vec<Message> = chat_history
296 .into_iter()
297 .map(|message| message.try_into())
298 .collect::<Result<Vec<Vec<Message>>, _>>()?
299 .into_iter()
300 .flatten()
301 .collect();
302
303 full_history.extend(chat_history);
304
305 Ok(Self {
306 model: model.to_string(),
307 messages: full_history,
308 temperature: req.temperature,
309 additional_params: req.additional_params,
310 })
311 }
312}
313
314#[derive(Clone)]
315pub struct CompletionModel<T = reqwest::Client> {
316 client: Client<T>,
317 pub model: String,
319}
320
321impl<T> CompletionModel<T> {
322 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
323 Self {
324 client,
325 model: model.into(),
326 }
327 }
328
329 pub fn with_model(client: Client<T>, model: &str) -> Self {
330 Self {
331 client,
332 model: model.into(),
333 }
334 }
335}
336
337impl<T> completion::CompletionModel for CompletionModel<T>
338where
339 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
340{
341 type Response = CompletionResponse;
342 type StreamingResponse = openai::StreamingCompletionResponse;
343
344 type Client = Client<T>;
345
346 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
347 Self::new(client.clone(), model)
348 }
349
350 async fn completion(
351 &self,
352 completion_request: CompletionRequest,
353 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
354 let span = if tracing::Span::current().is_disabled() {
355 info_span!(
356 target: "rig::completions",
357 "chat",
358 gen_ai.operation.name = "chat",
359 gen_ai.provider.name = "hyperbolic",
360 gen_ai.request.model = self.model,
361 gen_ai.system_instructions = tracing::field::Empty,
362 gen_ai.response.id = tracing::field::Empty,
363 gen_ai.response.model = tracing::field::Empty,
364 gen_ai.usage.output_tokens = tracing::field::Empty,
365 gen_ai.usage.input_tokens = tracing::field::Empty,
366 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
367 )
368 } else {
369 tracing::Span::current()
370 };
371
372 span.record("gen_ai.system_instructions", &completion_request.preamble);
373 let request =
374 HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
375
376 if tracing::enabled!(tracing::Level::TRACE) {
377 tracing::trace!(target: "rig::completions",
378 "Hyperbolic completion request: {}",
379 serde_json::to_string_pretty(&request)?
380 );
381 }
382
383 let body = serde_json::to_vec(&request)?;
384
385 let req = self
386 .client
387 .post("/v1/chat/completions")?
388 .body(body)
389 .map_err(http_client::Error::from)?;
390
391 let async_block = async move {
392 let response = self.client.send::<_, bytes::Bytes>(req).await?;
393
394 let status = response.status();
395 let response_body = response.into_body().into_future().await?.to_vec();
396
397 if status.is_success() {
398 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
399 ApiResponse::Ok(response) => {
400 if tracing::enabled!(tracing::Level::TRACE) {
401 tracing::trace!(target: "rig::completions",
402 "Hyperbolic completion response: {}",
403 serde_json::to_string_pretty(&response)?
404 );
405 }
406
407 response.try_into()
408 }
409 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
410 }
411 } else {
412 Err(CompletionError::ProviderError(
413 String::from_utf8_lossy(&response_body).to_string(),
414 ))
415 }
416 };
417
418 async_block.instrument(span).await
419 }
420
421 async fn stream(
422 &self,
423 completion_request: CompletionRequest,
424 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
425 let span = if tracing::Span::current().is_disabled() {
426 info_span!(
427 target: "rig::completions",
428 "chat_streaming",
429 gen_ai.operation.name = "chat_streaming",
430 gen_ai.provider.name = "hyperbolic",
431 gen_ai.request.model = self.model,
432 gen_ai.system_instructions = tracing::field::Empty,
433 gen_ai.response.id = tracing::field::Empty,
434 gen_ai.response.model = tracing::field::Empty,
435 gen_ai.usage.output_tokens = tracing::field::Empty,
436 gen_ai.usage.input_tokens = tracing::field::Empty,
437 gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
438 )
439 } else {
440 tracing::Span::current()
441 };
442
443 span.record("gen_ai.system_instructions", &completion_request.preamble);
444 let mut request =
445 HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
446
447 let params = json_utils::merge(
448 request.additional_params.unwrap_or(serde_json::json!({})),
449 serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
450 );
451
452 request.additional_params = Some(params);
453
454 if tracing::enabled!(tracing::Level::TRACE) {
455 tracing::trace!(target: "rig::completions",
456 "Hyperbolic streaming completion request: {}",
457 serde_json::to_string_pretty(&request)?
458 );
459 }
460
461 let body = serde_json::to_vec(&request)?;
462
463 let req = self
464 .client
465 .post("/v1/chat/completions")?
466 .body(body)
467 .map_err(http_client::Error::from)?;
468
469 send_compatible_streaming_request(self.client.clone(), req)
470 .instrument(span)
471 .await
472 }
473}
474
475#[cfg(feature = "image")]
480pub use image_generation::*;
481
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use super::{ApiResponse, Client};
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    // Image model identifier strings, usable as the `model` argument.

    /// `SDXL1.0-base` model identifier.
    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    /// `SD2` model identifier.
    pub const SD2: &str = "SD2";
    /// `SD1.5` model identifier.
    pub const SD1_5: &str = "SD1.5";
    /// `SSD` model identifier.
    pub const SSD: &str = "SSD";
    /// `SDXL-turbo` model identifier.
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    /// `SDXL-ControlNet` model identifier.
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    /// `SD1.5-ControlNet` model identifier.
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    /// Image generation model bound to a Hyperbolic [`Client`].
    #[derive(Clone)]
    pub struct ImageGenerationModel<T> {
        /// Client used to send requests.
        client: Client<T>,
        /// Model identifier sent as `model_name`.
        pub model: String,
    }

    impl<T> ImageGenerationModel<T> {
        /// Creates an image generation model for `model` on the given client.
        pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
            let model = model.into();
            Self { client, model }
        }

        /// Same as `new`, taking the model name as `&str`.
        pub fn with_model(client: Client<T>, model: &str) -> Self {
            Self::new(client, model)
        }
    }

    /// One generated image, base64-encoded.
    #[derive(Clone, Deserialize)]
    pub struct Image {
        /// Base64 text of the image bytes.
        image: String,
    }

    /// Body of a successful image generation response.
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        /// Returned images; only the first one is decoded.
        images: Vec<Image>,
    }

    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// Decodes the first returned image from base64.
        ///
        /// # Errors
        /// Returns `ImageGenerationError::ResponseError` when no image is
        /// present or when the base64 text is invalid.
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            let Some(first) = value.images.first() else {
                return Err(ImageGenerationError::ResponseError(
                    "missing image data".into(),
                ));
            };

            let decoded = match BASE64_STANDARD.decode(&first.image) {
                Ok(bytes) => bytes,
                Err(err) => return Err(ImageGenerationError::ResponseError(err.to_string())),
            };

            Ok(Self {
                image: decoded,
                response: value,
            })
        }
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        /// Creates a model handle from a borrowed client by cloning it.
        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        /// Sends an image generation request to `/v1/image/generation`.
        ///
        /// The body carries `model_name`, `prompt`, `height` and `width`,
        /// with any additional parameters merged in.
        ///
        /// # Errors
        /// Returns `ProviderError` (status plus body) for a non-success
        /// status and `ResponseError` when the body is an API error object.
        /// Serialization, transport and decoding errors are propagated.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let mut payload = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            if let Some(extra) = generation_request.additional_params {
                merge_inplace(&mut payload, extra);
            }

            let body = serde_json::to_vec(&payload)?;

            let http_request = self
                .client
                .post("/v1/image/generation")?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let http_response = self
                .client
                .send::<_, bytes::Bytes>(http_request)
                .await?;

            let status = http_response.status();
            let raw = http_response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&raw)
                )));
            }

            // A success status can still carry an error object in the body.
            let parsed = serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&raw)?;
            match parsed {
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
                ApiResponse::Ok(generated) => generated.try_into(),
            }
        }
    }
}
612
613#[cfg(feature = "audio")]
617pub use audio_generation::*;
618use tracing::{Instrument, info_span};
619
#[cfg(feature = "audio")]
// The docs.rs badge must name the feature that actually gates this module.
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::{ApiResponse, Client};
    use crate::audio_generation;
    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
    use crate::http_client::{self, HttpClientExt};
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use bytes::Bytes;
    use serde::Deserialize;
    use serde_json::json;

    /// Audio generation model bound to a Hyperbolic [`Client`].
    ///
    /// The value passed to `make` is stored as the language, not as a model
    /// name.
    #[derive(Clone)]
    pub struct AudioGenerationModel<T> {
        /// Client used to send requests.
        client: Client<T>,
        /// Value sent as the `language` field of each request.
        pub language: String,
    }

    /// Body of a successful audio generation response.
    #[derive(Clone, Deserialize)]
    pub struct AudioGenerationResponse {
        /// Base64 text of the generated audio bytes.
        audio: String,
    }

    impl TryFrom<AudioGenerationResponse>
        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
    {
        type Error = AudioGenerationError;

        /// Decodes the base64 audio payload into raw bytes.
        ///
        /// # Errors
        /// Returns `AudioGenerationError::ResponseError` when the base64
        /// text is invalid.
        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
            let data = BASE64_STANDARD
                .decode(&value.audio)
                .map_err(|err| AudioGenerationError::ResponseError(err.to_string()))?;

            Ok(Self {
                audio: data,
                response: value,
            })
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = AudioGenerationResponse;
        type Client = Client<T>;

        /// Creates a model handle from a borrowed client by cloning it; the
        /// second argument is stored as the language.
        fn make(client: &Self::Client, language: impl Into<String>) -> Self {
            Self {
                client: client.clone(),
                language: language.into(),
            }
        }

        /// Sends an audio generation request to `/v1/audio/generation`.
        ///
        /// The body carries `language`, `speaker` (the request's voice),
        /// `text` and `speed`.
        ///
        /// # Errors
        /// Returns `ProviderError` for a non-success status (status plus
        /// body) or when the body is an API error object. Serialization,
        /// transport and decoding errors are propagated.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            let payload = json!({
                "language": self.language,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let body = serde_json::to_vec(&payload)?;

            let req = self
                .client
                .post("/v1/audio/generation")?
                .body(body)
                .map_err(http_client::Error::from)?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            // A success status can still carry an error object in the body.
            match serde_json::from_slice::<ApiResponse<AudioGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }
}
713
#[cfg(test)]
mod tests {
    use crate::providers::hyperbolic::Client;

    /// Both construction routes must succeed with a placeholder key.
    #[test]
    fn test_client_initialization() {
        let direct = Client::new("dummy-key");
        let _client = direct.expect("Client::new() failed");

        let built = Client::builder().api_key("dummy-key").build();
        let _client_from_builder = built.expect("Client::builder() failed");
    }
}