use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use std::collections::HashSet;
use syn::parse::{Parse, ParseStream};
use syn::{parse_macro_input, FnArg, Ident, ItemFn, LitStr, Pat, PatType, Token, Type};
/// Parsed arguments of a server-function route attribute.
///
/// Accepted shape: a string literal (the route) followed by an optional list of
/// extractor entries, e.g. `"/users/{id}", auth: Auth, db`. Each entry is an
/// identifier optionally followed by `: Type`; only the identifier is kept, the
/// type is parsed and discarded.
struct RouteAttr {
    /// The route literal; may contain `{name}` / `:name` path placeholders and a
    /// `?a&b` query-name list.
    path: LitStr,
    /// Names of arguments that are produced by extractors rather than supplied by
    /// the caller of the generated client stub.
    extractors: HashSet<String>,
}

impl Parse for RouteAttr {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let path = input.parse::<LitStr>()?;
        let extractors = {
            let mut names = HashSet::new();
            while !input.is_empty() {
                // The separating comma is parsed leniently: any parse error is discarded.
                let _ = input.parse::<Token![,]>();
                // A trailing comma ends the list.
                if input.is_empty() {
                    break;
                }
                let ident = input.parse::<Ident>()?;
                names.insert(ident.to_string());
                // Optional `: Type` annotation, validated but otherwise unused.
                if input.peek(Token![:]) {
                    input.parse::<Token![:]>()?;
                    input.parse::<Type>()?;
                }
            }
            names
        };
        Ok(Self { path, extractors })
    }
}
/// Returns `true` when `ty` is a path type whose final segment is named `Option`
/// (e.g. `Option<T>`, `std::option::Option<T>`). Generic arguments are not inspected.
fn is_option_type(ty: &Type) -> bool {
    match ty {
        Type::Path(type_path) => type_path
            .path
            .segments
            .last()
            .map_or(false, |segment| segment.ident == "Option"),
        _ => false,
    }
}
/// If `ty` is a path type whose last segment is `Form<...>`, returns a clone of
/// its first generic argument when that argument is a type; otherwise `None`.
fn unwrap_form_type(ty: &Type) -> Option<Type> {
    let segment = match ty {
        Type::Path(type_path) => type_path.path.segments.last()?,
        _ => return None,
    };
    if segment.ident != "Form" {
        return None;
    }
    match &segment.arguments {
        syn::PathArguments::AngleBracketed(bracketed) => match bracketed.args.first()? {
            syn::GenericArgument::Type(inner) => Some(inner.clone()),
            _ => None,
        },
        _ => None,
    }
}
/// Splits a route string into a `format!`-style template and its parameter names.
///
/// Everything before the first `?` is the path. Inside it, `{name}` (anywhere) and
/// `:name` (only at the very start or directly after a `/`) are parameters; each is
/// replaced by `{}` in the template and its name is recorded in order. A `{` with no
/// closing `}` swallows the rest of the path; a `:name` runs up to the next `/`.
///
/// Everything after the first `?` is split on `&`; non-empty pieces are returned
/// verbatim as query parameter names and do not appear in the template.
///
/// Returns `(template, path_param_names, query_param_names)`.
fn parse_path(path: &str) -> (String, Vec<String>, Vec<String>) {
    let (route, query) = match path.split_once('?') {
        Some((route, query)) => (route, Some(query)),
        None => (path, None),
    };

    let mut template = String::new();
    let mut path_args = Vec::new();
    // Previously scanned character; '\0' stands for "start of input", which (like
    // '/') lets a following ':' open a parameter.
    let mut last = '\0';
    let mut scanner = route.chars().peekable();

    while let Some(ch) = scanner.next() {
        match ch {
            '{' => {
                // `take_while` also consumes the closing '}'.
                let name: String = scanner.by_ref().take_while(|&c| c != '}').collect();
                template.push_str("{}");
                path_args.push(name);
                last = '}';
            }
            ':' if last == '\0' || last == '/' => {
                // Stop before the next '/', leaving it for the outer loop.
                let mut name = String::new();
                while let Some(c) = scanner.next_if(|&c| c != '/') {
                    name.push(c);
                }
                template.push_str("{}");
                last = name.chars().next_back().unwrap_or(':');
                path_args.push(name);
            }
            literal => {
                template.push(literal);
                last = literal;
            }
        }
    }

    let query_args = match query {
        Some(q) => q
            .split('&')
            .filter(|piece| !piece.is_empty())
            .map(String::from)
            .collect(),
        None => Vec::new(),
    };

    (template, path_args, query_args)
}
pub fn server_fn_impl(method: &str, attr: TokenStream, item: TokenStream) -> TokenStream {
let route = parse_macro_input!(attr as RouteAttr);
let func = parse_macro_input!(item as ItemFn);
let fn_vis = &func.vis;
let fn_sig = &func.sig;
let fn_name = &fn_sig.ident;
let fn_generics = &fn_sig.generics;
let fn_output = &fn_sig.output;
let fn_attrs = &func.attrs;
let (path_template, path_args, query_args) = parse_path(&route.path.value());
let path_arg_set: HashSet<&String> = path_args.iter().collect();
let query_arg_set: HashSet<&String> = query_args.iter().collect();
let typed_args: Vec<(&Ident, &Type)> = fn_sig
.inputs
.iter()
.filter_map(|input| match input {
FnArg::Typed(PatType { pat, ty, .. }) => {
if let Pat::Ident(pi) = pat.as_ref() {
Some((&pi.ident, ty.as_ref()))
} else {
None
}
}
_ => None,
})
.collect();
let client_args: Vec<(&Ident, &Type)> = typed_args
.iter()
.filter(|(name, _)| !route.extractors.contains(&name.to_string()))
.copied()
.collect();
let path_idents: Vec<&Ident> = path_args
.iter()
.filter_map(|name| {
client_args
.iter()
.find(|(n, _)| n.to_string() == *name)
.map(|(n, _)| *n)
})
.collect();
let query_idents: Vec<&Ident> = query_args
.iter()
.filter_map(|name| {
client_args
.iter()
.find(|(n, _)| n.to_string() == *name)
.map(|(n, _)| *n)
})
.collect();
let query_names: Vec<String> = query_idents.iter().map(|i| i.to_string()).collect();
let body_idents: Vec<&Ident> = client_args
.iter()
.filter(|(n, _)| {
let s = n.to_string();
!path_arg_set.contains(&s) && !query_arg_set.contains(&s)
})
.map(|(n, _)| *n)
.collect();
let stub_inputs: Vec<TokenStream2> = client_args
.iter()
.map(|(name, ty)| {
if let Some(inner) = unwrap_form_type(ty) {
quote! { #name: #inner }
} else {
quote! { #name: #ty }
}
})
.collect();
let path_format = if path_idents.is_empty() {
quote! { let __path: ::std::string::String = #path_template.to_string(); }
} else {
let tpl = LitStr::new(&path_template, route.path.span());
quote! {
let __path: ::std::string::String = format!(
#tpl,
#( ::urlencoding::encode(
&crate::common::fullstack::server_fn::to_url_value(&#path_idents)
) ),*
);
}
};
let query_attach = if query_idents.is_empty() {
quote! { let __url: ::std::string::String = __path; }
} else {
let pushers = query_idents.iter().zip(query_names.iter()).map(|(ident, name)| {
let ty = client_args
.iter()
.find(|(n, _)| n.to_string() == ident.to_string())
.map(|(_, t)| *t);
let is_option = ty.map(is_option_type).unwrap_or(false);
if is_option {
quote! {
if let ::std::option::Option::Some(v) = &#ident {
if __has_q { __url.push('&'); } else { __url.push('?'); __has_q = true; }
__url.push_str(#name);
__url.push('=');
__url.push_str(&::urlencoding::encode(
&crate::common::fullstack::server_fn::to_url_value(v)
));
}
}
} else {
quote! {
{
if __has_q { __url.push('&'); } else { __url.push('?'); __has_q = true; }
__url.push_str(#name);
__url.push('=');
__url.push_str(&::urlencoding::encode(
&crate::common::fullstack::server_fn::to_url_value(&#ident)
));
}
}
}
});
quote! {
let mut __url = __path;
let mut __has_q = false;
#( #pushers )*
}
};
let send_call = match method {
"GET" | "DELETE" => {
let fn_name = format_ident!("{}", method.to_lowercase());
quote! {
crate::common::fullstack::server_fn::#fn_name(&__url)
.await
.map_err(::std::convert::Into::into)
}
}
"POST" | "PUT" | "PATCH" => {
let fn_name = format_ident!("{}", method.to_lowercase());
if let Some(body) = body_idents.first() {
let body_name = LitStr::new(&body.to_string(), body.span());
quote! {
crate::common::fullstack::server_fn::#fn_name(
&__url,
&::serde_json::json!({ #body_name: &#body }),
)
.await
.map_err(::std::convert::Into::into)
}
} else {
quote! {
crate::common::fullstack::server_fn::#fn_name(&__url, &())
.await
.map_err(::std::convert::Into::into)
}
}
}
_ => unreachable!("unsupported method {method}"),
};
let tauri = quote! {
#(#fn_attrs)*
#[allow(unused_variables, unused_mut)]
#fn_vis async fn #fn_name #fn_generics ( #( #stub_inputs ),* ) #fn_output {
#path_format
#query_attach
#send_call
}
};
quote! {
#tauri
}
.into()
}