Skip to main content

json_steroids_derive/
lib.rs

1//! Procedural macros for json-steroids
2//!
3//! Generates efficient serializers and deserializers for data structures.
4
5use proc_macro::TokenStream;
6use proc_macro2::TokenStream as TokenStream2;
7use proc_macro_crate::{crate_name, FoundCrate};
8use quote::{format_ident, quote};
9use syn::{parse_macro_input, Data, DeriveInput, Fields, Ident, Type};
10
11/// Get the crate path - correctly resolves whether we are inside
12/// the `json_steroids` crate itself or an external consumer.
13fn crate_path() -> TokenStream2 {
14    match crate_name("json-steroids") {
15        Ok(FoundCrate::Itself) => quote! { crate },
16        Ok(FoundCrate::Name(name)) => {
17            let ident = proc_macro2::Ident::new(&name, proc_macro2::Span::call_site());
18            quote! { ::#ident }
19        }
20        // Fallback: we are inside the crate being compiled (unit tests, benchmarks)
21        Err(_) => quote! { crate },
22    }
23}
24
25/// Derive macro for JSON serialization
26///
27/// # Example
28/// ```ignore
29/// #[derive(JsonSerialize)]
30/// struct Person {
31///     name: String,
32///     age: u32,
33/// }
34/// ```
35#[proc_macro_derive(JsonSerialize, attributes(json))]
36pub fn derive_json_serialize(input: TokenStream) -> TokenStream {
37    let input = parse_macro_input!(input as DeriveInput);
38    let name = &input.ident;
39    let generics = &input.generics;
40    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
41    let krate = crate_path();
42
43    let serialize_body = generate_serialize_body(&input.data, name, &krate);
44
45    let expanded = quote! {
46        impl #impl_generics #krate::JsonSerialize for #name #ty_generics #where_clause {
47            fn json_serialize<W: #krate::writer::Writer>(&self, writer: &mut #krate::JsonWriter<W>) {
48                #serialize_body
49            }
50        }
51    };
52
53    TokenStream::from(expanded)
54}
55
56/// Derive macro for JSON deserialization
57///
58/// # Example
59/// ```ignore
60/// #[derive(JsonDeserialize)]
61/// struct Person {
62///     name: String,
63///     age: u32,
64/// }
65/// ```
66#[proc_macro_derive(JsonDeserialize, attributes(json))]
67pub fn derive_json_deserialize(input: TokenStream) -> TokenStream {
68    let input = parse_macro_input!(input as DeriveInput);
69    let name = &input.ident;
70    let generics = &input.generics;
71    let krate = crate_path();
72
73    let deserialize_body = generate_deserialize_body(&input.data, name, &krate);
74
75    // Add 'de lifetime to generics only if it doesn't already exist
76    let mut generics_with_de = generics.clone();
77    let has_de_lifetime = generics.lifetimes().any(|lt| lt.lifetime.ident == "de");
78    if !has_de_lifetime {
79        generics_with_de.params.insert(0, syn::parse_quote!('de));
80    }
81    let (impl_generics, _, _) = generics_with_de.split_for_impl();
82    let (_, ty_generics, where_clause) = generics.split_for_impl();
83
84    let expanded = quote! {
85        impl #impl_generics #krate::JsonDeserialize<'de> for #name #ty_generics #where_clause {
86            fn json_deserialize(parser: &mut #krate::JsonParser<'de>) -> #krate::Result<Self> {
87                #deserialize_body
88            }
89        }
90    };
91
92    TokenStream::from(expanded)
93}
94
95/// Combined derive for both serialization and deserialization
96#[proc_macro_derive(Json, attributes(json))]
97pub fn derive_json(input: TokenStream) -> TokenStream {
98    let input = parse_macro_input!(input as DeriveInput);
99    let name = &input.ident;
100    let generics = &input.generics;
101    let krate = crate_path();
102
103    let serialize_body = generate_serialize_body(&input.data, name, &krate);
104    let deserialize_body = generate_deserialize_body(&input.data, name, &krate);
105
106    // For serialize: use normal generics
107    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
108
109    // For deserialize: add 'de lifetime only if it doesn't already exist
110    let mut generics_with_de = generics.clone();
111    let has_de_lifetime = generics.lifetimes().any(|lt| lt.lifetime.ident == "de");
112    if !has_de_lifetime {
113        generics_with_de.params.insert(0, syn::parse_quote!('de));
114    }
115    let (impl_generics_de, _, _) = generics_with_de.split_for_impl();
116
117    let expanded = quote! {
118        impl #impl_generics #krate::JsonSerialize for #name #ty_generics #where_clause {
119            fn json_serialize<W: #krate::writer::Writer>(&self, writer: &mut #krate::JsonWriter<W>) {
120                #serialize_body
121            }
122        }
123
124        impl #impl_generics_de #krate::JsonDeserialize<'de> for #name #ty_generics #where_clause {
125            fn json_deserialize(parser: &mut #krate::JsonParser<'de>) -> #krate::Result<Self> {
126                #deserialize_body
127            }
128        }
129    };
130
131    TokenStream::from(expanded)
132}
133
134fn generate_serialize_body(data: &Data, _name: &Ident, krate: &TokenStream2) -> TokenStream2 {
135    match data {
136        Data::Struct(data_struct) => {
137            match &data_struct.fields {
138                Fields::Named(fields) => {
139                    let field_serializations: Vec<TokenStream2> = fields.named.iter().enumerate().map(|(i, f)| {
140                        let field_name = f.ident.as_ref().unwrap();
141                        let field_name_str = get_field_name(&f.attrs, field_name);
142                        let is_first = i == 0;
143
144                        if is_first {
145                            quote! {
146                                writer.write_unescape_key(#field_name_str);
147                                #krate::JsonSerialize::json_serialize(&self.#field_name, writer);
148                            }
149                        } else {
150                            quote! {
151                                writer.write_comma();
152                                writer.write_unescape_key(#field_name_str);
153                                #krate::JsonSerialize::json_serialize(&self.#field_name, writer);
154                            }
155                        }
156                    }).collect();
157
158                    quote! {
159                        writer.begin_object();
160                        #(#field_serializations)*
161                        writer.end_object();
162                    }
163                }
164                Fields::Unnamed(fields) => {
165                    let field_serializations: Vec<TokenStream2> = (0..fields.unnamed.len())
166                        .enumerate()
167                        .map(|(i, idx)| {
168                            let index = syn::Index::from(idx);
169                            if i == 0 {
170                                quote! {
171                                    #krate::JsonSerialize::json_serialize(&self.#index, writer);
172                                }
173                            } else {
174                                quote! {
175                                    writer.write_comma();
176                                    #krate::JsonSerialize::json_serialize(&self.#index, writer);
177                                }
178                            }
179                        })
180                        .collect();
181
182                    quote! {
183                        writer.begin_array();
184                        #(#field_serializations)*
185                        writer.end_array();
186                    }
187                }
188                Fields::Unit => {
189                    quote! { writer.write_null(); }
190                }
191            }
192        }
193        Data::Enum(data_enum) => {
194            let variants: Vec<TokenStream2> = data_enum.variants.iter().map(|variant| {
195                let variant_name = &variant.ident;
196                let variant_name_str = variant_name.to_string();
197
198                match &variant.fields {
199                    Fields::Unit => {
200                        quote! {
201                            Self::#variant_name => {
202                                writer.write_string(#variant_name_str);
203                            }
204                        }
205                    }
206                    Fields::Unnamed(fields) => {
207                        let field_names: Vec<Ident> = (0..fields.unnamed.len())
208                            .map(|i| format_ident!("f{}", i))
209                            .collect();
210                        let field_serializations: Vec<TokenStream2> = field_names.iter().enumerate().map(|(i, name)| {
211                            if i == 0 {
212                                quote! { #krate::JsonSerialize::json_serialize(#name, writer); }
213                            } else {
214                                quote! {
215                                    writer.write_comma();
216                                    #krate::JsonSerialize::json_serialize(#name, writer);
217                                }
218                            }
219                        }).collect();
220
221                        quote! {
222                            Self::#variant_name(#(#field_names),*) => {
223                                writer.begin_object();
224                                writer.write_key(#variant_name_str);
225                                writer.begin_array();
226                                #(#field_serializations)*
227                                writer.end_array();
228                                writer.end_object();
229                            }
230                        }
231                    }
232                    Fields::Named(fields) => {
233                        let field_names: Vec<&Ident> = fields.named.iter()
234                            .map(|f| f.ident.as_ref().unwrap())
235                            .collect();
236                        let field_serializations: Vec<TokenStream2> = field_names.iter().enumerate().map(|(i, name)| {
237                            let name_str = name.to_string();
238                            if i == 0 {
239                                quote! {
240                                    writer.write_key(#name_str);
241                                    #krate::JsonSerialize::json_serialize(#name, writer);
242                                }
243                            } else {
244                                quote! {
245                                    writer.write_comma();
246                                    writer.write_unescape_key(#name_str);
247                                    #krate::JsonSerialize::json_serialize(#name, writer);
248                                }
249                            }
250                        }).collect();
251
252                        quote! {
253                            Self::#variant_name { #(#field_names),* } => {
254                                writer.begin_object();
255                                writer.write_key(#variant_name_str);
256                                writer.begin_object();
257                                #(#field_serializations)*
258                                writer.end_object();
259                                writer.end_object();
260                            }
261                        }
262                    }
263                }
264            }).collect();
265
266            quote! {
267                match self {
268                    #(#variants)*
269                }
270            }
271        }
272        Data::Union(_) => {
273            quote! { compile_error!("Unions are not supported"); }
274        }
275    }
276}
277
278fn generate_deserialize_body(data: &Data, name: &Ident, krate: &TokenStream2) -> TokenStream2 {
279    match data {
280        Data::Struct(data_struct) => match &data_struct.fields {
281            Fields::Named(fields) => {
282                let field_declarations: Vec<TokenStream2> = fields
283                    .named
284                    .iter()
285                    .map(|f| {
286                        let field_name = f.ident.as_ref().unwrap();
287                        quote! { let mut #field_name = None; }
288                    })
289                    .collect();
290
291                let field_matches: Vec<TokenStream2> = fields.named.iter().map(|f| {
292                        let field_name = f.ident.as_ref().unwrap();
293                        let field_name_str = get_field_name(&f.attrs, field_name);
294                        quote! {
295                            #field_name_str => {
296                                #field_name = Some(#krate::JsonDeserialize::json_deserialize(parser)?);
297                            }
298                        }
299                    }).collect();
300
301                let field_unwraps: Vec<TokenStream2> = fields.named.iter().map(|f| {
302                        let field_name = f.ident.as_ref().unwrap();
303                        let field_name_str = field_name.to_string();
304                        let is_option = is_option_type(&f.ty);
305
306                        if is_option {
307                            quote! {
308                                #field_name: #field_name.unwrap_or(None)
309                            }
310                        } else {
311                            quote! {
312                                #field_name: #field_name.ok_or_else(|| #krate::JsonError::MissingField(#field_name_str.to_string()))?
313                            }
314                        }
315                    }).collect();
316
317                quote! {
318                    parser.expect_object_start()?;
319                    #(#field_declarations)*
320
321                    loop {
322                        match parser.next_object_key()? {
323                            Some(key) => {
324                                match key.as_ref() {
325                                    #(#field_matches)*
326                                    _ => {
327                                        parser.skip_value()?;
328                                    }
329                                }
330                            }
331                            None => break,
332                        }
333                    }
334
335                    parser.expect_object_end()?;
336
337                    Ok(#name {
338                        #(#field_unwraps),*
339                    })
340                }
341            }
342            Fields::Unnamed(fields) => {
343                let field_deserializations: Vec<TokenStream2> = (0..fields.unnamed.len())
344                    .enumerate()
345                    .map(|(i, _)| {
346                        if i == 0 {
347                            quote! { #krate::JsonDeserialize::json_deserialize(parser)? }
348                        } else {
349                            quote! {
350                                {
351                                    parser.expect_comma()?;
352                                    #krate::JsonDeserialize::json_deserialize(parser)?
353                                }
354                            }
355                        }
356                    })
357                    .collect();
358
359                quote! {
360                    parser.expect_array_start()?;
361                    let result = #name(#(#field_deserializations),*);
362                    parser.expect_array_end()?;
363                    Ok(result)
364                }
365            }
366            Fields::Unit => {
367                quote! {
368                    parser.expect_null()?;
369                    Ok(#name)
370                }
371            }
372        },
373        Data::Enum(data_enum) => {
374            let variant_matches: Vec<TokenStream2> = data_enum.variants.iter().map(|variant| {
375                let variant_name = &variant.ident;
376                let variant_name_str = variant_name.to_string();
377
378                match &variant.fields {
379                    Fields::Unit => {
380                        quote! {
381                            #variant_name_str => Ok(#name::#variant_name)
382                        }
383                    }
384                    Fields::Unnamed(fields) => {
385                        let field_deserializations: Vec<TokenStream2> = (0..fields.unnamed.len()).enumerate().map(|(i, _)| {
386                            if i == 0 {
387                                quote! { #krate::JsonDeserialize::json_deserialize(parser)? }
388                            } else {
389                                quote! {
390                                    {
391                                        parser.expect_comma()?;
392                                        #krate::JsonDeserialize::json_deserialize(parser)?
393                                    }
394                                }
395                            }
396                        }).collect();
397
398                        quote! {
399                            #variant_name_str => {
400                                parser.expect_array_start()?;
401                                let result = #name::#variant_name(#(#field_deserializations),*);
402                                parser.expect_array_end()?;
403                                Ok(result)
404                            }
405                        }
406                    }
407                    Fields::Named(fields) => {
408                        let field_declarations: Vec<TokenStream2> = fields.named.iter().map(|f| {
409                            let field_name = f.ident.as_ref().unwrap();
410                            quote! { let mut #field_name = None; }
411                        }).collect();
412
413                        let field_matches: Vec<TokenStream2> = fields.named.iter().map(|f| {
414                            let field_name = f.ident.as_ref().unwrap();
415                            let field_name_str = field_name.to_string();
416                            quote! {
417                                #field_name_str => {
418                                    #field_name = Some(#krate::JsonDeserialize::json_deserialize(parser)?);
419                                }
420                            }
421                        }).collect();
422
423                        let field_unwraps: Vec<TokenStream2> = fields.named.iter().map(|f| {
424                            let field_name = f.ident.as_ref().unwrap();
425                            let field_name_str = field_name.to_string();
426                            quote! {
427                                #field_name: #field_name.ok_or_else(|| #krate::JsonError::MissingField(#field_name_str.to_string()))?
428                            }
429                        }).collect();
430
431                        quote! {
432                            #variant_name_str => {
433                                parser.expect_object_start()?;
434                                #(#field_declarations)*
435
436                                loop {
437                                    match parser.next_object_key()? {
438                                        Some(key) => {
439                                            match key.as_ref() {
440                                                #(#field_matches)*
441                                                _ => { parser.skip_value()?; }
442                                            }
443                                        }
444                                        None => break,
445                                    }
446                                }
447
448                                Ok(#name::#variant_name {
449                                    #(#field_unwraps),*
450                                })
451                            }
452                        }
453                    }
454                }
455            }).collect();
456
457            quote! {
458                // Try string first (for unit variants)
459                if parser.peek_is_string()? {
460                    let variant_str = parser.parse_string()?;
461                    match variant_str.as_ref() {
462                        #(#variant_matches),*,
463                        _ => Err(#krate::JsonError::UnknownVariant(variant_str.to_string()))
464                    }
465                } else {
466                    // Object format: {"VariantName": ...}
467                    parser.expect_object_start()?;
468                    let key = parser.next_object_key()?.ok_or(#krate::JsonError::UnexpectedEnd)?;
469                    let result = match key.as_ref() {
470                        #(#variant_matches),*,
471                        _ => Err(#krate::JsonError::UnknownVariant(key.to_string()))
472                    };
473                    parser.expect_object_end()?;
474                    result
475                }
476            }
477        }
478        Data::Union(_) => {
479            quote! { compile_error!("Unions are not supported"); }
480        }
481    }
482}
483
484fn get_field_name(attrs: &[syn::Attribute], default: &Ident) -> String {
485    for attr in attrs {
486        if attr.path().is_ident("json") {
487            if let syn::Meta::List(meta_list) = &attr.meta {
488                let tokens = meta_list.tokens.to_string();
489                if let Some(name) = tokens.strip_prefix("rename = ") {
490                    return name.trim_matches('"').to_string();
491                }
492            }
493        }
494    }
495    default.to_string()
496}
497
498fn is_option_type(ty: &Type) -> bool {
499    if let Type::Path(type_path) = ty {
500        if let Some(segment) = type_path.path.segments.last() {
501            return segment.ident == "Option";
502        }
503    }
504    false
505}