rig/providers/huggingface/
client.rs1use std::fmt::Display;
2
3use super::completion::CompletionModel;
4use crate::agent::AgentBuilder;
5#[cfg(feature = "image")]
6use crate::image_generation::ImageGenerationError;
7#[cfg(feature = "image")]
8use crate::providers::huggingface::image_generation::ImageGenerationModel;
9use crate::providers::huggingface::transcription::TranscriptionModel;
10use crate::transcription::TranscriptionError;
11
/// Default root URL for Hugging Face requests.
///
/// The sub-provider route is appended to this value when a client is built.
const HUGGINGFACE_API_BASE_URL: &str = "https://router.huggingface.co/";
16
/// Inference backend that a Hugging Face request is routed to.
///
/// The default is [`SubProvider::HFInference`]. A route that has no named variant can be
/// supplied through [`SubProvider::Custom`].
///
/// `Eq` is derived alongside `PartialEq` because every payload (`String`) is itself `Eq`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum SubProvider {
    /// The `hf-inference/models` route (the default).
    #[default]
    HFInference,
    /// The `together` route.
    Together,
    /// The `sambanova` route.
    SambaNova,
    /// The `fireworks-ai` route.
    Fireworks,
    /// The `hyperbolic` route.
    Hyperbolic,
    /// The `nebius` route.
    Nebius,
    /// The `novita` route.
    Novita,
    /// A caller-supplied route string, used as given.
    Custom(String),
}
29
30impl SubProvider {
31 pub fn completion_endpoint(&self, model: &str) -> String {
35 match self {
36 SubProvider::HFInference => format!("/{}/v1/chat/completions", model),
37 _ => "/v1/chat/completions".to_string(),
38 }
39 }
40
41 pub fn transcription_endpoint(&self, model: &str) -> Result<String, TranscriptionError> {
45 match self {
46 SubProvider::HFInference => Ok(format!("/{}", model)),
47 _ => Err(TranscriptionError::ProviderError(format!(
48 "transcription endpoint is not supported yet for {}",
49 self
50 ))),
51 }
52 }
53
54 #[cfg(feature = "image")]
58 pub fn image_generation_endpoint(&self, model: &str) -> Result<String, ImageGenerationError> {
59 match self {
60 SubProvider::HFInference => Ok(format!("/{}", model)),
61 _ => Err(ImageGenerationError::ProviderError(format!(
62 "image generation endpoint is not supported yet for {}",
63 self
64 ))),
65 }
66 }
67
68 pub fn model_identifier(&self, model: &str) -> String {
69 match self {
70 SubProvider::Fireworks => format!("accounts/fireworks/models/{}", model),
71 _ => model.to_string(),
72 }
73 }
74}
75
76impl From<&str> for SubProvider {
77 fn from(s: &str) -> Self {
78 SubProvider::Custom(s.to_string())
79 }
80}
81
82impl From<String> for SubProvider {
83 fn from(value: String) -> Self {
84 SubProvider::Custom(value)
85 }
86}
87
88impl Display for SubProvider {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 let route = match self {
91 SubProvider::HFInference => "hf-inference/models".to_string(),
92 SubProvider::Together => "together".to_string(),
93 SubProvider::SambaNova => "sambanova".to_string(),
94 SubProvider::Fireworks => "fireworks-ai".to_string(),
95 SubProvider::Hyperbolic => "hyperbolic".to_string(),
96 SubProvider::Nebius => "nebius".to_string(),
97 SubProvider::Novita => "novita".to_string(),
98 SubProvider::Custom(route) => route.clone(),
99 };
100
101 write!(f, "{}", route)
102 }
103}
104
105pub struct ClientBuilder {
106 api_key: String,
107 base_url: String,
108 sub_provider: SubProvider,
109}
110
111impl ClientBuilder {
112 pub fn new(api_key: &str) -> Self {
113 Self {
114 api_key: api_key.to_string(),
115 base_url: HUGGINGFACE_API_BASE_URL.to_string(),
116 sub_provider: SubProvider::default(),
117 }
118 }
119
120 pub fn base_url(mut self, base_url: &str) -> Self {
121 self.base_url = base_url.to_string();
122 self
123 }
124
125 pub fn sub_provider(mut self, provider: impl Into<SubProvider>) -> Self {
126 self.sub_provider = provider.into();
127 self
128 }
129
130 pub fn build(self) -> Client {
131 let route = self.sub_provider.to_string();
132
133 let base_url = format!("{}/{}", self.base_url, route).replace("//", "/");
134
135 Client::from_url(self.api_key.as_str(), base_url.as_str(), self.sub_provider)
136 }
137}
138
139#[derive(Clone)]
140pub struct Client {
141 base_url: String,
142 http_client: reqwest::Client,
143 pub(crate) sub_provider: SubProvider,
144}
145
146impl Client {
147 pub fn new(api_key: &str) -> Self {
149 let base_url =
150 format!("{}/{}", HUGGINGFACE_API_BASE_URL, SubProvider::HFInference).replace("//", "/");
151 Self::from_url(api_key, &base_url, SubProvider::HFInference)
152 }
153
154 pub fn from_url(api_key: &str, base_url: &str, sub_provider: SubProvider) -> Self {
156 let http_client = reqwest::Client::builder()
157 .default_headers({
158 let mut headers = reqwest::header::HeaderMap::new();
159 headers.insert(
160 "Authorization",
161 format!("Bearer {api_key}")
162 .parse()
163 .expect("Failed to parse API key"),
164 );
165 headers.insert(
166 "Content-Type",
167 "application/json"
168 .parse()
169 .expect("Failed to parse Content-Type"),
170 );
171 headers
172 })
173 .build()
174 .expect("Failed to build HTTP client");
175
176 Self {
177 base_url: base_url.to_owned(),
178 http_client,
179 sub_provider,
180 }
181 }
182 pub fn from_env() -> Self {
185 let api_key = std::env::var("HUGGINGFACE_API_KEY").expect("HUGGINGFACE_API_KEY is not set");
186 Self::new(&api_key)
187 }
188
189 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
190 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
191 self.http_client.post(url)
192 }
193
194 pub fn completion_model(&self, model: &str) -> CompletionModel {
206 CompletionModel::new(self.clone(), model)
207 }
208
209 pub fn transcription_model(&self, model: &str) -> TranscriptionModel {
222 TranscriptionModel::new(self.clone(), model)
223 }
224
225 #[cfg(feature = "image")]
237 pub fn image_generation_model(&self, model: &str) -> ImageGenerationModel {
238 ImageGenerationModel::new(self.clone(), model)
239 }
240
241 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
256 AgentBuilder::new(self.completion_model(model))
257 }
258}