1use super::openai::{AssistantContent, send_compatible_streaming_request};
12
13use crate::client::{self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder};
14use crate::client::{BearerAuth, ProviderClient};
15use crate::http_client::{self, HttpClientExt};
16use crate::streaming::StreamingCompletionResponse;
17
18use crate::providers::openai;
19use crate::{
20 OneOrMany,
21 completion::{self, CompletionError, CompletionRequest},
22 json_utils,
23 providers::openai::Message,
24};
25use serde::{Deserialize, Serialize};
26
/// Base URL of the public Hyperbolic API; all request paths are joined onto it.
const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz";
31
/// Marker type implementing [`Provider`] for the Hyperbolic platform.
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicExt;
/// Builder marker for [`HyperbolicExt`]; carries no configuration of its own.
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicBuilder;

/// Hyperbolic authenticates with a bearer token in the `Authorization` header.
type HyperbolicApiKey = BearerAuth;
38
impl Provider for HyperbolicExt {
    type Builder = HyperbolicBuilder;

    /// Endpoint used to verify that a client/API key is functional.
    const VERIFY_PATH: &'static str = "/models";

    /// Constructs the provider marker. The builder is ignored because
    /// `HyperbolicExt` is a zero-sized type with no configuration.
    fn build<H>(
        _: &crate::client::ClientBuilder<
            Self::Builder,
            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        Ok(Self)
    }
}
54
// Capability matrix for Hyperbolic: chat completions always; image and audio
// generation only behind their cargo features. Embeddings, transcription and
// model listing are not wired up in this integration.
impl<H> Capabilities<H> for HyperbolicExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<ImageGenerationModel<H>>;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<AudioGenerationModel<H>>;
}
65
// Marker impl: opts into the default behavior provided by `DebugExt`.
impl DebugExt for HyperbolicExt {}
67
impl ProviderBuilder for HyperbolicBuilder {
    type Output = HyperbolicExt;
    type ApiKey = HyperbolicApiKey;

    /// Default host; request paths such as `/v1/chat/completions` are appended.
    const BASE_URL: &'static str = HYPERBOLIC_API_BASE_URL;
}
74
/// A Hyperbolic API client over a pluggable HTTP backend (defaults to `reqwest`).
pub type Client<H = reqwest::Client> = client::Client<HyperbolicExt, H>;
/// Builder for [`Client`]; the API key is supplied as a `String`.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<HyperbolicBuilder, String, H>;
77
impl ProviderClient for Client {
    type Input = HyperbolicApiKey;

    /// Creates a client from the `HYPERBOLIC_API_KEY` environment variable.
    ///
    /// # Panics
    /// Panics if the variable is unset or the client fails to build.
    fn from_env() -> Self {
        let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
        Self::new(&api_key).unwrap()
    }

    /// Creates a client from an explicit API key.
    ///
    /// # Panics
    /// Panics if the client fails to build.
    fn from_val(input: Self::Input) -> Self {
        Self::new(input).unwrap()
    }
}
92
/// Error payload returned by the Hyperbolic API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Either a successful payload of type `T` or an API-level error.
/// With `untagged`, serde tries the `Ok(T)` shape first and falls back to the
/// error shape, so even a 2xx body can decode to `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
104
/// A single embedding vector entry in an embeddings response.
/// NOTE(review): no embedding model is wired up in this file
/// (`Embeddings = Nothing` above) — confirm whether this type is still used.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}
111
/// Token usage statistics reported by Hyperbolic. Only prompt and total
/// counts are provided; completion tokens are derived as `total - prompt`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
117
118impl std::fmt::Display for Usage {
119 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120 write!(
121 f,
122 "Prompt tokens: {} Total tokens: {}",
123 self.prompt_tokens, self.total_tokens
124 )
125 }
126}
127
/// `meta-llama/Meta-Llama-3.1-8B-Instruct` completion model
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// `meta-llama/Llama-3.3-70B-Instruct` completion model
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// `meta-llama/Meta-Llama-3.1-70B-Instruct` completion model
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// `meta-llama/Meta-Llama-3-70B-Instruct` completion model
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// `NousResearch/Hermes-3-Llama-3.1-70b` completion model
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// `deepseek-ai/DeepSeek-V2.5` completion model
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// `Qwen/Qwen2.5-72B-Instruct` completion model
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// `meta-llama/Llama-3.2-3B-Instruct` completion model
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// `Qwen/Qwen2.5-Coder-32B-Instruct` completion model
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// `Qwen/QwQ-32B-Preview` completion model
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// `deepseek-ai/DeepSeek-R1-Zero` completion model
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// `deepseek-ai/DeepSeek-R1` completion model
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
156
/// Raw chat-completion response in Hyperbolic's OpenAI-compatible format.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    // Not guaranteed to be present in every response.
    pub usage: Option<Usage>,
}
169
170impl From<ApiErrorResponse> for CompletionError {
171 fn from(err: ApiErrorResponse) -> Self {
172 CompletionError::ProviderError(err.message)
173 }
174}
175
176impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
177 type Error = CompletionError;
178
179 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
180 let choice = response.choices.first().ok_or_else(|| {
181 CompletionError::ResponseError("Response contained no choices".to_owned())
182 })?;
183
184 let content = match &choice.message {
185 Message::Assistant {
186 content,
187 tool_calls,
188 ..
189 } => {
190 let mut content = content
191 .iter()
192 .map(|c| match c {
193 AssistantContent::Text { text } => completion::AssistantContent::text(text),
194 AssistantContent::Refusal { refusal } => {
195 completion::AssistantContent::text(refusal)
196 }
197 })
198 .collect::<Vec<_>>();
199
200 content.extend(
201 tool_calls
202 .iter()
203 .map(|call| {
204 completion::AssistantContent::tool_call(
205 &call.id,
206 &call.function.name,
207 call.function.arguments.clone(),
208 )
209 })
210 .collect::<Vec<_>>(),
211 );
212 Ok(content)
213 }
214 _ => Err(CompletionError::ResponseError(
215 "Response did not contain a valid message or tool call".into(),
216 )),
217 }?;
218
219 let choice = OneOrMany::many(content).map_err(|_| {
220 CompletionError::ResponseError(
221 "Response contained no message or tool call (empty)".to_owned(),
222 )
223 })?;
224
225 let usage = response
226 .usage
227 .as_ref()
228 .map(|usage| completion::Usage {
229 input_tokens: usage.prompt_tokens as u64,
230 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
231 total_tokens: usage.total_tokens as u64,
232 cached_input_tokens: 0,
233 })
234 .unwrap_or_default();
235
236 Ok(completion::CompletionResponse {
237 choice,
238 usage,
239 raw_response: response,
240 message_id: None,
241 })
242 }
243}
244
/// One completion choice within a [`CompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    pub finish_reason: String,
}
251
/// Request body for Hyperbolic's `/v1/chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct HyperbolicCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Extra provider-specific fields flattened into the top-level JSON object;
    // also used to inject `stream`/`stream_options` for streaming requests.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
