rig/providers/huggingface/
client.rs1use super::completion::CompletionModel;
2#[cfg(feature = "image")]
3use crate::client::ImageGenerationClient;
4use crate::client::{CompletionClient, ProviderClient, TranscriptionClient};
5#[cfg(feature = "image")]
6use crate::image_generation::ImageGenerationError;
7#[cfg(feature = "image")]
8use crate::providers::huggingface::image_generation::ImageGenerationModel;
9use crate::providers::huggingface::transcription::TranscriptionModel;
10use crate::transcription::TranscriptionError;
11use rig::client::impl_conversion_traits;
12use std::fmt::Display;
13
/// Root URL of the Hugging Face router. `ClientBuilder::new` and `Client::new`
/// start from this value and append the sub-provider route to it.
const HUGGINGFACE_API_BASE_URL: &str = "https://router.huggingface.co/";
18
/// Backend route that requests are sent to behind the Hugging Face router.
///
/// The `Display` implementation renders each variant as the URL route segment
/// that is appended to the client's base URL.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so the value can be
/// used as a key in hash-based collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum SubProvider {
    /// Hugging Face's own inference route; this is the default variant.
    #[default]
    HFInference,
    /// Route rendered as `together`.
    Together,
    /// Route rendered as `sambanova`.
    SambaNova,
    /// Route rendered as `fireworks-ai`.
    Fireworks,
    /// Route rendered as `hyperbolic`.
    Hyperbolic,
    /// Route rendered as `nebius`.
    Nebius,
    /// Route rendered as `novita`.
    Novita,
    /// Any other route; the contained string is rendered verbatim.
    Custom(String),
}
31
32impl SubProvider {
33 pub fn completion_endpoint(&self, model: &str) -> String {
37 match self {
38 SubProvider::HFInference => format!("/{model}/v1/chat/completions"),
39 _ => "/v1/chat/completions".to_string(),
40 }
41 }
42
43 pub fn transcription_endpoint(&self, model: &str) -> Result<String, TranscriptionError> {
47 match self {
48 SubProvider::HFInference => Ok(format!("/{model}")),
49 _ => Err(TranscriptionError::ProviderError(format!(
50 "transcription endpoint is not supported yet for {self}"
51 ))),
52 }
53 }
54
55 #[cfg(feature = "image")]
59 pub fn image_generation_endpoint(&self, model: &str) -> Result<String, ImageGenerationError> {
60 match self {
61 SubProvider::HFInference => Ok(format!("/{model}")),
62 _ => Err(ImageGenerationError::ProviderError(format!(
63 "image generation endpoint is not supported yet for {self}"
64 ))),
65 }
66 }
67
68 pub fn model_identifier(&self, model: &str) -> String {
69 match self {
70 SubProvider::Fireworks => format!("accounts/fireworks/models/{model}"),
71 _ => model.to_string(),
72 }
73 }
74}
75
76impl From<&str> for SubProvider {
77 fn from(s: &str) -> Self {
78 SubProvider::Custom(s.to_string())
79 }
80}
81
82impl From<String> for SubProvider {
83 fn from(value: String) -> Self {
84 SubProvider::Custom(value)
85 }
86}
87
88impl Display for SubProvider {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 let route = match self {
91 SubProvider::HFInference => "hf-inference/models".to_string(),
92 SubProvider::Together => "together".to_string(),
93 SubProvider::SambaNova => "sambanova".to_string(),
94 SubProvider::Fireworks => "fireworks-ai".to_string(),
95 SubProvider::Hyperbolic => "hyperbolic".to_string(),
96 SubProvider::Nebius => "nebius".to_string(),
97 SubProvider::Novita => "novita".to_string(),
98 SubProvider::Custom(route) => route.clone(),
99 };
100
101 write!(f, "{route}")
102 }
103}
104
105pub struct ClientBuilder {
106 api_key: String,
107 base_url: String,
108 sub_provider: SubProvider,
109}
110
111impl ClientBuilder {
112 pub fn new(api_key: &str) -> Self {
113 Self {
114 api_key: api_key.to_string(),
115 base_url: HUGGINGFACE_API_BASE_URL.to_string(),
116 sub_provider: SubProvider::default(),
117 }
118 }
119
120 pub fn base_url(mut self, base_url: &str) -> Self {
121 self.base_url = base_url.to_string();
122 self
123 }
124
125 pub fn sub_provider(mut self, provider: impl Into<SubProvider>) -> Self {
126 self.sub_provider = provider.into();
127 self
128 }
129
130 pub fn build(self) -> Client {
131 let route = self.sub_provider.to_string();
132
133 let base_url = format!("{}/{}", self.base_url, route).replace("//", "/");
134
135 Client::from_url(self.api_key.as_str(), base_url.as_str(), self.sub_provider)
136 }
137}
138
139#[derive(Clone)]
140pub struct Client {
141 base_url: String,
142 default_headers: reqwest::header::HeaderMap,
143 api_key: String,
144 http_client: reqwest::Client,
145 pub(crate) sub_provider: SubProvider,
146}
147
148impl std::fmt::Debug for Client {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 f.debug_struct("Client")
151 .field("base_url", &self.base_url)
152 .field("http_client", &self.http_client)
153 .field("default_headers", &self.default_headers)
154 .field("sub_provider", &self.sub_provider)
155 .field("api_key", &"<REDACTED>")
156 .finish()
157 }
158}
159
160impl Client {
161 pub fn new(api_key: &str) -> Self {
163 let base_url =
164 format!("{}/{}", HUGGINGFACE_API_BASE_URL, SubProvider::HFInference).replace("//", "/");
165 Self::from_url(api_key, &base_url, SubProvider::HFInference)
166 }
167
168 pub fn from_url(api_key: &str, base_url: &str, sub_provider: SubProvider) -> Self {
170 let mut default_headers = reqwest::header::HeaderMap::new();
171 default_headers.insert(
172 "Content-Type",
173 "application/json"
174 .parse()
175 .expect("Failed to parse Content-Type"),
176 );
177 let http_client = reqwest::Client::builder()
178 .build()
179 .expect("Failed to build HTTP client");
180
181 Self {
182 base_url: base_url.to_owned(),
183 api_key: api_key.to_string(),
184 default_headers,
185 http_client,
186 sub_provider,
187 }
188 }
189
190 pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
193 self.http_client = client;
194
195 self
196 }
197
198 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
199 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
200 self.http_client
201 .post(url)
202 .bearer_auth(&self.api_key)
203 .headers(self.default_headers.clone())
204 }
205}
206
207impl ProviderClient for Client {
208 fn from_env() -> Self {
211 let api_key = std::env::var("HUGGINGFACE_API_KEY").expect("HUGGINGFACE_API_KEY is not set");
212 Self::new(&api_key)
213 }
214}
215
216impl CompletionClient for Client {
217 type CompletionModel = CompletionModel;
218
219 fn completion_model(&self, model: &str) -> CompletionModel {
231 CompletionModel::new(self.clone(), model)
232 }
233}
234
235impl TranscriptionClient for Client {
236 type TranscriptionModel = TranscriptionModel;
237
238 fn transcription_model(&self, model: &str) -> TranscriptionModel {
251 TranscriptionModel::new(self.clone(), model)
252 }
253}
254
#[cfg(feature = "image")]
impl ImageGenerationClient for Client {
    type ImageGenerationModel = ImageGenerationModel;

    /// Creates an image-generation model for `model`, backed by a clone of
    /// this client.
    fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
        let client = self.clone();
        ImageGenerationModel::new(client, model)
    }
}
274
// NOTE(review): presumably generates the `AsEmbeddings` and `AsAudioGeneration`
// conversion impls for `Client`; the macro comes from `rig::client` and its
// expansion is not visible in this file, so confirm against its definition.
impl_conversion_traits!(AsEmbeddings, AsAudioGeneration for Client);