use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::{Attribute, Field, Fields, ItemStruct};
use wheel_rs::str_utils::{CamelFormat, snake_to_pascal, split_camel_case};
pub fn crud_dto_macro(input: ItemStruct) -> TokenStream {
let struct_name = &input.ident;
let struct_name_str = struct_name.to_string();
if !struct_name_str.ends_with("Dto") {
return syn::Error::new_spanned(struct_name, "Struct name must end with 'Dto'")
.to_compile_error()
.into();
}
let struct_name_split = split_camel_case(&struct_name_str, CamelFormat::Upper);
if struct_name_split.is_err() {
return syn::Error::new_spanned(
struct_name,
"Struct name must be a valid upper camel case",
)
.to_compile_error()
.into();
}
let mut struct_name_split = struct_name_split.unwrap();
struct_name_split.pop();
let module_name = format_ident!("{}", struct_name_split.join("_").to_lowercase());
let entity_name = struct_name_str.strip_suffix("Dto").unwrap();
let add_dto_name = format_ident!("{}AddDto", entity_name);
let modify_dto_name = format_ident!("{}ModifyDto", entity_name);
let save_dto_name = format_ident!("{}SaveDto", entity_name);
let query_dto_name = format_ident!("{}QueryDto", entity_name);
let vis = &input.vis;
let fields = &input.fields;
let (add_fields, modify_fields, save_fields, query_fields, query_field_to_condition) =
match fields {
Fields::Named(named_fields) => {
let mut add_field_tokens = Vec::new();
let mut modify_field_tokens = Vec::new();
let mut save_field_tokens = Vec::new();
let mut query_field_tokens = Vec::new();
let mut query_field_to_condition_tokens = Vec::new();
for field in &named_fields.named {
if field.ident.as_ref().unwrap() == "id" {
continue;
}
add_field_tokens.push(process_field(field, "add"));
modify_field_tokens.push(process_field(field, "modify"));
save_field_tokens.push(process_field(field, "save"));
query_field_tokens.push(process_field(field, "query"));
query_field_to_condition_tokens
.push(add_query_field_to_condition_tokens(field));
}
(
quote! { #(#add_field_tokens)* },
quote! { #(#modify_field_tokens)* },
quote! { #(#save_field_tokens)* },
quote! { #(#query_field_tokens)* },
quote! { #(#query_field_to_condition_tokens)* },
)
}
_ => (quote! {}, quote! {}, quote! {}, quote! {}, quote! {}),
};
let expanded = quote! {
use std::fmt::{Display, Formatter};
use derive_setters::Setters;
use sea_orm::{ActiveValue, ColumnTrait, Condition};
use typed_builder::TypedBuilder;
use wheel_rs::serde::{option_option_serde, u64_option_serde};
use crate::model::#module_name::{ActiveModel, Column};
#[derive(o2o::o2o, utoipa::ToSchema, Debug, Default, serde::Deserialize, validator::Validate, Setters, TypedBuilder)]
#[serde(default, rename_all = "camelCase")]
#[owned_into(ActiveModel)]
#[ghosts(
updator_id: Default::default(),
create_timestamp: Default::default(),
update_timestamp: Default::default(),
)]
#[builder]
#vis struct #add_dto_name {
#[into(match ~ {Some(v)=>ActiveValue::Set(v as i64),None=>ActiveValue::NotSet})]
#[serde(with = "u64_option_serde")]
#[builder(default, setter(strip_option))]
pub id: Option<u64>,
#add_fields
#[serde(skip_deserializing)]
#[into(creator_id, ActiveValue::Set(~ as i64))]
pub _current_user_id: u64,
}
#[derive(o2o::o2o, utoipa::ToSchema, Debug, Default, serde::Deserialize, validator::Validate, Setters, TypedBuilder)]
#[serde(default, rename_all = "camelCase")]
#[owned_into(ActiveModel)]
#[ghosts(
creator_id: Default::default(),
create_timestamp: Default::default(),
update_timestamp: Default::default(),
)]
#[builder]
#vis struct #modify_dto_name {
#[validate(required(message = "id不能为空"))]
#[into(match ~ {Some(v)=>ActiveValue::Set(v as i64),None=>ActiveValue::NotSet})]
#[builder(default, setter(strip_option))]
#[serde(with = "u64_option_serde")]
pub id: Option<u64>,
#modify_fields
#[serde(skip_deserializing)]
#[into(updator_id, ActiveValue::Set(~ as i64))]
pub _current_user_id: u64,
}
#[derive(o2o::o2o, utoipa::ToSchema, Debug, Default, serde::Deserialize, Setters, TypedBuilder)]
#[serde(default, rename_all = "camelCase")]
#[owned_into(#add_dto_name)]
#[owned_into(#modify_dto_name)]
#[builder]
#vis struct #save_dto_name {
#[serde(with = "u64_option_serde")]
#[builder(default, setter(strip_option))]
pub id: Option<u64>,
#save_fields
#[serde(skip_deserializing)]
pub _current_user_id: u64,
}
#[derive(utoipa::ToSchema, utoipa::IntoParams, Debug, Default, serde::Deserialize, Setters, TypedBuilder)]
#[serde(default, rename_all = "camelCase")]
#[builder]
#vis struct #query_dto_name {
#[serde(with = "u64_option_serde")]
#[builder(default, setter(strip_option))]
pub id: Option<u64>,
#query_fields
#[serde(rename = "_keyword")]
#[builder(default, setter(strip_option))]
pub _keyword: Option<String>,
#[serde(skip_deserializing)]
#[builder(default, setter(strip_option))]
pub _current_user_id: Option<u64>,
#[serde(rename = "_orderBy")]
#[builder(default, setter(strip_option))]
pub _order_by: Option<String>,
#[serde(rename = "_page")]
#[builder(default, setter(strip_option))]
pub _page: Option<u64>,
#[serde(rename = "_size")]
#[builder(default, setter(strip_option))]
pub _size: Option<u64>,
}
impl Display for #query_dto_name {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}
impl #query_dto_name {
pub fn to_condition(&self) -> Condition {
let mut condition = Condition::all();
if let Some(id) = self.id {
condition = condition.add(Column::Id.eq(id as i64));
}
#query_field_to_condition
condition
}
}
};
TokenStream::from(expanded)
}
fn process_field(field: &Field, target: &str) -> TokenStream {
let field_name = &field.ident;
let field_ty = &field.ty;
let original_attrs = &field.attrs;
let has_into = has_attr(original_attrs, "into");
let has_validate = has_attr(original_attrs, "validate");
let has_builder = has_attr(original_attrs, "builder");
let has_serde = has_attr(original_attrs, "serde");
let mut new_attrs = Vec::new();
for attr in original_attrs {
if target == "modify" || target == "save" {
if !attr.path().is_ident("validate") {
new_attrs.push(attr.clone());
}
} else {
new_attrs.push(attr.clone());
}
}
if target == "add" && !has_validate {
if let Some(attr) = generate_validate_attr(field) {
new_attrs.push(attr);
}
}
if (target == "add" || target == "modify") && !has_into {
new_attrs.push(generate_into_attr(field));
}
if !has_builder {
new_attrs.push(generate_builder_attr());
}
if !has_serde {
if let Some(attr) = generate_serde_attr(field) {
new_attrs.push(attr);
}
}
let wrapped_ty = wrap_type(field_ty);
quote! {
#(#new_attrs)*
pub #field_name: #wrapped_ty,
}
}
fn extract_field_comment(field: &Field) -> String {
for attr in &field.attrs {
if let Some(comment) = get_doc_comment(attr) {
return comment.trim().to_string();
}
}
field.ident.as_ref().unwrap().to_string()
}
/// Extracts the text of a `#[doc = "..."]` attribute (what `///` comments desugar to).
///
/// Returns the string value with surrounding whitespace trimmed. Returns `None` in these cases:
/// - the attribute is not `doc`;
/// - it is not in `name = value` form (e.g. `#[doc(hidden)]`);
/// - the value is not a string literal.
fn get_doc_comment(attr: &Attribute) -> Option<String> {
    if !attr.path().is_ident("doc") {
        return None;
    }
    match &attr.meta {
        syn::Meta::NameValue(syn::MetaNameValue {
            value:
                syn::Expr::Lit(syn::ExprLit {
                    lit: syn::Lit::Str(text),
                    ..
                }),
            ..
        }) => Some(text.value().trim().to_string()),
        _ => None,
    }
}
fn has_attr(attrs: &[Attribute], name: &str) -> bool {
attrs.iter().any(|attr| attr.path().is_ident(name))
}
/// Builds a `#[validate(..)]` attribute marking a non-`Option` field as required.
///
/// `String` fields additionally get `length(min = 1)`. The message is the label from
/// `extract_field_comment` followed by "不能为空" ("must not be empty"). Returns `None` in
/// these cases:
/// - the field is declared as `Option<..>`;
/// - the field's type is not a path type.
fn generate_validate_attr(field: &Field) -> Option<Attribute> {
    let last_segment = match &field.ty {
        syn::Type::Path(type_path) => type_path.path.segments.last()?,
        _ => return None,
    };
    let type_name = last_segment.ident.to_string();
    if type_name == "Option" {
        return None;
    }

    let message = format!("{}不能为空", extract_field_comment(field));
    let attr: Attribute = if type_name == "String" {
        syn::parse_quote! {
            #[validate(
                required(message = #message),
                length(min = 1, message = #message)
            )]
        }
    } else {
        syn::parse_quote! {
            #[validate(required(message = #message))]
        }
    };
    Some(attr)
}
/// Builds the `o2o` `#[into(..)]` attribute that maps a DTO field onto its `ActiveModel` value.
///
/// The DTO field is the declared type wrapped in `Option`:
/// - `None` maps to `ActiveValue::NotSet`.
/// - `Some(v)` maps to `ActiveValue::Set(..)`.
///
/// Unsigned integers (`u8`, `u16`, `u32`, `u64`) are cast to `i64`. For a declared
/// `Option<uN>`, the inner `Option` is converted with `map`. Every other type is passed
/// through unchanged.
fn generate_into_attr(field: &Field) -> Attribute {
    let ty = &field.ty;
    // Last path segment name of a path type, e.g. `u64` for `std::primitive::u64`.
    let last_name = |candidate: &syn::Type| -> Option<String> {
        match candidate {
            syn::Type::Path(type_path) => type_path
                .path
                .segments
                .last()
                .map(|segment| segment.ident.to_string()),
            _ => None,
        }
    };

    if is_option_type(ty) {
        // For a declared `Option<T>`, the decision depends on `T`.
        let inner_name = match ty {
            syn::Type::Path(type_path) => {
                type_path
                    .path
                    .segments
                    .last()
                    .and_then(|segment| match &segment.arguments {
                        syn::PathArguments::AngleBracketed(args) => match args.args.first() {
                            Some(syn::GenericArgument::Type(inner_ty)) => last_name(inner_ty),
                            _ => None,
                        },
                        _ => None,
                    })
            }
            _ => None,
        };
        if matches!(inner_name.as_deref(), Some("u64" | "u32" | "u16" | "u8")) {
            return syn::parse_quote!(
                #[into(match ~ {Some(v)=>ActiveValue::Set(v.map(|x| x as i64)),None=>ActiveValue::NotSet})]
            );
        }
    } else if matches!(last_name(ty).as_deref(), Some("u64" | "u32" | "u16" | "u8")) {
        return syn::parse_quote!(
            #[into(match ~ {Some(v)=>ActiveValue::Set(v as i64),None=>ActiveValue::NotSet})]
        );
    }

    syn::parse_quote!(
        #[into(match ~ {Some(v)=>ActiveValue::Set(v),None=>ActiveValue::NotSet})]
    )
}
/// Builds the default `typed_builder` field attribute,
/// `#[builder(default, setter(strip_option))]`, applied to fields that carry no `builder`
/// attribute of their own.
fn generate_builder_attr() -> Attribute {
    syn::parse_quote! {
        #[builder(default, setter(strip_option))]
    }
}
/// Builds a type-specific `#[serde(..)]` attribute for the `Option`-wrapped DTO field.
///
/// - A declared `u64` gets `with = "u64_option_serde"`.
/// - A declared `Option<String>` gets `deserialize_with = "option_option_serde::deserialize"`.
/// - Every other type gets no attribute (`None`).
fn generate_serde_attr(field: &Field) -> Option<Attribute> {
    let segment = match &field.ty {
        syn::Type::Path(type_path) => type_path.path.segments.last()?,
        _ => return None,
    };
    if segment.ident == "u64" {
        return Some(syn::parse_quote!(
            #[serde(with = "u64_option_serde")]
        ));
    }
    if segment.ident != "Option" {
        return None;
    }

    let inner_segment = match &segment.arguments {
        syn::PathArguments::AngleBracketed(args) => match args.args.first() {
            Some(syn::GenericArgument::Type(syn::Type::Path(inner_path))) => {
                inner_path.path.segments.last()?
            }
            _ => return None,
        },
        _ => return None,
    };
    if inner_segment.ident == "String" {
        Some(syn::parse_quote!(
            #[serde(deserialize_with = "option_option_serde::deserialize")]
        ))
    } else {
        None
    }
}
/// Reports whether `ty` is a path type whose last segment is `Option`
/// (matches `Option<T>`, `std::option::Option<T>`, ...).
fn is_option_type(ty: &syn::Type) -> bool {
    match ty {
        syn::Type::Path(type_path) => type_path
            .path
            .segments
            .last()
            .map_or(false, |segment| segment.ident == "Option"),
        _ => false,
    }
}
/// Wraps `ty` as `Option<ty>`.
///
/// The type's tokens are spliced directly rather than printed to a string and re-parsed. This
/// keeps the user's original spans, so type errors in the generated DTOs point at the field's
/// type instead of the macro invocation. Falls back to `ty` unchanged if the wrapped tokens
/// fail to parse.
fn wrap_type(ty: &syn::Type) -> syn::Type {
    syn::parse2(quote! { Option<#ty> }).unwrap_or_else(|_| ty.clone())
}
fn add_query_field_to_condition_tokens(field: &Field) -> TokenStream {
let field_name = field.ident.as_ref().unwrap();
let column_name = format_ident!("{}", snake_to_pascal(&field_name.to_string()));
let ty = &field.ty;
if is_option_type(ty) {
if let syn::Type::Path(type_path) = ty {
if let Some(segment) = type_path.path.segments.last() {
if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
if let syn::Type::Path(inner_path) = inner_ty {
if let Some(inner_segment) = inner_path.path.segments.last() {
let inner_type_name = inner_segment.ident.to_string();
let v = get_value_token_stream(&inner_type_name);
return quote! {
if let Some(v) = self.#field_name.as_ref() {
if let Some(v) = v {
condition = condition.add(Column::#column_name.eq(#v));
}else {
condition = condition.add(Column::#column_name.is_null());
}
}
};
}
}
}
}
}
}
}
if let syn::Type::Path(type_path) = ty {
if let Some(segment) = type_path.path.segments.last() {
let inner_type_name = segment.ident.to_string();
let v = get_value_token_stream(&inner_type_name);
return quote! {
if let Some(v) = self.#field_name.as_ref() {
condition = condition.add(Column::#column_name.eq(#v));
}
};
}
}
syn::Error::new_spanned(
ty,
format!(
"字段 '{}' 的类型无法识别,请使用基本类型或 Option<T> 包裹的基本类型",
field_name
),
)
.to_compile_error()
.into()
}
/// Returns the expression passed to `Column::X.eq(..)` for a query value `v` bound as `&T`.
/// `type_name` is the last path segment of `T`.
///
/// - `u8`, `u16`, `u32` and `u64` are widened to `i64`. This is lossless and matches the
///   `as i64` cast applied when these types are stored through the generated `#[into]`
///   attribute. Narrowing to a same-width signed type (e.g. `u8 as i8`) would wrap values
///   above the signed maximum to negatives, and the query would never match.
/// - `u128` is cast to `i128`; values above `i128::MAX` wrap.
/// - Other `Copy` primitives (`bool`, signed integers, floats, `char`) are dereferenced so an
///   owned value is passed.
/// - Any other type is passed as the reference `v`.
fn get_value_token_stream(type_name: &str) -> TokenStream {
    match type_name {
        "u8" | "u16" | "u32" | "u64" => quote! { *v as i64 },
        "u128" => quote! { *v as i128 },
        "bool" | "i8" | "i16" | "i32" | "i64" | "f32" | "f64" | "char" => quote! { *v },
        _ => quote! { v },
    }
}