gradio/
client.rs

1use anyhow::{Error, Result};
2use rand::{distributions::Alphanumeric, Rng};
3use regex::Regex;
4
5use crate::constants::*;
6use crate::preprocess_data;
7use crate::structs::*;
8use crate::{
9    data::{PredictionInput, PredictionOutput},
10    space::wake_up_space,
11    stream::PredictionStream,
12};
13
/// Connection options for a [`Client`].
///
/// Every field is optional; `ClientOptions::default()` sets both to `None`.
#[derive(Default)]
pub struct ClientOptions {
    /// Hugging Face access token. When set, the client sends it as an
    /// `Authorization: Bearer <token>` default header.
    pub hf_token: Option<String>,
    /// `(username, password)` pair. When set, the client submits it as a form
    /// to the app's login endpoint before fetching the app config.
    pub auth: Option<(String, String)>,
}

impl ClientOptions {
    /// Build options that carry only a Hugging Face token (`auth` stays `None`).
    pub fn with_hf_token(hf_token: String) -> Self {
        Self {
            hf_token: Some(hf_token),
            ..Self::default()
        }
    }

    /// Build options that carry only login credentials (`hf_token` stays `None`).
    pub fn with_auth(username: String, password: String) -> Self {
        let credentials = (username, password);
        Self {
            auth: Some(credentials),
            ..Self::default()
        }
    }
}
35
36#[derive(Clone, Debug)]
37pub struct Client {
38    pub session_hash: String,
39    pub jwt: Option<String>,
40    pub http_client: reqwest::Client,
41    pub api_root: String,
42    pub space_id: Option<String>,
43    config: AppConfig,
44    api_info: ApiInfo,
45}
46
47impl Client {
48    /// Create a new client
49    ///
50    /// # Arguments
51    ///
52    /// * `app_reference` - The reference to the app
53    /// * `options` - The options for the client
54    ///
55    /// # Returns
56    ///
57    /// A new Gradio client
58    ///
59    /// # Errors
60    ///
61    /// If the client cannot be created
62    ///
63    /// # Example
64    ///
65    /// ```
66    /// use gradio::{Client, ClientOptions};
67    ///
68    /// #[tokio::main]
69    /// async fn main() {
70    ///    let client = Client::new(
71    ///         "gradio/hello_world",
72    ///         ClientOptions::default()
73    ///     ).await.unwrap();
74    ///     println!("{:?}", client);
75    /// }
76    /// ```
77    pub async fn new(app_reference: &str, options: ClientOptions) -> Result<Self> {
78        let session_hash: String = rand::thread_rng()
79            .sample_iter(&Alphanumeric)
80            .take(10)
81            .map(char::from)
82            .collect();
83
84        let http_client = Client::build_http_client(&options.hf_token)?;
85
86        let (mut api_root, space_id) =
87            Client::resolve_app_reference(&http_client, app_reference).await?;
88
89        if let Some((username, password)) = &options.auth {
90            Client::authenticate(&http_client, &api_root, username, password).await?;
91        }
92
93        if let Some(space_id) = &space_id {
94            wake_up_space(&http_client, space_id).await?;
95        }
96
97        let config = Client::fetch_config(&http_client, &api_root).await?;
98        if let Some(ref api_prefix) = config.api_prefix {
99            api_root.push_str(api_prefix);
100        }
101
102        let api_info = Client::fetch_api_info(&http_client, &api_root).await?;
103
104        Ok(Self {
105            session_hash,
106            jwt: None,
107            http_client,
108            api_root,
109            space_id,
110            config,
111            api_info,
112        })
113    }
114
115    pub fn view_config(&self) -> AppConfig {
116        self.config.clone()
117    }
118
119    pub fn view_api(&self) -> ApiInfo {
120        self.api_info.clone()
121    }
122
123    pub async fn submit(
124        &self,
125        route: &str,
126        data: Vec<PredictionInput>,
127    ) -> Result<PredictionStream> {
128        let data = preprocess_data(&self.http_client, &self.api_root, data).await?;
129        let fn_index = Client::resolve_fn_index(&self.config, route)?;
130        PredictionStream::new(&self.http_client, &self.api_root, fn_index, data).await
131    }
132
133    pub async fn predict(
134        &self,
135        route: &str,
136        data: Vec<PredictionInput>,
137    ) -> Result<Vec<PredictionOutput>> {
138        let mut stream = self.submit(route, data).await?;
139        while let Some(message) = stream.next().await {
140            match message {
141                Ok(message) => match message {
142                    QueueDataMessage::Open
143                    | QueueDataMessage::Estimation { .. }
144                    | QueueDataMessage::ProcessStarts { .. }
145                    | QueueDataMessage::Progress { .. }
146                    | QueueDataMessage::Log { .. }
147                    | QueueDataMessage::Heartbeat => {}
148                    QueueDataMessage::ProcessCompleted { output, .. } => {
149                        return output.try_into();
150                    }
151                    QueueDataMessage::UnexpectedError { message } => {
152                        return Err(Error::msg(
153                            message.unwrap_or_else(|| "Unexpected error".to_string()),
154                        ));
155                    }
156                    QueueDataMessage::Unknown(m) => {
157                        eprintln!("[warning] Skipping unknown message: {:?}", m);
158                    }
159                },
160                Err(err) => {
161                    return Err(err);
162                }
163            }
164        }
165
166        Err(Error::msg("Stream ended unexpectedly"))
167    }
168
169    fn build_http_client(hf_token: &Option<String>) -> Result<reqwest::Client> {
170        let mut http_client_builder = reqwest::Client::builder()
171            .cookie_store(true)
172            .user_agent("Rust Gradio Client");
173        if let Some(hf_token) = hf_token {
174            http_client_builder =
175                http_client_builder.default_headers(reqwest::header::HeaderMap::from_iter(vec![(
176                    reqwest::header::AUTHORIZATION,
177                    format!("Bearer {}", hf_token).parse()?,
178                )]));
179        }
180
181        http_client_builder.build().map_err(Error::new)
182    }
183
184    async fn resolve_app_reference(
185        http_client: &reqwest::Client,
186        app_reference: &str,
187    ) -> Result<(String, Option<String>)> {
188        let app_reference = app_reference.trim_end_matches('/').to_string();
189        let mut api_root = app_reference.clone();
190        let mut space_id = None;
191        if Regex::new("^[a-zA-Z0-9_\\-\\.]+\\/[a-zA-Z0-9_\\-\\.]+$")?.is_match(&app_reference) {
192            let url = format!(
193                "https://huggingface.co/api/spaces/{}/{}",
194                app_reference, HOST_URL
195            );
196            let res = http_client.get(&url).send().await?;
197            let res = res.json::<HuggingFaceAPIHost>().await?;
198            api_root.clone_from(&res.host);
199            space_id = Some(app_reference);
200        } else if Regex::new(".*hf\\.space\\/{0,1}$")?.is_match(&app_reference) {
201            space_id = Some(app_reference.replace(".hf.space", ""));
202        }
203
204        Ok((api_root, space_id))
205    }
206
207    async fn authenticate(
208        http_client: &reqwest::Client,
209        api_root: &str,
210        username: &str,
211        password: &str,
212    ) -> Result<()> {
213        let res = http_client
214            .post(&format!("{}/{}", api_root, LOGIN_URL))
215            .form(&[("username", username), ("password", password)])
216            .send()
217            .await?;
218        if !res.status().is_success() {
219            return Err(Error::msg("Login failed"));
220        }
221        Ok(())
222    }
223
224    async fn fetch_config(http_client: &reqwest::Client, api_root: &str) -> Result<AppConfig> {
225        let res = http_client
226            .get(&format!("{}/{}", api_root, CONFIG_URL))
227            .send()
228            .await?;
229        if !res.status().is_success() {
230            return Err(Error::msg("Could not resolve app config"));
231        }
232
233        let json = res.json::<serde_json::Value>().await?;
234        let config: AppConfigVersionOnly = serde_json::from_value(json.clone())?;
235
236        if !config.version.starts_with("5.") && !config.version.starts_with("4.") {
237            eprintln!(
238                "Warning: This client is supposed to work with Gradio 5 & 4. The current version of the app is {}, which may cause issues.",
239                config.version
240            );
241        }
242
243        serde_json::from_value(json).map_err(Error::new)
244    }
245
246    async fn fetch_api_info(http_client: &reqwest::Client, api_root: &str) -> Result<ApiInfo> {
247        let res = http_client
248            .get(&format!("{}/{}", api_root, API_INFO_URL))
249            .send()
250            .await?;
251        if !res.status().is_success() {
252            return Err(Error::msg("Could not get API info"));
253        }
254        res.json::<ApiInfo>().await.map_err(Error::new)
255    }
256
257    fn resolve_fn_index(config: &AppConfig, route: &str) -> Result<i64> {
258        let route = route.trim_start_matches('/');
259        let found = config
260            .dependencies
261            .iter()
262            .enumerate()
263            .find(|(_i, d)| d.api_name == route)
264            .ok_or_else(|| Error::msg("Invalid route"))?;
265
266        if found.1.id == -1 {
267            Ok(found.0 as i64)
268        } else {
269            Ok(found.1.id)
270        }
271    }
272}