rig/providers/huggingface/
client.rs

1use super::completion::CompletionModel;
2#[cfg(feature = "image")]
3use crate::client::ImageGenerationClient;
4use crate::client::{
5    ClientBuilderError, CompletionClient, ProviderClient, TranscriptionClient, VerifyClient,
6    VerifyError,
7};
8use crate::http_client::{self, HttpClientExt};
9#[cfg(feature = "image")]
10use crate::image_generation::ImageGenerationError;
11#[cfg(feature = "image")]
12use crate::providers::huggingface::image_generation::ImageGenerationModel;
13use crate::providers::huggingface::transcription::TranscriptionModel;
14use crate::transcription::TranscriptionError;
15use bytes::Bytes;
16use rig::client::impl_conversion_traits;
17use std::fmt::Debug;
18use std::fmt::Display;
19
// ================================================================
// Main Huggingface Client
// ================================================================
/// Default base URL used by [`ClientBuilder::new`]; the sub-provider route is
/// appended to it when the client is built.
const HUGGINGFACE_API_BASE_URL: &str = "https://router.huggingface.co/";
24
/// Backend ("sub-provider") that serves requests made through the Huggingface client.
///
/// Each variant renders to a URL route segment through its `Display`
/// implementation. [`SubProvider::HFInference`] is the default.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum SubProvider {
    /// Huggingface's own inference service (route `hf-inference/models`).
    #[default]
    HFInference,
    /// Route `together`.
    Together,
    /// Route `sambanova`.
    SambaNova,
    /// Route `fireworks-ai`.
    Fireworks,
    /// Route `hyperbolic`.
    Hyperbolic,
    /// Route `nebius`.
    Nebius,
    /// Route `novita`.
    Novita,
    /// Any other route; the contained string is used verbatim as the route segment.
    Custom(String),
}
37
38impl SubProvider {
39    /// Get the chat completion endpoint for the SubProvider
40    /// Required because Huggingface Inference requires the model
41    /// in the url and in the request body.
42    pub fn completion_endpoint(&self, model: &str) -> String {
43        match self {
44            SubProvider::HFInference => format!("/{model}/v1/chat/completions"),
45            _ => "/v1/chat/completions".to_string(),
46        }
47    }
48
49    /// Get the transcription endpoint for the SubProvider
50    /// Required because Huggingface Inference requires the model
51    /// in the url and in the request body.
52    pub fn transcription_endpoint(&self, model: &str) -> Result<String, TranscriptionError> {
53        match self {
54            SubProvider::HFInference => Ok(format!("/{model}")),
55            _ => Err(TranscriptionError::ProviderError(format!(
56                "transcription endpoint is not supported yet for {self}"
57            ))),
58        }
59    }
60
61    /// Get the image generation endpoint for the SubProvider
62    /// Required because Huggingface Inference requires the model
63    /// in the url and in the request body.
64    #[cfg(feature = "image")]
65    pub fn image_generation_endpoint(&self, model: &str) -> Result<String, ImageGenerationError> {
66        match self {
67            SubProvider::HFInference => Ok(format!("/{model}")),
68            _ => Err(ImageGenerationError::ProviderError(format!(
69                "image generation endpoint is not supported yet for {self}"
70            ))),
71        }
72    }
73
74    pub fn model_identifier(&self, model: &str) -> String {
75        match self {
76            SubProvider::Fireworks => format!("accounts/fireworks/models/{model}"),
77            _ => model.to_string(),
78        }
79    }
80}
81
82impl From<&str> for SubProvider {
83    fn from(s: &str) -> Self {
84        SubProvider::Custom(s.to_string())
85    }
86}
87
88impl From<String> for SubProvider {
89    fn from(value: String) -> Self {
90        SubProvider::Custom(value)
91    }
92}
93
94impl Display for SubProvider {
95    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96        let route = match self {
97            SubProvider::HFInference => "hf-inference/models".to_string(),
98            SubProvider::Together => "together".to_string(),
99            SubProvider::SambaNova => "sambanova".to_string(),
100            SubProvider::Fireworks => "fireworks-ai".to_string(),
101            SubProvider::Hyperbolic => "hyperbolic".to_string(),
102            SubProvider::Nebius => "nebius".to_string(),
103            SubProvider::Novita => "novita".to_string(),
104            SubProvider::Custom(route) => route.clone(),
105        };
106
107        write!(f, "{route}")
108    }
109}
110
111pub struct ClientBuilder<T = reqwest::Client> {
112    api_key: String,
113    base_url: String,
114    sub_provider: SubProvider,
115    http_client: T,
116}
117
118impl<T> ClientBuilder<T>
119where
120    T: Default,
121{
122    pub fn new(api_key: &str) -> ClientBuilder<T> {
123        ClientBuilder {
124            api_key: api_key.to_string(),
125            base_url: HUGGINGFACE_API_BASE_URL.to_string(),
126            sub_provider: SubProvider::default(),
127            http_client: Default::default(),
128        }
129    }
130}
131
132impl<T> ClientBuilder<T> {
133    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<U> {
134        ClientBuilder {
135            api_key: self.api_key,
136            base_url: self.base_url,
137            sub_provider: self.sub_provider,
138            http_client,
139        }
140    }
141
142    pub fn base_url(mut self, base_url: &str) -> Self {
143        self.base_url = base_url.to_string();
144        self
145    }
146
147    pub fn sub_provider(mut self, provider: impl Into<SubProvider>) -> Self {
148        self.sub_provider = provider.into();
149        self
150    }
151
152    pub fn build(self) -> Result<Client<T>, ClientBuilderError> {
153        let route = self.sub_provider.to_string();
154        let base_url = format!("{}/{}", self.base_url, route).replace("//", "/");
155
156        let mut default_headers = reqwest::header::HeaderMap::new();
157        default_headers.insert(
158            "Content-Type",
159            "application/json"
160                .parse()
161                .expect("Failed to parse Content-Type"),
162        );
163
164        Ok(Client {
165            base_url,
166            default_headers,
167            api_key: self.api_key,
168            http_client: self.http_client,
169            sub_provider: self.sub_provider,
170        })
171    }
172}
173
174#[derive(Clone)]
175pub struct Client<T = reqwest::Client> {
176    base_url: String,
177    default_headers: reqwest::header::HeaderMap,
178    api_key: String,
179    http_client: T,
180    pub(crate) sub_provider: SubProvider,
181}
182
183impl<T> Debug for Client<T>
184where
185    T: Debug,
186{
187    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188        f.debug_struct("Client")
189            .field("base_url", &self.base_url)
190            .field("http_client", &self.http_client)
191            .field("default_headers", &self.default_headers)
192            .field("sub_provider", &self.sub_provider)
193            .field("api_key", &"<REDACTED>")
194            .finish()
195    }
196}
197
198impl<T> Client<T>
199where
200    T: Default,
201{
202    /// Create a new Huggingface client builder.
203    ///
204    /// # Example
205    /// ```
206    /// use rig::providers::huggingface::{ClientBuilder, self};
207    ///
208    /// // Initialize the Huggingface client
209    /// let client = Client::builder("your-huggingface-api-key")
210    ///    .build()
211    /// ```
212    pub fn builder(api_key: &str) -> ClientBuilder<T> {
213        ClientBuilder::new(api_key)
214    }
215
216    /// Create a new Huggingface client. For more control, use the `builder` method.
217    ///
218    /// # Panics
219    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
220    pub fn new(api_key: &str) -> Self {
221        Self::builder(api_key)
222            .build()
223            .expect("Huggingface client should build")
224    }
225}
226
227impl Client<reqwest::Client> {
228    pub(crate) fn post_reqwest(&self, path: &str) -> reqwest::RequestBuilder {
229        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
230
231        self.http_client
232            .post(url)
233            .headers(self.default_headers.clone())
234            .bearer_auth(&self.api_key)
235    }
236}
237
238impl<T> Client<T>
239where
240    T: HttpClientExt,
241{
242    pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
243        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
244
245        let mut req = http_client::Request::post(url);
246
247        if let Some(hs) = req.headers_mut() {
248            *hs = self.default_headers.clone();
249        }
250
251        http_client::with_bearer_auth(req, &self.api_key)
252    }
253
254    pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
255        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
256
257        let mut req = http_client::Request::get(url);
258
259        if let Some(hs) = req.headers_mut() {
260            *hs = self.default_headers.clone();
261        }
262
263        http_client::with_bearer_auth(req, &self.api_key)
264    }
265
266    pub(crate) async fn send<U, V>(
267        &self,
268        req: http_client::Request<U>,
269    ) -> http_client::Result<http_client::Response<http_client::LazyBody<V>>>
270    where
271        U: Into<Bytes> + Send,
272        V: From<Bytes> + Send + 'static,
273    {
274        self.http_client.send(req).await
275    }
276}
277
278impl ProviderClient for Client<reqwest::Client> {
279    /// Create a new Huggingface client from the `HUGGINGFACE_API_KEY` environment variable.
280    /// Panics if the environment variable is not set.
281    fn from_env() -> Self {
282        let api_key = std::env::var("HUGGINGFACE_API_KEY").expect("HUGGINGFACE_API_KEY is not set");
283        Self::new(&api_key)
284    }
285
286    fn from_val(input: crate::client::ProviderValue) -> Self {
287        let crate::client::ProviderValue::Simple(api_key) = input else {
288            panic!("Incorrect provider value type")
289        };
290        Self::new(&api_key)
291    }
292}
293
294impl CompletionClient for Client<reqwest::Client> {
295    type CompletionModel = CompletionModel<reqwest::Client>;
296
297    /// Create a new completion model with the given name
298    ///
299    /// # Example
300    /// ```
301    /// use rig::providers::huggingface::{Client, self}
302    ///
303    /// // Initialize the Huggingface client
304    /// let client = Client::new("your-huggingface-api-key");
305    ///
306    /// let completion_model = client.completion_model(huggingface::GEMMA_2);
307    /// ```
308    fn completion_model(&self, model: &str) -> CompletionModel<reqwest::Client> {
309        CompletionModel::new(self.clone(), model)
310    }
311}
312
313impl TranscriptionClient for Client<reqwest::Client> {
314    type TranscriptionModel = TranscriptionModel<reqwest::Client>;
315
316    /// Create a new transcription model with the given name
317    ///
318    /// # Example
319    /// ```
320    /// use rig::providers::huggingface::{Client, self}
321    ///
322    /// // Initialize the Huggingface client
323    /// let client = Client::new("your-huggingface-api-key");
324    ///
325    /// let completion_model = client.transcription_model(huggingface::WHISPER_LARGE_V3);
326    /// ```
327    ///
328    fn transcription_model(&self, model: &str) -> TranscriptionModel<reqwest::Client> {
329        TranscriptionModel::new(self.clone(), model)
330    }
331}
332
#[cfg(feature = "image")]
impl ImageGenerationClient for Client<reqwest::Client> {
    type ImageGenerationModel = ImageGenerationModel<reqwest::Client>;

