1use super::openai::{AssistantContent, send_compatible_streaming_request};
12
13use crate::client::{self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder};
14use crate::client::{BearerAuth, ProviderClient};
15use crate::http_client::{self, HttpClientExt};
16use crate::streaming::StreamingCompletionResponse;
17
18use crate::providers::openai;
19use crate::{
20 OneOrMany,
21 completion::{self, CompletionError, CompletionRequest},
22 json_utils,
23 providers::openai::Message,
24};
25use serde::{Deserialize, Serialize};
26
/// Base URL for the Hyperbolic REST API.
const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz";
31
/// Zero-sized marker type implementing [`Provider`] for the Hyperbolic API.
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicExt;
/// Zero-sized builder marker used to construct a Hyperbolic [`client::Client`].
#[derive(Debug, Default, Clone, Copy)]
pub struct HyperbolicBuilder;

/// Hyperbolic authenticates with a bearer token in the `Authorization` header.
type HyperbolicApiKey = BearerAuth;
38
impl Provider for HyperbolicExt {
    type Builder = HyperbolicBuilder;

    // Endpoint used to verify connectivity/credentials against the API.
    const VERIFY_PATH: &'static str = "/models";

    /// Builds the provider marker. Hyperbolic needs no per-provider state,
    /// so the client builder argument is ignored.
    fn build<H>(
        _: &crate::client::ClientBuilder<
            Self::Builder,
            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
            H,
        >,
    ) -> http_client::Result<Self> {
        Ok(Self)
    }
}
54
// Capability map: Hyperbolic supports chat completions (plus image and audio
// generation behind feature flags) but has no embeddings or transcription.
impl<H> Capabilities<H> for HyperbolicExt {
    type Completion = Capable<CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<ImageGenerationModel<H>>;
    #[cfg(feature = "audio")]
    type AudioGeneration = Capable<AudioGenerationModel<H>>;
}
64
// Default debug formatting is sufficient; no extra state to surface.
impl DebugExt for HyperbolicExt {}
66
/// Wires the Hyperbolic builder to its provider type, API-key type and base URL.
impl ProviderBuilder for HyperbolicBuilder {
    type Output = HyperbolicExt;
    type ApiKey = HyperbolicApiKey;

    const BASE_URL: &'static str = HYPERBOLIC_API_BASE_URL;
}
73
/// Hyperbolic client specialized over an HTTP backend (reqwest by default).
pub type Client<H = reqwest::Client> = client::Client<HyperbolicExt, H>;
/// Builder for [`Client`]; the `String` parameter is the raw API key input.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<HyperbolicBuilder, String, H>;
76
impl ProviderClient for Client {
    type Input = HyperbolicApiKey;

    /// Builds a client from the `HYPERBOLIC_API_KEY` environment variable.
    ///
    /// # Panics
    /// Panics if the variable is unset or client construction fails.
    fn from_env() -> Self {
        let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
        Self::new(&api_key).unwrap()
    }

    /// Builds a client from an explicit API key.
    ///
    /// # Panics
    /// Panics if client construction fails.
    fn from_val(input: Self::Input) -> Self {
        Self::new(input).unwrap()
    }
}
91
/// Error payload returned by the Hyperbolic API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}

