persian_rug_derive/
lib.rs

1use proc_macro::{self, TokenStream};
2use proc_macro2 as pm2;
3use quote::ToTokens;
4
5enum ConstraintItem {
6    Context(syn::Ident),
7    Access(Vec<syn::Type>),
8}
9
10impl syn::parse::Parse for ConstraintItem {
11    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
12        let attr: syn::Ident = input.parse()?;
13        match attr.to_string().as_str() {
14            "context" => {
15                let _: syn::Token![=] = input.parse()?;
16                let value = input.parse()?;
17                Ok(ConstraintItem::Context(value))
18            }
19            "access" => {
20                let content;
21                let _: syn::token::Paren = syn::parenthesized!(content in input);
22                let punc =
23                    syn::punctuated::Punctuated::<syn::Type, syn::Token![,]>::parse_terminated(
24                        &content,
25                    )?;
26                Ok(ConstraintItem::Access(punc.into_iter().collect()))
27            }
28            _ => Err(syn::Error::new_spanned(
29                attr,
30                "unsupported persian-rug constraint",
31            )),
32        }
33    }
34}
35
36struct ConstraintArgs {
37    pub context: syn::Ident,
38    pub used_types: Vec<syn::Type>,
39}
40
41impl syn::parse::Parse for ConstraintArgs {
42    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
43        let punc =
44            syn::punctuated::Punctuated::<ConstraintItem, syn::Token![,]>::parse_terminated(input)?;
45        let mut context = None;
46        let mut used_types = Vec::new();
47
48        for item in punc.into_iter() {
49            match item {
50                ConstraintItem::Context(id) => {
51                    context = Some(id);
52                }
53                ConstraintItem::Access(tys) => {
54                    used_types.extend(tys);
55                }
56            }
57        }
58
59        context
60            .map(|context| Self {
61                context,
62                used_types,
63            })
64            .ok_or_else(|| {
65                syn::Error::new(
66                    pm2::Span::call_site(),
67                    "No context provided for constraints.",
68                )
69            })
70    }
71}
72
73/// Add the type constraints necessary for an impl using persian-rug.
74///
75/// Rust currently requires all relevant constraints to be written out
76/// for every impl using a given type. For persian-rug in particular,
77/// there are typically many constraints of a simple kind: for every
78/// type owned by the given `Context`, there must be an `Owner`
79/// implementation for the context and there must be a matching
80/// `Contextual` implementation for the type. This macro simply
81/// generates these constraints for you.
82///
83/// The attribute takes two types of argument:
84/// - `context` specifies the name of the type of the context.
85/// - `access(...)` specifies the types that this impl requires to
86///   exist within that context. Typically each type requires some
87///   other types to also exist in its context for it to be
88///   well-formed.  This argument needs to be given the transitive
89///   closure of all such types, both direct and indirect dependencies
90///   of the impl itself. It is unfortunately not possible at present
91///   to find the indirect dependencies automatically.
92///
93/// Example:
94/// ```rust
95/// use persian_rug::{contextual, Context, Mutator, Proxy};
96///
97/// #[contextual(C)]
98/// struct Foo<C: Context> {
99///    _marker: core::marker::PhantomData<C>,
100///    a: i32
101/// }
102///
103/// struct Bar<C: Context> {
104///    foo: Proxy<Foo<C>>
105/// }
106///
107/// #[persian_rug::constraints(context = C, access(Foo<C>))]
108/// impl<C> Bar<C> {
109///    pub fn new<M: Mutator<Context=C>>(foo: Foo<C>, mut mutator: M) -> Self {
110///        Self { foo: mutator.add(foo) }
111///    }
112/// }
113/// ```
114#[proc_macro_attribute]
115pub fn constraints(args: TokenStream, input: TokenStream) -> TokenStream {
116    let mut target: syn::Item = syn::parse_macro_input!(input);
117
118    let generics = match &mut target {
119        syn::Item::Enum(e) => &mut e.generics,
120        syn::Item::Fn(f) => &mut f.sig.generics,
121        syn::Item::Impl(i) => &mut i.generics,
122        syn::Item::Struct(s) => &mut s.generics,
123        syn::Item::Trait(t) => &mut t.generics,
124        syn::Item::TraitAlias(t) => &mut t.generics,
125        syn::Item::Type(t) => &mut t.generics,
126        syn::Item::Union(u) => &mut u.generics,
127        _ => {
128            return syn::Error::new(
129                pm2::Span::call_site(),
130                "This attribute extends a where clause, or generic constraints. It cannot be used here."
131            )
132                .to_compile_error()
133                .into();
134        }
135    };
136
137    let ConstraintArgs {
138        context,
139        used_types,
140    } = syn::parse_macro_input!(args);
141
142    let wc = generics.make_where_clause();
143
144    let mut getters = syn::punctuated::Punctuated::<syn::TypeParamBound, syn::token::Add>::new();
145    getters.push(syn::parse_quote! { ::persian_rug::Context });
146    for ty in &used_types {
147        getters.push(syn::parse_quote! { ::persian_rug::Owner<#ty> });
148    }
149
150    wc.predicates.push(syn::parse_quote! {
151        #context: #getters
152    });
153
154    for ty in &used_types {
155        wc.predicates.push(syn::parse_quote! {
156            #ty: ::persian_rug::Contextual<Context = #context>
157        });
158    }
159
160    target.into_token_stream().into()
161}
162
163/// Convert an annotated struct into a `Context`
164///
165/// Each field marked with `#[table]` will be converted to be a
166/// `Table` of values of the same type. An implementation of `Context`
167/// will be provided. In addition, an implementation of `Owner` for
168/// each field type will be derived for the overall struct.
169///
170/// Note that a `Context` can only contain one table of each type.
171///
172/// Example:
173/// ```rust
174/// use persian_rug::{contextual, persian_rug, Proxy};
175///
176/// #[contextual(MyRug)]
177/// struct Foo {
178///    a: i32
179/// }
180///
181/// #[contextual(MyRug)]
182/// struct Bar {
183///    a: i32,
184///    b: Proxy<Foo>
185/// };
186///
187/// #[persian_rug]
188/// struct MyRug(#[table] Foo, #[table] Bar);
189/// ```
190#[proc_macro_attribute]
191pub fn persian_rug(_args: TokenStream, input: TokenStream) -> TokenStream {
192    let syn::DeriveInput {
193        attrs,
194        vis,
195        ident: ty_ident,
196        data,
197        generics,
198    } = syn::parse_macro_input!(input);
199
200    let (generics, ty_generics, wc) = generics.split_for_impl();
201
202    let mut impls = pm2::TokenStream::new();
203
204    let body = if let syn::Data::Struct(s) = data {
205        let mut fields = syn::punctuated::Punctuated::<syn::Field, syn::Token![,]>::new();
206
207        let mut process_field = |field: &syn::Field| {
208            let is_table = field.attrs.iter().any(|attr| attr.path.is_ident("table"));
209
210            let field_type = &field.ty;
211            let ident = field
212                .ident
213                .as_ref()
214                .map(|id| syn::Member::Named(id.clone()))
215                .unwrap_or_else(|| {
216                    syn::Member::Unnamed(syn::Index {
217                        index: fields.len() as u32,
218                        span: pm2::Span::call_site(),
219                    })
220                });
221
222            let vis = &field.vis;
223
224            let attrs = field
225                .attrs
226                .iter()
227                .filter(|a| !a.path.is_ident("table"))
228                .cloned()
229                .collect::<Vec<_>>();
230
231            if !is_table {
232                fields.push(field.clone());
233            } else {
234                fields.push(syn::Field {
235                    attrs,
236                    vis: vis.clone(),
237                    ident: if let syn::Member::Named(id) = &ident {
238                        Some(id.clone())
239                    } else {
240                        None
241                    },
242                    colon_token: field.colon_token,
243                    ty: syn::parse_quote! {
244                        ::persian_rug::Table<#field_type>
245                    },
246                });
247
248                impls.extend(quote::quote! {
249                    impl #generics ::persian_rug::Owner<#field_type> for #ty_ident #ty_generics #wc {
250                        fn add(&mut self, what: #field_type) -> ::persian_rug::Proxy<#field_type> {
251                            self.#ident.push(what)
252                        }
253                        fn get(&self, what: &::persian_rug::Proxy<#field_type>) -> &#field_type {
254                            self.#ident.get(what).unwrap()
255                        }
256                        fn get_mut(&mut self, what: &::persian_rug::Proxy<#field_type>) -> &mut #field_type {
257                            self.#ident.get_mut(what).unwrap()
258                        }
259                        fn get_iter(&self) -> ::persian_rug::TableIterator<'_, #field_type> {
260                            self.#ident.iter()
261                        }
262                        fn get_iter_mut(&mut self) -> ::persian_rug::TableMutIterator<'_, #field_type> {
263                            self.#ident.iter_mut()
264                        }
265                        fn get_proxy_iter(&self) -> ::persian_rug::TableProxyIterator<'_, #field_type> {
266                            self.#ident.iter_proxies()
267                        }
268                    }
269                });
270            }
271        };
272
273        match s.fields {
274            syn::Fields::Named(syn::FieldsNamed { named, .. }) => {
275                for field in named.iter() {
276                    (process_field)(field);
277                }
278                quote::quote! {
279                    #vis struct #ty_ident #generics #wc {
280                        #fields
281                    }
282                }
283            }
284            syn::Fields::Unnamed(syn::FieldsUnnamed { unnamed, .. }) => {
285                for field in unnamed.iter() {
286                    (process_field)(field);
287                }
288                quote::quote! {
289                    #vis struct #ty_ident #generics(
290                        #fields
291                    ) #wc;
292                }
293            }
294            syn::Fields::Unit => {
295                quote::quote! {
296                    #vis struct #ty_ident #generics #wc;
297                }
298            }
299        }
300    } else {
301        return syn::Error::new(
302            pm2::Span::call_site(),
303            "Only structs can be annotated as persian-rugs.",
304        )
305        .to_compile_error()
306        .into();
307    };
308
309    let attrs = {
310        let mut res = pm2::TokenStream::new();
311        for attr in attrs {
312            attr.to_tokens(&mut res);
313        }
314        res
315    };
316
317    let res = quote::quote! {
318        #attrs
319        #body
320
321        impl #generics ::persian_rug::Context for #ty_ident #ty_generics #wc {
322            fn add<T>(&mut self, what: T) -> ::persian_rug::Proxy<T>
323            where
324                #ty_ident #ty_generics: ::persian_rug::Owner<T>,
325                T: ::persian_rug::Contextual<Context=Self>
326            {
327                <Self as ::persian_rug::Owner<T>>::add(self, what)
328            }
329
330            fn get<T>(&self, what: &::persian_rug::Proxy<T>) -> &T
331            where
332                #ty_ident #ty_generics: ::persian_rug::Owner<T>,
333                T: ::persian_rug::Contextual<Context=Self>
334            {
335                <Self as ::persian_rug::Owner<T>>::get(self, what)
336            }
337
338            fn get_mut<T>(&mut self, what: &::persian_rug::Proxy<T>) -> &mut T
339            where
340                #ty_ident #ty_generics: ::persian_rug::Owner<T>,
341                T: ::persian_rug::Contextual<Context=Self>
342            {
343                <Self as ::persian_rug::Owner<T>>::get_mut(self, what)
344            }
345
346            fn get_iter<T>(&self) -> ::persian_rug::TableIterator<'_, T>
347            where
348                #ty_ident #ty_generics: ::persian_rug::Owner<T>,
349                T: ::persian_rug::Contextual<Context=Self>
350            {
351                <Self as ::persian_rug::Owner<T>>::get_iter(self)
352            }
353
354            fn get_iter_mut<T>(&mut self) -> ::persian_rug::TableMutIterator<'_, T>
355            where
356                #ty_ident #ty_generics: ::persian_rug::Owner<T>,
357                T: ::persian_rug::Contextual<Context=Self>
358            {
359                <Self as ::persian_rug::Owner<T>>::get_iter_mut(self)
360            }
361
362            fn get_proxy_iter<T>(&self) -> ::persian_rug::TableProxyIterator<'_, T>
363            where
364                #ty_ident #ty_generics: ::persian_rug::Owner<T>,
365                T: ::persian_rug::Contextual<Context=Self>
366            {
367                <Self as ::persian_rug::Owner<T>>::get_proxy_iter(self)
368            }
369        }
370
371        #impls
372    };
373
374    res.into()
375}
376
377/// Provide a implementation of `Contextual` for a type.
378///
379/// This is a very simple derive-style macro, that creates an
380/// impl for `Contextual` for the type it annotates. It takes
381/// one argument, which is the `Context` type that this
382/// type belongs to.
383///
384/// Example:
385/// ```rust
386/// use persian_rug::{contextual, Context};
387///
388/// #[contextual(C)]
389/// struct Foo<C: Context> {
390///    _marker: core::marker::PhantomData<C>
391/// }
392/// ```
393/// which is equivalent to the following:
394/// ```rust
395/// use persian_rug::{Context, Contextual};
396///
397/// struct Foo<C: Context> {
398///    _marker: core::marker::PhantomData<C>
399/// }
400///
401/// impl<C: Context> Contextual for Foo<C> {
402///    type Context = C;
403/// }
404/// ```
405#[proc_macro_attribute]
406pub fn contextual(args: TokenStream, input: TokenStream) -> TokenStream {
407    let body = pm2::TokenStream::from(input.clone());
408
409    let syn::DeriveInput {
410        ident, generics, ..
411    } = syn::parse_macro_input!(input);
412
413    if args.is_empty() {
414        return syn::Error::new(
415            pm2::Span::call_site(),
416            "You must specify the associated context when using contextual.",
417        )
418        .to_compile_error()
419        .into();
420    }
421
422    let context: syn::Type = syn::parse_macro_input!(args);
423
424    let (generics, ty_generics, wc) = generics.split_for_impl();
425
426    let res = quote::quote! {
427        #body
428
429        impl #generics ::persian_rug::Contextual for #ident #ty_generics #wc {
430            type Context = #context;
431        }
432    };
433
434    res.into()
435}