rig/providers/huggingface/
client.rs

1use super::completion::CompletionModel;
2#[cfg(feature = "image")]
3use crate::client::ImageGenerationClient;
4use crate::client::{CompletionClient, ProviderClient, TranscriptionClient};
5#[cfg(feature = "image")]
6use crate::image_generation::ImageGenerationError;
7#[cfg(feature = "image")]
8use crate::providers::huggingface::image_generation::ImageGenerationModel;
9use crate::providers::huggingface::transcription::TranscriptionModel;
10use crate::transcription::TranscriptionError;
11use rig::client::impl_conversion_traits;
12use std::fmt::Display;
13
14// ================================================================
15// Main Huggingface Client
16// ================================================================
/// Root URL of the Hugging Face inference router.
///
/// `ClientBuilder::new` uses it as the default base URL. `Client::new` and
/// `ClientBuilder::build` append the sub-provider's route (its `Display` output) to it
/// to form the client's base URL.
const HUGGINGFACE_API_BASE_URL: &str = "https://router.huggingface.co/";
18
/// The backend that the Hugging Face router forwards requests to.
///
/// The `Display` form of each variant is the route segment appended to the router base
/// URL. Plain strings convert into [`SubProvider::Custom`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum SubProvider {
    /// Hugging Face's own inference backend (route `hf-inference/models`); the default.
    /// The model id is also embedded in the request path for this backend.
    #[default]
    HFInference,
    /// Together (route `together`).
    Together,
    /// SambaNova (route `sambanova`).
    SambaNova,
    /// Fireworks (route `fireworks-ai`); model ids are prefixed with
    /// `accounts/fireworks/models/`.
    Fireworks,
    /// Hyperbolic (route `hyperbolic`).
    Hyperbolic,
    /// Nebius (route `nebius`).
    Nebius,
    /// Novita (route `novita`).
    Novita,
    /// Any other route, used verbatim as the route segment.
    Custom(String),
}
31
32impl SubProvider {
33    /// Get the chat completion endpoint for the SubProvider
34    /// Required because Huggingface Inference requires the model
35    /// in the url and in the request body.
36    pub fn completion_endpoint(&self, model: &str) -> String {
37        match self {
38            SubProvider::HFInference => format!("/{model}/v1/chat/completions"),
39            _ => "/v1/chat/completions".to_string(),
40        }
41    }
42
43    /// Get the transcription endpoint for the SubProvider
44    /// Required because Huggingface Inference requires the model
45    /// in the url and in the request body.
46    pub fn transcription_endpoint(&self, model: &str) -> Result<String, TranscriptionError> {
47        match self {
48            SubProvider::HFInference => Ok(format!("/{model}")),
49            _ => Err(TranscriptionError::ProviderError(format!(
50                "transcription endpoint is not supported yet for {self}"
51            ))),
52        }
53    }
54
55    /// Get the image generation endpoint for the SubProvider
56    /// Required because Huggingface Inference requires the model
57    /// in the url and in the request body.
58    #[cfg(feature = "image")]
59    pub fn image_generation_endpoint(&self, model: &str) -> Result<String, ImageGenerationError> {
60        match self {
61            SubProvider::HFInference => Ok(format!("/{model}")),
62            _ => Err(ImageGenerationError::ProviderError(format!(
63                "image generation endpoint is not supported yet for {self}"
64            ))),
65        }
66    }
67
68    pub fn model_identifier(&self, model: &str) -> String {
69        match self {
70            SubProvider::Fireworks => format!("accounts/fireworks/models/{model}"),
71            _ => model.to_string(),
72        }
73    }
74}
75
76impl From<&str> for SubProvider {
77    fn from(s: &str) -> Self {
78        SubProvider::Custom(s.to_string())
79    }
80}
81
82impl From<String> for SubProvider {
83    fn from(value: String) -> Self {
84        SubProvider::Custom(value)
85    }
86}
87
88impl Display for SubProvider {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        let route = match self {
91            SubProvider::HFInference => "hf-inference/models".to_string(),
92            SubProvider::Together => "together".to_string(),
93            SubProvider::SambaNova => "sambanova".to_string(),
94            SubProvider::Fireworks => "fireworks-ai".to_string(),
95            SubProvider::Hyperbolic => "hyperbolic".to_string(),
96            SubProvider::Nebius => "nebius".to_string(),
97            SubProvider::Novita => "novita".to_string(),
98            SubProvider::Custom(route) => route.clone(),
99        };
100
101        write!(f, "{route}")
102    }
103}
104
/// Builder for a Hugging Face [`Client`] that targets a specific [`SubProvider`].
///
/// Created with `ClientBuilder::new`. `build` appends the sub-provider's route to the
/// base URL and hands everything to `Client::from_url`.
pub struct ClientBuilder {
    // API key forwarded to the built client (sent as a bearer token on requests).
    api_key: String,
    // Router root URL; defaults to `HUGGINGFACE_API_BASE_URL`.
    base_url: String,
    // Sub-provider whose route is appended to `base_url`; defaults to `SubProvider::default()`.
    sub_provider: SubProvider,
}
110
111impl ClientBuilder {
112    pub fn new(api_key: &str) -> Self {
113        Self {
114            api_key: api_key.to_string(),
115            base_url: HUGGINGFACE_API_BASE_URL.to_string(),
116            sub_provider: SubProvider::default(),
117        }
118    }
119
120    pub fn base_url(mut self, base_url: &str) -> Self {
121        self.base_url = base_url.to_string();
122        self
123    }
124
125    pub fn sub_provider(mut self, provider: impl Into<SubProvider>) -> Self {
126        self.sub_provider = provider.into();
127        self
128    }
129
130    pub fn build(self) -> Client {
131        let route = self.sub_provider.to_string();
132
133        let base_url = format!("{}/{}", self.base_url, route).replace("//", "/");
134
135        Client::from_url(self.api_key.as_str(), base_url.as_str(), self.sub_provider)
136    }
137}
138
/// Hugging Face API client.
///
/// Requests built by `post` go to `base_url` joined with the request path. They are
/// authenticated with `api_key` as a bearer token and carry `default_headers`.
/// The `Debug` output redacts the API key.
#[derive(Clone)]
pub struct Client {
    // Prefix for every request URL; when created via `Client::new` or
    // `ClientBuilder::build` it already includes the sub-provider route.
    base_url: String,
    // Headers attached to every request (`from_url` sets `Content-Type: application/json`).
    default_headers: reqwest::header::HeaderMap,
    // Secret sent via bearer auth; never printed by `Debug`.
    api_key: String,
    // Underlying HTTP client; replaceable via `with_custom_client`.
    http_client: reqwest::Client,
    // Sub-provider this client routes to. `pub(crate)` so sibling modules can read it —
    // presumably to pick endpoints via the `SubProvider::*_endpoint` helpers (confirm).
    pub(crate) sub_provider: SubProvider,
}
147
148impl std::fmt::Debug for Client {
149    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150        f.debug_struct("Client")
151            .field("base_url", &self.base_url)
152            .field("http_client", &self.http_client)
153            .field("default_headers", &self.default_headers)
154            .field("sub_provider", &self.sub_provider)
155            .field("api_key", &"<REDACTED>")
156            .finish()
157    }
158}
159
160impl Client {
161    /// Create a new Huggingface client with the given API key.
162    pub fn new(api_key: &str) -> Self {
163        let base_url =
164            format!("{}/{}", HUGGINGFACE_API_BASE_URL, SubProvider::HFInference).replace("//", "/");
165        Self::from_url(api_key, &base_url, SubProvider::HFInference)
166    }
167
168    /// Create a new Client with the given API key and base API URL.
169    pub fn from_url(api_key: &str, base_url: &str, sub_provider: SubProvider) -> Self {
170        let mut default_headers = reqwest::header::HeaderMap::new();
171        default_headers.insert(
172            "Content-Type",
173            "application/json"
174                .parse()
175                .expect("Failed to parse Content-Type"),
176        );
177        let http_client = reqwest::Client::builder()
178            .build()
179            .expect("Failed to build HTTP client");
180
181        Self {
182            base_url: base_url.to_owned(),
183            api_key: api_key.to_string(),
184            default_headers,
185            http_client,
186            sub_provider,
187        }
188    }
189
190    /// Use your own `reqwest::Client`.
191    /// The API key will be automatically attached upon trying to make a request, so you shouldn't need to add it as a default header.
192    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
193        self.http_client = client;
194
195        self
196    }
197
198    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
199        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
200        self.http_client
201            .post(url)
202            .bearer_auth(&self.api_key)
203            .headers(self.default_headers.clone())
204    }
205}
206
207impl ProviderClient for Client {
208    /// Create a new Huggingface client from the `HUGGINGFACE_API_KEY` environment variable.
209    /// Panics if the environment variable is not set.
210    fn from_env() -> Self {
211        let api_key = std::env::var("HUGGINGFACE_API_KEY").expect("HUGGINGFACE_API_KEY is not set");
212        Self::new(&api_key)
213    }
214}
215
216impl CompletionClient for Client {
217    type CompletionModel = CompletionModel;
218
219    /// Create a new completion model with the given name
220    ///
221    /// # Example
222    /// ```
223    /// use rig::providers::huggingface::{Client, self}
224    ///
225    /// // Initialize the Huggingface client
226    /// let client = Client::new("your-huggingface-api-key");
227    ///
228    /// let completion_model = client.completion_model(huggingface::GEMMA_2);
229    /// ```
230    fn completion_model(&self, model: &str) -> CompletionModel {
231        CompletionModel::new(self.clone(), model)
232    }
233}
234
235impl TranscriptionClient for Client {
236    type TranscriptionModel = TranscriptionModel;
237
238    /// Create a new transcription model with the given name
239    ///
240    /// # Example
241    /// ```
242    /// use rig::providers::huggingface::{Client, self}
243    ///
244    /// // Initialize the Huggingface client
245    /// let client = Client::new("your-huggingface-api-key");
246    ///
247    /// let completion_model = client.transcription_model(huggingface::WHISPER_LARGE_V3);
248    /// ```
249    ///
250    fn transcription_model(&self, model: &str) -> TranscriptionModel {
251        TranscriptionModel::new(self.clone(), model)
252    }
253}
254
#[cfg(feature = "image")]
impl ImageGenerationClient for Client {
    type ImageGenerationModel = ImageGenerationModel;

    /// Create a new image generation model with the given name.
    ///
    /// The returned model holds a clone of this client (and so its API key and
    /// sub-provider).
    ///
    /// # Example
    /// ```
    /// use rig::client::ImageGenerationClient;
    /// use rig::providers::huggingface::Client;
    ///
    /// // Initialize the Huggingface client
    /// let client = Client::new("your-huggingface-api-key");
    ///
    /// // Any text-to-image model id served by the selected sub-provider.
    /// let image_generation_model =
    ///     client.image_generation_model("stabilityai/stable-diffusion-3-medium-diffusers");
    /// ```
    fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
        ImageGenerationModel::new(self.clone(), model)
    }
}
274
// NOTE(review): presumably generates `AsEmbeddings` / `AsAudioGeneration` conversion
// impls for `Client`, i.e. declares that this provider does not offer those capabilities
// directly — confirm against the macro definition in `rig::client`.
impl_conversion_traits!(AsEmbeddings, AsAudioGeneration for Client);