use quote::quote;
use syn;
use syn::{Data, DataStruct, DeriveInput, Fields, FieldsNamed, Type};
use crate::proc_macro::TokenStream;
/// Outermost wrapper recognised around a struct field's declared type.
enum NestedType {
    /// No recognised wrapper: the type is used as-is.
    Plain,
    /// The last path segment of the type is `Option`.
    Option,
    /// The last path segment of the type is `Vec`.
    Vec,
}
/// Type names (last path segment, after peeling `Option`/`Vec`) treated as
/// leaf values: fields whose innermost type is listed here are not handed to
/// the nested object filter.
const PRIMITIVE_TYPES: &[&str] = &[
    // signed integers
    "i8", "i16", "i32", "i64", "i128",
    // unsigned integers
    "u8", "u16", "u32", "u64", "u128",
    // floating point
    "f32", "f64",
    // pointer-sized integers
    "isize", "usize",
    // text
    "String", "str",
    // remaining scalars
    "bool", "char",
];
/// Reports whether `type_name` is one of the entries in `PRIMITIVE_TYPES`
/// (exact, case-sensitive comparison).
fn is_primitive_type(type_name: &str) -> bool {
    PRIMITIVE_TYPES
        .iter()
        .any(|&primitive| primitive == type_name)
}
/// Classifies the outermost wrapper of `ty` by the identifier of its last path
/// segment.
///
/// - `Option` maps to `NestedType::Option`.
/// - `Vec` maps to `NestedType::Vec`.
/// - Anything else, including non-path types such as references or tuples,
///   maps to `NestedType::Plain`.
///
/// Only the identifier is compared, so any path ending in `Option`/`Vec`
/// matches, whatever its module prefix.
fn detect_wrapper_type(ty: &Type) -> NestedType {
    let last_segment = match ty {
        Type::Path(type_path) => type_path.path.segments.last(),
        _ => None,
    };
    match last_segment {
        Some(segment) if segment.ident == "Option" => NestedType::Option,
        Some(segment) if segment.ident == "Vec" => NestedType::Vec,
        _ => NestedType::Plain,
    }
}
/// Returns the identifier of the innermost type of `ty`.
///
/// - `Option<_>` and `Vec<_>` are looked through via their first generic type
///   argument, recursively.
/// - The result is the identifier of the last path segment, so
///   `std::string::String` yields `"String"`.
/// - A wrapper with no type argument yields its own name.
/// - Non-path types (references, tuples, arrays, ...) yield an empty string.
fn extract_type_name(ty: &Type) -> String {
    let segment = match ty {
        Type::Path(type_path) => match type_path.path.segments.last() {
            Some(segment) => segment,
            None => return String::new(),
        },
        _ => return String::new(),
    };

    let is_wrapper = segment.ident == "Option" || segment.ident == "Vec";
    if is_wrapper {
        if let syn::PathArguments::AngleBracketed(generic_args) = &segment.arguments {
            let first_type = generic_args.args.iter().find_map(|arg| match arg {
                syn::GenericArgument::Type(inner) => Some(inner),
                _ => None,
            });
            if let Some(inner) = first_type {
                return extract_type_name(inner);
            }
        }
    }
    segment.ident.to_string()
}
pub(crate) fn impl_casbin(input: &mut DeriveInput) -> TokenStream {
let name = &input.ident;
let nested_field_info: Vec<(String, NestedType, String)> = if let Data::Struct(DataStruct {
fields: Fields::Named(FieldsNamed { named, .. }),
..
}) = &input.data
{
named
.iter()
.filter_map(|f| {
let inner_type_name = extract_type_name(&f.ty);
if !inner_type_name.is_empty() && !is_primitive_type(&inner_type_name) {
f.ident.as_ref().map(|ident| {
let field_name = ident.to_string();
let wrapper_type = detect_wrapper_type(&f.ty);
(field_name, wrapper_type, inner_type_name)
})
} else {
None
}
})
.collect()
} else {
vec![]
};
if let Data::Struct(DataStruct {
fields: Fields::Named(FieldsNamed { named, .. }),
..
}) = &mut input.data
{
for field in named.iter_mut() {
field.attrs.retain(|attr| {
!(attr.path.segments.len() == 1 && attr.path.segments[0].ident == "casbin")
});
}
}
let nested_filter_code: Vec<proc_macro2::TokenStream> = nested_field_info
.iter()
.map(|(field_name, wrapper_type, type_name)| {
match wrapper_type {
NestedType::Plain => {
quote! {
if let Some(v) = map.get_mut(#field_name) {
genies_auth::casbin_filter_object(v, #type_name, enforcer, subject);
}
}
}
NestedType::Option => {
quote! {
if let Some(v) = map.get_mut(#field_name) {
if !v.is_null() {
genies_auth::casbin_filter_object(v, #type_name, enforcer, subject);
}
}
}
}
NestedType::Vec => {
quote! {
if let Some(serde_json::Value::Array(arr)) = map.get_mut(#field_name) {
for item in arr.iter_mut() {
genies_auth::casbin_filter_object(item, #type_name, enforcer, subject);
}
}
}
}
}
})
.collect();
let expanded = quote! {
impl #name {
pub fn casbin_filter(
value: &mut serde_json::Value,
enforcer: &casbin::Enforcer,
subject: &str,
) {
use casbin::CoreApi;
let type_name = salvo::oapi::naming::assign_name::<#name>(salvo::oapi::naming::NameRule::Auto);
if let serde_json::Value::Object(map) = value {
let keys: Vec<String> = map.keys().cloned().collect();
for key in keys {
let resource = format!("{}.{}", type_name, key);
match enforcer.enforce((subject, &resource, "read")) {
Ok(false) => { map.remove(&key); }
_ => {} }
}
#(#nested_filter_code)*
}
}
}
};
let writer_impl = quote! {
#[async_trait::async_trait]
impl salvo::writing::Writer for #name {
async fn write(mut self, _req: &mut salvo::prelude::Request, depot: &mut salvo::prelude::Depot, res: &mut salvo::prelude::Response) {
let enforcer = depot.get::<std::sync::Arc<casbin::Enforcer>>("casbin_enforcer").ok().cloned();
let subject = depot.get::<String>("casbin_subject").ok().cloned();
match serde_json::to_value(&self) {
Ok(mut value) => {
if let (Some(ref e), Some(ref s)) = (enforcer, subject) {
Self::casbin_filter(&mut value, e, s);
}
res.render(salvo::prelude::Json(value));
}
Err(e) => {
res.status_code(salvo::http::StatusCode::INTERNAL_SERVER_ERROR);
res.render(format!("Serialization error: {}", e));
}
}
}
}
};
TokenStream::from(quote! {
#input
#expanded
#writer_impl
})
}