1use anyhow::{Error, Result};
2use rand::{distributions::Alphanumeric, Rng};
3use regex::Regex;
4
5use crate::constants::*;
6use crate::preprocess_data;
7use crate::structs::*;
8use crate::{
9 data::{PredictionInput, PredictionOutput},
10 space::wake_up_space,
11 stream::PredictionStream,
12};
13
/// Optional credentials supplied when constructing a `Client`.
///
/// The `Default` value leaves both fields unset.
#[derive(Default)]
pub struct ClientOptions {
    /// Hugging Face access token. When present, `Client` sends it as an
    /// `Authorization: Bearer <token>` default header on every request.
    pub hf_token: Option<String>,
    /// `(username, password)` pair. When present, `Client::new` posts it to
    /// the app's login endpoint before fetching the app configuration.
    pub auth: Option<(String, String)>,
}

impl ClientOptions {
    /// Builds options that carry only a Hugging Face token; `auth` stays `None`.
    pub fn with_hf_token(hf_token: String) -> ClientOptions {
        ClientOptions {
            hf_token: Some(hf_token),
            ..Default::default()
        }
    }

    /// Builds options that carry only login credentials; `hf_token` stays `None`.
    pub fn with_auth(username: String, password: String) -> Self {
        let credentials = (username, password);
        ClientOptions {
            auth: Some(credentials),
            ..Default::default()
        }
    }
}
35
/// A connection to a single Gradio app, created with `Client::new`.
#[derive(Clone, Debug)]
pub struct Client {
    /// Random 10-character alphanumeric string generated by `Client::new`.
    pub session_hash: String,
    /// JWT slot; `Client::new` always initializes it to `None`.
    pub jwt: Option<String>,
    /// HTTP client used for every request. It is built with a cookie store,
    /// a fixed user agent and, when a Hugging Face token was supplied, a
    /// default bearer `Authorization` header.
    pub http_client: reqwest::Client,
    /// Base URL that endpoint paths are appended to. `Client::new` appends
    /// the app config's `api_prefix` to it when the config provides one.
    pub api_root: String,
    /// Set when the app reference was recognized as a Hugging Face Space.
    pub space_id: Option<String>,
    /// App configuration fetched from the app during construction.
    config: AppConfig,
    /// API description fetched from the app during construction.
    api_info: ApiInfo,
}
46
47impl Client {
48 pub async fn new(app_reference: &str, options: ClientOptions) -> Result<Self> {
78 let session_hash: String = rand::thread_rng()
79 .sample_iter(&Alphanumeric)
80 .take(10)
81 .map(char::from)
82 .collect();
83
84 let http_client = Client::build_http_client(&options.hf_token)?;
85
86 let (mut api_root, space_id) =
87 Client::resolve_app_reference(&http_client, app_reference).await?;
88
89 if let Some((username, password)) = &options.auth {
90 Client::authenticate(&http_client, &api_root, username, password).await?;
91 }
92
93 if let Some(space_id) = &space_id {
94 wake_up_space(&http_client, space_id).await?;
95 }
96
97 let config = Client::fetch_config(&http_client, &api_root).await?;
98 if let Some(ref api_prefix) = config.api_prefix {
99 api_root.push_str(api_prefix);
100 }
101
102 let api_info = Client::fetch_api_info(&http_client, &api_root).await?;
103
104 Ok(Self {
105 session_hash,
106 jwt: None,
107 http_client,
108 api_root,
109 space_id,
110 config,
111 api_info,
112 })
113 }
114
115 pub fn view_config(&self) -> AppConfig {
116 self.config.clone()
117 }
118
119 pub fn view_api(&self) -> ApiInfo {
120 self.api_info.clone()
121 }
122
123 pub async fn submit(
124 &self,
125 route: &str,
126 data: Vec<PredictionInput>,
127 ) -> Result<PredictionStream> {
128 let data = preprocess_data(&self.http_client, &self.api_root, data).await?;
129 let fn_index = Client::resolve_fn_index(&self.config, route)?;
130 PredictionStream::new(&self.http_client, &self.api_root, fn_index, data).await
131 }
132
133 pub async fn predict(
134 &self,
135 route: &str,
136 data: Vec<PredictionInput>,
137 ) -> Result<Vec<PredictionOutput>> {
138 let mut stream = self.submit(route, data).await?;
139 while let Some(message) = stream.next().await {
140 match message {
141 Ok(message) => match message {
142 QueueDataMessage::Open
143 | QueueDataMessage::Estimation { .. }
144 | QueueDataMessage::ProcessStarts { .. }
145 | QueueDataMessage::Progress { .. }
146 | QueueDataMessage::Log { .. }
147 | QueueDataMessage::Heartbeat => {}
148 QueueDataMessage::ProcessCompleted { output, .. } => {
149 return output.try_into();
150 }
151 QueueDataMessage::UnexpectedError { message } => {
152 return Err(Error::msg(
153 message.unwrap_or_else(|| "Unexpected error".to_string()),
154 ));
155 }
156 QueueDataMessage::Unknown(m) => {
157 eprintln!("[warning] Skipping unknown message: {:?}", m);
158 }
159 },
160 Err(err) => {
161 return Err(err);
162 }
163 }
164 }
165
166 Err(Error::msg("Stream ended unexpectedly"))
167 }
168
169 fn build_http_client(hf_token: &Option<String>) -> Result<reqwest::Client> {
170 let mut http_client_builder = reqwest::Client::builder()
171 .cookie_store(true)
172 .user_agent("Rust Gradio Client");
173 if let Some(hf_token) = hf_token {
174 http_client_builder =
175 http_client_builder.default_headers(reqwest::header::HeaderMap::from_iter(vec![(
176 reqwest::header::AUTHORIZATION,
177 format!("Bearer {}", hf_token).parse()?,
178 )]));
179 }
180
181 http_client_builder.build().map_err(Error::new)
182 }
183
184 async fn resolve_app_reference(
185 http_client: &reqwest::Client,
186 app_reference: &str,
187 ) -> Result<(String, Option<String>)> {
188 let app_reference = app_reference.trim_end_matches('/').to_string();
189 let mut api_root = app_reference.clone();
190 let mut space_id = None;
191 if Regex::new("^[a-zA-Z0-9_\\-\\.]+\\/[a-zA-Z0-9_\\-\\.]+$")?.is_match(&app_reference) {
192 let url = format!(
193 "https://huggingface.co/api/spaces/{}/{}",
194 app_reference, HOST_URL
195 );
196 let res = http_client.get(&url).send().await?;
197 let res = res.json::<HuggingFaceAPIHost>().await?;
198 api_root.clone_from(&res.host);
199 space_id = Some(app_reference);
200 } else if Regex::new(".*hf\\.space\\/{0,1}$")?.is_match(&app_reference) {
201 space_id = Some(app_reference.replace(".hf.space", ""));
202 }
203
204 Ok((api_root, space_id))
205 }
206
207 async fn authenticate(
208 http_client: &reqwest::Client,
209 api_root: &str,
210 username: &str,
211 password: &str,
212 ) -> Result<()> {
213 let res = http_client
214 .post(&format!("{}/{}", api_root, LOGIN_URL))
215 .form(&[("username", username), ("password", password)])
216 .send()
217 .await?;
218 if !res.status().is_success() {
219 return Err(Error::msg("Login failed"));
220 }
221 Ok(())
222 }
223
224 async fn fetch_config(http_client: &reqwest::Client, api_root: &str) -> Result<AppConfig> {
225 let res = http_client
226 .get(&format!("{}/{}", api_root, CONFIG_URL))
227 .send()
228 .await?;
229 if !res.status().is_success() {
230 return Err(Error::msg("Could not resolve app config"));
231 }
232
233 let json = res.json::<serde_json::Value>().await?;
234 let config: AppConfigVersionOnly = serde_json::from_value(json.clone())?;
235
236 if !config.version.starts_with("5.") && !config.version.starts_with("4.") {
237 eprintln!(
238 "Warning: This client is supposed to work with Gradio 5 & 4. The current version of the app is {}, which may cause issues.",
239 config.version
240 );
241 }
242
243 serde_json::from_value(json).map_err(Error::new)
244 }
245
246 async fn fetch_api_info(http_client: &reqwest::Client, api_root: &str) -> Result<ApiInfo> {
247 let res = http_client
248 .get(&format!("{}/{}", api_root, API_INFO_URL))
249 .send()
250 .await?;
251 if !res.status().is_success() {
252 return Err(Error::msg("Could not get API info"));
253 }
254 res.json::<ApiInfo>().await.map_err(Error::new)
255 }
256
257 fn resolve_fn_index(config: &AppConfig, route: &str) -> Result<i64> {
258 let route = route.trim_start_matches('/');
259 let found = config
260 .dependencies
261 .iter()
262 .enumerate()
263 .find(|(_i, d)| d.api_name == route)
264 .ok_or_else(|| Error::msg("Invalid route"))?;
265
266 if found.1.id == -1 {
267 Ok(found.0 as i64)
268 } else {
269 Ok(found.1.id)
270 }
271 }
272}