near_schema_checker_macro/
lib.rs

1use proc_macro::TokenStream;
2
3#[proc_macro_derive(ProtocolSchema)]
4pub fn protocol_schema(input: TokenStream) -> TokenStream {
5    helper::protocol_schema_impl(input)
6}
7
/// Schema-generating implementation of `#[derive(ProtocolSchema)]`, compiled
/// only when `enable_const_type_id` and the `protocol_schema` feature are on.
#[cfg(all(enable_const_type_id, feature = "protocol_schema"))]
mod helper {
    use proc_macro::TokenStream;
    use proc_macro2::TokenStream as TokenStream2;
    use quote::{format_ident, quote};
    use syn::{
        Data, DeriveInput, Field, Fields, FieldsNamed, FieldsUnnamed, GenericArgument,
        GenericParam, Generics, Index, Path, PathArguments, PathSegment, Type, TypePath, Variant,
        parse_macro_input,
    };

    /// Expands `#[derive(ProtocolSchema)]` for a struct or enum.
    ///
    /// The expansion consists of:
    /// * `pub static <Name>_INFO: near_schema_checker_lib::ProtocolSchemaInfo`,
    ///   holding the type's name, its `TypeId`, and its fields or variants;
    /// * a `near_schema_checker_lib::inventory::submit!` registering that static;
    /// * an `impl near_schema_checker_lib::ProtocolSchema` for the type with an
    ///   empty `ensure_registration`.
    ///
    /// Unions are rejected with a compile error pointing at the type name.
    pub fn protocol_schema_impl(input: TokenStream) -> TokenStream {
        let input = parse_macro_input!(input as DeriveInput);
        let name = &input.ident;
        let info_name = format_ident!("{}_INFO", name);
        let generics = &input.generics;

        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

        // `TypeId::of` needs a `'static` type, so lifetime parameters are
        // dropped from the self type whose id is recorded.
        let ty_generics_without_lifetimes = remove_lifetimes(generics);
        let type_id = quote! { std::any::TypeId::of::<#name #ty_generics_without_lifetimes>() };

        let info = match &input.data {
            Data::Struct(data_struct) => {
                let fields = extract_struct_fields(&data_struct.fields);
                quote! {
                    near_schema_checker_lib::ProtocolSchemaInfo::Struct {
                        name: stringify!(#name),
                        type_id: #type_id,
                        fields: #fields,
                    }
                }
            }
            Data::Enum(data_enum) => {
                let variants = extract_enum_variants(&data_enum.variants);
                quote! {
                    near_schema_checker_lib::ProtocolSchemaInfo::Enum {
                        name: stringify!(#name),
                        type_id: #type_id,
                        variants: #variants,
                    }
                }
            }
            Data::Union(_) => {
                // Emit a spanned compile error rather than panicking inside
                // the proc macro, so rustc points at the offending type.
                return syn::Error::new_spanned(name, "Unions are not supported")
                    .to_compile_error()
                    .into();
            }
        };

        let expanded = quote! {
            #[allow(non_upper_case_globals)]
            pub static #info_name: near_schema_checker_lib::ProtocolSchemaInfo = #info;

            near_schema_checker_lib::inventory::submit! {
                #info_name
            }

            impl #impl_generics near_schema_checker_lib::ProtocolSchema for #name #ty_generics #where_clause {
                fn ensure_registration() {}
            }
        };

        TokenStream::from(expanded)
    }

    /// Builds a `&[(field_name, (type_name, &[TypeId]))]` slice literal for a
    /// struct's fields. Tuple-struct fields are named by their index ("0",
    /// "1", ...), and unit structs yield `&[]`.
    fn extract_struct_fields(fields: &Fields) -> TokenStream2 {
        match fields {
            Fields::Named(FieldsNamed { named, .. }) => {
                let fields = extract_from_named_fields(named);
                quote! { &[#(#fields),*] }
            }
            Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
                let fields = extract_from_unnamed_fields(unnamed);
                quote! { &[#(#fields),*] }
            }
            Fields::Unit => quote! { &[] },
        }
    }

    /// Builds a `&[(discriminant, variant_name, fields)]` slice literal for an
    /// enum's variants.
    ///
    /// The discriminant is the explicit `= expr` when present, otherwise the
    /// variant's zero-based position in the declaration. Note that after an
    /// explicit discriminant, this position is not Rust's implicit
    /// "previous + 1" value. `fields` is `None` for unit variants and
    /// `Some(&[...])` (same element shape as struct fields) otherwise.
    fn extract_enum_variants(
        variants: &syn::punctuated::Punctuated<Variant, syn::token::Comma>,
    ) -> TokenStream2 {
        let variants = variants.iter().enumerate().map(|(idx, v)| {
            let name = &v.ident;
            let discriminant = match &v.discriminant {
                // Parenthesize so the cast applies to the whole expression
                // (e.g. `1 << 3`), not just to its last operand.
                Some((_, expr)) => quote! { (#expr) as _ },
                None => quote! { #idx as _ },
            };
            let fields = match &v.fields {
                Fields::Named(FieldsNamed { named, .. }) => {
                    let fields = extract_from_named_fields(named);
                    quote! { Some(&[#(#fields),*]) }
                }
                Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
                    let fields = extract_from_unnamed_fields(unnamed);
                    quote! { Some(&[#(#fields),*]) }
                }
                Fields::Unit => quote! { None },
            };
            quote! { (#discriminant, stringify!(#name), #fields) }
        });
        quote! { &[#(#variants),*] }
    }

    /// Returns `TypeId::of` expressions for `ty` followed by those of all of
    /// its type arguments, recursively and depth-first.
    ///
    /// For example, for `Vec<Vec<u32>>` it returns the ids of
    /// `Vec<Vec<u32>>`, `Vec<u32>` and `u32`, in that order. Only the last path
    /// segment's angle-bracketed type arguments are descended into; non-path
    /// types contribute just their own id.
    fn extract_type_ids_from_type(ty: &Type) -> Vec<TokenStream2> {
        let mut result = vec![quote! { std::any::TypeId::of::<#ty>() }];
        let type_path = match ty {
            Type::Path(type_path) => type_path,
            _ => return result,
        };

        // TODO (#11755): last segment does not necessarily cover all generics.
        // For example, consider `<Apple as Fruit<Round>>::AssocType`. Here
        // `AssocType` in `impl Fruit<Round>` for `Apple` can be a `Vec<Round>`
        // or any other instantiation of a generic type.
        // Not urgent because protocol structs are expected to be simple.
        let last_segment = type_path
            .path
            .segments
            .last()
            .expect("a type path always has at least one segment");
        let params = match &last_segment.arguments {
            PathArguments::AngleBracketed(params) => params,
            _ => return result,
        };

        result.extend(params.args.iter().flat_map(|arg| match arg {
            GenericArgument::Type(ty) => extract_type_ids_from_type(ty),
            _ => Vec::new(),
        }));
        result
    }

    /// Produces a `(type_name, &[TypeId])` expression describing a field type.
    ///
    /// * Paths: `type_name` is the last segment's identifier (e.g. `Vec` for
    ///   `Vec<u32>`); the ids cover the type and all nested type arguments.
    /// * References are described by their referent type.
    /// * Arrays and slices: the stringified type plus the element type's id.
    /// * Tuples and any other type: the stringified type plus its own id.
    ///
    /// The result is embedded in a module-level `static`, where the deriving
    /// type's lifetime parameters are not in scope. Since `TypeId::of` needs
    /// `'static` types, lifetimes are stripped from every type passed to it.
    /// The `stringify!`-ed names keep the original tokens.
    fn extract_type_info(ty: &Type) -> TokenStream2 {
        match ty {
            Type::Path(type_path) => {
                let type_name = &type_path
                    .path
                    .segments
                    .last()
                    .expect("a type path always has at least one segment")
                    .ident;
                let type_without_lifetimes = remove_lifetimes_from_type(type_path);
                let type_ids = extract_type_ids_from_type(&type_without_lifetimes);
                let type_ids_count = type_ids.len();

                quote! {
                    {
                        const TYPE_IDS_COUNT: usize = #type_ids_count;
                        const fn create_array() -> [std::any::TypeId; TYPE_IDS_COUNT] {
                            [#(#type_ids),*]
                        }
                        (stringify!(#type_name), &create_array())
                    }
                }
            }
            Type::Reference(type_ref) => extract_type_info(&type_ref.elem),
            Type::Array(array) => {
                let elem = &array.elem;
                let len = &array.len;
                let elem_for_id = remove_lifetimes_from_type_recursive(elem);
                quote! {
                    (stringify!([#elem; #len]), &[std::any::TypeId::of::<#elem_for_id>()])
                }
            }
            Type::Slice(slice) => {
                let elem = &slice.elem;
                let elem_for_id = remove_lifetimes_from_type_recursive(elem);
                quote! {
                    (stringify!([#elem]), &[std::any::TypeId::of::<#elem_for_id>()])
                }
            }
            Type::Tuple(tuple) => {
                let tuple_for_id = remove_lifetimes_from_type_recursive(ty);
                quote! { (stringify!(#tuple), &[std::any::TypeId::of::<#tuple_for_id>()]) }
            }
            _ => {
                // Diagnostic printed to stdout while the macro runs.
                println!("Unsupported type: {:?}", ty);
                let ty_for_id = remove_lifetimes_from_type_recursive(ty);
                quote! { (stringify!(#ty), &[std::any::TypeId::of::<#ty_for_id>()]) }
            }
        }
    }

    /// Rebuilds a type path with lifetime arguments removed.
    ///
    /// Descends into type arguments and into the qualified-self type (the `T`
    /// in `<T as Trait>::Assoc`). In angle-bracketed argument lists, only
    /// type and const arguments are kept. Lifetimes and any other argument
    /// kinds are dropped, and a segment left with no arguments loses its
    /// `<...>` entirely. Parenthesized arguments (`Fn(A) -> B` sugar) are kept
    /// unchanged.
    fn remove_lifetimes_from_type(type_path: &TypePath) -> Type {
        let segments = type_path.path.segments.iter().map(|segment| {
            let arguments = match &segment.arguments {
                PathArguments::AngleBracketed(args) => {
                    let new_args: Vec<_> = args
                        .args
                        .iter()
                        .filter_map(|arg| match arg {
                            GenericArgument::Type(ty) => Some(GenericArgument::Type(
                                remove_lifetimes_from_type_recursive(ty),
                            )),
                            GenericArgument::Const(c) => Some(GenericArgument::Const(c.clone())),
                            _ => None,
                        })
                        .collect();

                    if new_args.is_empty() {
                        PathArguments::None
                    } else {
                        PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
                            colon2_token: args.colon2_token,
                            lt_token: args.lt_token,
                            args: new_args.into_iter().collect(),
                            gt_token: args.gt_token,
                        })
                    }
                }
                other => other.clone(),
            };
            PathSegment { ident: segment.ident.clone(), arguments }
        });

        // cspell:ignore qself
        // The number of segments is unchanged, so `qself.position` stays valid.
        let qself = type_path.qself.clone().map(|mut qself| {
            qself.ty = Box::new(remove_lifetimes_from_type_recursive(&qself.ty));
            qself
        });

        Type::Path(TypePath {
            qself,
            path: Path {
                leading_colon: type_path.path.leading_colon,
                segments: segments.collect(),
            },
        })
    }

    /// Removes lifetimes from `ty`.
    ///
    /// Descends through paths, references (`&'a T` becomes `&T`), tuples,
    /// arrays, slices, and parenthesized or invisible-group types. Any other
    /// kind of type is returned unchanged.
    fn remove_lifetimes_from_type_recursive(ty: &Type) -> Type {
        match ty {
            Type::Path(type_path) => remove_lifetimes_from_type(type_path),
            Type::Reference(type_ref) => Type::Reference(syn::TypeReference {
                and_token: type_ref.and_token,
                lifetime: None,
                mutability: type_ref.mutability,
                elem: Box::new(remove_lifetimes_from_type_recursive(&type_ref.elem)),
            }),
            Type::Tuple(tuple) => {
                let mut tuple = tuple.clone();
                // Rewrite elements in place so the original punctuation
                // (including a 1-tuple's trailing comma) is preserved.
                for elem in tuple.elems.iter_mut() {
                    *elem = remove_lifetimes_from_type_recursive(elem);
                }
                Type::Tuple(tuple)
            }
            Type::Array(array) => {
                let mut array = array.clone();
                array.elem = Box::new(remove_lifetimes_from_type_recursive(&array.elem));
                Type::Array(array)
            }
            Type::Slice(slice) => {
                let mut slice = slice.clone();
                slice.elem = Box::new(remove_lifetimes_from_type_recursive(&slice.elem));
                Type::Slice(slice)
            }
            Type::Paren(paren) => {
                let mut paren = paren.clone();
                paren.elem = Box::new(remove_lifetimes_from_type_recursive(&paren.elem));
                Type::Paren(paren)
            }
            Type::Group(group) => {
                let mut group = group.clone();
                group.elem = Box::new(remove_lifetimes_from_type_recursive(&group.elem));
                Type::Group(group)
            }
            _ => ty.clone(),
        }
    }

    /// Yields `(field_name, type_info)` tuples for named fields.
    fn extract_from_named_fields(
        named: &syn::punctuated::Punctuated<Field, syn::token::Comma>,
    ) -> impl Iterator<Item = TokenStream2> + '_ {
        named.iter().map(|f| {
            let name = &f.ident;
            let ty = &f.ty;
            let type_info = extract_type_info(ty);
            quote! { (stringify!(#name), #type_info) }
        })
    }

    /// Yields `(index, type_info)` tuples for tuple fields. The name is the
    /// stringified positional index ("0", "1", ...).
    fn extract_from_unnamed_fields(
        unnamed: &syn::punctuated::Punctuated<Field, syn::token::Comma>,
    ) -> impl Iterator<Item = TokenStream2> + '_ {
        unnamed.iter().enumerate().map(|(i, f)| {
            let index = Index::from(i);
            let ty = &f.ty;
            let type_info = extract_type_info(ty);
            quote! { (stringify!(#index), #type_info) }
        })
    }

    /// Builds the generic argument list (`<T, N>`) for the deriving type with
    /// lifetime parameters removed, or nothing when no parameters remain.
    ///
    /// Only the parameter identifiers are emitted. Bounds, defaults and the
    /// `const ... : Ty` declaration syntax are not valid inside a type's
    /// argument list. Note that the result is used in a module-level `static`,
    /// where type/const parameters of the deriving type are not in scope.
    fn remove_lifetimes(generics: &Generics) -> proc_macro2::TokenStream {
        let params: Vec<TokenStream2> = generics
            .params
            .iter()
            .filter_map(|param| match param {
                GenericParam::Type(type_param) => {
                    let ident = &type_param.ident;
                    Some(quote! { #ident })
                }
                GenericParam::Const(const_param) => {
                    let ident = &const_param.ident;
                    Some(quote! { #ident })
                }
                GenericParam::Lifetime(_) => None,
            })
            .collect();

        if params.is_empty() {
            quote! {}
        } else {
            quote! { <#(#params),*> }
        }
    }
}
287
/// Stub used when schema generation is disabled, i.e. when either the
/// `enable_const_type_id` cfg or the `protocol_schema` feature is off.
#[cfg(not(all(enable_const_type_id, feature = "protocol_schema")))]
mod helper {
    /// Ignores the derive input entirely; the derive expands to no tokens.
    pub fn protocol_schema_impl(_input: proc_macro::TokenStream) -> proc_macro::TokenStream {
        proc_macro::TokenStream::new()
    }
}