1use proc_macro::TokenStream;
28use quote::quote;
29use syn::{DeriveInput, Meta, parse_macro_input, spanned::Spanned};
30
31#[proc_macro_derive(Dataclass, attributes(dataclass))]
32pub fn dataclass_macro(input: TokenStream) -> TokenStream {
33 let input = parse_macro_input!(input as DeriveInput);
35
36 match impl_dataclass(&input) {
37 Ok(ts) => ts.into(),
38 Err(err) => err.to_compile_error().into(),
39 }
40}
41
42fn impl_dataclass(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
43 let name = &input.ident;
44 let generics = &input.generics;
45 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
46
47 let fields = match &input.data {
49 syn::Data::Struct(ds) => match &ds.fields {
50 syn::Fields::Named(named) => &named.named,
51 _ => {
52 return Err(syn::Error::new_spanned(
53 &input.ident,
54 "Dataclass macro only supports structs with named fields",
55 ));
56 }
57 },
58 _ => {
59 return Err(syn::Error::new_spanned(
60 &input.ident,
61 "Dataclass macro requires a struct",
62 ));
63 }
64 };
65
66 struct FieldInfo {
68 ident: syn::Ident,
69 ty: syn::Type,
70 default: Option<proc_macro2::TokenStream>,
71 }
72
73 let mut infos = Vec::new();
74 for field in fields.iter() {
75 let ident = field
76 .ident
77 .clone()
78 .expect("named fields should have idents");
79 let ty = field.ty.clone();
80
81 let mut default = None;
82 for attr in &field.attrs {
83 if attr.path().is_ident("dataclass") {
84 if let Meta::List(list) = &attr.meta {
87 let tokens = list.tokens.to_string();
88 let inside = tokens.trim();
90 let inside = inside.trim_start_matches('(').trim_end_matches(')');
91 let mut parts = Vec::new();
93 let mut start = 0usize;
94 let mut in_quotes = false;
95 for (i, c) in inside.char_indices() {
96 match c {
97 '"' => in_quotes = !in_quotes,
98 ',' if !in_quotes => {
99 parts.push(inside[start..i].trim());
100 start = i + 1;
101 }
102 _ => {}
103 }
104 }
105 if start < inside.len() {
106 parts.push(inside[start..].trim());
107 }
108 for part in parts.into_iter().filter(|s| !s.is_empty()) {
109 if part == "default" {
110 default = Some(quote! { ::core::default::Default::default() });
111 } else if part.starts_with("default=") || part.starts_with("default =") {
112 if let Some(eq_idx) = part.find('=') {
114 let rhs = part[eq_idx + 1..].trim();
115 let rhs = if rhs.starts_with('"') && rhs.ends_with('"') {
117 &rhs[1..rhs.len() - 1]
118 } else {
119 rhs
120 };
121 let expr: syn::Expr = syn::parse_str(rhs).map_err(|e| {
122 syn::Error::new(
123 field.span(),
124 format!("invalid default expression: {}", e),
125 )
126 })?;
127 default = Some(quote! { #expr });
128 }
129 } else {
130 return Err(syn::Error::new(
131 field.span(),
132 "unknown dataclass attribute",
133 ));
134 }
135 }
136 }
137 }
138 }
139
140 infos.push(FieldInfo { ident, ty, default });
141 }
142
143 let mut params = Vec::new();
145 let mut construct_fields = Vec::new();
146 let mut all_have_default = true;
147 for info in &infos {
148 let ident = &info.ident;
149 let ty = &info.ty;
150 if info.default.is_none() {
151 params.push(quote! { #ident: #ty });
152 construct_fields.push(quote! { #ident });
153 all_have_default = false;
154 } else {
155 let expr = info.default.as_ref().unwrap();
156 construct_fields.push(quote! { #ident: #expr });
157 }
158 }
159
160 let field_idents: Vec<_> = infos.iter().map(|f| f.ident.clone()).collect();
162 let field_idents_ref: Vec<_> = field_idents.iter().collect();
163
164 let type_idents: Vec<syn::Ident> = generics
166 .params
167 .iter()
168 .filter_map(|p| match p {
169 syn::GenericParam::Type(ty) => Some(ty.ident.clone()),
170 _ => None,
171 })
172 .collect();
173
174 let mut clone_bounds = where_clause.cloned();
175 let mut debug_bounds = where_clause.cloned();
176 let mut partial_bounds = where_clause.cloned();
177 let mut eq_bounds = where_clause.cloned();
178 let mut default_bounds = where_clause.cloned();
179
180 if !type_idents.is_empty() {
181 let bounds_tokens =
182 quote! { #(#type_idents: Clone + std::fmt::Debug + PartialEq + Eq + Default),* };
183 clone_bounds = Some(syn::parse2(quote! { where #bounds_tokens })?);
184 debug_bounds = clone_bounds.clone();
185 partial_bounds = clone_bounds.clone();
186 eq_bounds = clone_bounds.clone();
187 default_bounds = clone_bounds.clone();
188 }
189
190 let name_str = name.to_string();
192 let new_fn = quote! {
193 impl #impl_generics #name #ty_generics #where_clause {
194 pub fn new(#(#params),*) -> Self {
195 Self { #(#construct_fields),* }
196 }
197 }
198 };
199
200 let clone_assigns = field_idents_ref
202 .iter()
203 .map(|ident| quote! { #ident: self.#ident.clone() });
204 let clone_impl = quote! {
205 impl #impl_generics Clone for #name #ty_generics #clone_bounds {
206 fn clone(&self) -> Self {
207 Self { #(#clone_assigns),* }
208 }
209 }
210 };
211
212 let debug_fields = field_idents_ref
214 .iter()
215 .map(|ident| quote! { .field(stringify!(#ident), &self.#ident) });
216 let debug_impl = quote! {
217 impl #impl_generics std::fmt::Debug for #name #ty_generics #debug_bounds {
218 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219 f.debug_struct(#name_str)
220 #(#debug_fields)*
221 .finish()
222 }
223 }
224 };
225
226 let eq_checks = field_idents_ref
228 .iter()
229 .map(|ident| quote! { self.#ident == other.#ident });
230 let eq_impl = quote! {
231 impl #impl_generics PartialEq for #name #ty_generics #partial_bounds {
232 fn eq(&self, other: &Self) -> bool {
233 #(#eq_checks)&&*
234 }
235 }
236 impl #impl_generics Eq for #name #ty_generics #eq_bounds {}
237 };
238
239 let default_impl = if all_have_default {
241 let default_assigns = infos.iter().map(|f| {
242 let id = &f.ident;
243 let expr = f.default.as_ref().unwrap();
244 quote! { #id: #expr }
245 });
246 Some(quote! {
247 impl #impl_generics Default for #name #ty_generics #default_bounds {
248 fn default() -> Self {
249 Self { #(#default_assigns),* }
250 }
251 }
252 })
253 } else {
254 None
255 };
256
257 let expanded = quote! {
258 #new_fn
259 #clone_impl
260 #debug_impl
261 #eq_impl
262 #default_impl
263 };
264
265 Ok(expanded)
266}