dataclasses/
lib.rs

1//! A simple `Dataclass` derive macro inspired by Python's `dataclasses`.
2//!
3//! Features implemented:
4//! * Generate `pub fn new(...) -> Self` constructor which accepts values for non-default fields
5//! * Support field-level defaults via `#[dataclass(default)]` and `#[dataclass(default = "expr")]` (expr as string literal)
6//! * Implement `Clone`, `Debug`, `PartialEq`, `Eq` for the struct
7//! * Implement `Default` when all fields have defaults
8//!
9//! Examples:
10//!
11//! ```rust
12//! use dataclasses::Dataclass;
13//!
14//! #[derive(Dataclass)]
15//! struct Person {
16//!     name: String,
17//!     age: i32,
18//!     #[dataclass(default)]
19//!     nickname: Option<String>,
20//!     #[dataclass(default = "Vec::new()")]
21//!     tags: Vec<String>,
22//! }
23//!
24//! let p = Person::new("Alice".into(), 30);
25//! assert_eq!(p.nickname, None);
26//! ```
27use proc_macro::TokenStream;
28use quote::quote;
29use syn::{DeriveInput, Meta, parse_macro_input, spanned::Spanned};
30
31#[proc_macro_derive(Dataclass, attributes(dataclass))]
32pub fn dataclass_macro(input: TokenStream) -> TokenStream {
33    // Parse input
34    let input = parse_macro_input!(input as DeriveInput);
35
36    match impl_dataclass(&input) {
37        Ok(ts) => ts.into(),
38        Err(err) => err.to_compile_error().into(),
39    }
40}
41
42fn impl_dataclass(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
43    let name = &input.ident;
44    let generics = &input.generics;
45    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
46
47    // Only support structs with named fields
48    let fields = match &input.data {
49        syn::Data::Struct(ds) => match &ds.fields {
50            syn::Fields::Named(named) => &named.named,
51            _ => {
52                return Err(syn::Error::new_spanned(
53                    &input.ident,
54                    "Dataclass macro only supports structs with named fields",
55                ));
56            }
57        },
58        _ => {
59            return Err(syn::Error::new_spanned(
60                &input.ident,
61                "Dataclass macro requires a struct",
62            ));
63        }
64    };
65
66    // For each field collect info
67    struct FieldInfo {
68        ident: syn::Ident,
69        ty: syn::Type,
70        default: Option<proc_macro2::TokenStream>,
71    }
72
73    let mut infos = Vec::new();
74    for field in fields.iter() {
75        let ident = field
76            .ident
77            .clone()
78            .expect("named fields should have idents");
79        let ty = field.ty.clone();
80
81        let mut default = None;
82        for attr in &field.attrs {
83            if attr.path().is_ident("dataclass") {
84                // attr like #[dataclass(default)] or #[dataclass(default = "Vec::new()")]
85                // For syn 2: the meta tokens are in attr.meta or Meta::List(meta)
86                if let Meta::List(list) = &attr.meta {
87                    let tokens = list.tokens.to_string();
88                    // tokens look like "(default)" or "(default = \"Vec::new()\")"
89                    let inside = tokens.trim();
90                    let inside = inside.trim_start_matches('(').trim_end_matches(')');
91                    // Split on commas, but ignore commas inside double quotes
92                    let mut parts = Vec::new();
93                    let mut start = 0usize;
94                    let mut in_quotes = false;
95                    for (i, c) in inside.char_indices() {
96                        match c {
97                            '"' => in_quotes = !in_quotes,
98                            ',' if !in_quotes => {
99                                parts.push(inside[start..i].trim());
100                                start = i + 1;
101                            }
102                            _ => {}
103                        }
104                    }
105                    if start < inside.len() {
106                        parts.push(inside[start..].trim());
107                    }
108                    for part in parts.into_iter().filter(|s| !s.is_empty()) {
109                        if part == "default" {
110                            default = Some(quote! { ::core::default::Default::default() });
111                        } else if part.starts_with("default=") || part.starts_with("default =") {
112                            // find the literal string after '='
113                            if let Some(eq_idx) = part.find('=') {
114                                let rhs = part[eq_idx + 1..].trim();
115                                // strip possible surrounding quotes
116                                let rhs = if rhs.starts_with('"') && rhs.ends_with('"') {
117                                    &rhs[1..rhs.len() - 1]
118                                } else {
119                                    rhs
120                                };
121                                let expr: syn::Expr = syn::parse_str(rhs).map_err(|e| {
122                                    syn::Error::new(
123                                        field.span(),
124                                        format!("invalid default expression: {}", e),
125                                    )
126                                })?;
127                                default = Some(quote! { #expr });
128                            }
129                        } else {
130                            return Err(syn::Error::new(
131                                field.span(),
132                                "unknown dataclass attribute",
133                            ));
134                        }
135                    }
136                }
137            }
138        }
139
140        infos.push(FieldInfo { ident, ty, default });
141    }
142
143    // Build 'new' function params and body
144    let mut params = Vec::new();
145    let mut construct_fields = Vec::new();
146    let mut all_have_default = true;
147    for info in &infos {
148        let ident = &info.ident;
149        let ty = &info.ty;
150        if info.default.is_none() {
151            params.push(quote! { #ident: #ty });
152            construct_fields.push(quote! { #ident });
153            all_have_default = false;
154        } else {
155            let expr = info.default.as_ref().unwrap();
156            construct_fields.push(quote! { #ident: #expr });
157        }
158    }
159
160    // Collect clones for clone impl and fields for Debug/PartialEq
161    let field_idents: Vec<_> = infos.iter().map(|f| f.ident.clone()).collect();
162    let field_idents_ref: Vec<_> = field_idents.iter().collect();
163
164    // Determine generics type params for where clauses
165    let type_idents: Vec<syn::Ident> = generics
166        .params
167        .iter()
168        .filter_map(|p| match p {
169            syn::GenericParam::Type(ty) => Some(ty.ident.clone()),
170            _ => None,
171        })
172        .collect();
173
174    let mut clone_bounds = where_clause.cloned();
175    let mut debug_bounds = where_clause.cloned();
176    let mut partial_bounds = where_clause.cloned();
177    let mut eq_bounds = where_clause.cloned();
178    let mut default_bounds = where_clause.cloned();
179
180    if !type_idents.is_empty() {
181        let bounds_tokens =
182            quote! { #(#type_idents: Clone + std::fmt::Debug + PartialEq + Eq + Default),* };
183        clone_bounds = Some(syn::parse2(quote! { where #bounds_tokens })?);
184        debug_bounds = clone_bounds.clone();
185        partial_bounds = clone_bounds.clone();
186        eq_bounds = clone_bounds.clone();
187        default_bounds = clone_bounds.clone();
188    }
189
190    // Build impl tokens
191    let name_str = name.to_string();
192    let new_fn = quote! {
193        impl #impl_generics #name #ty_generics #where_clause {
194            pub fn new(#(#params),*) -> Self {
195                Self { #(#construct_fields),* }
196            }
197        }
198    };
199
200    // Clone impl
201    let clone_assigns = field_idents_ref
202        .iter()
203        .map(|ident| quote! { #ident: self.#ident.clone() });
204    let clone_impl = quote! {
205        impl #impl_generics Clone for #name #ty_generics #clone_bounds {
206            fn clone(&self) -> Self {
207                Self { #(#clone_assigns),* }
208            }
209        }
210    };
211
212    // Debug impl
213    let debug_fields = field_idents_ref
214        .iter()
215        .map(|ident| quote! { .field(stringify!(#ident), &self.#ident) });
216    let debug_impl = quote! {
217        impl #impl_generics std::fmt::Debug for #name #ty_generics #debug_bounds {
218            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219                f.debug_struct(#name_str)
220                    #(#debug_fields)*
221                    .finish()
222            }
223        }
224    };
225
226    // PartialEq impl
227    let eq_checks = field_idents_ref
228        .iter()
229        .map(|ident| quote! { self.#ident == other.#ident });
230    let eq_impl = quote! {
231        impl #impl_generics PartialEq for #name #ty_generics #partial_bounds {
232            fn eq(&self, other: &Self) -> bool {
233                #(#eq_checks)&&*
234            }
235        }
236        impl #impl_generics Eq for #name #ty_generics #eq_bounds {}
237    };
238
239    // Default impl only if all fields have a default expression
240    let default_impl = if all_have_default {
241        let default_assigns = infos.iter().map(|f| {
242            let id = &f.ident;
243            let expr = f.default.as_ref().unwrap();
244            quote! { #id: #expr }
245        });
246        Some(quote! {
247            impl #impl_generics Default for #name #ty_generics #default_bounds {
248                fn default() -> Self {
249                    Self { #(#default_assigns),* }
250                }
251            }
252        })
253    } else {
254        None
255    };
256
257    let expanded = quote! {
258        #new_fn
259        #clone_impl
260        #debug_impl
261        #eq_impl
262        #default_impl
263    };
264
265    Ok(expanded)
266}