// rig/providers/huggingface/client.rs

1use super::completion::CompletionModel;
2#[cfg(feature = "image")]
3use crate::client::ImageGenerationClient;
4use crate::client::{
5    ClientBuilderError, CompletionClient, ProviderClient, TranscriptionClient, VerifyClient,
6    VerifyError,
7};
8use crate::http_client::{self, HttpClientExt};
9#[cfg(feature = "image")]
10use crate::image_generation::ImageGenerationError;
11#[cfg(feature = "image")]
12use crate::providers::huggingface::image_generation::ImageGenerationModel;
13use crate::providers::huggingface::transcription::TranscriptionModel;
14use crate::transcription::TranscriptionError;
15use bytes::Bytes;
16use rig::client::impl_conversion_traits;
17use std::fmt::Debug;
18use std::fmt::Display;
19
// ================================================================
// Main Huggingface Client
// ================================================================

/// Root URL of the Hugging Face inference router used when the builder is not given one.
///
/// Override it per client with `ClientBuilder::base_url`.
const HUGGINGFACE_API_BASE_URL: &str = "https://router.huggingface.co";
24
/// Inference sub-provider that requests are routed to through the Hugging Face router.
///
/// `HFInference` is the default. Use `Custom` for any route without a dedicated variant;
/// converting a plain `&str`/`String` into a `SubProvider` always yields `Custom`.
///
/// `Eq` and `Hash` are derived so sub-providers can be used as map or set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum SubProvider {
    /// Hugging Face's own inference service (the default).
    #[default]
    HFInference,
    /// Together AI.
    Together,
    /// SambaNova.
    SambaNova,
    /// Fireworks AI.
    Fireworks,
    /// Hyperbolic.
    Hyperbolic,
    /// Nebius.
    Nebius,
    /// Novita.
    Novita,
    /// Arbitrary route string, used verbatim.
    Custom(String),
}
37
38impl SubProvider {
39    /// Get the chat completion endpoint for the SubProvider
40    /// Required because Huggingface Inference requires the model
41    /// in the url and in the request body.
42    pub fn completion_endpoint(&self, _model: &str) -> String {
43        "v1/chat/completions".to_string()
44    }
45
46    /// Get the transcription endpoint for the SubProvider
47    /// Required because Huggingface Inference requires the model
48    /// in the url and in the request body.
49    pub fn transcription_endpoint(&self, model: &str) -> Result<String, TranscriptionError> {
50        match self {
51            SubProvider::HFInference => Ok(format!("/{model}")),
52            _ => Err(TranscriptionError::ProviderError(format!(
53                "transcription endpoint is not supported yet for {self}"
54            ))),
55        }
56    }
57
58    /// Get the image generation endpoint for the SubProvider
59    /// Required because Huggingface Inference requires the model
60    /// in the url and in the request body.
61    #[cfg(feature = "image")]
62    pub fn image_generation_endpoint(&self, model: &str) -> Result<String, ImageGenerationError> {
63        match self {
64            SubProvider::HFInference => Ok(format!("/{model}")),
65            _ => Err(ImageGenerationError::ProviderError(format!(
66                "image generation endpoint is not supported yet for {self}"
67            ))),
68        }
69    }
70
71    pub fn model_identifier(&self, model: &str) -> String {
72        match self {
73            SubProvider::Fireworks => format!("accounts/fireworks/models/{model}"),
74            _ => model.to_string(),
75        }
76    }
77}
78
79impl From<&str> for SubProvider {
80    fn from(s: &str) -> Self {
81        SubProvider::Custom(s.to_string())
82    }
83}
84
85impl From<String> for SubProvider {
86    fn from(value: String) -> Self {
87        SubProvider::Custom(value)
88    }
89}
90
91impl Display for SubProvider {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        let route = match self {
94            SubProvider::HFInference => "hf-inference/models".to_string(),
95            SubProvider::Together => "together".to_string(),
96            SubProvider::SambaNova => "sambanova".to_string(),
97            SubProvider::Fireworks => "fireworks-ai".to_string(),
98            SubProvider::Hyperbolic => "hyperbolic".to_string(),
99            SubProvider::Nebius => "nebius".to_string(),
100            SubProvider::Novita => "novita".to_string(),
101            SubProvider::Custom(route) => route.clone(),
102        };
103
104        write!(f, "{route}")
105    }
106}
107
108pub struct ClientBuilder<T = reqwest::Client> {
109    api_key: String,
110    base_url: String,
111    sub_provider: SubProvider,
112    http_client: T,
113}
114
115impl<T> ClientBuilder<T>
116where
117    T: Default,
118{
119    pub fn new(api_key: &str) -> ClientBuilder<T> {
120        ClientBuilder {
121            api_key: api_key.to_string(),
122            base_url: HUGGINGFACE_API_BASE_URL.to_string(),
123            sub_provider: SubProvider::default(),
124            http_client: Default::default(),
125        }
126    }
127}
128
129impl<T> ClientBuilder<T> {
130    pub fn new_with_client(api_key: &str, http_client: T) -> ClientBuilder<T> {
131        ClientBuilder {
132            api_key: api_key.to_string(),
133            base_url: HUGGINGFACE_API_BASE_URL.to_string(),
134            sub_provider: SubProvider::default(),
135            http_client,
136        }
137    }
138
139    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<U> {
140        ClientBuilder {
141            api_key: self.api_key,
142            base_url: self.base_url,
143            sub_provider: self.sub_provider,
144            http_client,
145        }
146    }
147
148    pub fn base_url(mut self, base_url: &str) -> Self {
149        self.base_url = base_url.to_string();
150        self
151    }
152
153    pub fn sub_provider(mut self, provider: impl Into<SubProvider>) -> Self {
154        self.sub_provider = provider.into();
155        self
156    }
157
158    pub fn build(self) -> Result<Client<T>, ClientBuilderError> {
159        let mut default_headers = reqwest::header::HeaderMap::new();
160        default_headers.insert(
161            "Content-Type",
162            "application/json"
163                .parse()
164                .expect("Failed to parse Content-Type"),
165        );
166
167        Ok(Client {
168            base_url: self.base_url,
169            default_headers,
170            api_key: self.api_key,
171            http_client: self.http_client,
172            sub_provider: self.sub_provider,
173        })
174    }
175}
176
177#[derive(Clone)]
178pub struct Client<T = reqwest::Client> {
179    base_url: String,
180    default_headers: reqwest::header::HeaderMap,
181    api_key: String,
182    pub http_client: T,
183    pub(crate) sub_provider: SubProvider,
184}
185
186impl<T> Debug for Client<T>
187where
188    T: Debug,
189{
190    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191        f.debug_struct("Client")
192            .field("base_url", &self.base_url)
193            .field("http_client", &self.http_client)
194            .field("default_headers", &self.default_headers)
195            .field("sub_provider", &self.sub_provider)
196            .field("api_key", &"<REDACTED>")
197            .finish()
198    }
199}
200
201impl<T> Client<T>
202where
203    T: HttpClientExt,
204{
205    pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
206        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
207
208        println!("URL: {url}");
209
210        let mut req = http_client::Request::post(url);
211
212        if let Some(hs) = req.headers_mut() {
213            *hs = self.default_headers.clone();
214        }
215
216        http_client::with_bearer_auth(req, &self.api_key)
217    }
218
219    pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
220        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
221
222        let mut req = http_client::Request::get(url);
223
224        if let Some(hs) = req.headers_mut() {
225            *hs = self.default_headers.clone();
226        }
227
228        http_client::with_bearer_auth(req, &self.api_key)
229    }
230
231    pub(crate) async fn send<U, V>(
232        &self,
233        req: http_client::Request<U>,
234    ) -> http_client::Result<http_client::Response<http_client::LazyBody<V>>>
235    where
236        U: Into<Bytes> + Send,
237        V: From<Bytes> + Send + 'static,
238    {
239        self.http_client.send(req).await
240    }
241}
242
243impl Client<reqwest::Client> {
244    /// Create a new Huggingface client builder.
245    ///
246    /// # Example
247    /// ```
248    /// use rig::providers::huggingface::{ClientBuilder, self};
249    ///
250    /// // Initialize the Huggingface client
251    /// let client = Client::builder("your-huggingface-api-key")
252    ///    .build()
253    /// ```
254    pub fn builder(api_key: &str) -> ClientBuilder<reqwest::Client> {
255        ClientBuilder::new(api_key)
256    }
257
258    /// Create a new Huggingface client. For more control, use the `builder` method.
259    ///
260    /// # Panics
261    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
262    pub fn new(api_key: &str) -> Self {
263        Self::builder(api_key)
264            .build()
265            .expect("Huggingface client should build")
266    }
267
268    pub fn from_env() -> Self {
269        <Self as ProviderClient>::from_env()
270    }
271}
272
273impl<T> ProviderClient for Client<T>
274where
275    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
276{
277    /// Create a new Huggingface client from the `HUGGINGFACE_API_KEY` environment variable.
278    /// Panics if the environment variable is not set.
279    fn from_env() -> Self {
280        let api_key = std::env::var("HUGGINGFACE_API_KEY").expect("HUGGINGFACE_API_KEY is not set");
281        ClientBuilder::<T>::new(&api_key).build().unwrap()
282    }
283
284    fn from_val(input: crate::client::ProviderValue) -> Self {
285        let crate::client::ProviderValue::Simple(api_key) = input else {
286            panic!("Incorrect provider value type")
287        };
288        ClientBuilder::<T>::new(&api_key).build().unwrap()
289    }
290}
291
292impl<T> CompletionClient for Client<T>
293where
294    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
295{
296    type CompletionModel = CompletionModel<T>;
297
298    /// Create a new completion model with the given name
299    ///
300    /// # Example
301    /// ```
302    /// use rig::providers::huggingface::{Client, self}
303    ///
304    /// // Initialize the Huggingface client
305    /// let client = Client::new("your-huggingface-api-key");
306    ///
307    /// let completion_model = client.completion_model(huggingface::GEMMA_2);
308    /// ```
309    fn completion_model(&self, model: &str) -> Self::CompletionModel {
310        CompletionModel::new(self.clone(), model)
311    }
312}
313
314impl<T> TranscriptionClient for Client<T>
315where
316    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
317{
318    type TranscriptionModel = TranscriptionModel<T>;
319
320    /// Create a new transcription model with the given name
321    ///
322    /// # Example
323    /// ```
324    /// use rig::providers::huggingface::{Client, self}
325    ///
326    /// // Initialize the Huggingface client
327    /// let client = Client::new("your-huggingface-api-key");
328    ///
329    /// let completion_model = client.transcription_model(huggingface::WHISPER_LARGE_V3);
330    /// ```
331    ///
332    fn transcription_model(&self, model: &str) -> Self::TranscriptionModel {
333        TranscriptionModel::new(self.clone(), model)
334    }
335}
336
#[cfg(feature = "image")]
impl<T> ImageGenerationClient for Client<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
{
    type ImageGenerationModel = ImageGenerationModel<T>;

