rig/providers/huggingface/
client.rs

1use super::completion::CompletionModel;
2#[cfg(feature = "image")]
3use crate::client::ImageGenerationClient;
4use crate::client::{CompletionClient, ProviderClient, TranscriptionClient};
5#[cfg(feature = "image")]
6use crate::image_generation::ImageGenerationError;
7#[cfg(feature = "image")]
8use crate::providers::huggingface::image_generation::ImageGenerationModel;
9use crate::providers::huggingface::transcription::TranscriptionModel;
10use crate::transcription::TranscriptionError;
11use rig::client::impl_conversion_traits;
12use std::fmt::Display;
13
14// ================================================================
15// Main Huggingface Client
16// ================================================================
/// Root URL of the Huggingface router.
///
/// The sub-provider route (the `Display` output of `SubProvider`) is appended
/// to this value to form a client's base URL. Note that the value already ends
/// with a `/`.
const HUGGINGFACE_API_BASE_URL: &str = "https://router.huggingface.co/";
18
/// Inference backend that requests are routed to.
///
/// The `Display` implementation of this type yields the URL route segment of
/// each variant; that segment is appended to the client's base URL.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so values can be used as
/// keys in hash-based collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum SubProvider {
    /// Huggingface's own inference backend (route `hf-inference/models`).
    /// This is the default.
    #[default]
    HFInference,
    /// Route `together`.
    Together,
    /// Route `sambanova`.
    SambaNova,
    /// Route `fireworks-ai`.
    Fireworks,
    /// Route `hyperbolic`.
    Hyperbolic,
    /// Route `nebius`.
    Nebius,
    /// Route `novita`.
    Novita,
    /// Any other backend; the wrapped string is used verbatim as the route.
    Custom(String),
}
31
32impl SubProvider {
33    /// Get the chat completion endpoint for the SubProvider
34    /// Required because Huggingface Inference requires the model
35    /// in the url and in the request body.
36    pub fn completion_endpoint(&self, model: &str) -> String {
37        match self {
38            SubProvider::HFInference => format!("/{model}/v1/chat/completions"),
39            _ => "/v1/chat/completions".to_string(),
40        }
41    }
42
43    /// Get the transcription endpoint for the SubProvider
44    /// Required because Huggingface Inference requires the model
45    /// in the url and in the request body.
46    pub fn transcription_endpoint(&self, model: &str) -> Result<String, TranscriptionError> {
47        match self {
48            SubProvider::HFInference => Ok(format!("/{model}")),
49            _ => Err(TranscriptionError::ProviderError(format!(
50                "transcription endpoint is not supported yet for {self}"
51            ))),
52        }
53    }
54
55    /// Get the image generation endpoint for the SubProvider
56    /// Required because Huggingface Inference requires the model
57    /// in the url and in the request body.
58    #[cfg(feature = "image")]
59    pub fn image_generation_endpoint(&self, model: &str) -> Result<String, ImageGenerationError> {
60        match self {
61            SubProvider::HFInference => Ok(format!("/{}", model)),
62            _ => Err(ImageGenerationError::ProviderError(format!(
63                "image generation endpoint is not supported yet for {}",
64                self
65            ))),
66        }
67    }
68
69    pub fn model_identifier(&self, model: &str) -> String {
70        match self {
71            SubProvider::Fireworks => format!("accounts/fireworks/models/{model}"),
72            _ => model.to_string(),
73        }
74    }
75}
76
77impl From<&str> for SubProvider {
78    fn from(s: &str) -> Self {
79        SubProvider::Custom(s.to_string())
80    }
81}
82
83impl From<String> for SubProvider {
84    fn from(value: String) -> Self {
85        SubProvider::Custom(value)
86    }
87}
88
89impl Display for SubProvider {
90    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91        let route = match self {
92            SubProvider::HFInference => "hf-inference/models".to_string(),
93            SubProvider::Together => "together".to_string(),
94            SubProvider::SambaNova => "sambanova".to_string(),
95            SubProvider::Fireworks => "fireworks-ai".to_string(),
96            SubProvider::Hyperbolic => "hyperbolic".to_string(),
97            SubProvider::Nebius => "nebius".to_string(),
98            SubProvider::Novita => "novita".to_string(),
99            SubProvider::Custom(route) => route.clone(),
100        };
101
102        write!(f, "{route}")
103    }
104}
105
106pub struct ClientBuilder {
107    api_key: String,
108    base_url: String,
109    sub_provider: SubProvider,
110}
111
112impl ClientBuilder {
113    pub fn new(api_key: &str) -> Self {
114        Self {
115            api_key: api_key.to_string(),
116            base_url: HUGGINGFACE_API_BASE_URL.to_string(),
117            sub_provider: SubProvider::default(),
118        }
119    }
120
121    pub fn base_url(mut self, base_url: &str) -> Self {
122        self.base_url = base_url.to_string();
123        self
124    }
125
126    pub fn sub_provider(mut self, provider: impl Into<SubProvider>) -> Self {
127        self.sub_provider = provider.into();
128        self
129    }
130
131    pub fn build(self) -> Client {
132        let route = self.sub_provider.to_string();
133
134        let base_url = format!("{}/{}", self.base_url, route).replace("//", "/");
135
136        Client::from_url(self.api_key.as_str(), base_url.as_str(), self.sub_provider)
137    }
138}
139
140#[derive(Clone, Debug)]
141pub struct Client {
142    base_url: String,
143    http_client: reqwest::Client,
144    pub(crate) sub_provider: SubProvider,
145}
146
147impl Client {
148    /// Create a new Huggingface client with the given API key.
149    pub fn new(api_key: &str) -> Self {
150        let base_url =
151            format!("{}/{}", HUGGINGFACE_API_BASE_URL, SubProvider::HFInference).replace("//", "/");
152        Self::from_url(api_key, &base_url, SubProvider::HFInference)
153    }
154
155    /// Create a new Client with the given API key and base API URL.
156    pub fn from_url(api_key: &str, base_url: &str, sub_provider: SubProvider) -> Self {
157        let http_client = reqwest::Client::builder()
158            .default_headers({
159                let mut headers = reqwest::header::HeaderMap::new();
160                headers.insert(
161                    "Authorization",
162                    format!("Bearer {api_key}")
163                        .parse()
164                        .expect("Failed to parse API key"),
165                );
166                headers.insert(
167                    "Content-Type",
168                    "application/json"
169                        .parse()
170                        .expect("Failed to parse Content-Type"),
171                );
172                headers
173            })
174            .build()
175            .expect("Failed to build HTTP client");
176
177        Self {
178            base_url: base_url.to_owned(),
179            http_client,
180            sub_provider,
181        }
182    }
183
184    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
185        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
186        self.http_client.post(url)
187    }
188}
189
190impl ProviderClient for Client {
191    /// Create a new Huggingface client from the `HUGGINGFACE_API_KEY` environment variable.
192    /// Panics if the environment variable is not set.
193    fn from_env() -> Self {
194        let api_key = std::env::var("HUGGINGFACE_API_KEY").expect("HUGGINGFACE_API_KEY is not set");
195        Self::new(&api_key)
196    }
197}
198
199impl CompletionClient for Client {
200    type CompletionModel = CompletionModel;
201
202    /// Create a new completion model with the given name
203    ///
204    /// # Example
205    /// ```
206    /// use rig::providers::huggingface::{Client, self}
207    ///
208    /// // Initialize the Huggingface client
209    /// let client = Client::new("your-huggingface-api-key");
210    ///
211    /// let completion_model = client.completion_model(huggingface::GEMMA_2);
212    /// ```
213    fn completion_model(&self, model: &str) -> CompletionModel {
214        CompletionModel::new(self.clone(), model)
215    }
216}
217
218impl TranscriptionClient for Client {
219    type TranscriptionModel = TranscriptionModel;
220
221    /// Create a new transcription model with the given name
222    ///
223    /// # Example
224    /// ```
225    /// use rig::providers::huggingface::{Client, self}
226    ///
227    /// // Initialize the Huggingface client
228    /// let client = Client::new("your-huggingface-api-key");
229    ///
230    /// let completion_model = client.transcription_model(huggingface::WHISPER_LARGE_V3);
231    /// ```
232    ///
233    fn transcription_model(&self, model: &str) -> TranscriptionModel {
234        TranscriptionModel::new(self.clone(), model)
235    }
236}
237
#[cfg(feature = "image")]
impl ImageGenerationClient for Client {
    type ImageGenerationModel = ImageGenerationModel;

    /// Create a new image generation model with the given name.
    ///
    /// The returned model holds its own clone of this client.
    ///
    /// # Example
    /// ```
    /// use rig::client::ImageGenerationClient;
    /// use rig::providers::huggingface::Client;
    ///
    /// // Initialize the Huggingface client
    /// let client = Client::new("your-huggingface-api-key");
    ///
    /// let image_generation_model =
    ///     client.image_generation_model("black-forest-labs/FLUX.1-dev");
    /// ```
    fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
        ImageGenerationModel::new(self.clone(), model)
    }
}
257
// Expands `impl_conversion_traits!` for `Client` with the `AsEmbeddings` and
// `AsAudioGeneration` trait names.
// NOTE(review): presumably this provides the conversion-trait impls for
// capabilities this provider does not implement itself — confirm against the
// macro definition in `rig::client`.
impl_conversion_traits!(AsEmbeddings, AsAudioGeneration for Client);