1use gradio::ClientOptions;
2use heck::ToSnakeCase;
3use proc_macro2::{Ident, Span};
4use proc_macro::TokenStream;
5use syn::{parse_macro_input, punctuated::Punctuated, Expr, ItemStruct, Meta};
6use quote::quote;
7
8
/// Selects which flavour of client `gradio_api` generates.
///
/// `Sync` emits blocking methods (`new_sync`, `predict_sync`, `submit_sync`);
/// `Async` emits `async fn`s that `.await` the async client calls.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Syncity {
    /// Blocking API (`option = "sync"`).
    Sync,
    /// `async` API (`option = "async"`).
    Async,
}
13
14fn make_compile_error(message: &str) -> TokenStream {
15 syn::Error::new(Span::call_site(), message).to_compile_error().into()
16}
17
18#[proc_macro_attribute]
69pub fn gradio_api(args: TokenStream, input: TokenStream) -> TokenStream {
70 let args = parse_macro_input!(args with Punctuated::<Meta, syn::Token![,]>::parse_terminated);
71 let input = parse_macro_input!(input as ItemStruct);
72 let (mut url, mut option, mut grad_token, mut grad_login, mut grad_password) = (None, None, None, None, None);
73
74 for item in args.iter() {
76 let Ok(meta_value) = item.require_name_value() else {continue;};
77 let Expr::Lit(ref lit_val) = meta_value.value else {continue;};
78 let syn::Lit::Str(ref lit_val) = lit_val.lit else {continue;};
79 let arg_value = lit_val.value();
80 if item.path().is_ident("url") {
81 url = Some(arg_value);
82 } else if item.path().is_ident("option") {
83 option = Some(if arg_value == "sync" { Syncity::Sync } else { Syncity::Async });
84 } else if item.path().is_ident("hf_token") {
85 grad_token = Some(arg_value);
86 } else if item.path().is_ident("auth_username") {
87 grad_login = Some(arg_value);
88 } else if item.path().is_ident("auth_password") {
89 grad_password = Some(arg_value);
90 }
91 }
92
93 if url.is_none() {
95 return make_compile_error("url is required");
96 }
97 let mut grad_opts = ClientOptions::default();
98 let mut grad_auth = None;
99 if grad_token.is_some() {
100 grad_opts.hf_token = grad_token.clone();
101 }
102 if grad_login.is_some() ^ grad_password.is_some() {
103 return make_compile_error("Both login and password must be present!");
104 } else if grad_login.is_some() && grad_password.is_some() {
105 grad_auth = Some((grad_login.clone().unwrap(), grad_password.clone().unwrap()));
106 grad_opts.auth = grad_auth.clone();
107 }
108
109 let Some(option) = option else {
111 return make_compile_error("option is required");
112 };
113
114 let client = gradio::Client::new_sync(&(url.clone().unwrap()[..]), grad_opts).unwrap();
116 let api = client.view_api().named_endpoints;
117
118 let grad_auth_ts = if grad_auth.is_some() {
120 quote! {Some((#grad_login, #grad_password))}
121 } else { quote!{None}};
122 let grad_token_ts = if let Some(val) = grad_token {
123 quote! {Some(#val)}
124 } else { quote!{None}};
125 let grad_opts_ts = quote! {
126 gradio::ClientOptions {
127 auth: #grad_auth_ts,
128 hf_token: #grad_token_ts
129 }
130 };
131
132
133 let mut functions: Vec<proc_macro2::TokenStream> = Vec::new();
135 for (name, info) in api.iter() {
136 let method_name = Ident::new(&(name.to_snake_case()), Span::call_site());
137 let background_name = Ident::new(&format!("{}_background", name.to_snake_case()), Span::call_site());
138
139 let (args, args_call): (Vec<proc_macro2::TokenStream>, Vec<proc_macro2::TokenStream>) = info.parameters.iter().enumerate().map(|(i, arg)| {
140 let (_arg_name, arg_ident) = match &arg.label {
141 Some(label) => (label.clone(), Ident::new(&label.to_snake_case(), Span::call_site())),
142 None => (format!("arg{}", i), Ident::new(&format!("arg{}", i), Span::call_site())),
143 };
144 let is_file = arg.python_type.r#type == "filepath";
145 let arg_type: proc_macro2::TokenStream = if is_file {
146 quote! { impl Into<std::path::PathBuf> }
147 } else {
148 quote! { impl gradio::serde::Serialize }
149 };
150 (quote! { #arg_ident: #arg_type },
151 if is_file { quote! { gradio::PredictionInput::from_file(#arg_ident) } }
152 else { quote! { gradio::PredictionInput::from_value(#arg_ident) } })
153 }).unzip();
154
155 let function: TokenStream = match option {
157 Syncity::Sync => {
158 quote! {
159 pub fn #method_name(&self, #(#args),*) -> Result<Vec<gradio::PredictionOutput>, gradio::anyhow::Error> {
160 self.client.predict_sync(#name, vec![#(#args_call.into()),*])
161 }
162
163 pub fn #background_name(&self, #(#args),*) -> Result<gradio::PredictionStream, gradio::anyhow::Error> {
164 self.client.submit_sync(#name, vec![#(#args_call.into()),*])
165 }
166 }
167 },
168 Syncity::Async => {
169 quote! {
170 pub async fn #method_name(&self, #(#args),*) -> Result<Vec<gradio::PredictionOutput>, gradio::anyhow::Error> {
171 self.client.predict(#name, vec![#(#args_call.into()),*]).await
172 }
173
174 pub async fn #background_name(&self, #(#args),*) -> Result<gradio::PredictionStream, gradio::anyhow::Error> {
175 self.client.submit(#name, vec![#(#args_call.into()),*]).await
176 }
177 }
178 },
179 }.into();
180
181 functions.push(function.into());
182 }
183
184 let vis = input.vis.clone();
186 let struct_name = input.ident.clone();
187 let api_struct = match option {
188 Syncity::Sync => {
189 quote! {
190 #vis struct #struct_name {
191 client: gradio::Client
192 }
193
194 impl #struct_name {
195 pub fn new() -> Result<Self, ()> {
196 match gradio::Client::new_sync(#url, #grad_opts_ts) {
197 Ok(client) => Ok(Self { client }),
198 Err(_) => Err(())
199 }
200 }
201
202 pub fn custom_endpoint(&self, endpoint: &str, arguments: Vec<gradio::PredictionInput>) -> Result<Vec<gradio::PredictionOutput>, gradio::anyhow::Error> {
203 self.client.predict_sync(endpoint, arguments)
204 }
205
206 pub fn custom_endpoint_background(&self, endpoint: &str, arguments: Vec<gradio::PredictionInput>) -> Result<gradio::PredictionStream, gradio::anyhow::Error> {
207 self.client.submit_sync(endpoint, arguments)
208 }
209
210
211 #(#functions)*
212 }
213 }
214 },
215 Syncity::Async => {
216 quote! {
217 #vis struct #struct_name {
218 client: gradio::Client
219 }
220
221 impl #struct_name {
222 pub async fn new() -> Result<Self, ()> {
223 match gradio::Client::new(#url, #grad_opts_ts).await {
224 Ok(client) => Ok(Self { client }),
225 Err(_) => Err(())
226 }
227 }
228
229 pub async fn custom_endpoint(&self, endpoint: &str, arguments: Vec<gradio::PredictionInput>) -> Result<Vec<gradio::PredictionOutput>, gradio::anyhow::Error> {
230 self.client.predict(endpoint, arguments).await
231 }
232
233 pub async fn custom_endpoint_background(&self, endpoint: &str, arguments: Vec<gradio::PredictionInput>) -> Result<gradio::PredictionStream, gradio::anyhow::Error> {
234 self.client.submit(endpoint, arguments).await
235 }
236
237 #(#functions)*
238 }
239 }
240 },
241 };
242
243 api_struct.into()
244}