serde_inner_serialize_core/
lib.rs

1use proc_macro2::TokenStream;
2use quote::{quote, ToTokens};
3use syn::{parse2, AngleBracketedGenericArguments, Data, DeriveInput, GenericArgument, Turbofish, Type };
4
5mod tests;
6
7pub trait InnerSerializableTrait {
8    const TYPE_NAME: &'static str;
9    fn count_fields() -> usize;
10    fn inner_serialize<S>(&self, state: &mut S) -> Result<(), S::Error>
11    where
12        S: serde::ser::SerializeStruct;
13}
14
15pub trait OuterSerializableTrait<T> where T: InnerSerializableTrait {
16    const TYPE_NAME: &'static str;
17
18    fn get_full_type_name(&self) -> &'static str;
19    fn _get_full_type_name(&self) -> String;
20}
21
22pub fn inner_serializable_core(input: TokenStream) -> TokenStream {
23
24    let input = parse2::<DeriveInput>(input).unwrap();
25    let name = &input.ident;
26    let generics = &input.generics;
27    let (impl_generics, ty_generics, where_clause) = &generics.split_for_impl();
28
29    let type_name_impl = if cfg!(feature = "const_type_name") {
30        quote! {
31            const TYPE_NAME: &'static str = std::any::type_name::<#name>();
32        }
33    } else {
34        quote! {
35            const TYPE_NAME: &'static str = stringify!(#name);
36        }
37    };
38
39    let count = if let syn::Data::Struct(data) = &input.data {
40        data.fields.iter().count()
41    } else {
42        0
43    };
44
45    let serialize_fields = if let Data::Struct(data) = &input.data {
46        data.fields.iter().map(|field| {
47            let field_name = &field.ident;
48            let field_name_str = field_name.as_ref().unwrap().to_string();
49            quote! {
50                state.serialize_field(#field_name_str, &self.#field_name)?;
51            }
52        }).collect::<Vec<_>>().into_iter()
53    } else {
54        Vec::new().into_iter()
55    };
56    
57    let expanded = quote! {      
58
59        impl #impl_generics serde_inner_serialize::InnerSerializableTrait for #name #ty_generics #where_clause {
60    
61            #type_name_impl
62
63            fn count_fields() -> usize {
64                #count
65            }
66
67            fn inner_serialize<S>(&self, state: &mut S) -> Result<(), S::Error>
68            where
69                S: serde::ser::SerializeStruct,
70            {
71                #(#serialize_fields)*
72                Ok(())
73            }
74        }
75    };
76
77    TokenStream::from(expanded)
78
79}
80
81fn extract_types_from_turbofish(turbofish: &Turbofish) -> Vec<Type> {
82    let ts = turbofish.to_token_stream();
83    let abga = parse2::<AngleBracketedGenericArguments>(ts).unwrap();
84    abga.args.iter()
85        .filter_map(|arg| match arg {
86            // Only look at type arguments, ignore lifetimes and const params
87            GenericArgument::Type(ty) => Some(ty.clone()),
88            _ => None,
89        })
90        .collect()
91}
92
93
94pub fn outer_serializable_core(input: TokenStream) -> TokenStream {
95    let input = parse2::<DeriveInput>(input).unwrap();
96    let name = &input.ident;
97    let generics = &input.generics;
98    let (impl_generics, ty_generics, where_clause) = &generics.split_for_impl();
99    let turbofish = ty_generics.as_turbofish();
100    let ty_inner_type = extract_types_from_turbofish(&turbofish)[0].clone();
101
102    let type_name_impl = if cfg!(feature = "const_type_name") {
103        quote! {
104            const TYPE_NAME: &'static str = std::any::type_name::<#name>();
105        }
106    } else {
107        quote! {
108            const TYPE_NAME: &'static str = stringify!(#name);
109        }
110    };
111
112    let expanded = quote! {
113        impl #impl_generics OuterSerializableTrait<#ty_inner_type> for #name #ty_generics #where_clause {
114    
115            #type_name_impl
116
117            fn _get_full_type_name(&self) -> String {
118                let mut s = String::from(#name #turbofish :: TYPE_NAME);
119                s.push_str("->");
120                s.push_str(#ty_inner_type :: TYPE_NAME);
121                s
122            }
123
124            fn get_full_type_name(&self) -> &'static str {
125                static TYPEMAP: std::sync::LazyLock<std::sync::Mutex<std::collections::HashMap<&'static str, &'static str>>> = std::sync::LazyLock::new(|| std::sync::Mutex::new(std::collections::HashMap::<&'static str, &'static str>::new()));
126                let mut jtn = TYPEMAP.lock().unwrap();
127                let e = jtn.entry(#ty_inner_type :: TYPE_NAME);
128                e.or_insert_with(|| {
129                    let s : &str = &self._get_full_type_name();
130                    let leaked_str: &'static str = s.to_string().leak();
131                    // println!("LEAKING: {} @ {:p}", leaked_str, leaked_str);
132                    leaked_str
133                })
134            }
135        }
136    };
137    TokenStream::from(expanded)
138}
139