use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{parse_macro_input, punctuated::Punctuated, ItemFn, LitStr, Token};
pub fn sa_check_login_impl(_attr: TokenStream, item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as ItemFn);
let fn_name = &input.sig.ident;
let fn_inputs = &input.sig.inputs;
let fn_output = &input.sig.output;
let fn_body = &input.block;
let fn_attrs = &input.attrs;
let fn_vis = &input.vis;
let fn_asyncness = &input.sig.asyncness;
let fn_generics = &input.sig.generics;
let fn_where_clause = &input.sig.generics.where_clause;
if fn_asyncness.is_none() {
return syn::Error::new_spanned(fn_name, "sa_check_login requires async function")
.to_compile_error()
.into();
}
let check_code = quote! {
if !spring_sa_token::StpUtil::is_login_current() {
return Err(spring_web::error::KnownWebError::unauthorized("Not logged in").into());
}
};
let expanded: TokenStream2 = quote! {
#(#fn_attrs)*
#fn_vis #fn_asyncness fn #fn_name #fn_generics(#fn_inputs) #fn_output #fn_where_clause {
#check_code
#fn_body
}
};
expanded.into()
}
pub fn sa_check_role_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
let role = parse_macro_input!(attr as LitStr);
let input = parse_macro_input!(item as ItemFn);
let fn_name = &input.sig.ident;
let fn_inputs = &input.sig.inputs;
let fn_output = &input.sig.output;
let fn_body = &input.block;
let fn_attrs = &input.attrs;
let fn_vis = &input.vis;
let fn_asyncness = &input.sig.asyncness;
let fn_generics = &input.sig.generics;
let fn_where_clause = &input.sig.generics.where_clause;
let role_value = role.value();
if fn_asyncness.is_none() {
return syn::Error::new_spanned(fn_name, "sa_check_role requires async function")
.to_compile_error()
.into();
}
let check_code = quote! {
let __login_id = spring_sa_token::StpUtil::get_login_id_as_string()
.await
.map_err(|_| spring_web::error::KnownWebError::unauthorized("Not logged in"))?;
if !spring_sa_token::StpUtil::has_role(&__login_id, #role_value).await {
return Err(spring_web::error::KnownWebError::forbidden(
format!("Missing required role: {}", #role_value)
).into());
}
};
let expanded: TokenStream2 = quote! {
#(#fn_attrs)*
#fn_vis #fn_asyncness fn #fn_name #fn_generics(#fn_inputs) #fn_output #fn_where_clause {
#check_code
#fn_body
}
};
expanded.into()
}
pub fn sa_check_permission_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
let permission = parse_macro_input!(attr as LitStr);
let perm_value = permission.value();
if perm_value.trim().is_empty() {
return syn::Error::new_spanned(&permission, "Permission identifier cannot be empty")
.to_compile_error()
.into();
}
let input = parse_macro_input!(item as ItemFn);
let fn_name = &input.sig.ident;
let fn_inputs = &input.sig.inputs;
let fn_output = &input.sig.output;
let fn_body = &input.block;
let fn_attrs = &input.attrs;
let fn_vis = &input.vis;
let fn_asyncness = &input.sig.asyncness;
let fn_generics = &input.sig.generics;
let fn_where_clause = &input.sig.generics.where_clause;
if fn_asyncness.is_none() {
return syn::Error::new_spanned(fn_name, "sa_check_permission requires async function")
.to_compile_error()
.into();
}
let check_code = quote! {
let __login_id = spring_sa_token::StpUtil::get_login_id_as_string()
.await
.map_err(|_| spring_web::error::KnownWebError::unauthorized("Not logged in"))?;
if !spring_sa_token::StpUtil::has_permission(&__login_id, #perm_value).await {
return Err(spring_web::error::KnownWebError::forbidden(
format!("Missing required permission: {}", #perm_value)
).into());
}
};
let expanded: TokenStream2 = quote! {
#(#fn_attrs)*
#fn_vis #fn_asyncness fn #fn_name #fn_generics(#fn_inputs) #fn_output #fn_where_clause {
#check_code
#fn_body
}
};
expanded.into()
}
pub fn sa_check_roles_and_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
let roles = parse_macro_input!(attr with Punctuated::<LitStr, Token![,]>::parse_terminated);
let input = parse_macro_input!(item as ItemFn);
let fn_name = &input.sig.ident;
let fn_inputs = &input.sig.inputs;
let fn_output = &input.sig.output;
let fn_body = &input.block;
let fn_attrs = &input.attrs;
let fn_vis = &input.vis;
let fn_asyncness = &input.sig.asyncness;
let fn_generics = &input.sig.generics;
let fn_where_clause = &input.sig.generics.where_clause;
if fn_asyncness.is_none() {
return syn::Error::new_spanned(fn_name, "sa_check_roles_and requires async function")
.to_compile_error()
.into();
}
let role_values: Vec<String> = roles.iter().map(|r| r.value()).collect();
let role_checks = role_values.iter().map(|role| {
quote! {
if !spring_sa_token::StpUtil::has_role(&__login_id, #role).await {
return Err(spring_web::error::KnownWebError::forbidden(
format!("Missing required role: {}", #role)
).into());
}
}
});
let check_code = quote! {
let __login_id = spring_sa_token::StpUtil::get_login_id_as_string()
.await
.map_err(|_| spring_web::error::KnownWebError::unauthorized("Not logged in"))?;
#(#role_checks)*
};
let expanded: TokenStream2 = quote! {
#(#fn_attrs)*
#fn_vis #fn_asyncness fn #fn_name #fn_generics(#fn_inputs) #fn_output #fn_where_clause {
#check_code
#fn_body
}
};
expanded.into()
}
pub fn sa_check_roles_or_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
let roles = parse_macro_input!(attr with Punctuated::<LitStr, Token![,]>::parse_terminated);
let input = parse_macro_input!(item as ItemFn);
let fn_name = &input.sig.ident;
let fn_inputs = &input.sig.inputs;
let fn_output = &input.sig.output;
let fn_body = &input.block;
let fn_attrs = &input.attrs;
let fn_vis = &input.vis;
let fn_asyncness = &input.sig.asyncness;
let fn_generics = &input.sig.generics;
let fn_where_clause = &input.sig.generics.where_clause;
if fn_asyncness.is_none() {
return syn::Error::new_spanned(fn_name, "sa_check_roles_or requires async function")
.to_compile_error()
.into();
}
let role_values: Vec<String> = roles.iter().map(|r| r.value()).collect();
let role_checks = role_values.iter().map(|role| {
quote! {
if spring_sa_token::StpUtil::has_role(&__login_id, #role).await {
__has_any_role = true;
}
}
});
let roles_str = role_values.join(", ");
let check_code = quote! {
let __login_id = spring_sa_token::StpUtil::get_login_id_as_string()
.await
.map_err(|_| spring_web::error::KnownWebError::unauthorized("Not logged in"))?;
let mut __has_any_role = false;
#(#role_checks)*
if !__has_any_role {
return Err(spring_web::error::KnownWebError::forbidden(
format!("Missing any of required roles: {}", #roles_str)
).into());
}
};
let expanded: TokenStream2 = quote! {
#(#fn_attrs)*
#fn_vis #fn_asyncness fn #fn_name #fn_generics(#fn_inputs) #fn_output #fn_where_clause {
#check_code
#fn_body
}
};
expanded.into()
}
pub fn sa_check_permissions_and_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
let permissions =
parse_macro_input!(attr with Punctuated::<LitStr, Token![,]>::parse_terminated);
let input = parse_macro_input!(item as ItemFn);
let fn_name = &input.sig.ident;
let fn_inputs = &input.sig.inputs;
let fn_output = &input.sig.output;
let fn_body = &input.block;
let fn_attrs = &input.attrs;
let fn_vis = &input.vis;
let fn_asyncness = &input.sig.asyncness;
let fn_generics = &input.sig.generics;
let fn_where_clause = &input.sig.generics.where_clause;
if fn_asyncness.is_none() {
return syn::Error::new_spanned(
fn_name,
"sa_check_permissions_and requires async function",
)
.to_compile_error()
.into();
}
let perm_values: Vec<String> = permissions.iter().map(|p| p.value()).collect();
let perm_checks = perm_values.iter().map(|perm| {
quote! {
if !spring_sa_token::StpUtil::has_permission(&__login_id, #perm).await {
return Err(spring_web::error::KnownWebError::forbidden(
format!("Missing required permission: {}", #perm)
).into());
}
}
});
let check_code = quote! {
let __login_id = spring_sa_token::StpUtil::get_login_id_as_string()
.await
.map_err(|_| spring_web::error::KnownWebError::unauthorized("Not logged in"))?;
#(#perm_checks)*
};
let expanded: TokenStream2 = quote! {
#(#fn_attrs)*
#fn_vis #fn_asyncness fn #fn_name #fn_generics(#fn_inputs) #fn_output #fn_where_clause {
#check_code
#fn_body
}
};
expanded.into()
}
pub fn sa_check_permissions_or_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
let permissions =
parse_macro_input!(attr with Punctuated::<LitStr, Token![,]>::parse_terminated);
let input = parse_macro_input!(item as ItemFn);
let fn_name = &input.sig.ident;
let fn_inputs = &input.sig.inputs;
let fn_output = &input.sig.output;
let fn_body = &input.block;
let fn_attrs = &input.attrs;
let fn_vis = &input.vis;
let fn_asyncness = &input.sig.asyncness;
let fn_generics = &input.sig.generics;
let fn_where_clause = &input.sig.generics.where_clause;
if fn_asyncness.is_none() {
return syn::Error::new_spanned(
fn_name,
"sa_check_permissions_or requires async function",
)
.to_compile_error()
.into();
}
let perm_values: Vec<String> = permissions.iter().map(|p| p.value()).collect();
let perm_checks = perm_values.iter().map(|perm| {
quote! {
if spring_sa_token::StpUtil::has_permission(&__login_id, #perm).await {
__has_any_perm = true;
}
}
});
let perms_str = perm_values.join(", ");
let check_code = quote! {
let __login_id = spring_sa_token::StpUtil::get_login_id_as_string()
.await
.map_err(|_| spring_web::error::KnownWebError::unauthorized("Not logged in"))?;
let mut __has_any_perm = false;
#(#perm_checks)*
if !__has_any_perm {
return Err(spring_web::error::KnownWebError::forbidden(
format!("Missing any of required permissions: {}", #perms_str)
).into());
}
};
let expanded: TokenStream2 = quote! {
#(#fn_attrs)*
#fn_vis #fn_asyncness fn #fn_name #fn_generics(#fn_inputs) #fn_output #fn_where_clause {
#check_code
#fn_body
}
};
expanded.into()
}
/// Implementation of the `sa_ignore` attribute macro.
///
/// The attribute injects no check of its own: any arguments are discarded and the
/// annotated item is emitted exactly as written. (Presumably other tooling reads the
/// attribute as a marker — NOTE(review): confirm where it is consumed.)
pub fn sa_ignore_impl(_args: TokenStream, annotated: TokenStream) -> TokenStream {
    annotated
}