rig/providers/huggingface/
client.rs1use super::completion::CompletionModel;
2#[cfg(feature = "image")]
3use crate::client::ImageGenerationClient;
4use crate::client::{
5 ClientBuilderError, CompletionClient, ProviderClient, TranscriptionClient, VerifyClient,
6 VerifyError,
7};
8use crate::http_client::{self, HttpClientExt};
9#[cfg(feature = "image")]
10use crate::image_generation::ImageGenerationError;
11#[cfg(feature = "image")]
12use crate::providers::huggingface::image_generation::ImageGenerationModel;
13use crate::providers::huggingface::transcription::TranscriptionModel;
14use crate::transcription::TranscriptionError;
15use bytes::Bytes;
16use rig::client::impl_conversion_traits;
17use std::fmt::Debug;
18use std::fmt::Display;
19
/// Base URL a freshly created `ClientBuilder` starts with (see `ClientBuilder::new`).
/// `ClientBuilder::build` appends the selected sub-provider's route to it.
20const HUGGINGFACE_API_BASE_URL: &str = "https://router.huggingface.co/";
24
/// Inference back-end a request is routed to.
///
/// The `Display` implementation of this type yields the route segment that
/// `ClientBuilder::build` appends to the base URL.
#[derive(Clone, Debug, Default, PartialEq)]
pub enum SubProvider {
    /// Default variant; rendered as `hf-inference/models`.
    #[default]
    HFInference,
    /// Rendered as `together`.
    Together,
    /// Rendered as `sambanova`.
    SambaNova,
    /// Rendered as `fireworks-ai`.
    Fireworks,
    /// Rendered as `hyperbolic`.
    Hyperbolic,
    /// Rendered as `nebius`.
    Nebius,
    /// Rendered as `novita`.
    Novita,
    /// Caller-supplied route, rendered verbatim.
    Custom(String),
}
37
38impl SubProvider {
39 pub fn completion_endpoint(&self, model: &str) -> String {
43 match self {
44 SubProvider::HFInference => format!("/{model}/v1/chat/completions"),
45 _ => "/v1/chat/completions".to_string(),
46 }
47 }
48
49 pub fn transcription_endpoint(&self, model: &str) -> Result<String, TranscriptionError> {
53 match self {
54 SubProvider::HFInference => Ok(format!("/{model}")),
55 _ => Err(TranscriptionError::ProviderError(format!(
56 "transcription endpoint is not supported yet for {self}"
57 ))),
58 }
59 }
60
61 #[cfg(feature = "image")]
65 pub fn image_generation_endpoint(&self, model: &str) -> Result<String, ImageGenerationError> {
66 match self {
67 SubProvider::HFInference => Ok(format!("/{model}")),
68 _ => Err(ImageGenerationError::ProviderError(format!(
69 "image generation endpoint is not supported yet for {self}"
70 ))),
71 }
72 }
73
74 pub fn model_identifier(&self, model: &str) -> String {
75 match self {
76 SubProvider::Fireworks => format!("accounts/fireworks/models/{model}"),
77 _ => model.to_string(),
78 }
79 }
80}
81
82impl From<&str> for SubProvider {
83 fn from(s: &str) -> Self {
84 SubProvider::Custom(s.to_string())
85 }
86}
87
88impl From<String> for SubProvider {
89 fn from(value: String) -> Self {
90 SubProvider::Custom(value)
91 }
92}
93
94impl Display for SubProvider {
95 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96 let route = match self {
97 SubProvider::HFInference => "hf-inference/models".to_string(),
98 SubProvider::Together => "together".to_string(),
99 SubProvider::SambaNova => "sambanova".to_string(),
100 SubProvider::Fireworks => "fireworks-ai".to_string(),
101 SubProvider::Hyperbolic => "hyperbolic".to_string(),
102 SubProvider::Nebius => "nebius".to_string(),
103 SubProvider::Novita => "novita".to_string(),
104 SubProvider::Custom(route) => route.clone(),
105 };
106
107 write!(f, "{route}")
108 }
109}
110
111pub struct ClientBuilder<T = reqwest::Client> {
112 api_key: String,
113 base_url: String,
114 sub_provider: SubProvider,
115 http_client: T,
116}
117
118impl<T> ClientBuilder<T>
119where
120 T: Default,
121{
122 pub fn new(api_key: &str) -> ClientBuilder<T> {
123 ClientBuilder {
124 api_key: api_key.to_string(),
125 base_url: HUGGINGFACE_API_BASE_URL.to_string(),
126 sub_provider: SubProvider::default(),
127 http_client: Default::default(),
128 }
129 }
130}
131
132impl<T> ClientBuilder<T> {
133 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<U> {
134 ClientBuilder {
135 api_key: self.api_key,
136 base_url: self.base_url,
137 sub_provider: self.sub_provider,
138 http_client,
139 }
140 }
141
142 pub fn base_url(mut self, base_url: &str) -> Self {
143 self.base_url = base_url.to_string();
144 self
145 }
146
147 pub fn sub_provider(mut self, provider: impl Into<SubProvider>) -> Self {
148 self.sub_provider = provider.into();
149 self
150 }
151
152 pub fn build(self) -> Result<Client<T>, ClientBuilderError> {
153 let route = self.sub_provider.to_string();
154 let base_url = format!("{}/{}", self.base_url, route).replace("//", "/");
155
156 let mut default_headers = reqwest::header::HeaderMap::new();
157 default_headers.insert(
158 "Content-Type",
159 "application/json"
160 .parse()
161 .expect("Failed to parse Content-Type"),
162 );
163
164 Ok(Client {
165 base_url,
166 default_headers,
167 api_key: self.api_key,
168 http_client: self.http_client,
169 sub_provider: self.sub_provider,
170 })
171 }
172}
173
174#[derive(Clone)]
175pub struct Client<T = reqwest::Client> {
176 base_url: String,
177 default_headers: reqwest::header::HeaderMap,
178 api_key: String,
179 http_client: T,
180 pub(crate) sub_provider: SubProvider,
181}
182
183impl<T> Debug for Client<T>
184where
185 T: Debug,
186{
187 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188 f.debug_struct("Client")
189 .field("base_url", &self.base_url)
190 .field("http_client", &self.http_client)
191 .field("default_headers", &self.default_headers)
192 .field("sub_provider", &self.sub_provider)
193 .field("api_key", &"<REDACTED>")
194 .finish()
195 }
196}
197
198impl<T> Client<T>
199where
200 T: Default,
201{
202 pub fn builder(api_key: &str) -> ClientBuilder<T> {
213 ClientBuilder::new(api_key)
214 }
215
216 pub fn new(api_key: &str) -> Self {
221 Self::builder(api_key)
222 .build()
223 .expect("Huggingface client should build")
224 }
225}
226
227impl Client<reqwest::Client> {
228 pub(crate) fn post_reqwest(&self, path: &str) -> reqwest::RequestBuilder {
229 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
230
231 self.http_client
232 .post(url)
233 .headers(self.default_headers.clone())
234 .bearer_auth(&self.api_key)
235 }
236}
237
238impl<T> Client<T>
239where
240 T: HttpClientExt,
241{
242 pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
243 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
244
245 let mut req = http_client::Request::post(url);
246
247 if let Some(hs) = req.headers_mut() {
248 *hs = self.default_headers.clone();
249 }
250
251 http_client::with_bearer_auth(req, &self.api_key)
252 }
253
254 pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
255 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
256
257 let mut req = http_client::Request::get(url);
258
259 if let Some(hs) = req.headers_mut() {
260 *hs = self.default_headers.clone();
261 }
262
263 http_client::with_bearer_auth(req, &self.api_key)
264 }
265
266 pub(crate) async fn send<U, V>(
267 &self,
268 req: http_client::Request<U>,
269 ) -> http_client::Result<http_client::Response<http_client::LazyBody<V>>>
270 where
271 U: Into<Bytes> + Send,
272 V: From<Bytes> + Send + 'static,
273 {
274 self.http_client.send(req).await
275 }
276}
277
278impl ProviderClient for Client<reqwest::Client> {
279 fn from_env() -> Self {
282 let api_key = std::env::var("HUGGINGFACE_API_KEY").expect("HUGGINGFACE_API_KEY is not set");
283 Self::new(&api_key)
284 }
285
286 fn from_val(input: crate::client::ProviderValue) -> Self {
287 let crate::client::ProviderValue::Simple(api_key) = input else {
288 panic!("Incorrect provider value type")
289 };
290 Self::new(&api_key)
291 }
292}
293
294impl CompletionClient for Client<reqwest::Client> {
295 type CompletionModel = CompletionModel<reqwest::Client>;
296
297 fn completion_model(&self, model: &str) -> CompletionModel<reqwest::Client> {
309 CompletionModel::new(self.clone(), model)
310 }
311}
312
313impl TranscriptionClient for Client<reqwest::Client> {
314 type TranscriptionModel = TranscriptionModel<reqwest::Client>;
315
316 fn transcription_model(&self, model: &str) -> TranscriptionModel<reqwest::Client> {
329 TranscriptionModel::new(self.clone(), model)
330 }
331}
332
#[cfg(feature = "image")]
impl ImageGenerationClient for Client<reqwest::Client> {
    type ImageGenerationModel = ImageGenerationModel<reqwest::Client>;