261
262impl TryFrom<(&str, CompletionRequest)> for HyperbolicCompletionRequest {
263 type Error = CompletionError;
264
265 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
266 if req.output_schema.is_some() {
267 tracing::warn!("Structured outputs currently not supported for Hyperbolic");
268 }
269
270 let model = req.model.clone().unwrap_or_else(|| model.to_string());
271 if req.tool_choice.is_some() {
272 tracing::warn!("WARNING: `tool_choice` not supported on Hyperbolic");
273 }
274
275 if !req.tools.is_empty() {
276 tracing::warn!("WARNING: `tools` not supported on Hyperbolic");
277 }
278
279 let mut full_history: Vec<Message> = match &req.preamble {
280 Some(preamble) => vec![Message::system(preamble)],
281 None => vec![],
282 };
283
284 if let Some(docs) = req.normalized_documents() {
285 let docs: Vec<Message> = docs.try_into()?;
286 full_history.extend(docs);
287 }
288
289 let chat_history: Vec<Message> = req
290 .chat_history
291 .clone()
292 .into_iter()
293 .map(|message| message.try_into())
294 .collect::<Result<Vec<Vec<Message>>, _>>()?
295 .into_iter()
296 .flatten()
297 .collect();
298
299 full_history.extend(chat_history);
300
301 Ok(Self {
302 model: model.to_string(),
303 messages: full_history,
304 temperature: req.temperature,
305 additional_params: req.additional_params,
306 })
307 }
308}
309
/// A chat-completion model bound to a Hyperbolic [`Client`].
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    // Model identifier, e.g. [`LLAMA_3_1_8B`].
    pub model: String,
}
316
317impl<T> CompletionModel<T> {
318 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
319 Self {
320 client,
321 model: model.into(),
322 }
323 }
324
325 pub fn with_model(client: Client<T>, model: &str) -> Self {
326 Self {
327 client,
328 model: model.into(),
329 }
330 }
331}
332
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = openai::StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Sends a non-streaming chat completion request to
    /// `/v1/chat/completions` and converts the response into rig's
    /// provider-agnostic form.
    ///
    /// # Errors
    /// Returns [`CompletionError::ProviderError`] on non-2xx statuses or
    /// API-level errors, and propagates serialization/transport failures.
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Reuse the caller's tracing span when one is active; otherwise open
        // a fresh `gen_ai`-annotated span for this call.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Guarded so the pretty-printing cost is only paid at TRACE level.
        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Hyperbolic completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        let async_block = async move {
            let response = self.client.send::<_, bytes::Bytes>(req).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                // A 2xx body can still decode to an API-level error via the
                // untagged `ApiResponse` enum.
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        if tracing::enabled!(tracing::Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Hyperbolic completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        };

        async_block.instrument(span).await
    }

    /// Sends a streaming chat completion request.
    ///
    /// Forces `stream: true` (with usage reporting) into the request's
    /// additional parameters and delegates SSE handling to the shared
    /// OpenAI-compatible streaming implementation.
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let mut request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Merge the streaming flags into any user-supplied additional params;
        // the right-hand side wins on key conflicts per `json_utils::merge`.
        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Hyperbolic streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        send_compatible_streaming_request(self.client.clone(), req)
            .instrument(span)
            .await
    }
}
468
469#[cfg(feature = "image")]
474pub use image_generation::*;
475
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use super::{ApiResponse, Client};
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    // Stable Diffusion model identifiers accepted by Hyperbolic.
    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    pub const SD2: &str = "SD2";
    pub const SD1_5: &str = "SD1.5";
    pub const SSD: &str = "SSD";
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    /// An image generation model bound to a Hyperbolic [`Client`].
    #[derive(Clone)]
    pub struct ImageGenerationModel<T> {
        client: Client<T>,
        pub model: String,
    }

    impl<T> ImageGenerationModel<T> {
        /// Binds `client` to the given image model identifier.
        pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
            Self {
                client,
                model: model.into(),
            }
        }

        /// Convenience constructor taking the model name as a string slice.
        pub fn with_model(client: Client<T>, model: &str) -> Self {
            Self {
                client,
                model: model.into(),
            }
        }
    }

    /// A single base64-encoded image in the API response.
    #[derive(Clone, Deserialize)]
    pub struct Image {
        image: String,
    }

    /// Raw response from `/v1/image/generation`.
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        images: Vec<Image>,
    }

    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// Decodes the first returned image from base64.
        ///
        /// # Errors
        /// Returns [`ImageGenerationError::ResponseError`] when the response
        /// contains no images or the payload is not valid base64. (The
        /// previous version indexed `images[0]` and `expect`ed the decode,
        /// panicking on a malformed provider response.)
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            let first = value.images.first().ok_or_else(|| {
                ImageGenerationError::ResponseError("Response contained no images".to_owned())
            })?;

            let data = BASE64_STANDARD.decode(&first.image).map_err(|e| {
                ImageGenerationError::ResponseError(format!("Could not decode image: {e}"))
            })?;

            Ok(Self {
                image: data,
                response: value,
            })
        }
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        /// Sends an image generation request to `/v1/image/generation`.
        ///
        /// # Errors
        /// Returns [`ImageGenerationError::ProviderError`] on non-2xx
        /// statuses and propagates serialization/transport failures.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            // User-supplied params extend/override the defaults above.
            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let body = serde_json::to_vec(&request)?;

            let request = self
                .client
                .post("/v1/image/generation")?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, bytes::Bytes>(request).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
            }
        }
    }
}
602
603#[cfg(feature = "audio")]
607pub use audio_generation::*;
608use tracing::{Instrument, info_span};
609
#[cfg(feature = "audio")]
// Fixed: this previously said `feature = "image"`, so docs.rs would label the
// audio items with the wrong feature gate.
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    use super::{ApiResponse, Client};
    use crate::audio_generation;
    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
    use crate::http_client::{self, HttpClientExt};
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use bytes::Bytes;
    use serde::Deserialize;
    use serde_json::json;

    /// An audio generation model bound to a Hyperbolic [`Client`].
    /// Unlike the other models, it is parameterized by a `language` rather
    /// than a model name.
    #[derive(Clone)]
    pub struct AudioGenerationModel<T> {
        client: Client<T>,
        pub language: String,
    }

    /// Raw response from `/v1/audio/generation`: base64-encoded audio.
    #[derive(Clone, Deserialize)]
    pub struct AudioGenerationResponse {
        audio: String,
    }

    impl TryFrom<AudioGenerationResponse>
        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
    {
        type Error = AudioGenerationError;

        /// Decodes the base64 audio payload.
        ///
        /// # Errors
        /// Returns [`AudioGenerationError::ProviderError`] when the payload
        /// is not valid base64. (The previous version `expect`ed the decode,
        /// panicking on a malformed provider response.)
        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
            let data = BASE64_STANDARD.decode(&value.audio).map_err(|e| {
                AudioGenerationError::ProviderError(format!("Could not decode audio: {e}"))
            })?;

            Ok(Self {
                audio: data,
                response: value,
            })
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = AudioGenerationResponse;
        type Client = Client<T>;

        fn make(client: &Self::Client, language: impl Into<String>) -> Self {
            Self {
                client: client.clone(),
                language: language.into(),
            }
        }

        /// Sends a speech generation request to `/v1/audio/generation`.
        ///
        /// # Errors
        /// Returns [`AudioGenerationError::ProviderError`] on non-2xx
        /// statuses or API-level errors, and propagates
        /// serialization/transport failures.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            let request = json!({
                "language": self.language,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post("/v1/audio/generation")?
                .body(body)
                .map_err(http_client::Error::from)?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            match serde_json::from_slice::<ApiResponse<AudioGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }
}
703
#[cfg(test)]
mod tests {
    use crate::providers::hyperbolic::Client;

    /// Smoke check: both construction paths produce a client from a
    /// placeholder key (no network traffic is performed).
    #[test]
    fn test_client_initialization() {
        let _direct: Client = Client::new("dummy-key").expect("Client::new() failed");
        let _built: Client = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}