/// Untagged wrapper: a response body deserializes either as the expected
/// payload `T` or as an API error object.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
103
/// Single embedding entry in an embeddings response.
// NOTE(review): embeddings are not wired up in `Capabilities` (`Embeddings =
// Nothing`), so this type appears unused here — confirm before removing.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}
110
/// Token usage as reported by Hyperbolic (prompt and total only; there is no
/// separate completion-token field in this payload).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
116
117impl std::fmt::Display for Usage {
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 write!(
120 f,
121 "Prompt tokens: {} Total tokens: {}",
122 self.prompt_tokens, self.total_tokens
123 )
124 }
125}
126
/// Meta Llama 3.1 8B Instruct model identifier.
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// Meta Llama 3.3 70B Instruct model identifier.
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// Meta Llama 3.1 70B Instruct model identifier.
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// Meta Llama 3 70B Instruct model identifier.
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// NousResearch Hermes 3 (Llama 3.1 70B) model identifier.
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// DeepSeek V2.5 model identifier.
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// Qwen 2.5 72B Instruct model identifier.
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// Meta Llama 3.2 3B Instruct model identifier.
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// Qwen 2.5 Coder 32B Instruct model identifier.
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// Qwen QwQ 32B Preview model identifier.
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// DeepSeek R1 Zero model identifier.
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// DeepSeek R1 model identifier.
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
155
/// Raw chat-completion response body from Hyperbolic.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Option<Usage>,
}
168
169impl From<ApiErrorResponse> for CompletionError {
170 fn from(err: ApiErrorResponse) -> Self {
171 CompletionError::ProviderError(err.message)
172 }
173}
174
175impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
176 type Error = CompletionError;
177
178 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
179 let choice = response.choices.first().ok_or_else(|| {
180 CompletionError::ResponseError("Response contained no choices".to_owned())
181 })?;
182
183 let content = match &choice.message {
184 Message::Assistant {
185 content,
186 tool_calls,
187 ..
188 } => {
189 let mut content = content
190 .iter()
191 .map(|c| match c {
192 AssistantContent::Text { text } => completion::AssistantContent::text(text),
193 AssistantContent::Refusal { refusal } => {
194 completion::AssistantContent::text(refusal)
195 }
196 })
197 .collect::<Vec<_>>();
198
199 content.extend(
200 tool_calls
201 .iter()
202 .map(|call| {
203 completion::AssistantContent::tool_call(
204 &call.id,
205 &call.function.name,
206 call.function.arguments.clone(),
207 )
208 })
209 .collect::<Vec<_>>(),
210 );
211 Ok(content)
212 }
213 _ => Err(CompletionError::ResponseError(
214 "Response did not contain a valid message or tool call".into(),
215 )),
216 }?;
217
218 let choice = OneOrMany::many(content).map_err(|_| {
219 CompletionError::ResponseError(
220 "Response contained no message or tool call (empty)".to_owned(),
221 )
222 })?;
223
224 let usage = response
225 .usage
226 .as_ref()
227 .map(|usage| completion::Usage {
228 input_tokens: usage.prompt_tokens as u64,
229 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
230 total_tokens: usage.total_tokens as u64,
231 cached_input_tokens: 0,
232 })
233 .unwrap_or_default();
234
235 Ok(completion::CompletionResponse {
236 choice,
237 usage,
238 raw_response: response,
239 })
240 }
241}
242
/// One completion choice: its index, the assistant message and finish reason.
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    pub finish_reason: String,
}
249
/// JSON body sent to Hyperbolic's `/v1/chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct HyperbolicCompletionRequest {
    model: String,
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Extra provider parameters are flattened into the top-level JSON object
    // (used by `stream` to inject `stream`/`stream_options`).
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
259
260impl TryFrom<(&str, CompletionRequest)> for HyperbolicCompletionRequest {
261 type Error = CompletionError;
262
263 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
264 if req.tool_choice.is_some() {
265 tracing::warn!("WARNING: `tool_choice` not supported on Hyperbolic");
266 }
267
268 if !req.tools.is_empty() {
269 tracing::warn!("WARNING: `tools` not supported on Hyperbolic");
270 }
271
272 let mut full_history: Vec<Message> = match &req.preamble {
273 Some(preamble) => vec![Message::system(preamble)],
274 None => vec![],
275 };
276
277 if let Some(docs) = req.normalized_documents() {
278 let docs: Vec<Message> = docs.try_into()?;
279 full_history.extend(docs);
280 }
281
282 let chat_history: Vec<Message> = req
283 .chat_history
284 .clone()
285 .into_iter()
286 .map(|message| message.try_into())
287 .collect::<Result<Vec<Vec<Message>>, _>>()?
288 .into_iter()
289 .flatten()
290 .collect();
291
292 full_history.extend(chat_history);
293
294 Ok(Self {
295 model: model.to_string(),
296 messages: full_history,
297 temperature: req.temperature,
298 additional_params: req.additional_params,
299 })
300 }
301}
302
/// Chat-completion model bound to a Hyperbolic [`Client`] and a model name.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
}
309
310impl<T> CompletionModel<T> {
311 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
312 Self {
313 client,
314 model: model.into(),
315 }
316 }
317
318 pub fn with_model(client: Client<T>, model: &str) -> Self {
319 Self {
320 client,
321 model: model.into(),
322 }
323 }
324}
325
impl<T> completion::CompletionModel for CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
{
    type Response = CompletionResponse;
    // Streaming reuses the OpenAI-compatible streaming response type.
    type StreamingResponse = openai::StreamingCompletionResponse;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Sends a non-streaming chat completion to `/v1/chat/completions` and
    /// converts the response into the generic completion type.
    ///
    /// # Errors
    /// Returns a [`CompletionError`] on HTTP failure, a non-success status,
    /// deserialization failure, or an API-level error payload.
    async fn completion(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        // Open a gen_ai tracing span unless the caller already established one.
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Guard the pretty-print behind the TRACE level check so the
        // serialization cost is only paid when tracing is enabled.
        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Hyperbolic completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        // Run the request/response handling inside the span via `instrument`.
        let async_block = async move {
            let response = self.client.send::<_, bytes::Bytes>(req).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if status.is_success() {
                // A 2xx body may still be an API-level error object.
                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(&response_body)? {
                    ApiResponse::Ok(response) => {
                        if tracing::enabled!(tracing::Level::TRACE) {
                            tracing::trace!(target: "rig::completions",
                                "Hyperbolic completion response: {}",
                                serde_json::to_string_pretty(&response)?
                            );
                        }

                        response.try_into()
                    }
                    ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
                }
            } else {
                Err(CompletionError::ProviderError(
                    String::from_utf8_lossy(&response_body).to_string(),
                ))
            }
        };

        async_block.instrument(span).await
    }

