rig/providers/huggingface/
client.rs1use super::completion::CompletionModel;
2#[cfg(feature = "image")]
3use crate::client::ImageGenerationClient;
4use crate::client::{
5 ClientBuilderError, CompletionClient, ProviderClient, TranscriptionClient, VerifyClient,
6 VerifyError,
7};
8use crate::http_client::{self, HttpClientExt};
9#[cfg(feature = "image")]
10use crate::image_generation::ImageGenerationError;
11#[cfg(feature = "image")]
12use crate::providers::huggingface::image_generation::ImageGenerationModel;
13use crate::providers::huggingface::transcription::TranscriptionModel;
14use crate::transcription::TranscriptionError;
15use bytes::Bytes;
16use rig::client::impl_conversion_traits;
17use std::fmt::Debug;
18use std::fmt::Display;
19
/// Default root URL of the Hugging Face inference router.
///
/// Request paths (sub-provider route plus endpoint) are joined onto this value unless the
/// builder overrides it via `ClientBuilder::base_url`.
const HUGGINGFACE_API_BASE_URL: &str = "https://router.huggingface.co";
24
/// Provider that the Hugging Face router forwards requests to.
///
/// The `Display` form of each variant is the route segment used for that provider
/// (for example `hf-inference/models` for [`SubProvider::HFInference`]); any other route
/// can be supplied verbatim through [`SubProvider::Custom`].
///
/// Equality is total, so the type also derives `Eq` and `Hash` and can be used as a map key.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum SubProvider {
    /// Routed as `hf-inference/models`; the default sub-provider.
    #[default]
    HFInference,
    /// Routed as `together`.
    Together,
    /// Routed as `sambanova`.
    SambaNova,
    /// Routed as `fireworks-ai`; model names get the `accounts/fireworks/models/` prefix.
    Fireworks,
    /// Routed as `hyperbolic`.
    Hyperbolic,
    /// Routed as `nebius`.
    Nebius,
    /// Routed as `novita`.
    Novita,
    /// Arbitrary route segment, used exactly as given.
    Custom(String),
}
37
38impl SubProvider {
39 pub fn completion_endpoint(&self, _model: &str) -> String {
43 "v1/chat/completions".to_string()
44 }
45
46 pub fn transcription_endpoint(&self, model: &str) -> Result<String, TranscriptionError> {
50 match self {
51 SubProvider::HFInference => Ok(format!("/{model}")),
52 _ => Err(TranscriptionError::ProviderError(format!(
53 "transcription endpoint is not supported yet for {self}"
54 ))),
55 }
56 }
57
58 #[cfg(feature = "image")]
62 pub fn image_generation_endpoint(&self, model: &str) -> Result<String, ImageGenerationError> {
63 match self {
64 SubProvider::HFInference => Ok(format!("/{model}")),
65 _ => Err(ImageGenerationError::ProviderError(format!(
66 "image generation endpoint is not supported yet for {self}"
67 ))),
68 }
69 }
70
71 pub fn model_identifier(&self, model: &str) -> String {
72 match self {
73 SubProvider::Fireworks => format!("accounts/fireworks/models/{model}"),
74 _ => model.to_string(),
75 }
76 }
77}
78
79impl From<&str> for SubProvider {
80 fn from(s: &str) -> Self {
81 SubProvider::Custom(s.to_string())
82 }
83}
84
85impl From<String> for SubProvider {
86 fn from(value: String) -> Self {
87 SubProvider::Custom(value)
88 }
89}
90
91impl Display for SubProvider {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 let route = match self {
94 SubProvider::HFInference => "hf-inference/models".to_string(),
95 SubProvider::Together => "together".to_string(),
96 SubProvider::SambaNova => "sambanova".to_string(),
97 SubProvider::Fireworks => "fireworks-ai".to_string(),
98 SubProvider::Hyperbolic => "hyperbolic".to_string(),
99 SubProvider::Nebius => "nebius".to_string(),
100 SubProvider::Novita => "novita".to_string(),
101 SubProvider::Custom(route) => route.clone(),
102 };
103
104 write!(f, "{route}")
105 }
106}
107
108pub struct ClientBuilder<T = reqwest::Client> {
109 api_key: String,
110 base_url: String,
111 sub_provider: SubProvider,
112 http_client: T,
113}
114
115impl<T> ClientBuilder<T>
116where
117 T: Default,
118{
119 pub fn new(api_key: &str) -> ClientBuilder<T> {
120 ClientBuilder {
121 api_key: api_key.to_string(),
122 base_url: HUGGINGFACE_API_BASE_URL.to_string(),
123 sub_provider: SubProvider::default(),
124 http_client: Default::default(),
125 }
126 }
127}
128
129impl<T> ClientBuilder<T> {
130 pub fn new_with_client(api_key: &str, http_client: T) -> ClientBuilder<T> {
131 ClientBuilder {
132 api_key: api_key.to_string(),
133 base_url: HUGGINGFACE_API_BASE_URL.to_string(),
134 sub_provider: SubProvider::default(),
135 http_client,
136 }
137 }
138
139 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<U> {
140 ClientBuilder {
141 api_key: self.api_key,
142 base_url: self.base_url,
143 sub_provider: self.sub_provider,
144 http_client,
145 }
146 }
147
148 pub fn base_url(mut self, base_url: &str) -> Self {
149 self.base_url = base_url.to_string();
150 self
151 }
152
153 pub fn sub_provider(mut self, provider: impl Into<SubProvider>) -> Self {
154 self.sub_provider = provider.into();
155 self
156 }
157
158 pub fn build(self) -> Result<Client<T>, ClientBuilderError> {
159 let mut default_headers = reqwest::header::HeaderMap::new();
160 default_headers.insert(
161 "Content-Type",
162 "application/json"
163 .parse()
164 .expect("Failed to parse Content-Type"),
165 );
166
167 Ok(Client {
168 base_url: self.base_url,
169 default_headers,
170 api_key: self.api_key,
171 http_client: self.http_client,
172 sub_provider: self.sub_provider,
173 })
174 }
175}
176
177#[derive(Clone)]
178pub struct Client<T = reqwest::Client> {
179 base_url: String,
180 default_headers: reqwest::header::HeaderMap,
181 api_key: String,
182 pub http_client: T,
183 pub(crate) sub_provider: SubProvider,
184}
185
186impl<T> Debug for Client<T>
187where
188 T: Debug,
189{
190 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191 f.debug_struct("Client")
192 .field("base_url", &self.base_url)
193 .field("http_client", &self.http_client)
194 .field("default_headers", &self.default_headers)
195 .field("sub_provider", &self.sub_provider)
196 .field("api_key", &"<REDACTED>")
197 .finish()
198 }
199}
200
201impl<T> Client<T>
202where
203 T: HttpClientExt,
204{
205 pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
206 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
207
208 println!("URL: {url}");
209
210 let mut req = http_client::Request::post(url);
211
212 if let Some(hs) = req.headers_mut() {
213 *hs = self.default_headers.clone();
214 }
215
216 http_client::with_bearer_auth(req, &self.api_key)
217 }
218
219 pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
220 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
221
222 let mut req = http_client::Request::get(url);
223
224 if let Some(hs) = req.headers_mut() {
225 *hs = self.default_headers.clone();
226 }
227
228 http_client::with_bearer_auth(req, &self.api_key)
229 }
230
231 pub(crate) async fn send<U, V>(
232 &self,
233 req: http_client::Request<U>,
234 ) -> http_client::Result<http_client::Response<http_client::LazyBody<V>>>
235 where
236 U: Into<Bytes> + Send,
237 V: From<Bytes> + Send + 'static,
238 {
239 self.http_client.send(req).await
240 }
241}
242
243impl Client<reqwest::Client> {
244 pub fn builder(api_key: &str) -> ClientBuilder<reqwest::Client> {
255 ClientBuilder::new(api_key)
256 }
257
258 pub fn new(api_key: &str) -> Self {
263 Self::builder(api_key)
264 .build()
265 .expect("Huggingface client should build")
266 }
267
268 pub fn from_env() -> Self {
269 <Self as ProviderClient>::from_env()
270 }
271}
272
273impl<T> ProviderClient for Client<T>
274where
275 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
276{
277 fn from_env() -> Self {
280 let api_key = std::env::var("HUGGINGFACE_API_KEY").expect("HUGGINGFACE_API_KEY is not set");
281 ClientBuilder::<T>::new(&api_key).build().unwrap()
282 }
283
284 fn from_val(input: crate::client::ProviderValue) -> Self {
285 let crate::client::ProviderValue::Simple(api_key) = input else {
286 panic!("Incorrect provider value type")
287 };
288 ClientBuilder::<T>::new(&api_key).build().unwrap()
289 }
290}
291
292impl<T> CompletionClient for Client<T>
293where
294 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
295{
296 type CompletionModel = CompletionModel<T>;
297
298 fn completion_model(&self, model: &str) -> Self::CompletionModel {
310 CompletionModel::new(self.clone(), model)
311 }
312}
313
314impl<T> TranscriptionClient for Client<T>
315where
316 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
317{
318 type TranscriptionModel = TranscriptionModel<T>;
319
320 fn transcription_model(&self, model: &str) -> Self::TranscriptionModel {
333 TranscriptionModel::new(self.clone(), model)
334 }
335}
336
#[cfg(feature = "image")]
impl<T> ImageGenerationClient for Client<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
{
    type ImageGenerationModel = ImageGenerationModel<T>;