    /// Create a new image generation model with the given name.
    ///
    /// The model holds a clone of this client.
    ///
    /// # Example
    /// ```
    /// use rig::client::ImageGenerationClient;
    /// use rig::providers::huggingface::Client;
    ///
    /// // Initialize the Huggingface client
    /// let client = Client::new("your-huggingface-api-key");
    ///
    /// let image_model = client.image_generation_model("black-forest-labs/FLUX.1-dev");
    /// ```
    fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
        ImageGenerationModel::new(self.clone(), model)
    }
}
359
360impl<T> VerifyClient for Client<T>
361where
362    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
363{
364    #[cfg_attr(feature = "worker", worker::send)]
365    async fn verify(&self) -> Result<(), VerifyError> {
366        let req = self
367            .get("/api/whoami-v2")?
368            .body(http_client::NoBody)
369            .map_err(|e| VerifyError::HttpError(e.into()))?;
370
371        let response = self
372            .http_client
373            .send::<_, Vec<u8>>(req)
374            .await
375            .map_err(|e| VerifyError::HttpError(http_client::Error::Instance(e.into())))?;
376
377        match response.status() {
378            reqwest::StatusCode::OK => Ok(()),
379            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
380            reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
381                let text = http_client::text(response).await?;
382                Err(VerifyError::ProviderError(text))
383            }
384            _ => {
385                let text = http_client::text(response).await?;
386                Err(VerifyError::ProviderError(text))
387            }
388        }
389    }
390}
391
392impl_conversion_traits!(AsEmbeddings, AsAudioGeneration for Client<T>);