// rig/providers/huggingface/client.rs

1use super::completion::CompletionModel;
2#[cfg(feature = "image")]
3use crate::client::ImageGenerationClient;
4use crate::client::{ClientBuilderError, CompletionClient, ProviderClient, TranscriptionClient};
5#[cfg(feature = "image")]
6use crate::image_generation::ImageGenerationError;
7#[cfg(feature = "image")]
8use crate::providers::huggingface::image_generation::ImageGenerationModel;
9use crate::providers::huggingface::transcription::TranscriptionModel;
10use crate::transcription::TranscriptionError;
11use rig::client::impl_conversion_traits;
12use std::fmt::Display;
13
// ================================================================
// Main Huggingface Client
// ================================================================

/// Root of the Hugging Face inference router. `ClientBuilder` starts from this
/// value and appends the selected sub-provider route when the client is built.
const HUGGINGFACE_API_BASE_URL: &str = "https://router.huggingface.co/";
18
/// Provider behind the Hugging Face router that requests are sent to.
///
/// The `Display` form of each variant is the URL path segment that is appended
/// to the router base URL when a client is built.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum SubProvider {
    /// Hugging Face's own inference backend (route `hf-inference/models`).
    /// This is the default, and the only variant whose endpoints also carry the
    /// model id in the URL path.
    #[default]
    HFInference,
    /// Route `together`.
    Together,
    /// Route `sambanova`.
    SambaNova,
    /// Route `fireworks-ai`; model ids are qualified as
    /// `accounts/fireworks/models/<model>`.
    Fireworks,
    /// Route `hyperbolic`.
    Hyperbolic,
    /// Route `nebius`.
    Nebius,
    /// Route `novita`.
    Novita,
    /// Any other route, used verbatim as the path segment.
    Custom(String),
}
31
32impl SubProvider {
33    /// Get the chat completion endpoint for the SubProvider
34    /// Required because Huggingface Inference requires the model
35    /// in the url and in the request body.
36    pub fn completion_endpoint(&self, model: &str) -> String {
37        match self {
38            SubProvider::HFInference => format!("/{model}/v1/chat/completions"),
39            _ => "/v1/chat/completions".to_string(),
40        }
41    }
42
43    /// Get the transcription endpoint for the SubProvider
44    /// Required because Huggingface Inference requires the model
45    /// in the url and in the request body.
46    pub fn transcription_endpoint(&self, model: &str) -> Result<String, TranscriptionError> {
47        match self {
48            SubProvider::HFInference => Ok(format!("/{model}")),
49            _ => Err(TranscriptionError::ProviderError(format!(
50                "transcription endpoint is not supported yet for {self}"
51            ))),
52        }
53    }
54
55    /// Get the image generation endpoint for the SubProvider
56    /// Required because Huggingface Inference requires the model
57    /// in the url and in the request body.
58    #[cfg(feature = "image")]
59    pub fn image_generation_endpoint(&self, model: &str) -> Result<String, ImageGenerationError> {
60        match self {
61            SubProvider::HFInference => Ok(format!("/{model}")),
62            _ => Err(ImageGenerationError::ProviderError(format!(
63                "image generation endpoint is not supported yet for {self}"
64            ))),
65        }
66    }
67
68    pub fn model_identifier(&self, model: &str) -> String {
69        match self {
70            SubProvider::Fireworks => format!("accounts/fireworks/models/{model}"),
71            _ => model.to_string(),
72        }
73    }
74}
75
76impl From<&str> for SubProvider {
77    fn from(s: &str) -> Self {
78        SubProvider::Custom(s.to_string())
79    }
80}
81
82impl From<String> for SubProvider {
83    fn from(value: String) -> Self {
84        SubProvider::Custom(value)
85    }
86}
87
88impl Display for SubProvider {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        let route = match self {
91            SubProvider::HFInference => "hf-inference/models".to_string(),
92            SubProvider::Together => "together".to_string(),
93            SubProvider::SambaNova => "sambanova".to_string(),
94            SubProvider::Fireworks => "fireworks-ai".to_string(),
95            SubProvider::Hyperbolic => "hyperbolic".to_string(),
96            SubProvider::Nebius => "nebius".to_string(),
97            SubProvider::Novita => "novita".to_string(),
98            SubProvider::Custom(route) => route.clone(),
99        };
100
101        write!(f, "{route}")
102    }
103}
104
105pub struct ClientBuilder {
106    api_key: String,
107    base_url: String,
108    sub_provider: SubProvider,
109    http_client: Option<reqwest::Client>,
110}
111
112impl ClientBuilder {
113    pub fn new(api_key: &str) -> Self {
114        Self {
115            api_key: api_key.to_string(),
116            base_url: HUGGINGFACE_API_BASE_URL.to_string(),
117            sub_provider: SubProvider::default(),
118            http_client: None,
119        }
120    }
121
122    pub fn base_url(mut self, base_url: &str) -> Self {
123        self.base_url = base_url.to_string();
124        self
125    }
126
127    pub fn sub_provider(mut self, provider: impl Into<SubProvider>) -> Self {
128        self.sub_provider = provider.into();
129        self
130    }
131
132    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
133        self.http_client = Some(client);
134        self
135    }
136
137    pub fn build(self) -> Result<Client, ClientBuilderError> {
138        let route = self.sub_provider.to_string();
139        let base_url = format!("{}/{}", self.base_url, route).replace("//", "/");
140
141        let mut default_headers = reqwest::header::HeaderMap::new();
142        default_headers.insert(
143            "Content-Type",
144            "application/json"
145                .parse()
146                .expect("Failed to parse Content-Type"),
147        );
148        let http_client = if let Some(http_client) = self.http_client {
149            http_client
150        } else {
151            reqwest::Client::builder().build()?
152        };
153
154        Ok(Client {
155            base_url,
156            default_headers,
157            api_key: self.api_key,
158            http_client,
159            sub_provider: self.sub_provider,
160        })
161    }
162}
163
164#[derive(Clone)]
165pub struct Client {
166    base_url: String,
167    default_headers: reqwest::header::HeaderMap,
168    api_key: String,
169    http_client: reqwest::Client,
170    pub(crate) sub_provider: SubProvider,
171}
172
173impl std::fmt::Debug for Client {
174    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175        f.debug_struct("Client")
176            .field("base_url", &self.base_url)
177            .field("http_client", &self.http_client)
178            .field("default_headers", &self.default_headers)
179            .field("sub_provider", &self.sub_provider)
180            .field("api_key", &"<REDACTED>")
181            .finish()
182    }
183}
184
185impl Client {
186    /// Create a new Huggingface client builder.
187    ///
188    /// # Example
189    /// ```
190    /// use rig::providers::huggingface::{ClientBuilder, self};
191    ///
192    /// // Initialize the Huggingface client
193    /// let client = Client::builder("your-huggingface-api-key")
194    ///    .build()
195    /// ```
196    pub fn builder(api_key: &str) -> ClientBuilder {
197        ClientBuilder::new(api_key)
198    }
199
200    /// Create a new Huggingface client. For more control, use the `builder` method.
201    ///
202    /// # Panics
203    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
204    pub fn new(api_key: &str) -> Self {
205        Self::builder(api_key)
206            .build()
207            .expect("Huggingface client should build")
208    }
209
210    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
211        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
212        self.http_client
213            .post(url)
214            .bearer_auth(&self.api_key)
215            .headers(self.default_headers.clone())
216    }
217}
218
219impl ProviderClient for Client {
220    /// Create a new Huggingface client from the `HUGGINGFACE_API_KEY` environment variable.
221    /// Panics if the environment variable is not set.
222    fn from_env() -> Self {
223        let api_key = std::env::var("HUGGINGFACE_API_KEY").expect("HUGGINGFACE_API_KEY is not set");
224        Self::new(&api_key)
225    }
226
227    fn from_val(input: crate::client::ProviderValue) -> Self {
228        let crate::client::ProviderValue::Simple(api_key) = input else {
229            panic!("Incorrect provider value type")
230        };
231        Self::new(&api_key)
232    }
233}
234
235impl CompletionClient for Client {
236    type CompletionModel = CompletionModel;
237
238    /// Create a new completion model with the given name
239    ///
240    /// # Example
241    /// ```
242    /// use rig::providers::huggingface::{Client, self}
243    ///
244    /// // Initialize the Huggingface client
245    /// let client = Client::new("your-huggingface-api-key");
246    ///
247    /// let completion_model = client.completion_model(huggingface::GEMMA_2);
248    /// ```
249    fn completion_model(&self, model: &str) -> CompletionModel {
250        CompletionModel::new(self.clone(), model)
251    }
252}
253
254impl TranscriptionClient for Client {
255    type TranscriptionModel = TranscriptionModel;
256
257    /// Create a new transcription model with the given name
258    ///
259    /// # Example
260    /// ```
261    /// use rig::providers::huggingface::{Client, self}
262    ///
263    /// // Initialize the Huggingface client
264    /// let client = Client::new("your-huggingface-api-key");
265    ///
266    /// let completion_model = client.transcription_model(huggingface::WHISPER_LARGE_V3);
267    /// ```
268    ///
269    fn transcription_model(&self, model: &str) -> TranscriptionModel {
270        TranscriptionModel::new(self.clone(), model)
271    }
272}
273
#[cfg(feature = "image")]
impl ImageGenerationClient for Client {
    type ImageGenerationModel = ImageGenerationModel;

    /// Create a new image generation model with the given name.
    ///
    /// Only [`SubProvider::HFInference`] currently exposes an image generation
    /// endpoint.
    ///
    /// # Example
    /// ```
    /// use rig::client::ImageGenerationClient;
    /// use rig::providers::huggingface::Client;
    ///
    /// // Initialize the Huggingface client
    /// let client = Client::new("your-huggingface-api-key");
    ///
    /// // Any Hugging Face text-to-image model id.
    /// let model_id = "stabilityai/stable-diffusion-xl-base-1.0";
    /// let image_generation_model = client.image_generation_model(model_id);
    /// ```
    fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
        ImageGenerationModel::new(self.clone(), model)
    }
}
293
294impl_conversion_traits!(AsEmbeddings, AsAudioGeneration for Client);