1use super::openai::{AssistantContent, send_compatible_streaming_request};
12
13use crate::client::{self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder};
14use crate::client::{BearerAuth, ProviderClient};
15use crate::http_client::{self, HttpClientExt};
16use crate::streaming::StreamingCompletionResponse;
17
18use crate::providers::openai;
19use crate::{
20 OneOrMany,
21 completion::{self, CompletionError, CompletionRequest},
22 json_utils,
23 providers::openai::Message,
24};
25use serde::{Deserialize, Serialize};
26
/// Default base URL for the Hyperbolic API.
const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz";
31
/// Marker type carrying the Hyperbolic provider extension.
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicExt;
/// Marker builder used to construct Hyperbolic clients.
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicBuilder;

/// Hyperbolic authenticates with a standard `Authorization: Bearer <key>` header.
type HyperbolicApiKey = BearerAuth;
38
impl Provider for HyperbolicExt {
    type Builder = HyperbolicBuilder;

    // Path used by the client's connectivity/auth verification check.
    const VERIFY_PATH: &'static str = "/models";
}
44
// Capability matrix for Hyperbolic: chat completions are always available;
// image and audio generation only behind their respective cargo features.
// Embeddings, transcription and model listing are not implemented.
impl<H> Capabilities<H> for HyperbolicExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<ImageGenerationModel<H>>;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<AudioGenerationModel<H>>;
}
55
// Marker impl: the empty body shows `DebugExt` requires no items here —
// its behavior comes entirely from default methods.
impl DebugExt for HyperbolicExt {}
57
impl ProviderBuilder for HyperbolicBuilder {
    // The extension is the same zero-sized type regardless of HTTP backend.
    type Extension<H>
        = HyperbolicExt
    where
        H: HttpClientExt;
    type ApiKey = HyperbolicApiKey;

    const BASE_URL: &'static str = HYPERBOLIC_API_BASE_URL;

    /// Produces the stateless extension; this construction never fails.
    fn build<H>(
        _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(HyperbolicExt)
    }
}
76
/// Hyperbolic client, generic over the HTTP backend (reqwest by default).
pub type Client<H = reqwest::Client> = client::Client<HyperbolicExt, H>;
/// Builder for [`Client`]; the API key is supplied as a `String`.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<HyperbolicBuilder, String, H>;
79
80impl ProviderClient for Client {
81 type Input = HyperbolicApiKey;
82
83 fn from_env() -> Self {
86 let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
87 Self::new(&api_key).unwrap()
88 }
89
90 fn from_val(input: Self::Input) -> Self {
91 Self::new(input).unwrap()
92 }
93}
94
/// Error payload returned by the Hyperbolic API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
99
/// A Hyperbolic response body is either the expected payload or an error
/// object; `untagged` makes serde try the variants in order (success first).
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
106
/// One embedding vector from an embeddings response.
// NOTE(review): embeddings are not wired up for this provider
// (`Capabilities::Embeddings = Nothing`), so this type appears unused here —
// confirm before removing.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}
113
/// Token usage statistics as reported by Hyperbolic.
///
/// Only prompt and total counts are present; completion tokens are derived
/// downstream as `total_tokens - prompt_tokens`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
119
120impl std::fmt::Display for Usage {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 write!(
123 f,
124 "Prompt tokens: {} Total tokens: {}",
125 self.prompt_tokens, self.total_tokens
126 )
127 }
128}
129
// Model identifiers accepted by Hyperbolic's chat completion endpoint.
/// Meta Llama 3.1 8B Instruct
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Meta Llama 3.3 70B Instruct
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// Meta Llama 3.1 70B Instruct
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// Meta Llama 3 70B Instruct
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// NousResearch Hermes 3 (Llama 3.1 70B base)
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// DeepSeek V2.5
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// Qwen 2.5 72B Instruct
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// Meta Llama 3.2 3B Instruct
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// Qwen 2.5 Coder 32B Instruct
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// Qwen QwQ 32B (preview)
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// DeepSeek R1 Zero
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// DeepSeek R1
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
158
/// Raw (non-streaming) chat completion response from Hyperbolic.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    // Optional: a missing usage block maps to zeroed counters downstream.
    pub usage: Option<Usage>,
}
171
172impl From<ApiErrorResponse> for CompletionError {
173 fn from(err: ApiErrorResponse) -> Self {
174 CompletionError::ProviderError(err.message)
175 }
176}
177
178impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
179 type Error = CompletionError;
180
181 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
182 let choice = response.choices.first().ok_or_else(|| {
183 CompletionError::ResponseError("Response contained no choices".to_owned())
184 })?;
185
186 let content = match &choice.message {
187 Message::Assistant {
188 content,
189 tool_calls,
190 ..
191 } => {
192 let mut content = content
193 .iter()
194 .map(|c| match c {
195 AssistantContent::Text { text } => completion::AssistantContent::text(text),
196 AssistantContent::Refusal { refusal } => {
197 completion::AssistantContent::text(refusal)
198 }
199 })
200 .collect::<Vec<_>>();
201
202 content.extend(
203 tool_calls
204 .iter()
205 .map(|call| {
206 completion::AssistantContent::tool_call(
207 &call.id,
208 &call.function.name,
209 call.function.arguments.clone(),
210 )
211 })
212 .collect::<Vec<_>>(),
213 );
214 Ok(content)
215 }
216 _ => Err(CompletionError::ResponseError(
217 "Response did not contain a valid message or tool call".into(),
218 )),
219 }?;
220
221 let choice = OneOrMany::many(content).map_err(|_| {
222 CompletionError::ResponseError(
223 "Response contained no message or tool call (empty)".to_owned(),
224 )
225 })?;
226
227 let usage = response
228 .usage
229 .as_ref()
230 .map(|usage| completion::Usage {
231 input_tokens: usage.prompt_tokens as u64,
232 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
233 total_tokens: usage.total_tokens as u64,
234 cached_input_tokens: 0,
235 })
236 .unwrap_or_default();
237
238 Ok(completion::CompletionResponse {
239 choice,
240 usage,
241 raw_response: response,
242 message_id: None,
243 })
244 }
245}
246
/// One generated alternative within a completion response.
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    pub finish_reason: String,
}
253
/// JSON body sent to `POST /v1/chat/completions`.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct HyperbolicCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Flattened so caller-supplied keys (e.g. `stream`) serialize at the top
    // level of the request body.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
