1use super::openai::{AssistantContent, send_compatible_streaming_request};
12
13use crate::client::{self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder};
14use crate::client::{BearerAuth, ProviderClient};
15use crate::http_client::{self, HttpClientExt};
16use crate::streaming::StreamingCompletionResponse;
17
18use crate::providers::openai;
19use crate::{
20 OneOrMany,
21 completion::{self, CompletionError, CompletionRequest},
22 json_utils,
23 providers::openai::Message,
24};
25use serde::{Deserialize, Serialize};
26
/// Base URL for all Hyperbolic API requests; endpoint paths are joined onto this.
const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz";
31
/// Zero-sized marker type identifying the Hyperbolic provider.
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicExt;
/// Zero-sized builder marker for constructing a Hyperbolic [`Client`].
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicBuilder;

// Hyperbolic authenticates with a standard `Authorization: Bearer <key>` header.
type HyperbolicApiKey = BearerAuth;
38
39impl Provider for HyperbolicExt {
40 type Builder = HyperbolicBuilder;
41
42 const VERIFY_PATH: &'static str = "/models";
43
44 fn build<H>(
45 _: &crate::client::ClientBuilder<
46 Self::Builder,
47 <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
48 H,
49 >,
50 ) -> http_client::Result<Self> {
51 Ok(Self)
52 }
53}
54
/// Capability matrix for Hyperbolic: chat completions always; image and audio
/// generation behind feature flags; embeddings and transcription not offered.
impl<H> Capabilities<H> for HyperbolicExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<ImageGenerationModel<H>>;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<AudioGenerationModel<H>>;
}
64
// No provider-specific debug output; rely on the trait's defaults.
impl DebugExt for HyperbolicExt {}
66
/// Wires the builder marker to its output type, API-key scheme, and base URL.
impl ProviderBuilder for HyperbolicBuilder {
    type Output = HyperbolicExt;
    type ApiKey = HyperbolicApiKey;

    const BASE_URL: &'static str = HYPERBOLIC_API_BASE_URL;
}
73
/// Hyperbolic client, defaulting to `reqwest` as the HTTP backend.
pub type Client<H = reqwest::Client> = client::Client<HyperbolicExt, H>;
/// Builder for [`Client`], keyed by a plain `String` API key.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<HyperbolicBuilder, String, H>;
76
77impl ProviderClient for Client {
78 type Input = HyperbolicApiKey;
79
80 fn from_env() -> Self {
83 let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
84 Self::new(&api_key).unwrap()
85 }
86
87 fn from_val(input: Self::Input) -> Self {
88 Self::new(input).unwrap()
89 }
90}
91
/// Error payload returned by the Hyperbolic API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    // Human-readable error description from the provider.
    message: String,
}
96
/// Untagged wrapper: a 2xx body may still be an error object, so deserialization
/// tries the success shape `T` first and falls back to [`ApiErrorResponse`].
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
103
/// One embedding vector in an embeddings response.
/// NOTE(review): no embedding capability is wired up in this file — this type
/// appears unused here; confirm before removing.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}
110
/// Token accounting reported by Hyperbolic. Only prompt and total counts are
/// provided; completion tokens are derived as `total - prompt` by callers.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
116
117impl std::fmt::Display for Usage {
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 write!(
120 f,
121 "Prompt tokens: {} Total tokens: {}",
122 self.prompt_tokens, self.total_tokens
123 )
124 }
125}
126
// Model identifiers accepted by the Hyperbolic chat completions endpoint.
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
155
/// Raw, OpenAI-compatible chat completion response body from Hyperbolic.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    // Unix timestamp (seconds) of response creation.
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    // Token accounting; absent on some responses, hence `Option`.
    pub usage: Option<Usage>,
}
168
169impl From<ApiErrorResponse> for CompletionError {
170 fn from(err: ApiErrorResponse) -> Self {
171 CompletionError::ProviderError(err.message)
172 }
173}
174
175impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
176 type Error = CompletionError;
177
178 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
179 let choice = response.choices.first().ok_or_else(|| {
180 CompletionError::ResponseError("Response contained no choices".to_owned())
181 })?;
182
183 let content = match &choice.message {
184 Message::Assistant {
185 content,
186 tool_calls,
187 ..
188 } => {
189 let mut content = content
190 .iter()
191 .map(|c| match c {
192 AssistantContent::Text { text } => completion::AssistantContent::text(text),
193 AssistantContent::Refusal { refusal } => {
194 completion::AssistantContent::text(refusal)
195 }
196 })
197 .collect::<Vec<_>>();
198
199 content.extend(
200 tool_calls
201 .iter()
202 .map(|call| {
203 completion::AssistantContent::tool_call(
204 &call.id,
205 &call.function.name,
206 call.function.arguments.clone(),
207 )
208 })
209 .collect::<Vec<_>>(),
210 );
211 Ok(content)
212 }
213 _ => Err(CompletionError::ResponseError(
214 "Response did not contain a valid message or tool call".into(),
215 )),
216 }?;
217
218 let choice = OneOrMany::many(content).map_err(|_| {
219 CompletionError::ResponseError(
220 "Response contained no message or tool call (empty)".to_owned(),
221 )
222 })?;
223
224 let usage = response
225 .usage
226 .as_ref()
227 .map(|usage| completion::Usage {
228 input_tokens: usage.prompt_tokens as u64,
229 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
230 total_tokens: usage.total_tokens as u64,
231 })
232 .unwrap_or_default();
233
234 Ok(completion::CompletionResponse {
235 choice,
236 usage,
237 raw_response: response,
238 })
239 }
240}
241
/// One completion alternative within a [`CompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    // e.g. "stop" or "length" — passed through verbatim from the provider.
    pub finish_reason: String,
}
248
249#[derive(Debug, Serialize, Deserialize)]
250pub(super) struct HyperbolicCompletionRequest {
251 model: String,
252 pub messages: Vec<Message>,
253 #[serde(flatten, skip_serializing_if = "Option::is_none")]
254 temperature: Option<f64>,
255 #[serde(flatten, skip_serializing_if = "Option::is_none")]
256 pub additional_params: Option<serde_json::Value>,
257}
258
/// Builds the provider request from a model name and a generic
/// [`CompletionRequest`], assembling message history in order:
/// system preamble, normalized documents, then chat history.
impl TryFrom<(&str, CompletionRequest)> for HyperbolicCompletionRequest {
    type Error = CompletionError;

    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        // Tool features are silently dropped: warn so callers can tell why
        // their tools never fire.
        if req.tool_choice.is_some() {
            tracing::warn!("WARNING: `tool_choice` not supported on Hyperbolic");
        }

