algebraeon_macros/lib.rs
1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use proc_macro2::Span;
5use quote::quote;
6use syn::visit_mut::VisitMut;
7use syn::{
8 Attribute, DeriveInput, Error, FnArg, Ident, ItemTrait, PatIdent, Receiver, TraitItem,
9 TraitItemFn, parse_macro_input,
10};
11
12fn has_option(attrs: &[Attribute], option_name: &str) -> bool {
13 for attr in attrs
14 .iter()
15 .filter(|a| a.path().is_ident("canonical_structure"))
16 {
17 let mut found = false;
18
19 // `parse_nested_meta` lets us walk through the arguments in #[canonical_structure(...)]
20 let _ = attr.parse_nested_meta(|meta| {
21 if meta.path.is_ident(option_name) {
22 found = true;
23 }
24 // Continue parsing
25 Ok(())
26 });
27
28 if found {
29 return true;
30 }
31 }
32 false
33}
34
35/// Generate a canonical structure type for a type `T` by decorating it with `#[derive(CanonicalStructure)]`.
36/// Optional additional structure can be generated by adding `#[canonical_structure(eq, partial_ord, ord)]`.
37/// The type must implement `Debug` and `Clone`. The optional additional structures may require `T` to implement further traits.
38/// Requires `MetaType`, `Signature`, and `SetSignature` to be in scope. The optional additional structures may require further items to be in scope.
39///
40/// # Example
41/// ```rust,ignore
42/// #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, CanonicalStructure)]
43/// #[canonical_structure(eq, partial_ord, ord)]
44/// struct MyValue {
45/// data: i64,
46/// more_data: u32,
47/// }
48/// ```
49/// `#[derive(CanonicalStructure)]` generates the following
50/// ```rust,ignore
51/// #[derive(Debug, Clone, PartialEq, Eq)]
52/// struct MyValueCanonicalStructure {}
53///
54/// impl Signature for MyValueCanonicalStructure {}
55///
56/// impl MyValueCanonicalStructure {
57/// fn new() -> Self {
58/// Self {}
59/// }
60/// }
61///
62/// impl SetSignature for MyValueCanonicalStructure {
63/// type Set = MyValue;
64/// fn validate_element(&self, _x: &Self::Set) -> Result<(), String> {
65/// Ok(())
66/// }
67/// }
68///
69/// impl MetaType for MyValue {
70/// type Signature = MyValueCanonicalStructure;
71/// fn structure() -> Self::Signature {
72/// MyValueCanonicalStructure::new()
73/// }
74/// }
75///
76/// impl MyValue {
77/// pub fn structure_ref() -> &'static MyValueCanonicalStructure {
78/// static CELL: std::sync::OnceLock<MyValueCanonicalStructure> = std::sync::OnceLock::new();
79/// CELL.get_or_init(|| MyValueCanonicalStructure::new())
80/// }
81/// }
82/// ```
83///
84/// `#[canonical_structure(eq)]` requires `MyValue: Eq`, and `EqSignature` to be in scope. It generates the following
85/// ```rust,ignore
86/// impl EqSignature for MyValueCanonicalStructure
87/// where
88/// MyValue: Eq,
89/// {
90/// fn equal(&self, a: &Self::Set, b: &Self::Set) -> bool {
91/// a == b
92/// }
93/// }
94/// ```
95///
96/// `#[canonical_structure(partial_eq)]` requires `#[canonical_structure(eq)]`, `MyValue: PartialEq`, and `PartialEqSignature` to be in scope. It generates the following
97/// ```rust,ignore
98/// impl PartialOrdSignature for MyValueCanonicalStructure
99/// where
100/// MyValue: Ord,
101/// {
102/// fn partial_cmp(&self, a: &Self::Set, b: &Self::Set) -> Option<std::cmp::Ordering> {
103/// Some(a.cmp(b))
104/// }
105/// }
106/// ```
107///
108/// `#[canonical_structure(ord)]` requires `#[canonical_structure(partial_eq)]`, `MyValue: Ord`, and `OrdSignature` to be in scope. It generates the following
109/// ```rust,ignore
110/// impl OrdSignature for MyValueCanonicalStructure
111/// where
112/// MyValue: Ord,
113/// {
114/// fn cmp(&self, a: &Self::Set, b: &Self::Set) -> std::cmp::Ordering {
115/// a.cmp(b)
116/// }
117/// fn sort<S: std::borrow::Borrow<Self::Set>>(&self, mut a: Vec<S>) -> Vec<S> {
118/// a.sort_unstable_by(|x, y| x.borrow().cmp(y.borrow()));
119/// a
120/// }
121/// }
122/// ```
123#[proc_macro_derive(CanonicalStructure, attributes(canonical_structure))]
124pub fn derive_newtype(input: TokenStream) -> TokenStream {
125 let input = parse_macro_input!(input as DeriveInput);
126
127 let name = input.ident;
128 let vis = input.vis;
129 let newtype_name = Ident::new(&format!("{name}CanonicalStructure"), name.span());
130
131 let has_eq = has_option(&input.attrs, "eq");
132 let has_partial_ord = has_option(&input.attrs, "partial_ord");
133 let has_ord = has_option(&input.attrs, "ord");
134
135 let impl_eq_signature = if has_eq {
136 quote! {
137 impl EqSignature for #newtype_name
138 where #name: Eq
139 {
140 fn equal(&self, a: &Self::Set, b: &Self::Set) -> bool {
141 a == b
142 }
143 }
144 }
145 } else {
146 quote! {}
147 };
148
149 let impl_partial_ord_signature = if has_partial_ord {
150 quote! {
151 impl PartialOrdSignature for #newtype_name
152 where #name: Ord
153 {
154 fn partial_cmp(&self, a: &Self::Set, b: &Self::Set) -> Option<std::cmp::Ordering> {
155 Some(a.cmp(b))
156 }
157 }
158 }
159 } else {
160 quote! {}
161 };
162
163 let impl_ord_signature = if has_ord {
164 quote! {
165 impl OrdSignature for #newtype_name
166 where #name: Ord
167 {
168 fn cmp(&self, a: &Self::Set, b: &Self::Set) -> std::cmp::Ordering {
169 a.cmp(b)
170 }
171
172 fn sort<S: std::borrow::Borrow<Self::Set>>(&self, mut a: Vec<S>) -> Vec<S> {
173 a.sort_unstable_by(|x, y| x.borrow().cmp(y.borrow()));
174 a
175 }
176 }
177 }
178 } else {
179 quote! {}
180 };
181
182 let expanded = quote! {
183 #[derive(Debug, Clone, PartialEq, Eq)]
184 #vis struct #newtype_name {}
185
186 impl #newtype_name {
187 fn new() -> Self {
188 Self {}
189 }
190 }
191
192 impl Signature for #newtype_name {}
193
194 impl SetSignature for #newtype_name {
195 type Set = #name;
196
197 fn validate_element(&self, _x : &Self::Set) -> Result<(), String> {
198 Ok(())
199 }
200 }
201
202 #impl_eq_signature
203 #impl_partial_ord_signature
204 #impl_ord_signature
205
206 impl MetaType for #name {
207 type Signature = #newtype_name;
208
209 fn structure() -> Self::Signature {
210 #newtype_name::new()
211 }
212 }
213
214 impl #name {
215 pub fn structure_ref() -> &'static #newtype_name{
216 static CELL: std::sync::OnceLock<#newtype_name> = std::sync::OnceLock::new();
217 CELL.get_or_init(|| #newtype_name::new())
218 }
219 }
220 };
221
222 TokenStream::from(expanded)
223}
224
225/// In a structure trait decorated with `#[proc_macro_attribute]`, decorate a method with `#[skip_meta]` to exclude it from the auto-generated a meta structure trait.
226///
227/// # Example
228/// The decorated structure trait
229/// ```rust,ignore
230/// #[signature_meta_trait]
231/// pub trait MySignature: SetSignature {
232/// fn special_element(&self) -> Self::Set;
233/// #[skip_meta]
234/// fn binary_operation(&self, a: &Self::Set, b: &Self::Set) -> Self::Set;
235/// }
236/// ```
237/// produces the following meta structure trait.
238/// ```rust,ignore
239/// pub trait MetaMySignature: MetaType
240/// where
241/// Self::Signature: MySignature,
242/// {
243/// fn special_element() -> Self {
244/// Self::structure().special_element()
245/// }
246/// }
247/// ```
248#[proc_macro_attribute]
249pub fn skip_meta(_attr: TokenStream, item: TokenStream) -> TokenStream {
250 item
251}
252
253/// Decorate a structure trait with this to auto-generate a meta structure trait.
254///
255/// # Example
256/// The decorated structure trait
257/// ```rust,ignore
258/// #[signature_meta_trait]
259/// pub trait MySignature: SetSignature {
260/// fn special_element(&self) -> Self::Set;
261/// fn binary_operation(&self, a: &Self::Set, b: &Self::Set) -> Self::Set;
262/// }
263/// ```
264/// produces the following meta structure trait,
265/// ```rust,ignore
266/// pub trait MetaMySignature: MetaType
267/// where
268/// Self::Signature: MySignature,
269/// {
270/// fn special_element() -> Self {
271/// Self::structure().special_element()
272/// }
273/// fn binary_operation(&self, b: &Self) -> Self {
274/// Self::structure().binary_operation(self, b)
275/// }
276/// }
277/// ```
278/// and auto-implementation for meta structure types.
279/// ```rust,ignore
280/// impl<T> MetaMySignature for T
281/// where
282/// T: MetaType,
283/// T::Signature: MySignature,
284/// {
285/// }
286/// ```
287#[proc_macro_attribute]
288pub fn signature_meta_trait(_args: TokenStream, input: TokenStream) -> TokenStream {
289 let trait_item = parse_macro_input!(input as ItemTrait);
290
291 let expanded = expand_meta_trait(&trait_item);
292
293 quote! {
294 #trait_item
295 #expanded
296 }
297 .into()
298}
299
300/// Expand MetaTrait + impl
301fn expand_meta_trait(trait_item: &ItemTrait) -> proc_macro2::TokenStream {
302 let sig_trait_ident = &trait_item.ident;
303 let meta_trait_ident = Ident::new(&format!("Meta{}", sig_trait_ident), Span::call_site());
304
305 let mut meta_methods = Vec::new();
306
307 for item in &trait_item.items {
308 if let TraitItem::Fn(TraitItemFn { attrs, sig, .. }) = item {
309 if attrs.iter().any(|attr| attr.path().is_ident("skip_meta")) {
310 continue;
311 }
312
313 let mut meta_sig = sig.clone();
314 // Check the first argument is self, &self, or &mut self
315 if let Some(first_arg) = meta_sig.inputs.first() {
316 match first_arg {
317 FnArg::Receiver(_) => {
318 meta_sig.inputs = meta_sig.inputs.into_iter().skip(1).collect();
319 ReplaceSelfSetSignature {
320 sig_trait_ident: sig_trait_ident.clone(),
321 }
322 .visit_signature_mut(&mut meta_sig);
323
324 let ident = meta_sig.ident.clone();
325
326 let mut meta_args = Vec::new();
327 #[allow(clippy::never_loop)]
328 for arg in &mut meta_sig.inputs {
329 match arg {
330 FnArg::Typed(pat_type) => match pat_type.pat.as_mut() {
331 syn::Pat::Ident(pat_ident) => {
332 pat_ident.mutability = None;
333 meta_args.push(pat_ident.clone());
334 }
335 _ => {
336 return Error::new_spanned(
337 trait_item,
338 "Invalid pattern in argument list. Must be a plain Ident.",
339 )
340 .to_compile_error();
341 }
342 },
343 FnArg::Receiver(_) => {
344 panic!();
345 }
346 }
347 }
348
349 if let Some(first) = sig.inputs.iter().nth(1) {
350 match first {
351 FnArg::Receiver(_) => {}
352 FnArg::Typed(pat_type) => match pat_type.ty.as_ref() {
353 syn::Type::Reference(type_reference) => {
354 if let syn::Type::Path(type_path) =
355 type_reference.elem.as_ref()
356 && is_type_path_self_set(type_path)
357 {
358 // if the first argument is `a: &Self::Set` then replace it with `&self` in the meta type
359 // if the first argument is `a: &mut Self::Set` then replace it with `&mut self` in the meta type
360 meta_args[0] = PatIdent {
361 attrs: vec![],
362 by_ref: None,
363 mutability: None,
364 ident: Ident::new("self", Span::call_site()),
365 subpat: None,
366 };
367 meta_sig.inputs[0] = FnArg::Receiver(Receiver {
368 attrs: vec![],
369 reference: Some((
370 syn::token::And {
371 spans: [Span::call_site()],
372 },
373 None,
374 )),
375 mutability: type_reference.mutability,
376 self_token: syn::token::SelfValue {
377 span: Span::call_site(),
378 },
379 colon_token: None,
380 ty: Box::new(syn::Type::Reference(
381 syn::TypeReference {
382 and_token: syn::token::And {
383 spans: [Span::call_site()],
384 },
385 lifetime: None,
386 mutability: type_reference.mutability,
387 elem: Box::new(syn::Type::Path(
388 syn::TypePath {
389 qself: None,
390 path: syn::Path::from(Ident::new(
391 "Self",
392 Span::call_site(),
393 )),
394 },
395 )),
396 },
397 )),
398 });
399 }
400 }
401 syn::Type::Path(type_path) => {
402 // if the first argument is `a: Self::Set` then replace it with `self` in the meta type (TODO)
403 if is_type_path_self_set(type_path) {
404 meta_args[0] = PatIdent {
405 attrs: vec![],
406 by_ref: None,
407 mutability: None,
408 ident: Ident::new("self", Span::call_site()),
409 subpat: None,
410 };
411 meta_sig.inputs[0] = FnArg::Receiver(Receiver {
412 attrs: vec![],
413 reference: None,
414 mutability: None,
415 self_token: syn::token::SelfValue {
416 span: Span::call_site(),
417 },
418 colon_token: None,
419 ty: Box::new(syn::Type::Path(syn::TypePath {
420 qself: None,
421 path: syn::Path::from(Ident::new(
422 "Self",
423 Span::call_site(),
424 )),
425 })),
426 });
427 }
428 }
429 _ => {}
430 },
431 }
432 }
433
434 meta_methods.push(quote! {
435 #(#attrs)*
436 #meta_sig {
437 Self::structure().#ident(#(#meta_args),*)
438 }
439 });
440 }
441 FnArg::Typed(_) => {
442 // Not a method receiver
443 }
444 }
445 }
446 }
447 }
448
449 let where_clauses = if let Some(where_clause) = &trait_item.generics.where_clause {
450 let mut predicates = where_clause.predicates.clone();
451 for predicate in &mut predicates {
452 ReplaceSelfSetSignature {
453 sig_trait_ident: sig_trait_ident.clone(),
454 }
455 .visit_where_predicate_mut(predicate);
456 }
457 quote!(#predicates)
458 } else {
459 quote!()
460 };
461
462 quote! {
463 pub trait #meta_trait_ident: MetaType
464 where
465 Self::Signature: #sig_trait_ident,
466 #where_clauses
467 {
468
469 #(#meta_methods)*
470 }
471
472 impl<T> #meta_trait_ident for T
473 where
474 T: MetaType,
475 T::Signature: #sig_trait_ident,
476 #where_clauses
477 {
478 }
479 }
480}
481
482struct ReplaceSelfSetSignature {
483 sig_trait_ident: Ident,
484}
485impl VisitMut for ReplaceSelfSetSignature {
486 fn visit_type_path_mut(&mut self, ty: &mut syn::TypePath) {
487 syn::visit_mut::visit_type_path_mut(self, ty);
488 if is_type_path_self_set(ty) {
489 // Replace `Self::Set` with `Self`
490 *ty = syn::parse_quote!(Self);
491 } else if ty.qself.is_none()
492 && ty.path.segments.len() == 1
493 && ty.path.segments[0].ident == "Self"
494 && ty.path.segments[0].arguments.is_empty()
495 {
496 // Replace `Self` with `Self::Signature`
497 *ty = syn::parse_quote!(Self::Signature);
498 } else if ty.qself.is_none()
499 && ty.path.segments.len() >= 2
500 && ty.path.segments[0].ident == "Self"
501 && ty.path.segments[0].arguments.is_empty()
502 {
503 // Replace `Self::Foo::Bar` with `<Self::Signature as #sig_trait_ident>::Foo::Bar`
504 let sig_trait_ident = &self.sig_trait_ident;
505 ty.path.segments[0] = syn::parse_quote!(#sig_trait_ident);
506 ty.qself = Some(syn::QSelf {
507 lt_token: syn::token::Lt {
508 spans: [Span::call_site()],
509 },
510 ty: syn::parse_quote!(Self::Signature),
511 position: 1,
512 as_token: Some(syn::token::As {
513 span: Span::call_site(),
514 }),
515 gt_token: syn::token::Gt {
516 spans: [Span::call_site()],
517 },
518 });
519 }
520 }
521}
522
523fn is_type_path_self_set(ty: &syn::TypePath) -> bool {
524 ty.qself.is_none()
525 && ty.path.segments.len() == 2
526 && ty.path.segments[0].ident == "Self"
527 && ty.path.segments[1].ident == "Set"
528 && ty.path.segments[1].arguments.is_empty()
529}