263
264impl TryFrom<(&str, CompletionRequest)> for HyperbolicCompletionRequest {
265 type Error = CompletionError;
266
267 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
268 if req.output_schema.is_some() {
269 tracing::warn!("Structured outputs currently not supported for Hyperbolic");
270 }
271
272 let model = req.model.clone().unwrap_or_else(|| model.to_string());
273 if req.tool_choice.is_some() {
274 tracing::warn!("WARNING: `tool_choice` not supported on Hyperbolic");
275 }
276
277 if !req.tools.is_empty() {
278 tracing::warn!("WARNING: `tools` not supported on Hyperbolic");
279 }
280
281 let mut full_history: Vec<Message> = match &req.preamble {
282 Some(preamble) => vec![Message::system(preamble)],
283 None => vec![],
284 };
285
286 if let Some(docs) = req.normalized_documents() {
287 let docs: Vec<Message> = docs.try_into()?;
288 full_history.extend(docs);
289 }
290
291 let chat_history: Vec<Message> = req
292 .chat_history
293 .clone()
294 .into_iter()
295 .map(|message| message.try_into())
296 .collect::<Result<Vec<Vec<Message>>, _>>()?
297 .into_iter()
298 .flatten()
299 .collect();
300
301 full_history.extend(chat_history);
302
303 Ok(Self {
304 model: model.to_string(),
305 messages: full_history,
306 temperature: req.temperature,
307 additional_params: req.additional_params,
308 })
309 }
310}
311
/// Chat completion model bound to a Hyperbolic [`Client`].
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}
318
319impl<T> CompletionModel<T> {
320 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
321 Self {
322 client,
323 model: model.into(),
324 }
325 }
326
327 pub fn with_model(client: Client<T>, model: &str) -> Self {
328 Self {
329 client,
330 model: model.into(),
331 }
332 }
333}
334
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    // Streaming reuses the OpenAI-compatible streaming response type.
    type StreamingResponse = openai::StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Sends a non-streaming chat completion to `POST /v1/chat/completions`
    /// and converts the provider payload into the generic response type.
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Open a fresh gen_ai span only when the caller has not already
        // established one; otherwise record into the current span.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Guarded so the pretty-print serialization only runs at TRACE level.
        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Hyperbolic completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        // The request/response cycle lives in one async block so the whole
        // exchange is instrumented with the span above.
        let async_block = async move {
            let response = self.client.send::<_, bytes::Bytes>(req).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                // A 2xx body can still carry an error payload, hence the
                // untagged `ApiResponse` decode.
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        if tracing::enabled!(tracing::Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Hyperbolic completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                // Non-2xx: surface the raw body as the provider error text.
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        };

        async_block.instrument(span).await
    }

    /// Sends a streaming chat completion; injects `stream: true` plus
    /// `stream_options.include_usage` into the request's additional params
    /// before delegating to the OpenAI-compatible SSE handler.
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let mut request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Merge the streaming flags with any caller-supplied params.
        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Hyperbolic streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        send_compatible_streaming_request(self.client.clone(), req)
            .instrument(span)
            .await
    }
}
472
473#[cfg(feature = "image")]
478pub use image_generation::*;
479
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use super::{ApiResponse, Client};
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    // Image model identifiers accepted by Hyperbolic.
    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    pub const SD2: &str = "SD2";
    pub const SD1_5: &str = "SD1.5";
    pub const SSD: &str = "SSD";
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    /// Image generation model bound to a Hyperbolic [`Client`].
    #[derive(Clone)]
    pub struct ImageGenerationModel<T> {
        client: Client<T>,
        pub model: String,
    }

    impl<T> ImageGenerationModel<T> {
        /// Binds `client` to the given image model identifier.
        pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
            Self {
                client,
                model: model.into(),
            }
        }

        /// Convenience constructor taking the model id as `&str`.
        pub fn with_model(client: Client<T>, model: &str) -> Self {
            Self {
                client,
                model: model.into(),
            }
        }
    }

    /// One base64-encoded image in the API response.
    #[derive(Clone, Deserialize)]
    pub struct Image {
        image: String,
    }

    /// Raw image generation response from Hyperbolic.
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        images: Vec<Image>,
    }

    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// Decodes the first returned image from base64.
        ///
        /// # Errors
        /// Returns `ResponseError` if the response contains no images or the
        /// payload is not valid base64 (previously these were panics via
        /// indexing and `.expect(..)` inside a fallible conversion).
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            let first = value.images.first().ok_or_else(|| {
                ImageGenerationError::ResponseError("Response contained no images".to_owned())
            })?;

            let data = BASE64_STANDARD.decode(&first.image).map_err(|e| {
                ImageGenerationError::ResponseError(format!("Could not decode image: {e}"))
            })?;

            Ok(Self {
                image: data,
                response: value,
            })
        }
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        /// POSTs to `/v1/image/generation` and returns the decoded image.
        ///
        /// # Errors
        /// Returns `ProviderError` on non-2xx responses or API-level errors,
        /// and `ResponseError` on malformed payloads.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            // Caller-supplied params are merged into the request body.
            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let body = serde_json::to_vec(&request)?;

            let request = self
                .client
                .post("/v1/image/generation")?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, bytes::Bytes>(request).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            // A 2xx body may still carry an error payload (untagged decode).
            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
            }
        }
    }
}
606
607#[cfg(feature = "audio")]
611pub use audio_generation::*;
612use tracing::{Instrument, info_span};
613
#[cfg(feature = "audio")]
// Fixed: the docs.rs cfg badge previously said `feature = "image"` although
// this module is gated on the `audio` feature.
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::{ApiResponse, Client};
    use crate::audio_generation;
    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
    use crate::http_client::{self, HttpClientExt};
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use bytes::Bytes;
    use serde::Deserialize;
    use serde_json::json;

    /// Audio (speech) generation model bound to a Hyperbolic [`Client`].
    ///
    /// The request body is keyed by `language` rather than a model name, so
    /// the model slot stores the language string.
    #[derive(Clone)]
    pub struct AudioGenerationModel<T> {
        client: Client<T>,
        pub language: String,
    }

    /// Raw API response carrying base64-encoded audio.
    #[derive(Clone, Deserialize)]
    pub struct AudioGenerationResponse {
        audio: String,
    }

    impl TryFrom<AudioGenerationResponse>
        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
    {
        type Error = AudioGenerationError;

        /// Decodes the base64 audio payload.
        ///
        /// # Errors
        /// Returns `ProviderError` when the payload is not valid base64
        /// (previously a panic via `.expect(..)` inside a fallible
        /// conversion).
        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
            let data = BASE64_STANDARD.decode(&value.audio).map_err(|e| {
                AudioGenerationError::ProviderError(format!("Could not decode audio: {e}"))
            })?;

            Ok(Self {
                audio: data,
                response: value,
            })
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = AudioGenerationResponse;
        type Client = Client<T>;

        fn make(client: &Self::Client, language: impl Into<String>) -> Self {
            Self {
                client: client.clone(),
                language: language.into(),
            }
        }

        /// POSTs to `/v1/audio/generation` and returns the decoded audio.
        ///
        /// # Errors
        /// Returns `ProviderError` on non-2xx responses, API-level errors, or
        /// undecodable audio payloads.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            let request = json!({
                "language": self.language,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post("/v1/audio/generation")?
                .body(body)
                .map_err(http_client::Error::from)?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            // A 2xx body may still carry an error payload (untagged decode).
            match serde_json::from_slice::<ApiResponse<AudioGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }
}
707
#[cfg(test)]
mod tests {
    use crate::providers::hyperbolic::Client;

    /// Smoke test: both construction paths succeed with a placeholder key.
    #[test]
    fn test_client_initialization() {
        let _client = Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}