        if !req.tools.is_empty() {
            tracing::warn!("WARNING: `tools` not supported on Hyperbolic");
        }

        // The preamble, when present, must be the first (system) message.
        let mut full_history: Vec<Message> = match &req.preamble {
            Some(preamble) => vec![Message::system(preamble)],
            None => vec![],
        };

        // Documents are injected before the chat history.
        if let Some(docs) = req.normalized_documents() {
            let docs: Vec<Message> = docs.try_into()?;
            full_history.extend(docs);
        }

        // Each generic message may expand to several provider messages;
        // flatten while short-circuiting on the first conversion error.
        let chat_history: Vec<Message> = req
            .chat_history
            .clone()
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<Message>>, _>>()?
            .into_iter()
            .flatten()
            .collect();

        full_history.extend(chat_history);

        Ok(Self {
            model: model.to_string(),
            messages: full_history,
            temperature: req.temperature,
            additional_params: req.additional_params,
        })
    }
}
301
/// A Hyperbolic chat completion model bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    // Model identifier sent with every request, e.g. [`LLAMA_3_1_8B`].
    pub model: String,
}
308
309impl<T> CompletionModel<T> {
310 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
311 Self {
312 client,
313 model: model.into(),
314 }
315 }
316
317 pub fn with_model(client: Client<T>, model: &str) -> Self {
318 Self {
319 client,
320 model: model.into(),
321 }
322 }
323}
324
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    type StreamingResponse = openai::StreamingCompletionResponse;

    type Client = Client<T>;

    /// Binds a model name to a clone of `client`.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Sends a chat completion request to `/v1/chat/completions` and converts
    /// the provider response into rig's [`completion::CompletionResponse`].
    ///
    /// # Errors
    /// Returns [`CompletionError`] on serialization failure, transport
    /// failure, a non-success HTTP status, or a provider error payload.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();
        let request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
        let body = serde_json::to_vec(&request)?;

        // Only open a fresh tracing span when the caller has not already
        // established one, to avoid duplicated instrumentation.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        let async_block = async move {
            let response = self.client.send::<_, bytes::Bytes>(req).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                // A 2xx body may still carry a provider error object, so
                // decode through the untagged `ApiResponse` wrapper.
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        tracing::info!(target: "rig",
                            "Hyperbolic completion token usage: {:?}",
                            response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
                        );

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        };

        async_block.instrument(span).await
    }

    /// Opens a streaming chat completion by injecting `"stream": true` (and
    /// per-chunk usage reporting) into the request parameters, then delegating
    /// to the shared OpenAI-compatible SSE handler.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();
        let mut request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        // Force streaming mode on top of any caller-supplied parameters.
        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        send_compatible_streaming_request(self.client.http_client().clone(), req)
            .instrument(span)
            .await
    }
}
449
#[cfg(feature = "image")]
pub use image_generation::*;