    /// Returns an image-generation model handle for `model` that owns a clone of this client.
    fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
        let client = self.clone();
        ImageGenerationModel::new(client, model)
    }
}
359
360impl<T> VerifyClient for Client<T>
361where
362 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
363{
364 #[cfg_attr(feature = "worker", worker::send)]
365 async fn verify(&self) -> Result<(), VerifyError> {
366 let req = self
367 .get("/api/whoami-v2")?
368 .body(http_client::NoBody)
369 .map_err(|e| VerifyError::HttpError(e.into()))?;
370
371 let response = self
372 .http_client
373 .send::<_, Vec<u8>>(req)
374 .await
375 .map_err(|e| VerifyError::HttpError(http_client::Error::Instance(e.into())))?;
376
377 match response.status() {
378 reqwest::StatusCode::OK => Ok(()),
379 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
380 reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
381 let text = http_client::text(response).await?;
382 Err(VerifyError::ProviderError(text))
383 }
384 _ => {
385 let text = http_client::text(response).await?;
386 Err(VerifyError::ProviderError(text))
387 }
388 }
389 }
390}
391
// NOTE(review): presumably generates `AsEmbeddings` / `AsAudioGeneration` conversion impls
// for `Client<T>` that report those capabilities as unavailable, since this file implements
// no embedding or audio-generation client — confirm against the macro definition in
// `rig::client`.
impl_conversion_traits!(AsEmbeddings, AsAudioGeneration for Client<T>);