Skip to main content

algebraeon_macros/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use proc_macro2::Span;
5use quote::quote;
6use syn::visit_mut::VisitMut;
7use syn::{
8    Attribute, DeriveInput, Error, FnArg, Ident, ItemTrait, PatIdent, Receiver, TraitItem,
9    TraitItemFn, parse_macro_input,
10};
11
12fn has_option(attrs: &[Attribute], option_name: &str) -> bool {
13    for attr in attrs
14        .iter()
15        .filter(|a| a.path().is_ident("canonical_structure"))
16    {
17        let mut found = false;
18
19        // `parse_nested_meta` lets us walk through the arguments in #[canonical_structure(...)]
20        let _ = attr.parse_nested_meta(|meta| {
21            if meta.path.is_ident(option_name) {
22                found = true;
23            }
24            // Continue parsing
25            Ok(())
26        });
27
28        if found {
29            return true;
30        }
31    }
32    false
33}
34
35/// Generate a canonical structure type for a type `T` by decorating it with `#[derive(CanonicalStructure)]`.
36/// Optional additional structure can be generated by adding `#[canonical_structure(eq, partial_ord, ord)]`.
37/// The type must implement `Debug` and `Clone`. The optional additional structures may require `T` to implement further traits.
38/// Requires `MetaType`, `Signature`, and `SetSignature` to be in scope. The optional additional structures may require further items to be in scope.
39///
40/// # Example
41/// ```rust,ignore
42/// #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, CanonicalStructure)]
43/// #[canonical_structure(eq, partial_ord, ord)]
44/// struct MyValue {
45///     data: i64,
46///     more_data: u32,
47/// }
48/// ```
49/// `#[derive(CanonicalStructure)]` generates the following
50/// ```rust,ignore
51/// #[derive(Debug, Clone, PartialEq, Eq)]
52/// struct MyValueCanonicalStructure {}
53///
54/// impl Signature for MyValueCanonicalStructure {}
55///
56/// impl MyValueCanonicalStructure {
57///     fn new() -> Self {
58///         Self {}
59///     }
60/// }
61///
62/// impl SetSignature for MyValueCanonicalStructure {
63///     type Set = MyValue;
64///     fn validate_element(&self, _x: &Self::Set) -> Result<(), String> {
65///         Ok(())
66///     }
67/// }
68///
69/// impl MetaType for MyValue {
70///     type Signature = MyValueCanonicalStructure;
71///     fn structure() -> Self::Signature {
72///         MyValueCanonicalStructure::new()
73///     }
74/// }
75///
76/// impl MyValue {
77///     pub fn structure_ref() -> &'static MyValueCanonicalStructure {
78///         static CELL: std::sync::OnceLock<MyValueCanonicalStructure> = std::sync::OnceLock::new();
79///         CELL.get_or_init(|| MyValueCanonicalStructure::new())
80///     }
81/// }
82/// ```
83///
84/// `#[canonical_structure(eq)]` requires `MyValue: Eq`, and `EqSignature` to be in scope. It generates the following
85/// ```rust,ignore
86/// impl EqSignature for MyValueCanonicalStructure
87/// where
88///     MyValue: Eq,
89/// {
90///     fn equal(&self, a: &Self::Set, b: &Self::Set) -> bool {
91///         a == b
92///     }
93/// }
94/// ```
95///
96/// `#[canonical_structure(partial_eq)]` requires `#[canonical_structure(eq)]`, `MyValue: PartialEq`, and `PartialEqSignature` to be in scope. It generates the following
97/// ```rust,ignore
98/// impl PartialOrdSignature for MyValueCanonicalStructure
99/// where
100///     MyValue: Ord,
101/// {
102///     fn partial_cmp(&self, a: &Self::Set, b: &Self::Set) -> Option<std::cmp::Ordering> {
103///         Some(a.cmp(b))
104///     }
105/// }
106/// ```
107///
108/// `#[canonical_structure(ord)]` requires `#[canonical_structure(partial_eq)]`, `MyValue: Ord`, and `OrdSignature` to be in scope. It generates the following
109/// ```rust,ignore
110/// impl OrdSignature for MyValueCanonicalStructure
111/// where
112///     MyValue: Ord,
113/// {
114///     fn cmp(&self, a: &Self::Set, b: &Self::Set) -> std::cmp::Ordering {
115///         a.cmp(b)
116///     }
117///     fn sort<S: std::borrow::Borrow<Self::Set>>(&self, mut a: Vec<S>) -> Vec<S> {
118///         a.sort_unstable_by(|x, y| x.borrow().cmp(y.borrow()));
119///         a
120///     }
121/// }
122/// ```
123#[proc_macro_derive(CanonicalStructure, attributes(canonical_structure))]
124pub fn derive_newtype(input: TokenStream) -> TokenStream {
125    let input = parse_macro_input!(input as DeriveInput);
126
127    let name = input.ident;
128    let vis = input.vis;
129    let newtype_name = Ident::new(&format!("{name}CanonicalStructure"), name.span());
130
131    let has_eq = has_option(&input.attrs, "eq");
132    let has_partial_ord = has_option(&input.attrs, "partial_ord");
133    let has_ord = has_option(&input.attrs, "ord");
134
135    let impl_eq_signature = if has_eq {
136        quote! {
137            impl EqSignature for #newtype_name
138                where #name: Eq
139            {
140                fn equal(&self, a: &Self::Set, b: &Self::Set) -> bool {
141                    a == b
142                }
143            }
144        }
145    } else {
146        quote! {}
147    };
148
149    let impl_partial_ord_signature = if has_partial_ord {
150        quote! {
151            impl PartialOrdSignature for #newtype_name
152                where #name: Ord
153            {
154                fn partial_cmp(&self, a: &Self::Set, b: &Self::Set) -> Option<std::cmp::Ordering> {
155                    Some(a.cmp(b))
156                }
157            }
158        }
159    } else {
160        quote! {}
161    };
162
163    let impl_ord_signature = if has_ord {
164        quote! {
165            impl OrdSignature for #newtype_name
166                where #name: Ord
167            {
168                fn cmp(&self, a: &Self::Set, b: &Self::Set) -> std::cmp::Ordering {
169                    a.cmp(b)
170                }
171
172                fn sort<S: std::borrow::Borrow<Self::Set>>(&self, mut a: Vec<S>) -> Vec<S> {
173                    a.sort_unstable_by(|x, y| x.borrow().cmp(y.borrow()));
174                    a
175                }
176            }
177        }
178    } else {
179        quote! {}
180    };
181
182    let expanded = quote! {
183        #[derive(Debug, Clone, PartialEq, Eq)]
184        #vis struct #newtype_name {}
185
186        impl #newtype_name {
187            fn new() -> Self {
188                Self {}
189            }
190        }
191
192        impl Signature for #newtype_name {}
193
194        impl SetSignature for #newtype_name {
195            type Set = #name;
196
197            fn validate_element(&self, _x : &Self::Set) -> Result<(), String> {
198                Ok(())
199            }
200        }
201
202        #impl_eq_signature
203        #impl_partial_ord_signature
204        #impl_ord_signature
205
206        impl MetaType for #name {
207            type Signature = #newtype_name;
208
209            fn structure() -> Self::Signature {
210                #newtype_name::new()
211            }
212        }
213
214        impl #name {
215            pub fn structure_ref() -> &'static #newtype_name{
216                static CELL: std::sync::OnceLock<#newtype_name> = std::sync::OnceLock::new();
217                CELL.get_or_init(|| #newtype_name::new())
218            }
219        }
220    };
221
222    TokenStream::from(expanded)
223}
224
225/// In a structure trait decorated with `#[proc_macro_attribute]`, decorate a method with `#[skip_meta]` to exclude it from the auto-generated a meta structure trait.
226///
227/// # Example
228/// The decorated structure trait
229/// ```rust,ignore
230/// #[signature_meta_trait]
231/// pub trait MySignature: SetSignature {
232///     fn special_element(&self) -> Self::Set;
233///     #[skip_meta]
234///     fn binary_operation(&self, a: &Self::Set, b: &Self::Set) -> Self::Set;
235/// }
236/// ```
237/// produces the following meta structure trait.
238/// ```rust,ignore
239/// pub trait MetaMySignature: MetaType
240/// where
241///     Self::Signature: MySignature,
242/// {
243///     fn special_element() -> Self {
244///         Self::structure().special_element()
245///     }
246/// }
247/// ```
248#[proc_macro_attribute]
249pub fn skip_meta(_attr: TokenStream, item: TokenStream) -> TokenStream {
250    item
251}
252
253/// Decorate a structure trait with this to auto-generate a meta structure trait.
254///
255/// # Example
256/// The decorated structure trait
257/// ```rust,ignore
258/// #[signature_meta_trait]
259/// pub trait MySignature: SetSignature {
260///     fn special_element(&self) -> Self::Set;
261///     fn binary_operation(&self, a: &Self::Set, b: &Self::Set) -> Self::Set;
262/// }
263/// ```
264/// produces the following meta structure trait,
265/// ```rust,ignore
266/// pub trait MetaMySignature: MetaType
267/// where
268///     Self::Signature: MySignature,
269/// {
270///     fn special_element() -> Self {
271///         Self::structure().special_element()
272///     }
273///     fn binary_operation(&self, b: &Self) -> Self {
274///         Self::structure().binary_operation(self, b)
275///     }
276/// }
277/// ```
278/// and auto-implementation for meta structure types.
279/// ```rust,ignore
280/// impl<T> MetaMySignature for T
281/// where
282///     T: MetaType,
283///     T::Signature: MySignature,
284/// {
285/// }
286/// ```
287#[proc_macro_attribute]
288pub fn signature_meta_trait(_args: TokenStream, input: TokenStream) -> TokenStream {
289    let trait_item = parse_macro_input!(input as ItemTrait);
290
291    let expanded = expand_meta_trait(&trait_item);
292
293    quote! {
294        #trait_item
295        #expanded
296    }
297    .into()
298}
299
300/// Expand MetaTrait + impl
301fn expand_meta_trait(trait_item: &ItemTrait) -> proc_macro2::TokenStream {
302    let sig_trait_ident = &trait_item.ident;
303    let meta_trait_ident = Ident::new(&format!("Meta{}", sig_trait_ident), Span::call_site());
304
305    let mut meta_methods = Vec::new();
306
307    for item in &trait_item.items {
308        if let TraitItem::Fn(TraitItemFn { attrs, sig, .. }) = item {
309            if attrs.iter().any(|attr| attr.path().is_ident("skip_meta")) {
310                continue;
311            }
312
313            let mut meta_sig = sig.clone();
314            // Check the first argument is self, &self, or &mut self
315            if let Some(first_arg) = meta_sig.inputs.first() {
316                match first_arg {
317                    FnArg::Receiver(_) => {
318                        meta_sig.inputs = meta_sig.inputs.into_iter().skip(1).collect();
319                        ReplaceSelfSetSignature {
320                            sig_trait_ident: sig_trait_ident.clone(),
321                        }
322                        .visit_signature_mut(&mut meta_sig);
323
324                        let ident = meta_sig.ident.clone();
325
326                        let mut meta_args = Vec::new();
327                        #[allow(clippy::never_loop)]
328                        for arg in &mut meta_sig.inputs {
329                            match arg {
330                                FnArg::Typed(pat_type) => match pat_type.pat.as_mut() {
331                                    syn::Pat::Ident(pat_ident) => {
332                                        pat_ident.mutability = None;
333                                        meta_args.push(pat_ident.clone());
334                                    }
335                                    _ => {
336                                        return Error::new_spanned(
337                                        trait_item,
338                                        "Invalid pattern in argument list. Must be a plain Ident.",
339                                    )
340                                    .to_compile_error();
341                                    }
342                                },
343                                FnArg::Receiver(_) => {
344                                    panic!();
345                                }
346                            }
347                        }
348
349                        if let Some(first) = sig.inputs.iter().nth(1) {
350                            match first {
351                                FnArg::Receiver(_) => {}
352                                FnArg::Typed(pat_type) => match pat_type.ty.as_ref() {
353                                    syn::Type::Reference(type_reference) => {
354                                        if let syn::Type::Path(type_path) =
355                                            type_reference.elem.as_ref()
356                                            && is_type_path_self_set(type_path)
357                                        {
358                                            // if the first argument is `a: &Self::Set` then replace it with `&self` in the meta type
359                                            // if the first argument is `a: &mut Self::Set` then replace it with `&mut self` in the meta type
360                                            meta_args[0] = PatIdent {
361                                                attrs: vec![],
362                                                by_ref: None,
363                                                mutability: None,
364                                                ident: Ident::new("self", Span::call_site()),
365                                                subpat: None,
366                                            };
367                                            meta_sig.inputs[0] = FnArg::Receiver(Receiver {
368                                                attrs: vec![],
369                                                reference: Some((
370                                                    syn::token::And {
371                                                        spans: [Span::call_site()],
372                                                    },
373                                                    None,
374                                                )),
375                                                mutability: type_reference.mutability,
376                                                self_token: syn::token::SelfValue {
377                                                    span: Span::call_site(),
378                                                },
379                                                colon_token: None,
380                                                ty: Box::new(syn::Type::Reference(
381                                                    syn::TypeReference {
382                                                        and_token: syn::token::And {
383                                                            spans: [Span::call_site()],
384                                                        },
385                                                        lifetime: None,
386                                                        mutability: type_reference.mutability,
387                                                        elem: Box::new(syn::Type::Path(
388                                                            syn::TypePath {
389                                                                qself: None,
390                                                                path: syn::Path::from(Ident::new(
391                                                                    "Self",
392                                                                    Span::call_site(),
393                                                                )),
394                                                            },
395                                                        )),
396                                                    },
397                                                )),
398                                            });
399                                        }
400                                    }
401                                    syn::Type::Path(type_path) => {
402                                        // if the first argument is `a: Self::Set` then replace it with `self` in the meta type (TODO)
403                                        if is_type_path_self_set(type_path) {
404                                            meta_args[0] = PatIdent {
405                                                attrs: vec![],
406                                                by_ref: None,
407                                                mutability: None,
408                                                ident: Ident::new("self", Span::call_site()),
409                                                subpat: None,
410                                            };
411                                            meta_sig.inputs[0] = FnArg::Receiver(Receiver {
412                                                attrs: vec![],
413                                                reference: None,
414                                                mutability: None,
415                                                self_token: syn::token::SelfValue {
416                                                    span: Span::call_site(),
417                                                },
418                                                colon_token: None,
419                                                ty: Box::new(syn::Type::Path(syn::TypePath {
420                                                    qself: None,
421                                                    path: syn::Path::from(Ident::new(
422                                                        "Self",
423                                                        Span::call_site(),
424                                                    )),
425                                                })),
426                                            });
427                                        }
428                                    }
429                                    _ => {}
430                                },
431                            }
432                        }
433
434                        meta_methods.push(quote! {
435                            #(#attrs)*
436                            #meta_sig {
437                                Self::structure().#ident(#(#meta_args),*)
438                            }
439                        });
440                    }
441                    FnArg::Typed(_) => {
442                        // Not a method receiver
443                    }
444                }
445            }
446        }
447    }
448
449    let where_clauses = if let Some(where_clause) = &trait_item.generics.where_clause {
450        let mut predicates = where_clause.predicates.clone();
451        for predicate in &mut predicates {
452            ReplaceSelfSetSignature {
453                sig_trait_ident: sig_trait_ident.clone(),
454            }
455            .visit_where_predicate_mut(predicate);
456        }
457        quote!(#predicates)
458    } else {
459        quote!()
460    };
461
462    quote! {
463        pub trait #meta_trait_ident: MetaType
464        where
465            Self::Signature: #sig_trait_ident,
466            #where_clauses
467        {
468
469            #(#meta_methods)*
470        }
471
472        impl<T> #meta_trait_ident for T
473        where
474            T: MetaType,
475            T::Signature: #sig_trait_ident,
476             #where_clauses
477        {
478        }
479    }
480}
481
482struct ReplaceSelfSetSignature {
483    sig_trait_ident: Ident,
484}
485impl VisitMut for ReplaceSelfSetSignature {
486    fn visit_type_path_mut(&mut self, ty: &mut syn::TypePath) {
487        syn::visit_mut::visit_type_path_mut(self, ty);
488        if is_type_path_self_set(ty) {
489            // Replace `Self::Set` with `Self`
490            *ty = syn::parse_quote!(Self);
491        } else if ty.qself.is_none()
492            && ty.path.segments.len() == 1
493            && ty.path.segments[0].ident == "Self"
494            && ty.path.segments[0].arguments.is_empty()
495        {
496            // Replace `Self` with `Self::Signature`
497            *ty = syn::parse_quote!(Self::Signature);
498        } else if ty.qself.is_none()
499            && ty.path.segments.len() >= 2
500            && ty.path.segments[0].ident == "Self"
501            && ty.path.segments[0].arguments.is_empty()
502        {
503            // Replace `Self::Foo::Bar` with `<Self::Signature as #sig_trait_ident>::Foo::Bar`
504            let sig_trait_ident = &self.sig_trait_ident;
505            ty.path.segments[0] = syn::parse_quote!(#sig_trait_ident);
506            ty.qself = Some(syn::QSelf {
507                lt_token: syn::token::Lt {
508                    spans: [Span::call_site()],
509                },
510                ty: syn::parse_quote!(Self::Signature),
511                position: 1,
512                as_token: Some(syn::token::As {
513                    span: Span::call_site(),
514                }),
515                gt_token: syn::token::Gt {
516                    spans: [Span::call_site()],
517                },
518            });
519        }
520    }
521}
522
523fn is_type_path_self_set(ty: &syn::TypePath) -> bool {
524    ty.qself.is_none()
525        && ty.path.segments.len() == 2
526        && ty.path.segments[0].ident == "Self"
527        && ty.path.segments[1].ident == "Set"
528        && ty.path.segments[1].arguments.is_empty()
529}