1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::punctuated::Punctuated;
5use syn::token::Comma;
6use syn::{parse_macro_input, Attribute, Data, DeriveInput, Expr, Fields, Lit, Meta};
7
8#[derive(Default)]
10struct DataclassOptions {
11 init: bool,
12 repr: bool,
13 eq: bool,
14 order: bool,
15 unsafe_hash: bool,
16 frozen: bool,
17 match_args: bool,
18 kw_only: bool,
19 slots: bool,
20 weakref_slot: bool,
21}
22
23impl DataclassOptions {
24 fn from_meta_list(meta_list: Punctuated<Meta, Comma>) -> Self {
25 let mut options = DataclassOptions {
26 init: true, repr: true,
28 eq: true,
29 order: false,
30 unsafe_hash: false,
31 frozen: false,
32 match_args: true,
33 kw_only: false,
34 slots: false,
35 weakref_slot: false,
36 };
37
38 for meta in meta_list {
39 match meta {
40 Meta::NameValue(nv) => {
41 if let Some(ident) = nv.path.get_ident() {
42 let value = match nv.value {
43 Expr::Lit(expr_lit) => match expr_lit.lit {
44 Lit::Bool(lit_bool) => lit_bool.value(),
45 _ => panic!("Expected boolean value for option {}", ident),
46 },
47 _ => panic!("Expected literal value for option {}", ident),
48 };
49
50 match ident.to_string().as_str() {
51 "init" => options.init = value,
52 "repr" => options.repr = value,
53 "eq" => options.eq = value,
54 "order" => options.order = value,
55 "unsafe_hash" => options.unsafe_hash = value,
56 "kw_only" => options.kw_only = value,
57 "slots" => options.slots = value,
58 "frozen" => options.frozen = value,
59 "match_args" => options.match_args = value,
60 "weakref_slot" => options.weakref_slot = value,
61 _ => panic!("Unknown option: {}", ident),
62 }
63 }
64 }
65 _ => panic!("Expected name = value pair"),
66 }
67 }
68
69 options
70 }
71}
72
73fn has_serde_attribute(attrs: &[Attribute]) -> bool {
74 attrs.iter().any(|attr| {
75 if let Ok(Meta::Path(path)) = attr.parse_args::<Meta>() {
76 path.is_ident("serde")
77 } else {
78 false
79 }
80 })
81}
82
83#[proc_macro_attribute]
84pub fn dataclass(args: TokenStream, input: TokenStream) -> TokenStream {
85 let args =
86 parse_macro_input!(args with syn::punctuated::Punctuated::<Meta, Comma>::parse_terminated);
87 let mut input = parse_macro_input!(input as DeriveInput);
88
89 let options = DataclassOptions::from_meta_list(args);
90
91 if !has_serde_attribute(&input.attrs) {
93 input.attrs.push(syn::parse_quote!(
95 #[cfg_attr(feature = "serde", derive(::serde::Serialize, ::serde::Deserialize))]
96 ));
97 }
98
99 implement_dataclass(input, options)
100}
101
102fn implement_dataclass(input: DeriveInput, options: DataclassOptions) -> TokenStream {
103 let struct_name = &input.ident;
104 let attrs = &input.attrs;
105
106 let fields = match &input.data {
107 Data::Struct(data_struct) => match &data_struct.fields {
108 Fields::Named(fields_named) => &fields_named.named,
109 _ => panic!("Dataclass only works with named fields"),
110 },
111 _ => panic!("Dataclass only works with structs"),
112 };
113
114 let field_names: Vec<_> = fields
115 .iter()
116 .map(|field| field.ident.as_ref().unwrap())
117 .collect();
118 let field_types: Vec<_> = fields.iter().map(|field| &field.ty).collect();
119
120 let mut implementations = TokenStream2::new();
121
122 if options.init {
124 let constructor = if options.kw_only {
125 quote! {
126 impl #struct_name {
127 pub fn new(#(#field_names: #field_types),*) -> Self {
128 Self {
129 #(#field_names,)*
130 }
131 }
132 }
133 }
134 } else {
135 quote! {
136 impl #struct_name {
137 pub fn new(#(#field_names: #field_types),*) -> Self {
138 Self {
139 #(#field_names,)*
140 }
141 }
142 }
143 }
144 };
145 implementations.extend(constructor);
146 }
147
148 if options.repr {
150 let debug_impl = quote! {
151 impl std::fmt::Debug for #struct_name {
152 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153 f.debug_struct(stringify!(#struct_name))
154 #(.field(stringify!(#field_names), &self.#field_names))*
155 .finish()
156 }
157 }
158 };
159 implementations.extend(debug_impl);
160 }
161
162 if options.eq {
164 let eq_impl = quote! {
165 impl PartialEq for #struct_name {
166 fn eq(&self, other: &Self) -> bool {
167 #(self.#field_names == other.#field_names)&&*
168 }
169 }
170
171 impl Eq for #struct_name {}
172 };
173 implementations.extend(eq_impl);
174 }
175
176 if options.order {
178 let ord_impl = quote! {
179 impl PartialOrd for #struct_name {
180 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
181 Some(self.cmp(other))
182 }
183 }
184
185 impl Ord for #struct_name {
186 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
187 #(
188 if let std::cmp::Ordering::Equal = self.#field_names.cmp(&other.#field_names) {
189 } else {
190 return self.#field_names.cmp(&other.#field_names);
191 }
192 )*
193 std::cmp::Ordering::Equal
194 }
195 }
196 };
197 implementations.extend(ord_impl);
198 }
199
200 if options.unsafe_hash {
202 let hash_impl = quote! {
203 impl std::hash::Hash for #struct_name {
204 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
205 #(self.#field_names.hash(state);)*
206 }
207 }
208 };
209 implementations.extend(hash_impl);
210 }
211
212 let struct_fields = if options.frozen {
214 quote! {
215 #(pub(crate) #field_names: #field_types,)*
216 }
217 } else {
218 quote! {
219 #(pub #field_names: #field_types,)*
220 }
221 };
222
223 let expanded = quote! {
224 #[derive(Clone)]
225 #(#attrs)*
226 pub struct #struct_name {
227 #struct_fields
228 }
229
230 #implementations
231 };
232
233 TokenStream::from(expanded)
234}