dataclass_macro/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::punctuated::Punctuated;
5use syn::token::Comma;
6use syn::{parse_macro_input, Attribute, Data, DeriveInput, Expr, Fields, Lit, Meta};
7
/// Boolean switches accepted by the `dataclass` attribute macro.
///
/// The option names follow the parameters of Python's `dataclasses.dataclass`
/// decorator. `Default` yields the same values that `from_meta_list` starts
/// from before applying user-supplied overrides.
struct DataclassOptions {
    /// Generate an inherent `new` constructor taking every field.
    init: bool,
    /// Generate a `Debug` implementation.
    repr: bool,
    /// Generate `PartialEq` and `Eq` implementations.
    eq: bool,
    /// Generate `PartialOrd` and `Ord` implementations (field-by-field order).
    order: bool,
    /// Generate a `Hash` implementation.
    unsafe_hash: bool,
    /// Emit fields as `pub(crate)` instead of `pub`.
    frozen: bool,
    /// Parsed and stored; not consulted by the code generator in this file.
    match_args: bool,
    /// Parsed and stored; the generated constructor is the same either way.
    kw_only: bool,
    /// Parsed and stored; not consulted by the code generator in this file.
    slots: bool,
    /// Parsed and stored; not consulted by the code generator in this file.
    weakref_slot: bool,
}

