rig/providers/huggingface/
client.rs1use super::completion::CompletionModel;
2#[cfg(feature = "image")]
3use crate::client::ImageGenerationClient;
4use crate::client::{CompletionClient, ProviderClient, TranscriptionClient};
5#[cfg(feature = "image")]
6use crate::image_generation::ImageGenerationError;
7#[cfg(feature = "image")]
8use crate::providers::huggingface::image_generation::ImageGenerationModel;
9use crate::providers::huggingface::transcription::TranscriptionModel;
10use crate::transcription::TranscriptionError;
11use rig::client::impl_conversion_traits;
12use std::fmt::Display;
13
/// Default Hugging Face router URL. The selected sub-provider's route segment is appended to it
/// when a client is constructed.
const HUGGINGFACE_API_BASE_URL: &str = "https://router.huggingface.co/";

/// Inference provider that the Hugging Face router forwards requests to.
///
/// The `Display` output of each variant is the route segment appended to the router base URL.
/// `Eq` and `Hash` are derived so the type can be used as a set element or map key.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum SubProvider {
    /// The `hf-inference/models` route; the default sub-provider.
    #[default]
    HFInference,
    /// The `together` route.
    Together,
    /// The `sambanova` route.
    SambaNova,
    /// The `fireworks-ai` route (model names get an `accounts/fireworks/models/` prefix).
    Fireworks,
    /// The `hyperbolic` route.
    Hyperbolic,
    /// The `nebius` route.
    Nebius,
    /// The `novita` route.
    Novita,
    /// Any other route segment, used verbatim.
    Custom(String),
}
31
32impl SubProvider {
33 pub fn completion_endpoint(&self, model: &str) -> String {
37 match self {
38 SubProvider::HFInference => format!("/{model}/v1/chat/completions"),
39 _ => "/v1/chat/completions".to_string(),
40 }
41 }
42
43 pub fn transcription_endpoint(&self, model: &str) -> Result<String, TranscriptionError> {
47 match self {
48 SubProvider::HFInference => Ok(format!("/{model}")),
49 _ => Err(TranscriptionError::ProviderError(format!(
50 "transcription endpoint is not supported yet for {self}"
51 ))),
52 }
53 }
54
55 #[cfg(feature = "image")]
59 pub fn image_generation_endpoint(&self, model: &str) -> Result<String, ImageGenerationError> {
60 match self {
61 SubProvider::HFInference => Ok(format!("/{}", model)),
62 _ => Err(ImageGenerationError::ProviderError(format!(
63 "image generation endpoint is not supported yet for {}",
64 self
65 ))),
66 }
67 }
68
69 pub fn model_identifier(&self, model: &str) -> String {
70 match self {
71 SubProvider::Fireworks => format!("accounts/fireworks/models/{model}"),
72 _ => model.to_string(),
73 }
74 }
75}
76
77impl From<&str> for SubProvider {
78 fn from(s: &str) -> Self {
79 SubProvider::Custom(s.to_string())
80 }
81}
82
83impl From<String> for SubProvider {
84 fn from(value: String) -> Self {
85 SubProvider::Custom(value)
86 }
87}
88
89impl Display for SubProvider {
90 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91 let route = match self {
92 SubProvider::HFInference => "hf-inference/models".to_string(),
93 SubProvider::Together => "together".to_string(),
94 SubProvider::SambaNova => "sambanova".to_string(),
95 SubProvider::Fireworks => "fireworks-ai".to_string(),
96 SubProvider::Hyperbolic => "hyperbolic".to_string(),
97 SubProvider::Nebius => "nebius".to_string(),
98 SubProvider::Novita => "novita".to_string(),
99 SubProvider::Custom(route) => route.clone(),
100 };
101
102 write!(f, "{route}")
103 }
104}
105
106pub struct ClientBuilder {
107 api_key: String,
108 base_url: String,
109 sub_provider: SubProvider,
110}
111
112impl ClientBuilder {
113 pub fn new(api_key: &str) -> Self {
114 Self {
115 api_key: api_key.to_string(),
116 base_url: HUGGINGFACE_API_BASE_URL.to_string(),
117 sub_provider: SubProvider::default(),
118 }
119 }
120
121 pub fn base_url(mut self, base_url: &str) -> Self {
122 self.base_url = base_url.to_string();
123 self
124 }
125
126 pub fn sub_provider(mut self, provider: impl Into<SubProvider>) -> Self {
127 self.sub_provider = provider.into();
128 self
129 }
130
131 pub fn build(self) -> Client {
132 let route = self.sub_provider.to_string();
133
134 let base_url = format!("{}/{}", self.base_url, route).replace("//", "/");
135
136 Client::from_url(self.api_key.as_str(), base_url.as_str(), self.sub_provider)
137 }
138}
139
140#[derive(Clone, Debug)]
141pub struct Client {
142 base_url: String,
143 http_client: reqwest::Client,
144 pub(crate) sub_provider: SubProvider,
145}
146
147impl Client {
148 pub fn new(api_key: &str) -> Self {
150 let base_url =
151 format!("{}/{}", HUGGINGFACE_API_BASE_URL, SubProvider::HFInference).replace("//", "/");
152 Self::from_url(api_key, &base_url, SubProvider::HFInference)
153 }
154
155 pub fn from_url(api_key: &str, base_url: &str, sub_provider: SubProvider) -> Self {
157 let http_client = reqwest::Client::builder()
158 .default_headers({
159 let mut headers = reqwest::header::HeaderMap::new();
160 headers.insert(
161 "Authorization",
162 format!("Bearer {api_key}")
163 .parse()
164 .expect("Failed to parse API key"),
165 );
166 headers.insert(
167 "Content-Type",
168 "application/json"
169 .parse()
170 .expect("Failed to parse Content-Type"),
171 );
172 headers
173 })
174 .build()
175 .expect("Failed to build HTTP client");
176
177 Self {
178 base_url: base_url.to_owned(),
179 http_client,
180 sub_provider,
181 }
182 }
183
184 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
185 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
186 self.http_client.post(url)
187 }
188}
189
190impl ProviderClient for Client {
191 fn from_env() -> Self {
194 let api_key = std::env::var("HUGGINGFACE_API_KEY").expect("HUGGINGFACE_API_KEY is not set");
195 Self::new(&api_key)
196 }
197}
198
199impl CompletionClient for Client {
200 type CompletionModel = CompletionModel;
201
202 fn completion_model(&self, model: &str) -> CompletionModel {
214 CompletionModel::new(self.clone(), model)
215 }
216}
217
218impl TranscriptionClient for Client {
219 type TranscriptionModel = TranscriptionModel;
220
221 fn transcription_model(&self, model: &str) -> TranscriptionModel {
234 TranscriptionModel::new(self.clone(), model)
235 }
236}
237
#[cfg(feature = "image")]
impl ImageGenerationClient for Client {
    type ImageGenerationModel = ImageGenerationModel;

    /// Creates an [`ImageGenerationModel`] for `model`, backed by a clone of this client.
    fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
        let client = Client::clone(self);
        ImageGenerationModel::new(client, model)
    }
}
257
258impl_conversion_traits!(AsEmbeddings, AsAudioGeneration for Client);