pocket_prover_derive/
lib.rs

//! # pocket_prover-derive
//!
//! Procedural derive macros for `pocket_prover`.
//!
//! Example:
//!
//! ```ignore
//! #[macro_use]
//! extern crate pocket_prover_derive;
//! extern crate pocket_prover;
//!
//! use pocket_prover::Construct;
//!
//! #[derive(Construct)]
//! pub struct Foo {
//!     pub a: u64,
//!     pub b: u64,
//! }
//! ```
//!
//! Since `pocket_prover` uses only `u64`,
//! `u64` is the only valid concrete field type.
//!
//! The macro supports generic arguments, assuming that
//! the inner type implements `Construct`:
//!
//! ```ignore
//! #[derive(Construct)]
//! pub struct Bar<T = ()> {
//!     pub foo: T,
//!     pub a: u64,
//!     pub b: u64,
//! }
//! ```

36extern crate proc_macro;
37extern crate syn;
38#[macro_use]
39extern crate quote;
40
41use proc_macro::{TokenStream};
42use syn::{
43    Body, Ident, VariantData, PolyTraitRef, Ty, TyParamBound,
44    TraitBoundModifier, WherePredicate, WhereBoundPredicate
45};
46use quote::Tokens;
47
48#[proc_macro_derive(Construct)]
49pub fn construct(input: TokenStream) -> TokenStream {
50    // Construct a string representation of the type definition
51    let s = input.to_string();
52
53    // Parse the string representation
54    let ast = syn::parse_derive_input(&s).unwrap();
55
56    // Build the impl
57    let gen = impl_construct(&ast);
58
59    // Return the generated impl
60    gen.parse().unwrap()
61}
62
63fn impl_construct(ast: &syn::DeriveInput) -> quote::Tokens {
64    if let Body::Struct(ref body) = ast.body {
65        let name = &ast.ident;
66        let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
67
68        // Get field identifier and type for all struct fields.
69        let fields: Vec<(_, Ty)> = if let &VariantData::Struct(ref fields) = body {
70            fields.iter()
71                .map(|field| (
72                    field.ident.as_ref().unwrap(),
73                    field.ty.clone(),
74                )).collect()
75        } else {
76            panic!("Expected struct fields");
77        };
78
79        // Add constraints and compute offsets.
80        let mut where_clause = where_clause.clone();
81        let mut offsets = Tokens::new();
82        let mut i = 0;
83        let mut ns = vec![];
84        for &(_, ref ty) in &fields {
85            let ty_ident = if let &Ty::Path(_, ref parameters) = ty {
86                parameters.segments[0].ident.clone()
87            } else {
88                panic!("Could not find type identifier.")
89            };
90            if ty_ident != &Ident::new("u64") {
91                // Add `T: Construct` constraint.
92                where_clause.predicates.push(WherePredicate::BoundPredicate(WhereBoundPredicate {
93                    bound_lifetimes: vec![],
94                    bounded_ty: ty.clone(),
95                    bounds: vec![TyParamBound::Trait(PolyTraitRef {
96                        bound_lifetimes: vec![],
97                        trait_ref: Ident::new("Construct").into()
98                    }, TraitBoundModifier::None)]
99                }));
100                offsets.append(
101                    format!("let n{} = <{} as Construct>::n();", i, ty_ident)
102                );
103                i += 1;
104            } else {
105                ns.push(i);
106            }
107        }
108
109        // Map arguments to struct fields.
110        let mut field_tokens = Tokens::new();
111        let mut i = 0;
112        for &(ref field, ref ty) in &fields {
113            field_tokens.append(field);
114            field_tokens.append(":");
115            let ty_ident = if let &Ty::Path(_, ref parameters) = ty {
116                parameters.segments[0].ident.clone()
117            } else {
118                panic!("Could not find type identifier.")
119            };
120            if ty_ident == &Ident::new("u64") {
121                field_tokens.append("vs[");
122                if ns[i] != 0 {
123                    field_tokens.append(format!("n{}+", ns[i]-1));
124                }
125                field_tokens.append(&format!("{}", i));
126                field_tokens.append("]");
127                i += 1;
128            } else {
129                field_tokens.append("Construct::construct(vs)");
130            }
131            field_tokens.append(",");
132        }
133
134        quote! {
135            impl #impl_generics Construct for #name #ty_generics #where_clause {
136                fn construct(vs: &[u64]) -> Self {
137                    #offsets
138                    #name {
139                        #field_tokens
140                    }
141                }
142            }
143        }
144    } else {panic!("Must be a struct.")}
145}