rig/providers/huggingface/
client.rs1use super::completion::CompletionModel;
2#[cfg(feature = "image")]
3use crate::client::ImageGenerationClient;
4use crate::client::{ClientBuilderError, CompletionClient, ProviderClient, TranscriptionClient};
5#[cfg(feature = "image")]
6use crate::image_generation::ImageGenerationError;
7#[cfg(feature = "image")]
8use crate::providers::huggingface::image_generation::ImageGenerationModel;
9use crate::providers::huggingface::transcription::TranscriptionModel;
10use crate::transcription::TranscriptionError;
11use rig::client::impl_conversion_traits;
12use std::fmt::Display;
13
/// Default base URL used by `ClientBuilder::new`.
///
/// `ClientBuilder::build` appends the selected sub-provider route to this
/// value unless a different base URL was supplied via `ClientBuilder::base_url`.
const HUGGINGFACE_API_BASE_URL: &str = "https://router.huggingface.co/";
18
/// Backend route that requests are sent through.
///
/// Each variant maps to a URL route segment via its `Display` implementation;
/// that segment is appended to the client's base URL when the client is built.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so that values can be
/// used as `HashMap`/`HashSet` keys (every field is itself `Eq + Hash`).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum SubProvider {
    /// Default sub-provider; displayed as `hf-inference/models`.
    #[default]
    HFInference,
    /// Displayed as `together`.
    Together,
    /// Displayed as `sambanova`.
    SambaNova,
    /// Displayed as `fireworks-ai`.
    Fireworks,
    /// Displayed as `hyperbolic`.
    Hyperbolic,
    /// Displayed as `nebius`.
    Nebius,
    /// Displayed as `novita`.
    Novita,
    /// Any other route; the contained string is used verbatim as the route.
    Custom(String),
}
31
32impl SubProvider {
33 pub fn completion_endpoint(&self, model: &str) -> String {
37 match self {
38 SubProvider::HFInference => format!("/{model}/v1/chat/completions"),
39 _ => "/v1/chat/completions".to_string(),
40 }
41 }
42
43 pub fn transcription_endpoint(&self, model: &str) -> Result<String, TranscriptionError> {
47 match self {
48 SubProvider::HFInference => Ok(format!("/{model}")),
49 _ => Err(TranscriptionError::ProviderError(format!(
50 "transcription endpoint is not supported yet for {self}"
51 ))),
52 }
53 }
54
55 #[cfg(feature = "image")]
59 pub fn image_generation_endpoint(&self, model: &str) -> Result<String, ImageGenerationError> {
60 match self {
61 SubProvider::HFInference => Ok(format!("/{model}")),
62 _ => Err(ImageGenerationError::ProviderError(format!(
63 "image generation endpoint is not supported yet for {self}"
64 ))),
65 }
66 }
67
68 pub fn model_identifier(&self, model: &str) -> String {
69 match self {
70 SubProvider::Fireworks => format!("accounts/fireworks/models/{model}"),
71 _ => model.to_string(),
72 }
73 }
74}
75
76impl From<&str> for SubProvider {
77 fn from(s: &str) -> Self {
78 SubProvider::Custom(s.to_string())
79 }
80}
81
82impl From<String> for SubProvider {
83 fn from(value: String) -> Self {
84 SubProvider::Custom(value)
85 }
86}
87
88impl Display for SubProvider {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 let route = match self {
91 SubProvider::HFInference => "hf-inference/models".to_string(),
92 SubProvider::Together => "together".to_string(),
93 SubProvider::SambaNova => "sambanova".to_string(),
94 SubProvider::Fireworks => "fireworks-ai".to_string(),
95 SubProvider::Hyperbolic => "hyperbolic".to_string(),
96 SubProvider::Nebius => "nebius".to_string(),
97 SubProvider::Novita => "novita".to_string(),
98 SubProvider::Custom(route) => route.clone(),
99 };
100
101 write!(f, "{route}")
102 }
103}
104
105pub struct ClientBuilder {
106 api_key: String,
107 base_url: String,
108 sub_provider: SubProvider,
109 http_client: Option<reqwest::Client>,
110}
111
112impl ClientBuilder {
113 pub fn new(api_key: &str) -> Self {
114 Self {
115 api_key: api_key.to_string(),
116 base_url: HUGGINGFACE_API_BASE_URL.to_string(),
117 sub_provider: SubProvider::default(),
118 http_client: None,
119 }
120 }
121
122 pub fn base_url(mut self, base_url: &str) -> Self {
123 self.base_url = base_url.to_string();
124 self
125 }
126
127 pub fn sub_provider(mut self, provider: impl Into<SubProvider>) -> Self {
128 self.sub_provider = provider.into();
129 self
130 }
131
132 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
133 self.http_client = Some(client);
134 self
135 }
136
137 pub fn build(self) -> Result<Client, ClientBuilderError> {
138 let route = self.sub_provider.to_string();
139 let base_url = format!("{}/{}", self.base_url, route).replace("//", "/");
140
141 let mut default_headers = reqwest::header::HeaderMap::new();
142 default_headers.insert(
143 "Content-Type",
144 "application/json"
145 .parse()
146 .expect("Failed to parse Content-Type"),
147 );
148 let http_client = if let Some(http_client) = self.http_client {
149 http_client
150 } else {
151 reqwest::Client::builder().build()?
152 };
153
154 Ok(Client {
155 base_url,
156 default_headers,
157 api_key: self.api_key,
158 http_client,
159 sub_provider: self.sub_provider,
160 })
161 }
162}
163
164#[derive(Clone)]
165pub struct Client {
166 base_url: String,
167 default_headers: reqwest::header::HeaderMap,
168 api_key: String,
169 http_client: reqwest::Client,
170 pub(crate) sub_provider: SubProvider,
171}
172
173impl std::fmt::Debug for Client {
174 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175 f.debug_struct("Client")
176 .field("base_url", &self.base_url)
177 .field("http_client", &self.http_client)
178 .field("default_headers", &self.default_headers)
179 .field("sub_provider", &self.sub_provider)
180 .field("api_key", &"<REDACTED>")
181 .finish()
182 }
183}
184
185impl Client {
186 pub fn builder(api_key: &str) -> ClientBuilder {
197 ClientBuilder::new(api_key)
198 }
199
200 pub fn new(api_key: &str) -> Self {
205 Self::builder(api_key)
206 .build()
207 .expect("Huggingface client should build")
208 }
209
210 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
211 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
212 self.http_client
213 .post(url)
214 .bearer_auth(&self.api_key)
215 .headers(self.default_headers.clone())
216 }
217}
218
219impl ProviderClient for Client {
220 fn from_env() -> Self {
223 let api_key = std::env::var("HUGGINGFACE_API_KEY").expect("HUGGINGFACE_API_KEY is not set");
224 Self::new(&api_key)
225 }
226
227 fn from_val(input: crate::client::ProviderValue) -> Self {
228 let crate::client::ProviderValue::Simple(api_key) = input else {
229 panic!("Incorrect provider value type")
230 };
231 Self::new(&api_key)
232 }
233}
234
235impl CompletionClient for Client {
236 type CompletionModel = CompletionModel;
237
238 fn completion_model(&self, model: &str) -> CompletionModel {
250 CompletionModel::new(self.clone(), model)
251 }
252}
253
254impl TranscriptionClient for Client {
255 type TranscriptionModel = TranscriptionModel;
256
257 fn transcription_model(&self, model: &str) -> TranscriptionModel {
270 TranscriptionModel::new(self.clone(), model)
271 }
272}
273
#[cfg(feature = "image")]
impl ImageGenerationClient for Client {
    type ImageGenerationModel = ImageGenerationModel;

    /// Creates an image-generation model named `model` that owns a clone of
    /// this client.
    fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
        let client = self.clone();
        ImageGenerationModel::new(client, model)
    }
}
293
// Invokes `rig::client::impl_conversion_traits!` for `Client` with the
// `AsEmbeddings` and `AsAudioGeneration` traits.
// NOTE(review): presumably this generates the conversion-trait impls for the
// capabilities this provider does not implement itself — confirm against the
// macro definition.
impl_conversion_traits!(AsEmbeddings, AsAudioGeneration for Client);