near_schema_checker_macro/
lib.rs1use proc_macro::TokenStream;
2
3#[proc_macro_derive(ProtocolSchema)]
4pub fn protocol_schema(input: TokenStream) -> TokenStream {
5 helper::protocol_schema_impl(input)
6}
7
8#[cfg(all(enable_const_type_id, feature = "protocol_schema"))]
9mod helper {
10 use proc_macro::TokenStream;
11 use proc_macro2::TokenStream as TokenStream2;
12 use quote::{format_ident, quote};
13 use syn::{
14 Data, DeriveInput, Field, Fields, FieldsNamed, FieldsUnnamed, GenericArgument,
15 GenericParam, Generics, Index, Path, PathArguments, PathSegment, Type, TypePath, Variant,
16 parse_macro_input,
17 };
18
19 pub fn protocol_schema_impl(input: TokenStream) -> TokenStream {
20 let input = parse_macro_input!(input as DeriveInput);
21 let name = &input.ident;
22 let info_name = format_ident!("{}_INFO", name);
23 let generics = &input.generics;
24
25 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
26
27 let ty_generics_without_lifetimes = remove_lifetimes(generics);
29
30 let type_id = quote! { std::any::TypeId::of::<#name #ty_generics_without_lifetimes>() };
31 let info = match &input.data {
32 Data::Struct(data_struct) => {
33 let fields = extract_struct_fields(&data_struct.fields);
34 quote! {
35 near_schema_checker_lib::ProtocolSchemaInfo::Struct {
36 name: stringify!(#name),
37 type_id: #type_id,
38 fields: #fields,
39 }
40 }
41 }
42 Data::Enum(data_enum) => {
43 let variants = extract_enum_variants(&data_enum.variants);
44 quote! {
45 near_schema_checker_lib::ProtocolSchemaInfo::Enum {
46 name: stringify!(#name),
47 type_id: #type_id,
48 variants: #variants,
49 }
50 }
51 }
52 Data::Union(_) => panic!("Unions are not supported"),
53 };
54
55 let expanded = quote! {
56 #[allow(non_upper_case_globals)]
57 pub static #info_name: near_schema_checker_lib::ProtocolSchemaInfo = #info;
58
59 near_schema_checker_lib::inventory::submit! {
60 #info_name
61 }
62
63 impl #impl_generics near_schema_checker_lib::ProtocolSchema for #name #ty_generics #where_clause {
64 fn ensure_registration() {}
65 }
66 };
67
68 TokenStream::from(expanded)
69 }
70
71 fn extract_struct_fields(fields: &Fields) -> TokenStream2 {
72 match fields {
73 Fields::Named(FieldsNamed { named, .. }) => {
74 let fields = extract_from_named_fields(named);
75 quote! { &[#(#fields),*] }
76 }
77 Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
78 let fields = extract_from_unnamed_fields(unnamed);
79 quote! { &[#(#fields),*] }
80 }
81 Fields::Unit => quote! { &[] },
82 }
83 }
84
85 fn extract_enum_variants(
86 variants: &syn::punctuated::Punctuated<Variant, syn::token::Comma>,
87 ) -> TokenStream2 {
88 let variants = variants.iter().enumerate().map(|(idx, v)| {
89 let name = &v.ident;
90 let discriminant = match &v.discriminant {
91 Some((_, expr)) => quote! { #expr as _ },
92 None => quote! { #idx as _ },
93 };
94 let fields = match &v.fields {
95 Fields::Named(FieldsNamed { named, .. }) => {
96 let fields = extract_from_named_fields(named);
97 quote! { Some(&[#(#fields),*]) }
98 }
99 Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
100 let fields = extract_from_unnamed_fields(unnamed);
101 quote! { Some(&[#(#fields),*]) }
102 }
103 Fields::Unit => quote! { None },
104 };
105 quote! { (#discriminant, stringify!(#name), #fields) }
106 });
107 quote! { &[#(#variants),*] }
108 }
109
110 fn extract_type_ids_from_type(ty: &Type) -> Vec<TokenStream2> {
114 let mut result = vec![quote! { std::any::TypeId::of::<#ty>() }];
115 let type_path = match ty {
116 Type::Path(type_path) => type_path,
117 _ => return result,
118 };
119
120 let generic_params = &type_path.path.segments.last().unwrap().arguments;
126 let params = match generic_params {
127 PathArguments::AngleBracketed(params) => params,
128 _ => return result,
129 };
130
131 let inner_type_ids = params
132 .args
133 .iter()
134 .map(|arg| {
135 if let GenericArgument::Type(ty) = arg {
136 extract_type_ids_from_type(ty)
137 } else {
138 vec![]
139 }
140 })
141 .flatten()
142 .collect::<Vec<_>>();
143 result.extend(inner_type_ids);
144 result
145 }
146
147 fn extract_type_info(ty: &Type) -> TokenStream2 {
148 match ty {
149 Type::Path(type_path) => {
150 let type_name = &type_path.path.segments.last().unwrap().ident;
151 let type_without_lifetimes = remove_lifetimes_from_type(type_path);
152 let type_ids = extract_type_ids_from_type(&type_without_lifetimes);
153 let type_ids_count = type_ids.len();
154
155 quote! {
156 {
157 const TYPE_IDS_COUNT: usize = #type_ids_count;
158 const fn create_array() -> [std::any::TypeId; TYPE_IDS_COUNT] {
159 [#(#type_ids),*]
160 }
161 (stringify!(#type_name), &create_array())
162 }
163 }
164 }
165 Type::Reference(type_ref) => {
166 let elem = &type_ref.elem;
167 extract_type_info(elem)
168 }
169 Type::Array(array) => {
170 let elem = &array.elem;
171 let len = &array.len;
172 quote! {
173 (stringify!([#elem; #len]), &[std::any::TypeId::of::<#elem>()])
174 }
175 }
176 Type::Slice(slice) => {
177 let elem = &slice.elem;
178 quote! {
179 (stringify!([#elem]), &[std::any::TypeId::of::<#elem>()])
180 }
181 }
182 Type::Tuple(tuple) => {
183 quote! { (stringify!(#tuple), &[std::any::TypeId::of::<#tuple>()]) }
184 }
185 _ => {
186 println!("Unsupported type: {:?}", ty);
187 quote! { (stringify!(#ty), &[std::any::TypeId::of::<#ty>()]) }
188 }
189 }
190 }
191
192 fn remove_lifetimes_from_type(type_path: &TypePath) -> Type {
193 let segments = type_path.path.segments.iter().map(|segment| {
194 let mut new_segment =
195 PathSegment { ident: segment.ident.clone(), arguments: PathArguments::None };
196
197 if let PathArguments::AngleBracketed(args) = &segment.arguments {
198 let new_args: Vec<_> = args
199 .args
200 .iter()
201 .filter_map(|arg| match arg {
202 GenericArgument::Type(ty) => {
203 Some(GenericArgument::Type(remove_lifetimes_from_type_recursive(ty)))
204 }
205 GenericArgument::Const(c) => Some(GenericArgument::Const(c.clone())),
206 _ => None,
207 })
208 .collect();
209
210 if !new_args.is_empty() {
211 new_segment.arguments =
212 PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
213 colon2_token: args.colon2_token,
214 lt_token: args.lt_token,
215 args: new_args.into_iter().collect(),
216 gt_token: args.gt_token,
217 });
218 }
219 }
220
221 new_segment
222 });
223
224 Type::Path(TypePath {
226 qself: type_path.qself.clone(),
227 path: Path {
228 leading_colon: type_path.path.leading_colon,
229 segments: segments.collect(),
230 },
231 })
232 }
233
234 fn remove_lifetimes_from_type_recursive(ty: &Type) -> Type {
235 match ty {
236 Type::Path(type_path) => remove_lifetimes_from_type(type_path),
237 Type::Reference(type_ref) => Type::Reference(syn::TypeReference {
238 and_token: type_ref.and_token,
239 lifetime: None,
240 mutability: type_ref.mutability,
241 elem: Box::new(remove_lifetimes_from_type_recursive(&type_ref.elem)),
242 }),
243 _ => ty.clone(),
244 }
245 }
246
247 fn extract_from_named_fields(
248 named: &syn::punctuated::Punctuated<Field, syn::token::Comma>,
249 ) -> impl Iterator<Item = TokenStream2> + '_ {
250 named.iter().map(|f| {
251 let name = &f.ident;
252 let ty = &f.ty;
253 let type_info = extract_type_info(ty);
254 quote! { (stringify!(#name), #type_info) }
255 })
256 }
257
258 fn extract_from_unnamed_fields(
259 unnamed: &syn::punctuated::Punctuated<Field, syn::token::Comma>,
260 ) -> impl Iterator<Item = TokenStream2> + '_ {
261 unnamed.iter().enumerate().map(|(i, f)| {
262 let index = Index::from(i);
263 let ty = &f.ty;
264 let type_info = extract_type_info(ty);
265 quote! { (stringify!(#index), #type_info) }
266 })
267 }
268
269 fn remove_lifetimes(generics: &Generics) -> proc_macro2::TokenStream {
270 let params: Vec<_> = generics
271 .params
272 .iter()
273 .filter_map(|param| match param {
274 GenericParam::Type(type_param) => Some(quote! { #type_param }),
275 GenericParam::Const(const_param) => Some(quote! { #const_param }),
276 GenericParam::Lifetime(_) => None,
277 })
278 .collect();
279
280 if !params.is_empty() {
281 quote! { <#(#params),*> }
282 } else {
283 quote! {}
284 }
285 }
286}
287
/// Fallback used when schema checking is disabled, i.e. when the `enable_const_type_id` cfg is
/// unset or the `protocol_schema` feature is off.
#[cfg(not(all(enable_const_type_id, feature = "protocol_schema")))]
mod helper {
    use proc_macro::TokenStream;

    /// Ignores the derive input and emits no tokens, so the derive generates no items.
    pub fn protocol_schema_impl(_input: TokenStream) -> TokenStream {
        TokenStream::default()
    }
}