1use super::openai::{AssistantContent, send_compatible_streaming_request};
12
13use crate::client::{self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder};
14use crate::client::{BearerAuth, ProviderClient};
15use crate::http_client::{self, HttpClientExt};
16use crate::streaming::StreamingCompletionResponse;
17
18use crate::providers::openai;
19use crate::{
20 OneOrMany,
21 completion::{self, CompletionError, CompletionRequest},
22 json_utils,
23 providers::openai::Message,
24};
25use serde::{Deserialize, Serialize};
26
/// Base URL for the Hyperbolic API.
const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz";
31
/// Marker type implementing the Hyperbolic provider extension.
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicExt;
/// Builder marker used to construct [`HyperbolicExt`] via [`ProviderBuilder`].
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicBuilder;

/// Hyperbolic authenticates with a bearer token (`Authorization` header).
type HyperbolicApiKey = BearerAuth;
38
impl Provider for HyperbolicExt {
    type Builder = HyperbolicBuilder;

    // Endpoint queried to verify that the API key / connection works.
    const VERIFY_PATH: &'static str = "/models";
}
44
impl<H> Capabilities<H> for HyperbolicExt {
    // Chat completions are always available.
    type Completion = Capable<CompletionModel<H>>;
    // Embeddings, transcription and model listing are not implemented here.
    type Embeddings = Nothing;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    // Image and audio generation are gated behind crate features.
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<ImageGenerationModel<H>>;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<AudioGenerationModel<H>>;
}

impl DebugExt for HyperbolicExt {}
57
impl ProviderBuilder for HyperbolicBuilder {
    type Extension<H>
        = HyperbolicExt
    where
        H: HttpClientExt;
    type ApiKey = HyperbolicApiKey;

    const BASE_URL: &'static str = HYPERBOLIC_API_BASE_URL;

    /// The extension is a stateless marker, so building it never fails and the
    /// client builder argument is ignored.
    fn build<H>(
        _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(HyperbolicExt)
    }
}
76
/// Hyperbolic client, defaulting to `reqwest` as the HTTP backend.
pub type Client<H = reqwest::Client> = client::Client<HyperbolicExt, H>;
/// Builder for [`Client`]; the API key is supplied as a `String`.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<HyperbolicBuilder, String, H>;
79
impl ProviderClient for Client {
    type Input = HyperbolicApiKey;

    /// Builds a client from the `HYPERBOLIC_API_KEY` environment variable.
    ///
    /// # Panics
    /// Panics if the variable is unset or if client construction fails.
    fn from_env() -> Self {
        let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
        Self::new(&api_key).unwrap()
    }

    /// Builds a client from an explicit API key.
    ///
    /// # Panics
    /// Panics if client construction fails.
    fn from_val(input: Self::Input) -> Self {
        Self::new(input).unwrap()
    }
}
94
/// Error payload returned by the Hyperbolic API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// A Hyperbolic response body: either the expected payload `T` or an error.
/// `untagged` lets serde pick whichever shape deserializes successfully.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
106
/// A single embedding entry in an embeddings response.
///
/// NOTE(review): `Capabilities` declares `Embeddings = Nothing`, so this type
/// appears unused from this file — confirm other consumers before removing.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}
113
114#[derive(Clone, Debug, Deserialize, Serialize)]
115pub struct Usage {
116 pub prompt_tokens: usize,
117 pub total_tokens: usize,
118}
119
120impl std::fmt::Display for Usage {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 write!(
123 f,
124 "Prompt tokens: {} Total tokens: {}",
125 self.prompt_tokens, self.total_tokens
126 )
127 }
128}
129
/// `meta-llama/Meta-Llama-3.1-8B-Instruct` model identifier.
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// `meta-llama/Llama-3.3-70B-Instruct` model identifier.
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// `meta-llama/Meta-Llama-3.1-70B-Instruct` model identifier.
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// `meta-llama/Meta-Llama-3-70B-Instruct` model identifier.
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// `NousResearch/Hermes-3-Llama-3.1-70b` model identifier.
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// `deepseek-ai/DeepSeek-V2.5` model identifier.
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// `Qwen/Qwen2.5-72B-Instruct` model identifier.
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// `meta-llama/Llama-3.2-3B-Instruct` model identifier.
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// `Qwen/Qwen2.5-Coder-32B-Instruct` model identifier.
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// `Qwen/QwQ-32B-Preview` model identifier.
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// `deepseek-ai/DeepSeek-R1-Zero` model identifier.
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// `deepseek-ai/DeepSeek-R1` model identifier.
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
158
/// Raw (OpenAI-style) chat completion response returned by Hyperbolic.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    /// Token accounting; not always present in the response.
    pub usage: Option<Usage>,
}
171
172impl From<ApiErrorResponse> for CompletionError {
173 fn from(err: ApiErrorResponse) -> Self {
174 CompletionError::ProviderError(err.message)
175 }
176}
177
178impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
179 type Error = CompletionError;
180
181 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
182 let choice = response.choices.first().ok_or_else(|| {
183 CompletionError::ResponseError("Response contained no choices".to_owned())
184 })?;
185
186 let content = match &choice.message {
187 Message::Assistant {
188 content,
189 tool_calls,
190 ..
191 } => {
192 let mut content = content
193 .iter()
194 .map(|c| match c {
195 AssistantContent::Text { text } => completion::AssistantContent::text(text),
196 AssistantContent::Refusal { refusal } => {
197 completion::AssistantContent::text(refusal)
198 }
199 })
200 .collect::<Vec<_>>();
201
202 content.extend(
203 tool_calls
204 .iter()
205 .map(|call| {
206 completion::AssistantContent::tool_call(
207 &call.id,
208 &call.function.name,
209 call.function.arguments.clone(),
210 )
211 })
212 .collect::<Vec<_>>(),
213 );
214 Ok(content)
215 }
216 _ => Err(CompletionError::ResponseError(
217 "Response did not contain a valid message or tool call".into(),
218 )),
219 }?;
220
221 let choice = OneOrMany::many(content).map_err(|_| {
222 CompletionError::ResponseError(
223 "Response contained no message or tool call (empty)".to_owned(),
224 )
225 })?;
226
227 let usage = response
228 .usage
229 .as_ref()
230 .map(|usage| completion::Usage {
231 input_tokens: usage.prompt_tokens as u64,
232 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
233 total_tokens: usage.total_tokens as u64,
234 cached_input_tokens: 0,
235 cache_creation_input_tokens: 0,
236 })
237 .unwrap_or_default();
238
239 Ok(completion::CompletionResponse {
240 choice,
241 usage,
242 raw_response: response,
243 message_id: None,
244 })
245 }
246}
247
/// One candidate completion inside [`CompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    pub finish_reason: String,
}