    /// Sends a streaming chat completion to `/v1/chat/completions`, forcing
    /// `stream: true` with usage reporting in the request parameters.
    ///
    /// # Errors
    /// Returns a [`CompletionError`] on HTTP or serialization failure.
    async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "hyperbolic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        span.record("gen_ai.system_instructions", &completion_request.preamble);
        let mut request =
            HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        // Merge streaming flags over any caller-supplied extra parameters.
        let params = json_utils::merge(
            request.additional_params.unwrap_or(serde_json::json!({})),
            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

        request.additional_params = Some(params);

        if tracing::enabled!(tracing::Level::TRACE) {
            tracing::trace!(target: "rig::completions",
                "Hyperbolic streaming completion request: {}",
                serde_json::to_string_pretty(&request)?
            );
        }

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/v1/chat/completions")?
            .body(body)
            .map_err(http_client::Error::from)?;

        // Delegate SSE handling to the shared OpenAI-compatible implementation.
        send_compatible_streaming_request(self.client.clone(), req)
            .instrument(span)
            .await
    }
}
461
462#[cfg(feature = "image")]
467pub use image_generation::*;
468
#[cfg(feature = "image")]
#[cfg_attr(docsrs, doc(cfg(feature = "image")))]
mod image_generation {
    //! Image generation via Hyperbolic's `/v1/image/generation` endpoint.

    use super::{ApiResponse, Client};
    use crate::http_client::HttpClientExt;
    use crate::image_generation;
    use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
    use crate::json_utils::merge_inplace;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use serde::Deserialize;
    use serde_json::json;

    // Stable Diffusion model identifiers accepted by Hyperbolic.
    pub const SDXL1_0_BASE: &str = "SDXL1.0-base";
    pub const SD2: &str = "SD2";
    pub const SD1_5: &str = "SD1.5";
    pub const SSD: &str = "SSD";
    pub const SDXL_TURBO: &str = "SDXL-turbo";
    pub const SDXL_CONTROLNET: &str = "SDXL-ControlNet";
    pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet";

    /// Image generation model bound to a Hyperbolic [`Client`] and a model name.
    #[derive(Clone)]
    pub struct ImageGenerationModel<T> {
        client: Client<T>,
        pub model: String,
    }

    impl<T> ImageGenerationModel<T> {
        /// Creates an image generation model for `model` backed by `client`.
        pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
            Self {
                client,
                model: model.into(),
            }
        }