/// Image generation support for Hyperbolic (`/v1/image/generation`).
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    use super::{ApiResponse, Client};
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    // Image model identifiers accepted by Hyperbolic.
    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    pub const SD2: &str = "SD2";
    pub const SD1_5: &str = "SD1.5";
    pub const SSD: &str = "SSD";
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    /// An image generation model bound to a [`Client`].
    #[derive(Clone)]
    pub struct ImageGenerationModel<T> {
        client: Client<T>,
        // Image model identifier, e.g. [`SDXL1_0_BASE`].
        pub model: String,
    }

    impl<T> ImageGenerationModel<T> {
        /// Creates an image generation model for `model` backed by `client`.
        pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
            Self {
                client,
                model: model.into(),
            }
        }

        /// Convenience constructor taking the model name as a `&str`.
        pub fn with_model(client: Client<T>, model: &str) -> Self {
            Self {
                client,
                model: model.into(),
            }
        }
    }

    /// A single base64-encoded image returned by the API.
    #[derive(Clone, Deserialize)]
    pub struct Image {
        image: String,
    }

    /// Raw image generation response body from Hyperbolic.
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        images: Vec<Image>,
    }

    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// Decodes the first image to raw bytes.
        ///
        /// # Errors
        /// Returns `ResponseError` when the image list is empty or the base64
        /// payload is invalid — both are provider data errors, not internal
        /// bugs, so they must not panic.
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            let first = value.images.first().ok_or_else(|| {
                ImageGenerationError::ResponseError("Response contained no images".to_owned())
            })?;

            let data = BASE64_STANDARD.decode(&first.image).map_err(|e| {
                ImageGenerationError::ResponseError(format!("Could not decode image: {e}"))
            })?;

            Ok(Self {
                image: data,
                response: value,
            })
        }
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        /// Sends an image generation request and returns the decoded image.
        #[cfg_attr(feature = "worker", worker::send)]
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            // Hyperbolic uses `model_name` rather than the OpenAI-style `model` key.
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let body = serde_json::to_vec(&request)?;

            let request = self
                .client
                .post("/v1/image/generation")?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, bytes::Bytes>(request).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            // A 2xx body may still carry a provider error object.
            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
            }
        }
    }
}
584
585#[cfg(feature = "audio")]
589pub use audio_generation::*;
590use tracing::{Instrument, info_span};
591
592#[cfg(feature = "audio")]
593#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
594mod audio_generation {
595 use super::{ApiResponse, Client};
596 use crate::audio_generation;
597 use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
598 use crate::http_client::{self, HttpClientExt};
599 use base64::Engine;
600 use base64::prelude::BASE64_STANDARD;
601 use bytes::Bytes;
602 use serde::Deserialize;
603 use serde_json::json;
604
605 #[derive(Clone)]
606 pub struct AudioGenerationModel<T> {
607 client: Client<T>,
608 pub language: String,
609 }
610
611 #[derive(Clone, Deserialize)]
612 pub struct AudioGenerationResponse {
613 audio: String,
614 }
615
616 impl TryFrom<AudioGenerationResponse>
617 for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
618 {
619 type Error = AudioGenerationError;
620
621 fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
622 let data = BASE64_STANDARD
623 .decode(&value.audio)
624 .expect("Could not decode audio.");
625
626 Ok(Self {
627 audio: data,
628 response: value,
629 })
630 }
631 }
632
633 impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
634 where
635 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
636 {
637 type Response = AudioGenerationResponse;
638 type Client = Client<T>;
639
640 fn make(client: &Self::Client, language: impl Into<String>) -> Self {
641 Self {
642 client: client.clone(),
643 language: language.into(),
644 }
645 }
646
647 #[cfg_attr(feature = "worker", worker::send)]
648 async fn audio_generation(
649 &self,
650 request: AudioGenerationRequest,
651 ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
652 {
653 let request = json!({
654 "language": self.language,
655 "speaker": request.voice,
656 "text": request.text,
657 "speed": request.speed
658 });
659
660 let body = serde_json::to_vec(&request)?;
661
662 let req = self
663 .client
664 .post("/v1/audio/generation")?
665 .body(body)
666 .map_err(http_client::Error::from)?;
667
668 let response = self.client.send::<_, Bytes>(req).await?;
669 let status = response.status();
670 let response_body = response.into_body().into_future().await?.to_vec();
671
672 if !status.is_success() {
673 return Err(AudioGenerationError::ProviderError(format!(
674 "{status}: {}",
675 String::from_utf8_lossy(&response_body)
676 )));
677 }
678
679 match serde_json::from_slice::<ApiResponse<AudioGenerationResponse>>(&response_body)? {
680 ApiResponse::Ok(response) => response.try_into(),
681 ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
682 }
683 }
684 }
685}