/// Wire format for a Hyperbolic chat-completion request body.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct HyperbolicCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Flattened so arbitrary extra fields (e.g. `stream`) serialize at the
    // top level of the request body.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
264
265impl TryFrom<(&str, CompletionRequest)> for HyperbolicCompletionRequest {
266 type Error = CompletionError;
267
268 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
269 if req.output_schema.is_some() {
270 tracing::warn!("Structured outputs currently not supported for Hyperbolic");
271 }
272
273 let model = req.model.clone().unwrap_or_else(|| model.to_string());
274 if req.tool_choice.is_some() {
275 tracing::warn!("WARNING: `tool_choice` not supported on Hyperbolic");
276 }
277
278 if !req.tools.is_empty() {
279 tracing::warn!("WARNING: `tools` not supported on Hyperbolic");
280 }
281
282 let mut full_history: Vec<Message> = match &req.preamble {
283 Some(preamble) => vec![Message::system(preamble)],
284 None => vec![],
285 };
286
287 if let Some(docs) = req.normalized_documents() {
288 let docs: Vec<Message> = docs.try_into()?;
289 full_history.extend(docs);
290 }
291
292 let chat_history: Vec<Message> = req
293 .chat_history
294 .clone()
295 .into_iter()
296 .map(|message| message.try_into())
297 .collect::<Result<Vec<Vec<Message>>, _>>()?
298 .into_iter()
299 .flatten()
300 .collect();
301
302 full_history.extend(chat_history);
303
304 Ok(Self {
305 model: model.to_string(),
306 messages: full_history,
307 temperature: req.temperature,
308 additional_params: req.additional_params,
309 })
310 }
311}
312
/// Hyperbolic chat-completion model handle: a client plus a model identifier.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}
319
320impl<T> CompletionModel<T> {
321 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
322 Self {
323 client,
324 model: model.into(),
325 }
326 }
327
328 pub fn with_model(client: Client<T>, model: &str) -> Self {
329 Self {
330 client,
331 model: model.into(),
332 }
333 }
334}
335
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    // Streaming reuses the OpenAI-compatible SSE response type.
    type StreamingResponse = openai::StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Sends a non-streaming chat completion to `POST /v1/chat/completions`.
    ///
    /// # Errors
    /// Propagates serialization, HTTP, and provider errors; a 2xx body that
    /// deserializes as an error payload also yields a
    /// [`CompletionError::ProviderError`].
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Reuse the caller's span if one is active; otherwise open a new one
        // with gen_ai-convention fields (response fields filled in later).
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // The pretty-printed request is only serialized when TRACE is enabled.
        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Hyperbolic completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        let async_block = async move {
            let response = self.client.send::<_, bytes::Bytes>(req).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                // A 2xx body may still carry an error payload; the untagged
                // `ApiResponse` enum distinguishes the two shapes.
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        if tracing::enabled!(tracing::Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Hyperbolic completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                // Non-2xx: surface the raw body as the provider error text.
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        };

        async_block.instrument(span).await
    }

    /// Sends a streaming chat completion to `POST /v1/chat/completions`.
    ///
    /// Forces `stream: true` and `stream_options.include_usage: true` into the
    /// request's additional params, then delegates to the OpenAI-compatible
    /// SSE handler.
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cached_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let mut request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Merge streaming flags on top of any caller-provided extras.
        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Hyperbolic streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        send_compatible_streaming_request(self.client.clone(), req)
            .instrument(span)
            .await
    }
}
473
474#[cfg(feature = "image")]
479pub use image_generation::*;
480
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use super::{ApiResponse, Client};
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    /// `SDXL1.0-base` image model identifier.
    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    /// `SD2` image model identifier.
    pub const SD2: &str = "SD2";
    /// `SD1.5` image model identifier.
    pub const SD1_5: &str = "SD1.5";
    /// `SSD` image model identifier.
    pub const SSD: &str = "SSD";
    /// `SDXL-turbo` image model identifier.
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    /// `SDXL-ControlNet` image model identifier.
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    /// `SD1.5-ControlNet` image model identifier.
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    /// Hyperbolic image-generation model handle.
    #[derive(Clone)]
    pub struct ImageGenerationModel<T> {
        client: Client<T>,
        pub model: String,
    }

    impl<T> ImageGenerationModel<T> {
        /// Creates an image-generation model for `model` backed by `client`.
        pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
            Self {
                client,
                model: model.into(),
            }
        }

        /// Convenience constructor taking the model id as a `&str`.
        pub fn with_model(client: Client<T>, model: &str) -> Self {
            Self {
                client,
                model: model.into(),
            }
        }
    }

    /// One generated image: base64-encoded data.
    #[derive(Clone, Deserialize)]
    pub struct Image {
        image: String,
    }

    /// Raw image-generation response from Hyperbolic.
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        images: Vec<Image>,
    }

    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// Decodes the first image in the response.
        ///
        /// # Errors
        /// Returns [`ImageGenerationError::ResponseError`] when the response
        /// contains no images or the base64 payload is malformed. (The
        /// previous implementation indexed `images[0]` and `expect`ed the
        /// decode, panicking inside a fallible conversion.)
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            let first = value.images.first().ok_or_else(|| {
                ImageGenerationError::ResponseError("Response contained no images".to_owned())
            })?;

            let data = BASE64_STANDARD.decode(&first.image).map_err(|e| {
                ImageGenerationError::ResponseError(format!("Could not decode image: {e}"))
            })?;

            Ok(Self {
                image: data,
                response: value,
            })
        }
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        /// Sends an image-generation request to `POST /v1/image/generation`.
        ///
        /// # Errors
        /// Returns a provider error on non-2xx statuses or an error payload,
        /// and a response error when the image cannot be decoded.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            // Note: this endpoint expects `model_name`, not `model`.
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let body = serde_json::to_vec(&request)?;

            let request = self
                .client
                .post("/v1/image/generation")?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, bytes::Bytes>(request).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
            }
        }
    }
}
607
608#[cfg(feature = "audio")]
612pub use audio_generation::*;
613use tracing::{Instrument, info_span};
614
#[cfg(feature = "audio")]
// Fixed copy-paste bug: this doc(cfg(...)) previously advertised the
// `image` feature for the audio module.
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::{ApiResponse, Client};
    use crate::audio_generation;
    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
    use crate::http_client::{self, HttpClientExt};
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use bytes::Bytes;
    use serde::Deserialize;
    use serde_json::json;

    /// Hyperbolic text-to-speech model handle.
    ///
    /// Unlike the other models, audio generation is parameterized by
    /// `language` rather than a model name.
    #[derive(Clone)]
    pub struct AudioGenerationModel<T> {
        client: Client<T>,
        pub language: String,
    }

    /// Raw TTS response: base64-encoded audio data.
    #[derive(Clone, Deserialize)]
    pub struct AudioGenerationResponse {
        audio: String,
    }

    impl TryFrom<AudioGenerationResponse>
        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
    {
        type Error = AudioGenerationError;

        /// Decodes the base64 audio payload.
        ///
        /// # Errors
        /// Returns [`AudioGenerationError::ProviderError`] when the base64
        /// payload is malformed. (The previous implementation `expect`ed the
        /// decode, panicking inside a fallible conversion.)
        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
            let data = BASE64_STANDARD.decode(&value.audio).map_err(|e| {
                AudioGenerationError::ProviderError(format!("Could not decode audio: {e}"))
            })?;

            Ok(Self {
                audio: data,
                response: value,
            })
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = AudioGenerationResponse;
        type Client = Client<T>;

        fn make(client: &Self::Client, language: impl Into<String>) -> Self {
            Self {
                client: client.clone(),
                language: language.into(),
            }
        }

        /// Sends a TTS request to `POST /v1/audio/generation` and decodes the
        /// base64 audio in the response.
        ///
        /// # Errors
        /// Returns a provider error on non-2xx statuses, on an error payload,
        /// or when the audio cannot be base64-decoded.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            let request = json!({
                "language": self.language,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post("/v1/audio/generation")?
                .body(body)
                .map_err(http_client::Error::from)?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<AudioGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }
}
708
#[cfg(test)]
mod tests {
    /// Smoke-tests both client construction paths with a dummy key.
    /// Only construction is exercised; no network requests are made here.
    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::hyperbolic::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = crate::providers::hyperbolic::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}