    /// Create a new image generation model with the given name
    ///
    /// # Example
    /// ```
    /// use rig::client::ImageGenerationClient;
    /// use rig::providers::huggingface::Client;
    ///
    /// // Initialize the Huggingface client
    /// let client: Client = Client::new("your-huggingface-api-key");
    ///
    /// // Any image generation model id can be passed as a string.
    /// let image_model = client.image_generation_model("black-forest-labs/FLUX.1-dev");
    /// ```
    fn image_generation_model(&self, model: &str) -> ImageGenerationModel<reqwest::Client> {
        // The model gets its own clone of the client.
        ImageGenerationModel::new(self.clone(), model)
    }
}
352
353impl VerifyClient for Client<reqwest::Client> {
354    #[cfg_attr(feature = "worker", worker::send)]
355    async fn verify(&self) -> Result<(), VerifyError> {
356        let req = self
357            .get("/api/whoami-v2")?
358            .body(http_client::NoBody)
359            .map_err(|e| VerifyError::HttpError(e.into()))?;
360
361        let req = reqwest::Request::try_from(req)
362            .map_err(|e| VerifyError::HttpError(http_client::Error::Instance(e.into())))?;
363
364        let response: reqwest::Response = self
365            .http_client
366            .execute(req)
367            .await
368            .map_err(|e| VerifyError::HttpError(http_client::Error::Instance(e.into())))?;
369
370        match response.status() {
371            reqwest::StatusCode::OK => Ok(()),
372            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
373            reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
374                let text = response
375                    .text()
376                    .await
377                    .map_err(|e| VerifyError::HttpError(http_client::Error::Instance(e.into())))?;
378                Err(VerifyError::ProviderError(text))
379            }
380            _ => {
381                response
382                    .error_for_status()
383                    .map_err(|e| VerifyError::HttpError(http_client::Error::Instance(e.into())))?;
384                Ok(())
385            }
386        }
387    }
388}
389
// NOTE(review): presumably generates the `AsEmbeddings` and `AsAudioGeneration`
// conversion-trait impls for `Client<T>` — confirm against the macro
// definition in `rig::client`.
impl_conversion_traits!(AsEmbeddings, AsAudioGeneration for Client<T>);