impl Default for DataclassOptions {
    /// Returns the option set used when the attribute is written without arguments.
    ///
    /// A derived `Default` would set every flag to `false`, which disagrees with
    /// the defaults applied when parsing the attribute, so the values are spelled
    /// out explicitly.
    fn default() -> Self {
        Self {
            init: true,
            repr: true,
            eq: true,
            order: false,
            unsafe_hash: false,
            frozen: false,
            match_args: true,
            kw_only: false,
            slots: false,
            weakref_slot: false,
        }
    }
}
22
23impl DataclassOptions {
24    fn from_meta_list(meta_list: Punctuated<Meta, Comma>) -> Self {
25        let mut options = DataclassOptions {
26            init: true, // 默认值
27            repr: true,
28            eq: true,
29            order: false,
30            unsafe_hash: false,
31            frozen: false,
32            match_args: true,
33            kw_only: false,
34            slots: false,
35            weakref_slot: false,
36        };
37
38        for meta in meta_list {
39            match meta {
40                Meta::NameValue(nv) => {
41                    if let Some(ident) = nv.path.get_ident() {
42                        let value = match nv.value {
43                            Expr::Lit(expr_lit) => match expr_lit.lit {
44                                Lit::Bool(lit_bool) => lit_bool.value(),
45                                _ => panic!("Expected boolean value for option {}", ident),
46                            },
47                            _ => panic!("Expected literal value for option {}", ident),
48                        };
49
50                        match ident.to_string().as_str() {
51                            "init" => options.init = value,
52                            "repr" => options.repr = value,
53                            "eq" => options.eq = value,
54                            "order" => options.order = value,
55                            "unsafe_hash" => options.unsafe_hash = value,
56                            "kw_only" => options.kw_only = value,
57                            "slots" => options.slots = value,
58                            "frozen" => options.frozen = value,
59                            "match_args" => options.match_args = value,
60                            "weakref_slot" => options.weakref_slot = value,
61                            _ => panic!("Unknown option: {}", ident),
62                        }
63                    }
64                }
65                _ => panic!("Expected name = value pair"),
66            }
67        }
68
69        options
70    }
71}
72
73fn has_serde_attribute(attrs: &[Attribute]) -> bool {
74    attrs.iter().any(|attr| {
75        if let Ok(Meta::Path(path)) = attr.parse_args::<Meta>() {
76            path.is_ident("serde")
77        } else {
78            false
79        }
80    })
81}
82
83#[proc_macro_attribute]
84pub fn dataclass(args: TokenStream, input: TokenStream) -> TokenStream {
85    let args =
86        parse_macro_input!(args with syn::punctuated::Punctuated::<Meta, Comma>::parse_terminated);
87    let mut input = parse_macro_input!(input as DeriveInput);
88
89    let options = DataclassOptions::from_meta_list(args);
90
91    // check if serde attribute is already present
92    if !has_serde_attribute(&input.attrs) {
93        // add serde derive attribute
94        input.attrs.push(syn::parse_quote!(
95            #[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
96        ));
97    }
98
99    implement_dataclass(input, options)
100}
101
102fn implement_dataclass(input: DeriveInput, options: DataclassOptions) -> TokenStream {
103    let struct_name = &input.ident;
104    let attrs = &input.attrs;
105
106    let fields = match &input.data {
107        Data::Struct(data_struct) => match &data_struct.fields {
108            Fields::Named(fields_named) => &fields_named.named,
109            _ => panic!("Dataclass only works with named fields"),
110        },
111        _ => panic!("Dataclass only works with structs"),
112    };
113
114    let field_names: Vec<_> = fields
115        .iter()
116        .map(|field| field.ident.as_ref().unwrap())
117        .collect();
118    let field_types: Vec<_> = fields.iter().map(|field| &field.ty).collect();
119
120    let mut implementations = TokenStream2::new();
121
122    // (init option)
123    if options.init {
124        let constructor = if options.kw_only {
125            quote! {
126                impl #struct_name {
127                    pub fn new(#(#field_names: #field_types),*) -> Self {
128                        Self {
129                            #(#field_names,)*
130                        }
131                    }
132                }
133            }
134        } else {
135            quote! {
136                impl #struct_name {
137                    pub fn new(#(#field_names: #field_types),*) -> Self {
138                        Self {
139                            #(#field_names,)*
140                        }
141                    }
142                }
143            }
144        };
145        implementations.extend(constructor);
146    }
147
148    // Debug (repr option)
149    if options.repr {
150        let debug_impl = quote! {
151            impl std::fmt::Debug for #struct_name {
152                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153                    f.debug_struct(stringify!(#struct_name))
154                        #(.field(stringify!(#field_names), &self.#field_names))*
155                        .finish()
156                }
157            }
158        };
159        implementations.extend(debug_impl);
160    }
161
162    // (eq option)
163    if options.eq {
164        let eq_impl = quote! {
165            impl PartialEq for #struct_name {
166                fn eq(&self, other: &Self) -> bool {
167                    #(self.#field_names == other.#field_names)&&*
168                }
169            }
170
171            impl Eq for #struct_name {}
172        };
173        implementations.extend(eq_impl);
174    }
175
176    // (order option)
177    if options.order {
178        let ord_impl = quote! {
179            impl PartialOrd for #struct_name {
180                fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
181                    Some(self.cmp(other))
182                }
183            }
184
185            impl Ord for #struct_name {
186                fn cmp(&self, other: &Self) -> std::cmp::Ordering {
187                    #(
188                        if let std::cmp::Ordering::Equal = self.#field_names.cmp(&other.#field_names) {
189                        } else {
190                            return self.#field_names.cmp(&other.#field_names);
191                        }
192                    )*
193                    std::cmp::Ordering::Equal
194                }
195            }
196        };
197        implementations.extend(ord_impl);
198    }
199
200    // Hash (unsafe_hash option)
201    if options.unsafe_hash {
202        let hash_impl = quote! {
203            impl std::hash::Hash for #struct_name {
204                fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
205                    #(self.#field_names.hash(state);)*
206                }
207            }
208        };
209        implementations.extend(hash_impl);
210    }
211
212    // (frozen option)
213    let struct_fields = if options.frozen {
214        quote! {
215            #(pub(crate) #field_names: #field_types,)*
216        }
217    } else {
218        quote! {
219            #(pub #field_names: #field_types,)*
220        }
221    };
222
223    let expanded = quote! {
224        #[derive(Clone)]
225        #(#attrs)*
226        pub struct #struct_name {
227            #struct_fields
228        }
229
230        #implementations
231    };
232
233    TokenStream::from(expanded)
234}