rig/providers/huggingface/client.rs

1use super::completion::CompletionModel;
2#[cfg(feature = "image")]
3use crate::client::ImageGenerationClient;
4use crate::client::{CompletionClient, ProviderClient, TranscriptionClient};
5#[cfg(feature = "image")]
6use crate::image_generation::ImageGenerationError;
7#[cfg(feature = "image")]
8use crate::providers::huggingface::image_generation::ImageGenerationModel;
9use crate::providers::huggingface::transcription::TranscriptionModel;
10use crate::transcription::TranscriptionError;
11use rig::client::impl_conversion_traits;
12use std::fmt::Display;
13
// ---------------------------------------------------------------------------
// Hugging Face client
// ---------------------------------------------------------------------------

/// Root URL of the Hugging Face inference router.
///
/// The route of the selected [`SubProvider`] is appended to this value to form the
/// base URL that every request path is joined onto.
const HUGGINGFACE_API_BASE_URL: &str = "https://router.huggingface.co/";
18
/// Inference backend reached through the Hugging Face router.
///
/// The `Display` form of each variant is the route segment that is appended to the
/// router base URL (for example `"together"` or `"fireworks-ai"`).
///
/// `Eq` and `Hash` are derived (every payload is a `String`) so a sub-provider can be
/// used as a key in maps and sets.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum SubProvider {
    /// Hugging Face's own inference backend, routed via `hf-inference/models`.
    /// This is the default sub-provider. The model id is also embedded in request paths.
    #[default]
    HFInference,
    /// Routed via `together`.
    Together,
    /// Routed via `sambanova`.
    SambaNova,
    /// Routed via `fireworks-ai`; model ids get an `accounts/fireworks/models/` prefix.
    Fireworks,
    /// Routed via `hyperbolic`.
    Hyperbolic,
    /// Routed via `nebius`.
    Nebius,
    /// Routed via `novita`.
    Novita,
    /// Any other route segment, used verbatim.
    Custom(String),
}
31
32impl SubProvider {
33    /// Get the chat completion endpoint for the SubProvider
34    /// Required because Huggingface Inference requires the model
35    /// in the url and in the request body.
36    pub fn completion_endpoint(&self, model: &str) -> String {
37        match self {
38            SubProvider::HFInference => format!("/{model}/v1/chat/completions"),
39            _ => "/v1/chat/completions".to_string(),
40        }
41    }
42
43    /// Get the transcription endpoint for the SubProvider
44    /// Required because Huggingface Inference requires the model
45    /// in the url and in the request body.
46    pub fn transcription_endpoint(&self, model: &str) -> Result<String, TranscriptionError> {
47        match self {
48            SubProvider::HFInference => Ok(format!("/{model}")),
49            _ => Err(TranscriptionError::ProviderError(format!(
50                "transcription endpoint is not supported yet for {self}"
51            ))),
52        }
53    }
54
55    /// Get the image generation endpoint for the SubProvider
56    /// Required because Huggingface Inference requires the model
57    /// in the url and in the request body.
58    #[cfg(feature = "image")]
59    pub fn image_generation_endpoint(&self, model: &str) -> Result<String, ImageGenerationError> {
60        match self {
61            SubProvider::HFInference => Ok(format!("/{model}")),
62            _ => Err(ImageGenerationError::ProviderError(format!(
63                "image generation endpoint is not supported yet for {self}"
64            ))),
65        }
66    }
67
68    pub fn model_identifier(&self, model: &str) -> String {
69        match self {
70            SubProvider::Fireworks => format!("accounts/fireworks/models/{model}"),
71            _ => model.to_string(),
72        }
73    }
74}
75
76impl From<&str> for SubProvider {
77    fn from(s: &str) -> Self {
78        SubProvider::Custom(s.to_string())
79    }
80}
81
82impl From<String> for SubProvider {
83    fn from(value: String) -> Self {
84        SubProvider::Custom(value)
85    }
86}
87
88impl Display for SubProvider {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        let route = match self {
91            SubProvider::HFInference => "hf-inference/models".to_string(),
92            SubProvider::Together => "together".to_string(),
93            SubProvider::SambaNova => "sambanova".to_string(),
94            SubProvider::Fireworks => "fireworks-ai".to_string(),
95            SubProvider::Hyperbolic => "hyperbolic".to_string(),
96            SubProvider::Nebius => "nebius".to_string(),
97            SubProvider::Novita => "novita".to_string(),
98            SubProvider::Custom(route) => route.clone(),
99        };
100
101        write!(f, "{route}")
102    }
103}
104
105pub struct ClientBuilder {
106    api_key: String,
107    base_url: String,
108    sub_provider: SubProvider,
109}
110
111impl ClientBuilder {
112    pub fn new(api_key: &str) -> Self {
113        Self {
114            api_key: api_key.to_string(),
115            base_url: HUGGINGFACE_API_BASE_URL.to_string(),
116            sub_provider: SubProvider::default(),
117        }
118    }
119
120    pub fn base_url(mut self, base_url: &str) -> Self {
121        self.base_url = base_url.to_string();
122        self
123    }
124
125    pub fn sub_provider(mut self, provider: impl Into<SubProvider>) -> Self {
126        self.sub_provider = provider.into();
127        self
128    }
129
130    pub fn build(self) -> Client {
131        let route = self.sub_provider.to_string();
132
133        let base_url = format!("{}/{}", self.base_url, route).replace("//", "/");
134
135        Client::from_url(self.api_key.as_str(), base_url.as_str(), self.sub_provider)
136    }
137}
138
139#[derive(Clone)]
140pub struct Client {
141    base_url: String,
142    default_headers: reqwest::header::HeaderMap,
143    api_key: String,
144    http_client: reqwest::Client,
145    pub(crate) sub_provider: SubProvider,
146}
147
148impl std::fmt::Debug for Client {
149    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150        f.debug_struct("Client")
151            .field("base_url", &self.base_url)
152            .field("http_client", &self.http_client)
153            .field("default_headers", &self.default_headers)
154            .field("sub_provider", &self.sub_provider)
155            .field("api_key", &"<REDACTED>")
156            .finish()
157    }
158}
159
160impl Client {
161    /// Create a new Huggingface client with the given API key.
162    pub fn new(api_key: &str) -> Self {
163        let base_url =
164            format!("{}/{}", HUGGINGFACE_API_BASE_URL, SubProvider::HFInference).replace("//", "/");
165        Self::from_url(api_key, &base_url, SubProvider::HFInference)
166    }
167
168    /// Create a new Client with the given API key and base API URL.
169    pub fn from_url(api_key: &str, base_url: &str, sub_provider: SubProvider) -> Self {
170        let mut default_headers = reqwest::header::HeaderMap::new();
171        default_headers.insert(
172            "Content-Type",
173            "application/json"
174                .parse()
175                .expect("Failed to parse Content-Type"),
176        );
177        let http_client = reqwest::Client::builder()
178            .build()
179            .expect("Failed to build HTTP client");
180
181        Self {
182            base_url: base_url.to_owned(),
183            api_key: api_key.to_string(),
184            default_headers,
185            http_client,
186            sub_provider,
187        }
188    }
189
190    /// Use your own `reqwest::Client`.
191    /// The API key will be automatically attached upon trying to make a request, so you shouldn't need to add it as a default header.
192    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
193        self.http_client = client;
194
195        self
196    }
197
198    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
199        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
200        self.http_client
201            .post(url)
202            .bearer_auth(&self.api_key)
203            .headers(self.default_headers.clone())
204    }
205}
206
207impl ProviderClient for Client {
208    /// Create a new Huggingface client from the `HUGGINGFACE_API_KEY` environment variable.
209    /// Panics if the environment variable is not set.
210    fn from_env() -> Self {
211        let api_key = std::env::var("HUGGINGFACE_API_KEY").expect("HUGGINGFACE_API_KEY is not set");
212        Self::new(&api_key)
213    }
214
215    fn from_val(input: crate::client::ProviderValue) -> Self {
216        let crate::client::ProviderValue::Simple(api_key) = input else {
217            panic!("Incorrect provider value type")
218        };
219        Self::new(&api_key)
220    }
221}
222
223impl CompletionClient for Client {
224    type CompletionModel = CompletionModel;
225
226    /// Create a new completion model with the given name
227    ///
228    /// # Example
229    /// ```
230    /// use rig::providers::huggingface::{Client, self}
231    ///
232    /// // Initialize the Huggingface client
233    /// let client = Client::new("your-huggingface-api-key");
234    ///
235    /// let completion_model = client.completion_model(huggingface::GEMMA_2);
236    /// ```
237    fn completion_model(&self, model: &str) -> CompletionModel {
238        CompletionModel::new(self.clone(), model)
239    }
240}
241
242impl TranscriptionClient for Client {
243    type TranscriptionModel = TranscriptionModel;
244
245    /// Create a new transcription model with the given name
246    ///
247    /// # Example
248    /// ```
249    /// use rig::providers::huggingface::{Client, self}
250    ///
251    /// // Initialize the Huggingface client
252    /// let client = Client::new("your-huggingface-api-key");
253    ///
254    /// let completion_model = client.transcription_model(huggingface::WHISPER_LARGE_V3);
255    /// ```
256    ///
257    fn transcription_model(&self, model: &str) -> TranscriptionModel {
258        TranscriptionModel::new(self.clone(), model)
259    }
260}
261
#[cfg(feature = "image")]
impl ImageGenerationClient for Client {
    type ImageGenerationModel = ImageGenerationModel;

    /// Create a new image generation model with the given name.
    ///
    /// The model shares a clone of this client, including its sub-provider.
    ///
    /// # Example
    /// ```
    /// use rig::client::ImageGenerationClient;
    /// use rig::providers::huggingface::Client;
    ///
    /// // Initialize the Huggingface client
    /// let client = Client::new("your-huggingface-api-key");
    ///
    /// let image_generation_model =
    ///     client.image_generation_model("stabilityai/stable-diffusion-3-medium-diffusers");
    /// ```
    fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
        ImageGenerationModel::new(self.clone(), model)
    }
}
281
// This client implements only the completion, transcription and (feature-gated)
// image-generation client traits. No embeddings or audio-generation client is
// implemented in this file. The macro presumably supplies the default
// "not supported" `AsEmbeddings` / `AsAudioGeneration` impls for it — confirm against
// the `impl_conversion_traits!` definition in `crate::client`.
impl_conversion_traits!(AsEmbeddings, AsAudioGeneration for Client);