        /// Convenience constructor taking the model name as a string slice.
        pub fn with_model(client: Client<T>, model: &str) -> Self {
            Self {
                client,
                model: model.into(),
            }
        }
    }

    /// Single base64-encoded image in a generation response.
    #[derive(Clone, Deserialize)]
    pub struct Image {
        image: String,
    }

    /// Raw response body from the image generation endpoint.
    #[derive(Clone, Deserialize)]
    pub struct ImageGenerationResponse {
        images: Vec<Image>,
    }

    impl TryFrom<ImageGenerationResponse>
        for image_generation::ImageGenerationResponse<ImageGenerationResponse>
    {
        type Error = ImageGenerationError;

        /// Decodes the first returned image from base64.
        ///
        /// # Errors
        /// Returns [`ImageGenerationError::ResponseError`] when the response
        /// contains no images or the payload is not valid base64 (these cases
        /// previously panicked via indexing / `expect`).
        fn try_from(value: ImageGenerationResponse) -> Result<Self, Self::Error> {
            let first = value.images.first().ok_or_else(|| {
                ImageGenerationError::ResponseError("Response contained no images".to_owned())
            })?;

            let data = BASE64_STANDARD.decode(&first.image).map_err(|e| {
                ImageGenerationError::ResponseError(format!("Could not decode image: {e}"))
            })?;

            Ok(Self {
                image: data,
                response: value,
            })
        }
    }

    impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = ImageGenerationResponse;

        type Client = Client<T>;

        fn make(client: &Self::Client, model: impl Into<String>) -> Self {
            Self::new(client.clone(), model)
        }

        /// Sends an image generation request and decodes the response.
        ///
        /// # Errors
        /// Returns an [`ImageGenerationError`] on HTTP failure, a non-success
        /// status, or an API-level error payload.
        async fn image_generation(
            &self,
            generation_request: ImageGenerationRequest,
        ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
        {
            let mut request = json!({
                "model_name": self.model,
                "prompt": generation_request.prompt,
                "height": generation_request.height,
                "width": generation_request.width,
            });

            // Caller-supplied params are merged over the defaults above.
            if let Some(params) = generation_request.additional_params {
                merge_inplace(&mut request, params);
            }

            let body = serde_json::to_vec(&request)?;

            let request = self
                .client
                .post("/v1/image/generation")?
                .header("Content-Type", "application/json")
                .body(body)
                .map_err(|e| ImageGenerationError::HttpError(e.into()))?;

            let response = self.client.send::<_, bytes::Bytes>(request).await?;

            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(ImageGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            // A 2xx body may still be an API-level error object.
            match serde_json::from_slice::<ApiResponse<ImageGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)),
            }
        }
    }
}
595
596#[cfg(feature = "audio")]
600pub use audio_generation::*;
601use tracing::{Instrument, info_span};
602
#[cfg(feature = "audio")]
#[cfg_attr(docsrs, doc(cfg(feature = "audio")))]
mod audio_generation {
    //! Audio (text-to-speech) generation via `/v1/audio/generation`.

    use super::{ApiResponse, Client};
    use crate::audio_generation;
    use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest};
    use crate::http_client::{self, HttpClientExt};
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use bytes::Bytes;
    use serde::Deserialize;
    use serde_json::json;

    /// Audio generation model bound to a Hyperbolic [`Client`] and a language.
    #[derive(Clone)]
    pub struct AudioGenerationModel<T> {
        client: Client<T>,
        pub language: String,
    }

    /// Raw response: the generated audio, base64-encoded.
    #[derive(Clone, Deserialize)]
    pub struct AudioGenerationResponse {
        audio: String,
    }

    impl TryFrom<AudioGenerationResponse>
        for audio_generation::AudioGenerationResponse<AudioGenerationResponse>
    {
        type Error = AudioGenerationError;

        /// Decodes the base64 audio payload.
        ///
        /// # Errors
        /// Returns [`AudioGenerationError::ProviderError`] when the payload is
        /// not valid base64 (previously this panicked via `expect`).
        fn try_from(value: AudioGenerationResponse) -> Result<Self, Self::Error> {
            let data = BASE64_STANDARD.decode(&value.audio).map_err(|e| {
                AudioGenerationError::ProviderError(format!("Could not decode audio: {e}"))
            })?;

            Ok(Self {
                audio: data,
                response: value,
            })
        }
    }

    impl<T> audio_generation::AudioGenerationModel for AudioGenerationModel<T>
    where
        T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
    {
        type Response = AudioGenerationResponse;
        type Client = Client<T>;

        // NOTE(review): `make` takes the *language* rather than a model name,
        // unlike the other models in this file — confirm this is intentional.
        fn make(client: &Self::Client, language: impl Into<String>) -> Self {
            Self {
                client: client.clone(),
                language: language.into(),
            }
        }

        /// Sends a TTS request and decodes the returned audio.
        ///
        /// # Errors
        /// Returns an [`AudioGenerationError`] on HTTP failure, a non-success
        /// status, or an API-level error payload.
        async fn audio_generation(
            &self,
            request: AudioGenerationRequest,
        ) -> Result<audio_generation::AudioGenerationResponse<Self::Response>, AudioGenerationError>
        {
            let request = json!({
                "language": self.language,
                "speaker": request.voice,
                "text": request.text,
                "speed": request.speed
            });

            let body = serde_json::to_vec(&request)?;

            let req = self
                .client
                .post("/v1/audio/generation")?
                .body(body)
                .map_err(http_client::Error::from)?;

            let response = self.client.send::<_, Bytes>(req).await?;
            let status = response.status();
            let response_body = response.into_body().into_future().await?.to_vec();

            if !status.is_success() {
                return Err(AudioGenerationError::ProviderError(format!(
                    "{status}: {}",
                    String::from_utf8_lossy(&response_body)
                )));
            }

            // A 2xx body may still be an API-level error object.
            match serde_json::from_slice::<ApiResponse<AudioGenerationResponse>>(&response_body)? {
                ApiResponse::Ok(response) => response.try_into(),
                ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)),
            }
        }
    }
}