clasma/
lib.rs

1use std::collections::BTreeSet;
2use proc_macro2::TokenStream;
3use quote::{format_ident, quote};
4use syn::{parse_macro_input, punctuated::Punctuated, visit_mut::{self, VisitMut}, FnArg, Ident, ImplItem, Item, Macro, Pat, Signature, Type};
5
6fn extract_args(fields: &BTreeSet<Ident>, sig: &Signature) -> impl Iterator<Item = Result<(Option<bool>, Ident), syn::Error>> {
7    sig.inputs.iter().map(|arg| {
8        let pat_type = match arg {
9            FnArg::Typed(x) => x,
10            FnArg::Receiver(r) => return Ok((None, Ident::new("self", r.self_token.span))),
11        };
12        let Pat::Ident(pat_ident) = &*pat_type.pat else {
13            return Err(syn::Error::new_spanned(
14                pat_type.clone(),
15                "#[clasma] arguments must be normal identifiers",
16            ))
17        };
18        if !fields.contains(&pat_ident.ident) {
19            return Ok((None, pat_ident.ident.clone()))
20        };
21        let Type::Reference(refty) = &*pat_type.ty else {
22            return Err(syn::Error::new_spanned(
23                pat_type.clone(),
24                "#[clasma] arguments must be reference types",
25            ))
26        };
27        return Ok((Some(refty.mutability.is_some()), pat_ident.ident.clone()))
28    })
29}
30
31fn handle_fn<'a>(fields: &BTreeSet<Ident>, sig: &'a Signature)
32    -> Result<(&'a Ident,Vec<TokenStream>,Vec<TokenStream>,Ident,Vec<TokenStream>,Vec<TokenStream>), syn::Error>
33{
34    let args = extract_args(fields, sig).collect::<Result<Vec<_>,_>>()?;
35    let func_name = &sig.ident;
36    let match_args: Vec<_> = args.iter()
37        .filter(|(mu,_)| mu.is_none())
38        .map(|(_,arg)| {
39            let id = format_ident!("__{arg}");
40            quote! { $#id: expr }
41        }).collect();
42
43    let expan_args: Vec<_> = args.iter().map(|(mu,id)| {
44        let &Some(mu) = mu else {
45            let matchid = format_ident!("__{id}");
46            return quote! { $#matchid }
47        };
48        return if mu {
49            quote! { &mut ($st).#id }
50        } else {
51            quote! { &($st).#id }
52        }
53    }).collect();
54
55
56    let mac_scope_name = format_ident!("{func_name}_scope");
57
58    let match_fields: Vec<_> = fields.iter().map(|field| {
59        let id = format_ident!("__{field}");
60        quote! { $#id: ident }
61    }).collect();
62    let expan_args_scope: Vec<_> = args.iter().map(|(_,id)| {
63        let id = format_ident!("__{id}");
64        quote! { $#id }
65    }).collect();
66
67    return Ok((func_name, match_args, expan_args, mac_scope_name, match_fields, expan_args_scope));
68}
69
70struct ScopeMacroVisitor<'a>(&'a BTreeSet<Ident>);
71
72impl<'a> visit_mut::VisitMut for ScopeMacroVisitor<'a> {
73    fn visit_macro_mut(&mut self, mac: &mut Macro) {
74        'blk: {
75            let Some(last_segment) = mac.path.segments.last() else { break 'blk };
76            if !last_segment.ident.to_string().ends_with("_scope") { break 'blk };
77            let original_tokens = &mac.tokens;
78            let fields = &self.0;
79            mac.tokens = quote! { [ #(#fields)* ] #original_tokens };
80        }
81        visit_mut::visit_macro_mut(self, mac);
82    }
83}
84
85#[proc_macro_attribute]
86pub fn clasma(attr: proc_macro::TokenStream, item: proc_macro::TokenStream) -> proc_macro::TokenStream {
87    let fields: BTreeSet<_> = parse_macro_input!(attr with Punctuated::<Ident, syn::Token![,]>::parse_terminated).into_iter().collect();
88    let mut item = parse_macro_input!(item as Item);
89    ScopeMacroVisitor(&fields).visit_item_mut(&mut item);
90
91    match item {
92        Item::Fn(item_fn) => {
93            let (func_name, match_args, expan_args, mac_scope_name, match_fields, expan_args_scope)
94                    = match handle_fn(&fields, &item_fn.sig) {
95                Ok(x) => x,
96                Err(x) => return x.to_compile_error().into(),
97            };
98
99            let res = quote! {
100                #item_fn
101
102                #[macro_export]
103                macro_rules! #func_name {
104                    ( < $($lt:lifetime),+ $(, $t:ty)* >, $st:expr #(, #match_args)* ) => {
105                        #func_name::< $($lt),* $(, $t)* >( #(#expan_args),* );
106                    };
107                    ( < $($t:ty),+ >, $st:expr #(, #match_args)* ) => {
108                        #func_name::< $($t),* >( #(#expan_args),* );
109                    };
110
111                    ( $st:expr #(, #match_args)* ) => {
112                        #func_name( #(#expan_args),* );
113                    };
114                }
115
116                #[macro_export]
117                macro_rules! #mac_scope_name {
118                    ( [ #(#match_fields)* ] < $($lt:lifetime),+ $(, $t:ty)* > #(, #match_args)* ) => {
119                        #func_name::< $($lt),* $(, $t)* >( #(#expan_args_scope),* );
120                    };
121                    ( [ #(#match_fields)* ] < $($t:ty),+ > #(, #match_args)* ) => {
122                        #func_name::< $($t),* >( #(#expan_args_scope),* );
123                    };
124
125                    ( [ #(#match_fields)* ] #(#match_args),* ) => {
126                        #func_name( #(#expan_args_scope),* );
127                    };
128                }
129            };
130            return res.into();
131        },
132        Item::Impl(item_impl) => {
133            if item_impl.trait_.is_some() {
134                return syn::Error::new_spanned(
135                    item_impl,
136                    "clasma::partial currently does not support `impl Trait` blocks.",
137                ).to_compile_error().into();
138            }
139            let Some(st_name) = ('blk: {
140                    let Type::Path(st_path) = &*item_impl.self_ty else { break 'blk None };
141                    let Some(st_name) = st_path.path.segments.last() else { break 'blk None };
142                    Some(&st_name.ident)
143            }) else {
144                return syn::Error::new_spanned(
145                    item_impl,
146                    "clasma::partial only supports `impl` blocks of `path::to::Type`",
147                ).to_compile_error().into();
148            };
149
150
151
152            let macs: Result<Vec<_>, syn::Error> = item_impl.items.iter().filter_map(|item| {
153                let ImplItem::Fn(f) = item else { return None };
154
155                if !f.sig.inputs.iter().any(|arg| {
156                    let FnArg::Typed(pat_type) = arg else { return false };
157                    let Pat::Ident(pat_ident) = &*pat_type.pat else { return false };
158                    fields.contains(&pat_ident.ident)
159                }) {
160                    return None
161                }
162                // TODO produce a warning, if it "seems" like user is attempting to incorrectly use macro, instead of ignoring the `Err`
163                return handle_fn(&fields, &f.sig).ok();
164            }).map(|(func_name, match_args, expan_args, mac_scope_name, match_fields, expan_args_scope)| {
165                // It's really annoying that lifetimes and types can not be parsed unambiguously in one rule.
166                // TODO use a tt-muncher to make less verbose
167                return Ok(quote! {
168                    #[macro_export]
169                    macro_rules! #func_name {
170                        ( < $($lt1:lifetime),+ $(, $t1:ty)* >::< $($lt2:lifetime),+ $(, $t2:ty)* >, $st:expr #(, #match_args)* ) => {
171                            #st_name::< $($lt1),* $(, $t1)* >::#func_name::< $($lt2),* $(, $t2)* >( #(#expan_args),* );
172                        };
173                        ( < $($lt1:lifetime),+ $(, $t1:ty)* >::< $($t2:ty),+ >, $st:expr #(, #match_args)* ) => {
174                            #st_name::< $($lt1),* $(, $t1)* >::#func_name::< $($t2),* >( #(#expan_args),* );
175                        };
176                        ( < $($t1:ty),+ >::< $($lt2:lifetime),+ $(, $t2:ty)* >, $st:expr #(, #match_args)* ) => {
177                            #st_name::< $($t1),* >::#func_name::< $($lt2),* $(, $t2)* >( #(#expan_args),* );
178                        };
179                        ( < $($t1:ty),+ >::< $($t2:ty),+ >, $st:expr #(, #match_args)* ) => {
180                            #st_name::< $($t1),* >::#func_name::< $($t2),* >( #(#expan_args),* );
181                        };
182
183                        ( < $($lt:lifetime),+ $(, $t:ty)* >::, $st:expr #(, #match_args)* ) => {
184                            #st_name::< $($lt),* $(, $t)* >::#func_name( #(#expan_args),* );
185                        };
186                        ( < $($t:ty),+ >::, $st:expr #(, #match_args)* ) => {
187                            #st_name::< $($t),* >::#func_name( #(#expan_args),* );
188                        };
189
190                        ( < $($lt:lifetime),+ $(, $t:ty)* >, $st:expr #(, #match_args)* ) => {
191                            #st_name::#func_name::< $($lt),* $(, $t)* >( #(#expan_args),* );
192                        };
193                        ( < $($t:ty)+ >, $st:expr #(, #match_args)* ) => {
194                            #st_name::#func_name::< $($t),* >( #(#expan_args),* );
195                        };
196
197                        ( $st:expr #(, #match_args)* ) => {
198                            #st_name::#func_name( #(#expan_args),* );
199                        };
200                    }
201
202                    #[macro_export]
203                    macro_rules! #mac_scope_name {
204                        ( [ #(#match_fields)* ] < $($lt1:lifetime),+ $(, $t1:ty)* >::< $($lt2:lifetime),+ $(, $t2:ty)* > #(, #match_args)* ) => {
205                            #st_name::< $($lt1),* $(, $t1)* >::#func_name::< $($lt2),* $(, $t2)* >( #(#expan_args_scope),* );
206                        };
207                        ( [ #(#match_fields)* ] < $($lt1:lifetime),+ $(, $t1:ty)* >::< $($t2:ty),+ > #(, #match_args)* ) => {
208                            #st_name::< $($lt1),* $(, $t1)* >::#func_name::< $($t2),* >( #(#expan_args_scope),* );
209                        };
210                        ( [ #(#match_fields)* ] < $($t1:ty),+ >::< $($lt2:lifetime),+ $(, $t2:ty)* > #(, #match_args)* ) => {
211                            #st_name::< $($t1),* >::#func_name::< $($lt2),* $(, $t2)* >( #(#expan_args_scope),* );
212                        };
213                        ( [ #(#match_fields)* ] < $($t1:ty),+ >::< $($t2:ty),+ > #(, #match_args)* ) => {
214                            #st_name::< $($t1),* >::#func_name::< $($t2),* >( #(#expan_args_scope),* );
215                        };
216
217                        ( [ #(#match_fields)* ] < $($lt:lifetime),+ $(, $t:ty)* >:: #(, #match_args)* ) => {
218                            #st_name::< $($lt),* $(, $t)* >::#func_name( #(#expan_args_scope),* );
219                        };
220                        ( [ #(#match_fields)* ] < $($t:ty),+ >:: #(, #match_args)* ) => {
221                            #st_name::< $($t),* >::#func_name( #(#expan_args_scope),* );
222                        };
223
224                        ( [ #(#match_fields)* ] < $($lt:lifetime),+ $(, $t:ty)* > #(, #match_args)* ) => {
225                            #st_name::#func_name::< $($lt),* $(, $t)* >( #(#expan_args_scope),* );
226                        };
227                        ( [ #(#match_fields)* ] < $($t:ty)+ > #(, #match_args)* ) => {
228                            #st_name::#func_name::< $($t),* >( #(#expan_args_scope),* );
229                        };
230
231                        ( [ #(#match_fields)* ] #(#match_args),* ) => {
232                            #st_name::#func_name( #(#expan_args_scope),* );
233                        };
234                    }
235                });
236            }).collect();
237
238            let macs = match macs {Ok(x) => x, Err(x) => return x.to_compile_error().into()};
239            let res = quote! {
240                #item_impl
241
242                #(#macs)*
243            };
244            return res.into();
245        },
246        _ => {
247            return syn::Error::new_spanned(
248                item,
249                "clasma::partial must be applied to an `fn` or `impl` block.",
250            ).to_compile_error().into()
251        },
252    }
253}