rig/providers/huggingface/client.rs

1use super::completion::CompletionModel;
2#[cfg(feature = "image")]
3use crate::client::ImageGenerationClient;
4use crate::client::{
5    ClientBuilderError, CompletionClient, ProviderClient, TranscriptionClient, VerifyClient,
6    VerifyError,
7};
8#[cfg(feature = "image")]
9use crate::image_generation::ImageGenerationError;
10#[cfg(feature = "image")]
11use crate::providers::huggingface::image_generation::ImageGenerationModel;
12use crate::providers::huggingface::transcription::TranscriptionModel;
13use crate::transcription::TranscriptionError;
14use rig::client::impl_conversion_traits;
15use std::fmt::Display;
16
// ================================================================
// Main Huggingface Client
// ================================================================

/// Default base URL of the Huggingface inference router.
///
/// `ClientBuilder` appends the selected [`SubProvider`] route to this URL.
const HUGGINGFACE_API_BASE_URL: &str = "https://router.huggingface.co/";

/// Inference provider that requests are routed to through the Huggingface router.
///
/// The `Display` implementation yields the route segment placed in the request URL.
/// `Eq` and `Hash` are derived so a sub-provider can be used as a set or map key.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum SubProvider {
    /// Huggingface's own inference service; this is the default.
    #[default]
    HFInference,
    Together,
    SambaNova,
    Fireworks,
    Hyperbolic,
    Nebius,
    Novita,
    /// Any other route, used verbatim as the URL segment.
    Custom(String),
}
34
35impl SubProvider {
36    /// Get the chat completion endpoint for the SubProvider
37    /// Required because Huggingface Inference requires the model
38    /// in the url and in the request body.
39    pub fn completion_endpoint(&self, model: &str) -> String {
40        match self {
41            SubProvider::HFInference => format!("/{model}/v1/chat/completions"),
42            _ => "/v1/chat/completions".to_string(),
43        }
44    }
45
46    /// Get the transcription endpoint for the SubProvider
47    /// Required because Huggingface Inference requires the model
48    /// in the url and in the request body.
49    pub fn transcription_endpoint(&self, model: &str) -> Result<String, TranscriptionError> {
50        match self {
51            SubProvider::HFInference => Ok(format!("/{model}")),
52            _ => Err(TranscriptionError::ProviderError(format!(
53                "transcription endpoint is not supported yet for {self}"
54            ))),
55        }
56    }
57
58    /// Get the image generation endpoint for the SubProvider
59    /// Required because Huggingface Inference requires the model
60    /// in the url and in the request body.
61    #[cfg(feature = "image")]
62    pub fn image_generation_endpoint(&self, model: &str) -> Result<String, ImageGenerationError> {
63        match self {
64            SubProvider::HFInference => Ok(format!("/{model}")),
65            _ => Err(ImageGenerationError::ProviderError(format!(
66                "image generation endpoint is not supported yet for {self}"
67            ))),
68        }
69    }
70
71    pub fn model_identifier(&self, model: &str) -> String {
72        match self {
73            SubProvider::Fireworks => format!("accounts/fireworks/models/{model}"),
74            _ => model.to_string(),
75        }
76    }
77}
78
79impl From<&str> for SubProvider {
80    fn from(s: &str) -> Self {
81        SubProvider::Custom(s.to_string())
82    }
83}
84
85impl From<String> for SubProvider {
86    fn from(value: String) -> Self {
87        SubProvider::Custom(value)
88    }
89}
90
91impl Display for SubProvider {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        let route = match self {
94            SubProvider::HFInference => "hf-inference/models".to_string(),
95            SubProvider::Together => "together".to_string(),
96            SubProvider::SambaNova => "sambanova".to_string(),
97            SubProvider::Fireworks => "fireworks-ai".to_string(),
98            SubProvider::Hyperbolic => "hyperbolic".to_string(),
99            SubProvider::Nebius => "nebius".to_string(),
100            SubProvider::Novita => "novita".to_string(),
101            SubProvider::Custom(route) => route.clone(),
102        };
103
104        write!(f, "{route}")
105    }
106}
107
108pub struct ClientBuilder {
109    api_key: String,
110    base_url: String,
111    sub_provider: SubProvider,
112    http_client: Option<reqwest::Client>,
113}
114
115impl ClientBuilder {
116    pub fn new(api_key: &str) -> Self {
117        Self {
118            api_key: api_key.to_string(),
119            base_url: HUGGINGFACE_API_BASE_URL.to_string(),
120            sub_provider: SubProvider::default(),
121            http_client: None,
122        }
123    }
124
125    pub fn base_url(mut self, base_url: &str) -> Self {
126        self.base_url = base_url.to_string();
127        self
128    }
129
130    pub fn sub_provider(mut self, provider: impl Into<SubProvider>) -> Self {
131        self.sub_provider = provider.into();
132        self
133    }
134
135    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
136        self.http_client = Some(client);
137        self
138    }
139
140    pub fn build(self) -> Result<Client, ClientBuilderError> {
141        let route = self.sub_provider.to_string();
142        let base_url = format!("{}/{}", self.base_url, route).replace("//", "/");
143
144        let mut default_headers = reqwest::header::HeaderMap::new();
145        default_headers.insert(
146            "Content-Type",
147            "application/json"
148                .parse()
149                .expect("Failed to parse Content-Type"),
150        );
151        let http_client = if let Some(http_client) = self.http_client {
152            http_client
153        } else {
154            reqwest::Client::builder().build()?
155        };
156
157        Ok(Client {
158            base_url,
159            default_headers,
160            api_key: self.api_key,
161            http_client,
162            sub_provider: self.sub_provider,
163        })
164    }
165}
166
167#[derive(Clone)]
168pub struct Client {
169    base_url: String,
170    default_headers: reqwest::header::HeaderMap,
171    api_key: String,
172    http_client: reqwest::Client,
173    pub(crate) sub_provider: SubProvider,
174}
175
176impl std::fmt::Debug for Client {
177    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
178        f.debug_struct("Client")
179            .field("base_url", &self.base_url)
180            .field("http_client", &self.http_client)
181            .field("default_headers", &self.default_headers)
182            .field("sub_provider", &self.sub_provider)
183            .field("api_key", &"<REDACTED>")
184            .finish()
185    }
186}
187
188impl Client {
189    /// Create a new Huggingface client builder.
190    ///
191    /// # Example
192    /// ```
193    /// use rig::providers::huggingface::{ClientBuilder, self};
194    ///
195    /// // Initialize the Huggingface client
196    /// let client = Client::builder("your-huggingface-api-key")
197    ///    .build()
198    /// ```
199    pub fn builder(api_key: &str) -> ClientBuilder {
200        ClientBuilder::new(api_key)
201    }
202
203    /// Create a new Huggingface client. For more control, use the `builder` method.
204    ///
205    /// # Panics
206    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
207    pub fn new(api_key: &str) -> Self {
208        Self::builder(api_key)
209            .build()
210            .expect("Huggingface client should build")
211    }
212
213    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
214        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
215        self.http_client
216            .post(url)
217            .bearer_auth(&self.api_key)
218            .headers(self.default_headers.clone())
219    }
220
221    pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder {
222        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
223        self.http_client
224            .get(url)
225            .bearer_auth(&self.api_key)
226            .headers(self.default_headers.clone())
227    }
228}
229
230impl ProviderClient for Client {
231    /// Create a new Huggingface client from the `HUGGINGFACE_API_KEY` environment variable.
232    /// Panics if the environment variable is not set.
233    fn from_env() -> Self {
234        let api_key = std::env::var("HUGGINGFACE_API_KEY").expect("HUGGINGFACE_API_KEY is not set");
235        Self::new(&api_key)
236    }
237
238    fn from_val(input: crate::client::ProviderValue) -> Self {
239        let crate::client::ProviderValue::Simple(api_key) = input else {
240            panic!("Incorrect provider value type")
241        };
242        Self::new(&api_key)
243    }
244}
245
246impl CompletionClient for Client {
247    type CompletionModel = CompletionModel;
248
249    /// Create a new completion model with the given name
250    ///
251    /// # Example
252    /// ```
253    /// use rig::providers::huggingface::{Client, self}
254    ///
255    /// // Initialize the Huggingface client
256    /// let client = Client::new("your-huggingface-api-key");
257    ///
258    /// let completion_model = client.completion_model(huggingface::GEMMA_2);
259    /// ```
260    fn completion_model(&self, model: &str) -> CompletionModel {
261        CompletionModel::new(self.clone(), model)
262    }
263}
264
265impl TranscriptionClient for Client {
266    type TranscriptionModel = TranscriptionModel;
267
268    /// Create a new transcription model with the given name
269    ///
270    /// # Example
271    /// ```
272    /// use rig::providers::huggingface::{Client, self}
273    ///
274    /// // Initialize the Huggingface client
275    /// let client = Client::new("your-huggingface-api-key");
276    ///
277    /// let completion_model = client.transcription_model(huggingface::WHISPER_LARGE_V3);
278    /// ```
279    ///
280    fn transcription_model(&self, model: &str) -> TranscriptionModel {
281        TranscriptionModel::new(self.clone(), model)
282    }
283}
284
#[cfg(feature = "image")]
impl ImageGenerationClient for Client {
    type ImageGenerationModel = ImageGenerationModel;

