risc0_zeroio_derive/
lib.rs

1// Copyright 2023 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#![doc = include_str!("../README.md")]
16
17use proc_macro2::TokenStream;
18use quote::{format_ident, quote, ToTokens};
19use syn::{
20    parse_macro_input, punctuated::Punctuated, spanned::Spanned, token::Comma, DataEnum,
21    DataStruct, DeriveInput, Error, Fields, GenericParam, Generics, Ident, Index, Lifetime,
22    LifetimeDef, Result, Variant, Visibility,
23};
24
/// Token fragments describing the generics of the type being derived, both
/// for the original type and for its generated reference type.
#[derive(Default, Debug)]
struct GenericFragments {
    // Results from syn's split_for_impl() on the original type's generics.
    impl_generics: TokenStream,
    ty_generics: TokenStream,
    where_clause: TokenStream,

    // Generated name for the "reference" type (e.g. `FooRef`).
    ref_name: TokenStream,

    // Same, but for the generated reference type, whose generics are the
    // original ones plus the `'zeroio_deserialize` lifetime.
    ref_impl_generics: TokenStream,
    ref_ty_generics: TokenStream,
    ref_where_clause: TokenStream,

    // Comma-separated payload of the reference type's PhantomData tuple; the
    // reference type only stores a `&[u32]`, so generic parameters it does
    // not otherwise hold are mentioned here.
    phantom_types: TokenStream,
}
43
44fn to_tokens<T: ToTokens>(t: &T) -> TokenStream {
45    let mut tokens = TokenStream::new();
46    t.to_tokens(&mut tokens);
47    tokens
48}
49
50fn warning_inhibit() -> TokenStream {
51    quote! {
52        #[allow(unused_parens, non_camel_case_types, unused_variables, dead_code, missing_docs)]
53    }
54}
55
56fn make_generics(ref_name: &Ident, generics: &Generics) -> GenericFragments {
57    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
58
59    // We need this parameterized with a 'zeroio_deserialize' lifetime so that
60    // the reference structure outlives the buffer it points to.
61    let mut ref_generics = generics.clone();
62    let zeroio_lifetime = GenericParam::Lifetime(LifetimeDef::new(Lifetime::new(
63        "'zeroio_deserialize",
64        generics.span().unwrap().into(),
65    )));
66    ref_generics.params.push(zeroio_lifetime);
67    let (ref_impl_generics, ref_ty_generics, ref_where_clause) = ref_generics.split_for_impl();
68    //    let ref_ty_generics_turbofish = ref_ty_generics.as_turbofish();
69
70    // Since we don't actually hold any of the parameterized types, we
71    // have to reference them using PhantomData.
72    let mut phantom_types = Punctuated::<TokenStream, Comma>::from_iter(
73        generics.params.iter().filter_map(|p| match p {
74            GenericParam::Type(p) => Some(to_tokens(&p.ident)),
75            _ => None,
76        }),
77    );
78    phantom_types.push(quote! {&'zeroio_deserialize ()});
79
80    GenericFragments {
81        impl_generics: to_tokens(&impl_generics),
82        ty_generics: to_tokens(&ty_generics),
83        where_clause: to_tokens(&where_clause),
84        ref_name: to_tokens(&ref_name),
85        ref_impl_generics: to_tokens(&ref_impl_generics),
86        ref_ty_generics: to_tokens(&ref_ty_generics),
87        //        ref_ty_generics_turbofish: to_tokens(&ref_ty_generics_turbofish),
88        ref_where_clause: to_tokens(&ref_where_clause),
89        phantom_types: to_tokens(&phantom_types),
90    }
91}
92
/// Fragments for generating structures; shared by derived structs and by the
/// fields of individual enum variants.
#[derive(Debug)]
struct StructFragments {
    // Names of the fields (empty for tuple and unit shapes).
    field_name: Vec<Ident>,

    // Length calculation expressions: the sum of the fields' FIXED_WORDS and
    // the sum of the fields' `tot_len()`.
    fixed_words: TokenStream,
    tot_len: TokenStream,

    // "fill" implementation (for Serialize).
    fill: TokenStream,

    // Accessor methods for the reference type: one per named field, or
    // `elem0`, `elem1`, ... for tuple fields.
    accessors: TokenStream,

    // Body of the implementation of Deserialize.
    deserialize_impl: TokenStream,

    // Expressions naming each tuple field: `self.N` for structs, or the
    // `elemN` bindings used when matching an enum variant.
    tuple_id: Vec<TokenStream>,

    // Declaration of the reference type (buffer slice plus PhantomData).
    declare_ref: TokenStream,

    // Expression converting a reference (bound as `_val`) back to the original.
    from_ref: TokenStream,
}
120
121// Returns fragments for building serializable and deserializable
122// structures.  We can compose these to either generate a named struct
123// or an enum struct variant.
124fn make_struct_impls(
125    selfname: &TokenStream,
126    vis: &Visibility,
127    fields: &Fields,
128    generics: &GenericFragments,
129    is_enum: bool,
130) -> StructFragments {
131    // Names and types of struct fields.
132    let field_name: Vec<Ident> = fields.iter().flat_map(|f| &f.ident).cloned().collect();
133    let field_ty: Vec<TokenStream> = fields.iter().map(|f| to_tokens(&f.ty)).collect();
134
135    // Names for tuple parts when accessing locally in a member function.
136    let tuple_id: Vec<_> = (0..fields.len())
137        .map(|idx| match is_enum {
138            true => to_tokens(&format_ident!("elem{idx}")),
139            false => {
140                // We want self.0, not self.0usize
141                let idx = Index::from(idx);
142                quote! { self . #idx }
143            }
144        })
145        .collect();
146
147    // Calculate starting offsets in the structure based off the fixed
148    // length of each field.
149    let field_offsets: Vec<_> = (0..field_ty.len())
150        .map(|idx| {
151            if idx == 0 {
152                quote! { 0 }
153            } else {
154                let part_ty = &field_ty[0..idx];
155                quote! {
156                    #( <#part_ty as risc0_zeroio::Deserialize<'_>> :: FIXED_WORDS)+*
157                }
158            }
159        })
160        .collect();
161
162    let fixed_words = match &fields {
163        Fields::Unit => quote! {0},
164        Fields::Named(_) | Fields::Unnamed(_) => quote! {
165            #(<#field_ty as risc0_zeroio::Serialize>::FIXED_WORDS)+*
166        },
167    };
168    let selfref = if is_enum {
169        quote! {}
170    } else {
171        quote! {self.}
172    };
173    let tot_len = match &fields {
174        Fields::Unit => quote! {0},
175        Fields::Named(_) => quote! {
176            #(#selfref #field_name . tot_len())+*
177        },
178        Fields::Unnamed(_) => quote! {
179            #(#tuple_id . tot_len())+*
180        },
181    };
182
183    let fill = match &fields {
184        Fields::Unit => {
185            quote! { Ok(()) }
186        }
187        Fields::Named(_) => {
188            quote! {
189                let pos: usize = 0;
190                #(
191                    #selfref #field_name . fill(&mut _buf.descend(
192                        pos, <#field_ty as risc0_zeroio::Serialize>::FIXED_WORDS)?, _a)?;
193                    let pos = pos + <#field_ty as risc0_zeroio::Serialize>::FIXED_WORDS;
194                )*
195
196                Ok(())
197            }
198        }
199        Fields::Unnamed(_) => {
200            quote! {
201                let pos: usize = 0;
202                #(
203                    #tuple_id . fill(&mut _buf.descend(
204                        pos, <#field_ty as risc0_zeroio::Serialize>::FIXED_WORDS)?, _a)?;
205                    let pos = pos + <#field_ty as risc0_zeroio::Serialize>::FIXED_WORDS;
206                )*
207
208                Ok(())
209            }
210        }
211    };
212
213    // Field accessors
214    let accessors = match &fields {
215        Fields::Unit => TokenStream::new(),
216        Fields::Named(_) => quote! {
217            // For each field, generate an accessor method.
218            #(
219                #vis fn #field_name(&self) ->
220                    <#field_ty as risc0_zeroio::Deserialize<'zeroio_deserialize>>::RefType {
221                        <#field_ty as risc0_zeroio::Deserialize<'zeroio_deserialize>>::deserialize_from(
222                            &self.buf[#field_offsets ..]
223                        )
224                    }
225            )*
226        },
227        Fields::Unnamed(_) => {
228            let tuple_method = (0..fields.len()).map(|idx| format_ident!("elem{}", idx));
229
230            quote! {#(
231                #vis fn #tuple_method(&self) -> <#field_ty as risc0_zeroio::Deserialize<'zeroio_deserialize>>::RefType {
232                    <#field_ty as risc0_zeroio::Deserialize<'zeroio_deserialize>>::deserialize_from(
233                        &self.buf[#field_offsets ..]
234                    )
235                }
236            )*}
237        }
238    };
239
240    let from_ref = match &fields {
241        Fields::Unit => {
242            quote! { #selfname }
243        }
244        Fields::Named(_) => quote! {
245            #selfname{#(
246                #field_name:
247                <#field_ty as risc0_zeroio::Deserialize<'zeroio_deserialize>>::from_ref(
248                    &<#field_ty as risc0_zeroio::Deserialize<'zeroio_deserialize>>::deserialize_from(&_val.buf[#field_offsets ..]))
249            ),*}
250        },
251        Fields::Unnamed(_) => quote! {
252        #selfname(#(
253            <#field_ty as risc0_zeroio::Deserialize<'zeroio_deserialize>>::from_ref(
254                &<#field_ty as risc0_zeroio::Deserialize<'zeroio_deserialize>>::deserialize_from(&_val.buf[#field_offsets ..]))
255        ),*)},
256    };
257
258    let GenericFragments {
259        ref_name,
260        ref_ty_generics,
261        phantom_types,
262        ..
263    } = generics;
264
265    let inhibit_warns = warning_inhibit();
266    let declare_ref = quote! {
267        #inhibit_warns
268        #vis struct #ref_name #ref_ty_generics {
269            buf: &'zeroio_deserialize [u32],
270            phantom: core::marker::PhantomData <(#phantom_types)>,
271        }
272    };
273
274    let deserialize_impl = match &fields {
275        Fields::Unit => quote! {
276
277            type RefType = #ref_name #ref_ty_generics;
278
279            const FIXED_WORDS : usize = 0;
280
281            fn deserialize_from(_buf: &'zeroio_deserialize [u32]) -> Self::RefType {
282                Self::RefType{phantom: core::marker::PhantomData}
283            }
284
285            fn from_ref(_val: &Self::RefType) -> Self {
286                #from_ref
287            }
288        },
289        Fields::Named(_) | Fields::Unnamed(_) => quote! {
290            type RefType = #ref_name #ref_ty_generics;
291
292            const FIXED_WORDS : usize =
293                #(<#field_ty as risc0_zeroio::Deserialize<'_>>::FIXED_WORDS)+* ;
294
295            fn deserialize_from(_buf: &'zeroio_deserialize [u32]) -> Self::RefType {
296                Self::RefType { buf: _buf, phantom: core::marker::PhantomData }
297            }
298
299            fn from_ref(_val: &Self::RefType) -> Self {
300                #from_ref
301            }
302        },
303    };
304
305    StructFragments {
306        tot_len,
307        fill,
308        declare_ref,
309        deserialize_impl,
310        tuple_id,
311        accessors,
312        fixed_words,
313        field_name,
314        from_ref,
315    }
316}
317
318fn emit_serialize_struct(input: &DeriveInput, st: &DataStruct) -> Result<TokenStream> {
319    // Name of this structure
320    let name = &input.ident;
321
322    let genfrags @ GenericFragments {
323        impl_generics,
324        ty_generics,
325        where_clause,
326        ..
327    } = &make_generics(&format_ident!("{}Ref", name), &input.generics);
328
329    let StructFragments {
330        fixed_words,
331        tot_len,
332        fill,
333        ..
334    } = make_struct_impls(&quote! {#name}, &input.vis, &st.fields, &genfrags, false);
335
336    // We serialize structures as simply all the fields in order.
337    let inhibit_warns = warning_inhibit();
338    let expanded = quote! {
339        #inhibit_warns
340        impl #impl_generics risc0_zeroio::Serialize for #name #ty_generics #where_clause {
341            const FIXED_WORDS : usize = #fixed_words;
342
343            fn tot_len(&self) -> usize { #tot_len }
344
345            fn fill(&self, _buf: & mut risc0_zeroio::AllocBuf,
346                    _a: &mut risc0_zeroio::Alloc) -> risc0_zeroio::Result<()> {
347                #fill
348            }
349        }
350    };
351
352    Ok(expanded.into())
353}
354
355fn emit_deserialize_struct(input: &DeriveInput, fields: &Fields) -> Result<TokenStream> {
356    // Name of this structure
357    let name = &input.ident;
358
359    let genfrags @ GenericFragments {
360        ty_generics,
361        where_clause,
362        ref_name,
363        ref_impl_generics,
364        ref_ty_generics,
365        ref_where_clause,
366        ..
367    } = &make_generics(&format_ident!("{}Ref", name), &input.generics);
368
369    let StructFragments {
370        accessors,
371        declare_ref,
372        deserialize_impl,
373        ..
374    } = make_struct_impls(&quote! {#name}, &input.vis, fields, &genfrags, false);
375
376    let inhibit_warns = warning_inhibit();
377    let vis = &input.vis;
378    let expanded = quote! {
379        #declare_ref
380
381        #inhibit_warns
382        impl #ref_impl_generics #ref_name #ref_ty_generics #ref_where_clause {
383            #accessors
384
385            #vis fn into_orig(&self) -> #name #ty_generics {
386                <#name #ty_generics as risc0_zeroio::Deserialize<'zeroio_deserialize>>::from_ref(&self)
387            }
388        }
389
390        // Deserialize trait; construct this from a buffer.
391        #inhibit_warns
392        impl #ref_impl_generics risc0_zeroio::Deserialize<'zeroio_deserialize> for #name #ty_generics #where_clause {
393            #deserialize_impl
394        }
395    };
396
397    Ok(expanded.into())
398}
399
400fn make_var_generated_type(name: &Ident, var_name: &Ident) -> Ident {
401    format_ident!("{}Ref", format!("{}::{}", name, var_name).replace(":", "_"))
402}
403
/// Fragments for each enum variant.
struct VarFragments {
    // The variant as written in the enum definition.
    var: Variant,

    // Position of the variant in the enum; this is the id written as the
    // first word of a serialized enum.
    var_id: usize,
    var_name: Ident,
    // Generated reference type for this variant (see make_var_generated_type).
    var_ref_ty: Ident,

    // Struct fragments for the variant's fields.
    st: StructFragments,
}
415
416fn make_var_frags(name: &Ident, vis: &Visibility, var: &Variant, var_id: usize) -> VarFragments {
417    let var_name = var.ident.clone();
418    let var_ref_ty = make_var_generated_type(name, &var_name);
419    let generics = make_generics(&var_ref_ty, &Generics::default());
420    let st = make_struct_impls(
421        &quote! {#name :: #var_name},
422        vis,
423        &var.fields,
424        &generics,
425        true,
426    );
427    VarFragments {
428        var: var.clone(),
429        var_id,
430        var_name,
431        var_ref_ty,
432        st,
433    }
434}
435
436fn make_enum_frags(name: &Ident, vis: &Visibility, en: &DataEnum) -> Vec<VarFragments> {
437    en.variants
438        .iter()
439        .enumerate()
440        .map(|(var_id, var)| make_var_frags(name, vis, var, var_id))
441        .collect()
442}
443
444fn emit_serialize_enum(input: &DeriveInput, en: &DataEnum) -> Result<TokenStream> {
445    // Name of this enum
446    let name = &input.ident;
447
448    let GenericFragments {
449        impl_generics,
450        ty_generics,
451        where_clause,
452        ..
453    } = make_generics(&format_ident!("{}Ref", name), &input.generics);
454
455    let vars = make_enum_frags(name, &input.vis, en);
456
457    let match_tot_len = vars.iter().map(
458        |VarFragments {
459             var,
460             var_name,
461             st:
462                 StructFragments {
463                     tot_len,
464                     field_name,
465                     tuple_id,
466                     ..
467                 },
468             ..
469         }| {
470            match var.fields {
471                Fields::Unit => quote! {
472                    #name :: #var_name => #tot_len
473                },
474                Fields::Named(_) => quote! {
475                    #name :: #var_name{ #(#field_name),* } => #tot_len
476                },
477                Fields::Unnamed(_) => quote! {
478                    #name :: #var_name( #(#tuple_id),* ) => #tot_len
479                },
480            }
481        },
482    );
483
484    let match_and_fill = vars.iter().map(
485        |VarFragments {
486             var,
487             var_id,
488             var_name,
489             st:
490                 StructFragments {
491                     fill,
492                     field_name,
493                     tuple_id,
494                     fixed_words,
495                     ..
496                 },
497             ..
498         }| {
499            let var_id = *var_id as u32;
500            let enumfill = quote! {
501                let mut vardata = _a.alloc(#fixed_words)?;
502                {
503                    let _buf = &mut vardata;
504                    #fill
505                }?;
506                _enumbuf.fill_from([#var_id, vardata.rel_ptr_from(_enumbuf)])
507            };
508            match var.fields {
509                Fields::Unit => quote! {
510                    #name :: #var_name => { #enumfill }
511                },
512                Fields::Named(_) => quote! {
513                    #name :: #var_name{ #(#field_name),* } => { #enumfill }
514                },
515                Fields::Unnamed(_) => quote! {
516                    #name :: #var_name( #(#tuple_id).* ) => { #enumfill }
517                },
518            }
519        },
520    );
521
522    // We serialize enums as a descriminant ID, and a pointer to the variant inside.
523    let inhibit_warns = warning_inhibit();
524    let expanded = quote! {
525        #inhibit_warns
526        impl #impl_generics risc0_zeroio::Serialize for #name #ty_generics #where_clause {
527            // One word for id, one word for pointer.
528            const FIXED_WORDS: usize = 2;
529
530            fn tot_len(&self) -> usize {
531                <Self as risc0_zeroio::Serialize>::FIXED_WORDS + match self {
532                    #(#match_tot_len,)*
533                }
534            }
535
536            fn fill(&self, _enumbuf: &mut risc0_zeroio::AllocBuf, _a: &mut risc0_zeroio::Alloc)
537                    -> risc0_zeroio::Result<()> {
538                match self {
539                    #(#match_and_fill,)*
540                }
541            }
542        }
543    };
544    Ok(expanded.into())
545}
546
547fn emit_deserialize_enum(input: &DeriveInput, en: &DataEnum) -> Result<TokenStream> {
548    // Name of this enum
549    let name = &input.ident;
550    let vis = &input.vis;
551
552    let GenericFragments {
553        ty_generics,
554        where_clause,
555        ref_name,
556        ref_impl_generics,
557        ref_ty_generics,
558        ..
559    } = &make_generics(&format_ident!("{}Ref", name), &input.generics);
560
561    let vars = &make_enum_frags(name, &input.vis, en);
562
563    let var_name: &Vec<_> = &vars.iter().map(|var| &var.var_name).collect();
564    let var_ref_ty = vars.iter().map(|var| &var.var_ref_ty);
565
566    let declare_ref = vars.iter().map(|var| &var.st.declare_ref);
567
568    let match_and_deser = vars.iter().map(
569        |VarFragments {
570             var_id,
571             var_name,
572             var_ref_ty,
573             ..
574         }| {
575            let var_id = *var_id as u32;
576            quote! {
577                #var_id => Self::RefType::#var_name(
578                    #var_ref_ty::<'zeroio_deserialize>::deserialize_from(&_buf[ptr as usize..]))
579            }
580        },
581    );
582
583    let ref_impl = vars.iter().map(
584        |VarFragments {
585             var_ref_ty,
586             st:
587                 StructFragments {
588                     accessors,
589                     from_ref,
590                     ..
591                 },
592             ..
593         }| {
594            quote! {
595                impl<'zeroio_deserialize> #var_ref_ty <'zeroio_deserialize> {
596                    #accessors
597
598                    fn deserialize_from(_buf: &'zeroio_deserialize [u32]) -> Self {
599                        Self{buf: _buf, phantom: core::marker::PhantomData}
600                    }
601
602                    pub fn into_orig(&self) -> #name {
603                        let _val = self;
604                        #from_ref
605                    }
606                }
607            }
608        },
609    );
610
611    let inhibit_warns = warning_inhibit();
612    let expanded = quote! {
613        #inhibit_warns
614        #vis enum #ref_name #ref_ty_generics {#(
615                #var_name(#var_ref_ty<'zeroio_deserialize>),
616        )*}
617
618        // Constructed reference types for each variant
619        #(#declare_ref)*
620        #(#ref_impl)*
621
622        // Deserialize trait; construct this from a buffer.
623        #inhibit_warns
624        impl #ref_impl_generics risc0_zeroio::Deserialize<'zeroio_deserialize> for #name #ty_generics #where_clause {
625            type RefType = #ref_name #ref_ty_generics;
626
627            const FIXED_WORDS : usize = 2;
628
629            fn deserialize_from(_buf: &'zeroio_deserialize [u32]) -> Self::RefType {
630                let (id, ptr) = (_buf[0], _buf[1]);
631                match id {
632                    #(#match_and_deser,)*
633                    _ => panic!("Unknown variant id {}", id)
634                }
635            }
636
637            fn from_ref(_val: &Self::RefType) -> Self {
638                match _val {#(
639                    Self::RefType::#var_name(ref var) => var.into_orig()
640                ,)*}
641            }
642        }
643    };
644
645    Ok(expanded.into())
646}
647
648// With the debug-derive feature, this dumps out all the generated
649// code and the includes it via include! so we can see errors in
650// context.
651fn debug_dump(ty: &str, ident: &Ident, res: &mut TokenStream) {
652    if cfg!(feature = "debug-derive") {
653        let filename = format!("{ty}-{ident}.rs");
654        let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap();
655        let path = std::path::Path::new(&manifest_dir)
656            .join("src")
657            .join(&filename);
658        std::fs::write(&path, format!("{}", res)).unwrap();
659
660        let pathname = path.display().to_string();
661        *res = quote! {
662            include!(#pathname);
663        };
664    }
665}
666
667#[proc_macro_derive(Serialize)]
668pub fn derive_serialize(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
669    let input = parse_macro_input!(input as DeriveInput);
670
671    let mut res = match &input.data {
672        syn::Data::Struct(ref st) => emit_serialize_struct(&input, st),
673        syn::Data::Enum(en) => emit_serialize_enum(&input, &en),
674        _ => Err(Error::new(
675            input.span().unwrap().into(),
676            "Zeroio derive only supports structs and enums",
677        )),
678    }
679    .unwrap_or_else(|err| Error::to_compile_error(&err).into());
680    debug_dump("ser", &input.ident, &mut res);
681    res.into()
682}
683
684#[proc_macro_derive(Deserialize)]
685pub fn derive_deserialize(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
686    let input = parse_macro_input!(input as DeriveInput);
687
688    let mut res = match &input.data {
689        syn::Data::Struct(st) => emit_deserialize_struct(&input, &st.fields),
690        syn::Data::Enum(en) => emit_deserialize_enum(&input, &en),
691        _ => Err(Error::new(
692            input.span().unwrap().into(),
693            "Zeroio derive only supports structs and enums",
694        )),
695    }
696    .unwrap_or_else(|err| Error::to_compile_error(&err).into());
697    debug_dump("deser", &input.ident, &mut res);
698    res.into()
699}