// rig/providers/huggingface/client.rs

1use std::fmt::Display;
2
3use super::completion::CompletionModel;
4use crate::agent::AgentBuilder;
5#[cfg(feature = "image")]
6use crate::image_generation::ImageGenerationError;
7#[cfg(feature = "image")]
8use crate::providers::huggingface::image_generation::ImageGenerationModel;
9use crate::providers::huggingface::transcription::TranscriptionModel;
10use crate::transcription::TranscriptionError;
11
// ================================================================
// Main Huggingface Client
// ================================================================
/// Default base URL for Hugging Face API requests.
///
/// The sub-provider route (the `Display` form of `SubProvider`, e.g.
/// `hf-inference/models`) is appended to this URL when a `Client` is created.
const HUGGINGFACE_API_BASE_URL: &str = "https://router.huggingface.co/";
16
/// Sub-provider selecting the route appended to the Hugging Face base URL.
///
/// The `Display` impl yields that route segment; [`SubProvider::Custom`]
/// uses its string verbatim. `Eq` and `Hash` are derived so sub-providers can
/// be used as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum SubProvider {
    /// Hugging Face Inference, route `hf-inference/models`. This is the default.
    #[default]
    HFInference,
    /// Route `together`.
    Together,
    /// Route `sambanova`.
    SambaNova,
    /// Route `fireworks-ai`; model names get the `accounts/fireworks/models/` prefix.
    Fireworks,
    /// Route `hyperbolic`.
    Hyperbolic,
    /// Route `nebius`.
    Nebius,
    /// Route `novita`.
    Novita,
    /// Any other route, used verbatim.
    Custom(String),
}
29
30impl SubProvider {
31    /// Get the chat completion endpoint for the SubProvider
32    /// Required because Huggingface Inference requires the model
33    /// in the url and in the request body.
34    pub fn completion_endpoint(&self, model: &str) -> String {
35        match self {
36            SubProvider::HFInference => format!("/{}/v1/chat/completions", model),
37            _ => "/v1/chat/completions".to_string(),
38        }
39    }
40
41    /// Get the transcription endpoint for the SubProvider
42    /// Required because Huggingface Inference requires the model
43    /// in the url and in the request body.
44    pub fn transcription_endpoint(&self, model: &str) -> Result<String, TranscriptionError> {
45        match self {
46            SubProvider::HFInference => Ok(format!("/{}", model)),
47            _ => Err(TranscriptionError::ProviderError(format!(
48                "transcription endpoint is not supported yet for {}",
49                self
50            ))),
51        }
52    }
53
54    /// Get the image generation endpoint for the SubProvider
55    /// Required because Huggingface Inference requires the model
56    /// in the url and in the request body.
57    #[cfg(feature = "image")]
58    pub fn image_generation_endpoint(&self, model: &str) -> Result<String, ImageGenerationError> {
59        match self {
60            SubProvider::HFInference => Ok(format!("/{}", model)),
61            _ => Err(ImageGenerationError::ProviderError(format!(
62                "image generation endpoint is not supported yet for {}",
63                self
64            ))),
65        }
66    }
67
68    pub fn model_identifier(&self, model: &str) -> String {
69        match self {
70            SubProvider::Fireworks => format!("accounts/fireworks/models/{}", model),
71            _ => model.to_string(),
72        }
73    }
74}
75
76impl From<&str> for SubProvider {
77    fn from(s: &str) -> Self {
78        SubProvider::Custom(s.to_string())
79    }
80}
81
82impl From<String> for SubProvider {
83    fn from(value: String) -> Self {
84        SubProvider::Custom(value)
85    }
86}
87
88impl Display for SubProvider {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        let route = match self {
91            SubProvider::HFInference => "hf-inference/models".to_string(),
92            SubProvider::Together => "together".to_string(),
93            SubProvider::SambaNova => "sambanova".to_string(),
94            SubProvider::Fireworks => "fireworks-ai".to_string(),
95            SubProvider::Hyperbolic => "hyperbolic".to_string(),
96            SubProvider::Nebius => "nebius".to_string(),
97            SubProvider::Novita => "novita".to_string(),
98            SubProvider::Custom(route) => route.clone(),
99        };
100
101        write!(f, "{}", route)
102    }
103}
104
105pub struct ClientBuilder {
106    api_key: String,
107    base_url: String,
108    sub_provider: SubProvider,
109}
110
111impl ClientBuilder {
112    pub fn new(api_key: &str) -> Self {
113        Self {
114            api_key: api_key.to_string(),
115            base_url: HUGGINGFACE_API_BASE_URL.to_string(),
116            sub_provider: SubProvider::default(),
117        }
118    }
119
120    pub fn base_url(mut self, base_url: &str) -> Self {
121        self.base_url = base_url.to_string();
122        self
123    }
124
125    pub fn sub_provider(mut self, provider: impl Into<SubProvider>) -> Self {
126        self.sub_provider = provider.into();
127        self
128    }
129
130    pub fn build(self) -> Client {
131        let route = self.sub_provider.to_string();
132
133        let base_url = format!("{}/{}", self.base_url, route).replace("//", "/");
134
135        Client::from_url(self.api_key.as_str(), base_url.as_str(), self.sub_provider)
136    }
137}
138
139#[derive(Clone)]
140pub struct Client {
141    base_url: String,
142    http_client: reqwest::Client,
143    pub(crate) sub_provider: SubProvider,
144}
145
146impl Client {
147    /// Create a new Huggingface client with the given API key.
148    pub fn new(api_key: &str) -> Self {
149        let base_url =
150            format!("{}/{}", HUGGINGFACE_API_BASE_URL, SubProvider::HFInference).replace("//", "/");
151        Self::from_url(api_key, &base_url, SubProvider::HFInference)
152    }
153
154    /// Create a new Client with the given API key and base API URL.
155    pub fn from_url(api_key: &str, base_url: &str, sub_provider: SubProvider) -> Self {
156        let http_client = reqwest::Client::builder()
157            .default_headers({
158                let mut headers = reqwest::header::HeaderMap::new();
159                headers.insert(
160                    "Authorization",
161                    format!("Bearer {api_key}")
162                        .parse()
163                        .expect("Failed to parse API key"),
164                );
165                headers.insert(
166                    "Content-Type",
167                    "application/json"
168                        .parse()
169                        .expect("Failed to parse Content-Type"),
170                );
171                headers
172            })
173            .build()
174            .expect("Failed to build HTTP client");
175
176        Self {
177            base_url: base_url.to_owned(),
178            http_client,
179            sub_provider,
180        }
181    }
182    /// Create a new Huggingface client from the `HUGGINGFACE_API_KEY` environment variable.
183    /// Panics if the environment variable is not set.
184    pub fn from_env() -> Self {
185        let api_key = std::env::var("HUGGINGFACE_API_KEY").expect("HUGGINGFACE_API_KEY is not set");
186        Self::new(&api_key)
187    }
188
189    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
190        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
191        self.http_client.post(url)
192    }
193
194    /// Create a new completion model with the given name
195    ///
196    /// # Example
197    /// ```
198    /// use rig::providers::huggingface::{Client, self}
199    ///
200    /// // Initialize the Huggingface client
201    /// let client = Client::new("your-huggingface-api-key");
202    ///
203    /// let completion_model = client.completion_model(huggingface::GEMMA_2);
204    /// ```
205    pub fn completion_model(&self, model: &str) -> CompletionModel {
206        CompletionModel::new(self.clone(), model)
207    }
208
209    /// Create a new transcription model with the given name
210    ///
211    /// # Example
212    /// ```
213    /// use rig::providers::huggingface::{Client, self}
214    ///
215    /// // Initialize the Huggingface client
216    /// let client = Client::new("your-huggingface-api-key");
217    ///
218    /// let completion_model = client.transcription_model(huggingface::WHISPER_LARGE_V3);
219    /// ```
220    ///
221    pub fn transcription_model(&self, model: &str) -> TranscriptionModel {
222        TranscriptionModel::new(self.clone(), model)
223    }
224
225    /// Create a new image generation model with the given name
226    ///
227    /// # Example
228    /// ```
229    /// use rig::providers::huggingface::{Client, self}
230    ///
231    /// // Initialize the Huggingface client
232    /// let client = Client::new("your-huggingface-api-key");
233    ///
234    /// let completion_model = client.image_generation_model(huggingface::WHISPER_LARGE_V3);
235    /// ```
236    #[cfg(feature = "image")]
237    pub fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
238        ImageGenerationModel::new(self.clone(), model)
239    }
240
241    /// Create an agent builder with the given completion model.
242    ///
243    /// # Example
244    /// ```
245    /// use rig::providers::huggingface::{Client, self};
246    ///
247    /// // Initialize the Anthropic client
248    /// let client = Client::new("your-huggingface-api-key");
249    ///
250    /// let agent = client.agent(huggingface::GEMMA_2)
251    ///    .preamble("You are comedian AI with a mission to make people laugh.")
252    ///    .temperature(0.0)
253    ///    .build();
254    /// ```
255    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
256        AgentBuilder::new(self.completion_model(model))
257    }
258}