    /// Creates a new image generation model with the given name.
    ///
    /// # Example
    /// ```
    /// use rig::client::ImageGenerationClient;
    /// use rig::providers::huggingface::Client;
    ///
    /// // Initialize the Huggingface client
    /// let client = Client::new("your-huggingface-api-key");
    ///
    /// let image_generation_model =
    ///     client.image_generation_model("stabilityai/stable-diffusion-3-medium-diffusers");
    /// ```
    fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
        ImageGenerationModel::new(self.clone(), model)
    }
}
304
305impl VerifyClient for Client {
306    #[cfg_attr(feature = "worker", worker::send)]
307    async fn verify(&self) -> Result<(), VerifyError> {
308        let response = self.get("/api/whoami-v2").send().await?;
309        match response.status() {
310            reqwest::StatusCode::OK => Ok(()),
311            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
312            reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
313                Err(VerifyError::ProviderError(response.text().await?))
314            }
315            _ => {
316                response.error_for_status()?;
317                Ok(())
318            }
319        }
320    }
321}
322
// This file implements no embeddings or audio-generation client for `Client`; this macro
// presumably supplies default (capability-absent) implementations of the corresponding
// conversion traits. NOTE(review): confirm against the `impl_conversion_traits!` definition.
impl_conversion_traits!(AsEmbeddings, AsAudioGeneration for Client);