    /// Creates an image-generation model handle for `model`, backed by a clone of this client.
    fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
        let client = self.clone();
        ImageGenerationModel::new(client, model)
    }
}
352
353impl VerifyClient for Client<reqwest::Client> {
354 #[cfg_attr(feature = "worker", worker::send)]
355 async fn verify(&self) -> Result<(), VerifyError> {
356 let req = self
357 .get("/api/whoami-v2")?
358 .body(http_client::NoBody)
359 .map_err(|e| VerifyError::HttpError(e.into()))?;
360
361 let req = reqwest::Request::try_from(req)
362 .map_err(|e| VerifyError::HttpError(http_client::Error::Instance(e.into())))?;
363
364 let response: reqwest::Response = self
365 .http_client
366 .execute(req)
367 .await
368 .map_err(|e| VerifyError::HttpError(http_client::Error::Instance(e.into())))?;
369
370 match response.status() {
371 reqwest::StatusCode::OK => Ok(()),
372 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
373 reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
374 let text = response
375 .text()
376 .await
377 .map_err(|e| VerifyError::HttpError(http_client::Error::Instance(e.into())))?;
378 Err(VerifyError::ProviderError(text))
379 }
380 _ => {
381 response
382 .error_for_status()
383 .map_err(|e| VerifyError::HttpError(http_client::Error::Instance(e.into())))?;
384 Ok(())
385 }
386 }
387 }
388}
389
// Applies the `impl_conversion_traits!` macro (imported from `rig::client` above) to
// `Client<T>` for the `AsEmbeddings` and `AsAudioGeneration` traits.
// NOTE(review): presumably this generates placeholder conversions for capabilities this
// provider does not implement — confirm against the macro definition.
390impl_conversion_traits!(AsEmbeddings, AsAudioGeneration for Client<T>);