rig/providers/huggingface/
client.rs1use super::completion::CompletionModel;
2#[cfg(feature = "image")]
3use crate::client::ImageGenerationClient;
4use crate::client::{
5 ClientBuilderError, CompletionClient, ProviderClient, TranscriptionClient, VerifyClient,
6 VerifyError,
7};
8#[cfg(feature = "image")]
9use crate::image_generation::ImageGenerationError;
10#[cfg(feature = "image")]
11use crate::providers::huggingface::image_generation::ImageGenerationModel;
12use crate::providers::huggingface::transcription::TranscriptionModel;
13use crate::transcription::TranscriptionError;
14use rig::client::impl_conversion_traits;
15use std::fmt::Display;
16
/// Default router endpoint used by [`ClientBuilder::new`]; the selected
/// [`SubProvider`] route is appended to it when a [`Client`] is built.
const HUGGINGFACE_API_BASE_URL: &str = "https://router.huggingface.co/";
21
/// Inference provider that the Hugging Face router forwards requests to.
///
/// The `Display` form of each variant is the route segment appended to the
/// client's base URL. `HFInference` is the default.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum SubProvider {
    /// Hugging Face's own inference service (route `hf-inference/models`).
    #[default]
    HFInference,
    /// Route `together`.
    Together,
    /// Route `sambanova`.
    SambaNova,
    /// Route `fireworks-ai`.
    Fireworks,
    /// Route `hyperbolic`.
    Hyperbolic,
    /// Route `nebius`.
    Nebius,
    /// Route `novita`.
    Novita,
    /// Any other provider; the string is used verbatim as the route segment.
    Custom(String),
}
34
35impl SubProvider {
36 pub fn completion_endpoint(&self, model: &str) -> String {
40 match self {
41 SubProvider::HFInference => format!("/{model}/v1/chat/completions"),
42 _ => "/v1/chat/completions".to_string(),
43 }
44 }
45
46 pub fn transcription_endpoint(&self, model: &str) -> Result<String, TranscriptionError> {
50 match self {
51 SubProvider::HFInference => Ok(format!("/{model}")),
52 _ => Err(TranscriptionError::ProviderError(format!(
53 "transcription endpoint is not supported yet for {self}"
54 ))),
55 }
56 }
57
58 #[cfg(feature = "image")]
62 pub fn image_generation_endpoint(&self, model: &str) -> Result<String, ImageGenerationError> {
63 match self {
64 SubProvider::HFInference => Ok(format!("/{model}")),
65 _ => Err(ImageGenerationError::ProviderError(format!(
66 "image generation endpoint is not supported yet for {self}"
67 ))),
68 }
69 }
70
71 pub fn model_identifier(&self, model: &str) -> String {
72 match self {
73 SubProvider::Fireworks => format!("accounts/fireworks/models/{model}"),
74 _ => model.to_string(),
75 }
76 }
77}
78
79impl From<&str> for SubProvider {
80 fn from(s: &str) -> Self {
81 SubProvider::Custom(s.to_string())
82 }
83}
84
85impl From<String> for SubProvider {
86 fn from(value: String) -> Self {
87 SubProvider::Custom(value)
88 }
89}
90
91impl Display for SubProvider {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 let route = match self {
94 SubProvider::HFInference => "hf-inference/models".to_string(),
95 SubProvider::Together => "together".to_string(),
96 SubProvider::SambaNova => "sambanova".to_string(),
97 SubProvider::Fireworks => "fireworks-ai".to_string(),
98 SubProvider::Hyperbolic => "hyperbolic".to_string(),
99 SubProvider::Nebius => "nebius".to_string(),
100 SubProvider::Novita => "novita".to_string(),
101 SubProvider::Custom(route) => route.clone(),
102 };
103
104 write!(f, "{route}")
105 }
106}
107
108pub struct ClientBuilder {
109 api_key: String,
110 base_url: String,
111 sub_provider: SubProvider,
112 http_client: Option<reqwest::Client>,
113}
114
115impl ClientBuilder {
116 pub fn new(api_key: &str) -> Self {
117 Self {
118 api_key: api_key.to_string(),
119 base_url: HUGGINGFACE_API_BASE_URL.to_string(),
120 sub_provider: SubProvider::default(),
121 http_client: None,
122 }
123 }
124
125 pub fn base_url(mut self, base_url: &str) -> Self {
126 self.base_url = base_url.to_string();
127 self
128 }
129
130 pub fn sub_provider(mut self, provider: impl Into<SubProvider>) -> Self {
131 self.sub_provider = provider.into();
132 self
133 }
134
135 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
136 self.http_client = Some(client);
137 self
138 }
139
140 pub fn build(self) -> Result<Client, ClientBuilderError> {
141 let route = self.sub_provider.to_string();
142 let base_url = format!("{}/{}", self.base_url, route).replace("//", "/");
143
144 let mut default_headers = reqwest::header::HeaderMap::new();
145 default_headers.insert(
146 "Content-Type",
147 "application/json"
148 .parse()
149 .expect("Failed to parse Content-Type"),
150 );
151 let http_client = if let Some(http_client) = self.http_client {
152 http_client
153 } else {
154 reqwest::Client::builder().build()?
155 };
156
157 Ok(Client {
158 base_url,
159 default_headers,
160 api_key: self.api_key,
161 http_client,
162 sub_provider: self.sub_provider,
163 })
164 }
165}
166
167#[derive(Clone)]
168pub struct Client {
169 base_url: String,
170 default_headers: reqwest::header::HeaderMap,
171 api_key: String,
172 http_client: reqwest::Client,
173 pub(crate) sub_provider: SubProvider,
174}
175
176impl std::fmt::Debug for Client {
177 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
178 f.debug_struct("Client")
179 .field("base_url", &self.base_url)
180 .field("http_client", &self.http_client)
181 .field("default_headers", &self.default_headers)
182 .field("sub_provider", &self.sub_provider)
183 .field("api_key", &"<REDACTED>")
184 .finish()
185 }
186}
187
188impl Client {
189 pub fn builder(api_key: &str) -> ClientBuilder {
200 ClientBuilder::new(api_key)
201 }
202
203 pub fn new(api_key: &str) -> Self {
208 Self::builder(api_key)
209 .build()
210 .expect("Huggingface client should build")
211 }
212
213 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
214 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
215 self.http_client
216 .post(url)
217 .bearer_auth(&self.api_key)
218 .headers(self.default_headers.clone())
219 }
220
221 pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder {
222 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
223 self.http_client
224 .get(url)
225 .bearer_auth(&self.api_key)
226 .headers(self.default_headers.clone())
227 }
228}
229
230impl ProviderClient for Client {
231 fn from_env() -> Self {
234 let api_key = std::env::var("HUGGINGFACE_API_KEY").expect("HUGGINGFACE_API_KEY is not set");
235 Self::new(&api_key)
236 }
237
238 fn from_val(input: crate::client::ProviderValue) -> Self {
239 let crate::client::ProviderValue::Simple(api_key) = input else {
240 panic!("Incorrect provider value type")
241 };
242 Self::new(&api_key)
243 }
244}
245
246impl CompletionClient for Client {
247 type CompletionModel = CompletionModel;
248
249 fn completion_model(&self, model: &str) -> CompletionModel {
261 CompletionModel::new(self.clone(), model)
262 }
263}
264
265impl TranscriptionClient for Client {
266 type TranscriptionModel = TranscriptionModel;
267
268 fn transcription_model(&self, model: &str) -> TranscriptionModel {
281 TranscriptionModel::new(self.clone(), model)
282 }
283}
284
#[cfg(feature = "image")]
impl ImageGenerationClient for Client {
    type ImageGenerationModel = ImageGenerationModel;

    /// Returns an image-generation model handle for `model` that owns a clone of
    /// this client.
    fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
        let client = self.clone();
        ImageGenerationModel::new(client, model)
    }
}
304
305impl VerifyClient for Client {
306 #[cfg_attr(feature = "worker", worker::send)]
307 async fn verify(&self) -> Result<(), VerifyError> {
308 let response = self.get("/api/whoami-v2").send().await?;
309 match response.status() {
310 reqwest::StatusCode::OK => Ok(()),
311 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
312 reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
313 Err(VerifyError::ProviderError(response.text().await?))
314 }
315 _ => {
316 response.error_for_status()?;
317 Ok(())
318 }
319 }
320 }
321}
322
// NOTE(review): presumably implements the `AsEmbeddings` and `AsAudioGeneration`
// conversion traits for `Client` with their fallback behavior, since this file defines
// no embedding or audio-generation models — confirm against the macro definition.
impl_conversion_traits!(AsEmbeddings, AsAudioGeneration for Client);