1use std::collections::BTreeSet;
2use proc_macro2::TokenStream;
3use quote::{format_ident, quote};
4use syn::{parse_macro_input, punctuated::Punctuated, visit_mut::{self, VisitMut}, FnArg, Ident, ImplItem, Item, Macro, Pat, Signature, Type};
5
6fn extract_args(fields: &BTreeSet<Ident>, sig: &Signature) -> impl Iterator<Item = Result<(Option<bool>, Ident), syn::Error>> {
7 sig.inputs.iter().map(|arg| {
8 let pat_type = match arg {
9 FnArg::Typed(x) => x,
10 FnArg::Receiver(r) => return Ok((None, Ident::new("self", r.self_token.span))),
11 };
12 let Pat::Ident(pat_ident) = &*pat_type.pat else {
13 return Err(syn::Error::new_spanned(
14 pat_type.clone(),
15 "#[clasma] arguments must be normal identifiers",
16 ))
17 };
18 if !fields.contains(&pat_ident.ident) {
19 return Ok((None, pat_ident.ident.clone()))
20 };
21 let Type::Reference(refty) = &*pat_type.ty else {
22 return Err(syn::Error::new_spanned(
23 pat_type.clone(),
24 "#[clasma] arguments must be reference types",
25 ))
26 };
27 return Ok((Some(refty.mutability.is_some()), pat_ident.ident.clone()))
28 })
29}
30
31fn handle_fn<'a>(fields: &BTreeSet<Ident>, sig: &'a Signature)
32 -> Result<(&'a Ident,Vec<TokenStream>,Vec<TokenStream>,Ident,Vec<TokenStream>,Vec<TokenStream>), syn::Error>
33{
34 let args = extract_args(fields, sig).collect::<Result<Vec<_>,_>>()?;
35 let func_name = &sig.ident;
36 let match_args: Vec<_> = args.iter()
37 .filter(|(mu,_)| mu.is_none())
38 .map(|(_,arg)| {
39 let id = format_ident!("__{arg}");
40 quote! { $#id: expr }
41 }).collect();
42
43 let expan_args: Vec<_> = args.iter().map(|(mu,id)| {
44 let &Some(mu) = mu else {
45 let matchid = format_ident!("__{id}");
46 return quote! { $#matchid }
47 };
48 return if mu {
49 quote! { &mut ($st).#id }
50 } else {
51 quote! { &($st).#id }
52 }
53 }).collect();
54
55
56 let mac_scope_name = format_ident!("{func_name}_scope");
57
58 let match_fields: Vec<_> = fields.iter().map(|field| {
59 let id = format_ident!("__{field}");
60 quote! { $#id: ident }
61 }).collect();
62 let expan_args_scope: Vec<_> = args.iter().map(|(_,id)| {
63 let id = format_ident!("__{id}");
64 quote! { $#id }
65 }).collect();
66
67 return Ok((func_name, match_args, expan_args, mac_scope_name, match_fields, expan_args_scope));
68}
69
70struct ScopeMacroVisitor<'a>(&'a BTreeSet<Ident>);
71
72impl<'a> visit_mut::VisitMut for ScopeMacroVisitor<'a> {
73 fn visit_macro_mut(&mut self, mac: &mut Macro) {
74 'blk: {
75 let Some(last_segment) = mac.path.segments.last() else { break 'blk };
76 if !last_segment.ident.to_string().ends_with("_scope") { break 'blk };
77 let original_tokens = &mac.tokens;
78 let fields = &self.0;
79 mac.tokens = quote! { [ #(#fields)* ] #original_tokens };
80 }
81 visit_mut::visit_macro_mut(self, mac);
82 }
83}
84
85#[proc_macro_attribute]
86pub fn clasma(attr: proc_macro::TokenStream, item: proc_macro::TokenStream) -> proc_macro::TokenStream {
87 let fields: BTreeSet<_> = parse_macro_input!(attr with Punctuated::<Ident, syn::Token![,]>::parse_terminated).into_iter().collect();
88 let mut item = parse_macro_input!(item as Item);
89 ScopeMacroVisitor(&fields).visit_item_mut(&mut item);
90
91 match item {
92 Item::Fn(item_fn) => {
93 let (func_name, match_args, expan_args, mac_scope_name, match_fields, expan_args_scope)
94 = match handle_fn(&fields, &item_fn.sig) {
95 Ok(x) => x,
96 Err(x) => return x.to_compile_error().into(),
97 };
98
99 let res = quote! {
100 #item_fn
101
102 #[macro_export]
103 macro_rules! #func_name {
104 ( < $($lt:lifetime),+ $(, $t:ty)* >, $st:expr #(, #match_args)* ) => {
105 #func_name::< $($lt),* $(, $t)* >( #(#expan_args),* );
106 };
107 ( < $($t:ty),+ >, $st:expr #(, #match_args)* ) => {
108 #func_name::< $($t),* >( #(#expan_args),* );
109 };
110
111 ( $st:expr #(, #match_args)* ) => {
112 #func_name( #(#expan_args),* );
113 };
114 }
115
116 #[macro_export]
117 macro_rules! #mac_scope_name {
118 ( [ #(#match_fields)* ] < $($lt:lifetime),+ $(, $t:ty)* > #(, #match_args)* ) => {
119 #func_name::< $($lt),* $(, $t)* >( #(#expan_args_scope),* );
120 };
121 ( [ #(#match_fields)* ] < $($t:ty),+ > #(, #match_args)* ) => {
122 #func_name::< $($t),* >( #(#expan_args_scope),* );
123 };
124
125 ( [ #(#match_fields)* ] #(#match_args),* ) => {
126 #func_name( #(#expan_args_scope),* );
127 };
128 }
129 };
130 return res.into();
131 },
132 Item::Impl(item_impl) => {
133 if item_impl.trait_.is_some() {
134 return syn::Error::new_spanned(
135 item_impl,
136 "clasma::partial currently does not support `impl Trait` blocks.",
137 ).to_compile_error().into();
138 }
139 let Some(st_name) = ('blk: {
140 let Type::Path(st_path) = &*item_impl.self_ty else { break 'blk None };
141 let Some(st_name) = st_path.path.segments.last() else { break 'blk None };
142 Some(&st_name.ident)
143 }) else {
144 return syn::Error::new_spanned(
145 item_impl,
146 "clasma::partial only supports `impl` blocks of `path::to::Type`",
147 ).to_compile_error().into();
148 };
149
150
151
152 let macs: Result<Vec<_>, syn::Error> = item_impl.items.iter().filter_map(|item| {
153 let ImplItem::Fn(f) = item else { return None };
154
155 if !f.sig.inputs.iter().any(|arg| {
156 let FnArg::Typed(pat_type) = arg else { return false };
157 let Pat::Ident(pat_ident) = &*pat_type.pat else { return false };
158 fields.contains(&pat_ident.ident)
159 }) {
160 return None
161 }
162 return handle_fn(&fields, &f.sig).ok();
164 }).map(|(func_name, match_args, expan_args, mac_scope_name, match_fields, expan_args_scope)| {
165 return Ok(quote! {
168 #[macro_export]
169 macro_rules! #func_name {
170 ( < $($lt1:lifetime),+ $(, $t1:ty)* >::< $($lt2:lifetime),+ $(, $t2:ty)* >, $st:expr #(, #match_args)* ) => {
171 #st_name::< $($lt1),* $(, $t1)* >::#func_name::< $($lt2),* $(, $t2)* >( #(#expan_args),* );
172 };
173 ( < $($lt1:lifetime),+ $(, $t1:ty)* >::< $($t2:ty),+ >, $st:expr #(, #match_args)* ) => {
174 #st_name::< $($lt1),* $(, $t1)* >::#func_name::< $($t2),* >( #(#expan_args),* );
175 };
176 ( < $($t1:ty),+ >::< $($lt2:lifetime),+ $(, $t2:ty)* >, $st:expr #(, #match_args)* ) => {
177 #st_name::< $($t1),* >::#func_name::< $($lt2),* $(, $t2)* >( #(#expan_args),* );
178 };
179 ( < $($t1:ty),+ >::< $($t2:ty),+ >, $st:expr #(, #match_args)* ) => {
180 #st_name::< $($t1),* >::#func_name::< $($t2),* >( #(#expan_args),* );
181 };
182
183 ( < $($lt:lifetime),+ $(, $t:ty)* >::, $st:expr #(, #match_args)* ) => {
184 #st_name::< $($lt),* $(, $t)* >::#func_name( #(#expan_args),* );
185 };
186 ( < $($t:ty),+ >::, $st:expr #(, #match_args)* ) => {
187 #st_name::< $($t),* >::#func_name( #(#expan_args),* );
188 };
189
190 ( < $($lt:lifetime),+ $(, $t:ty)* >, $st:expr #(, #match_args)* ) => {
191 #st_name::#func_name::< $($lt),* $(, $t)* >( #(#expan_args),* );
192 };
193 ( < $($t:ty)+ >, $st:expr #(, #match_args)* ) => {
194 #st_name::#func_name::< $($t),* >( #(#expan_args),* );
195 };
196
197 ( $st:expr #(, #match_args)* ) => {
198 #st_name::#func_name( #(#expan_args),* );
199 };
200 }
201
202 #[macro_export]
203 macro_rules! #mac_scope_name {
204 ( [ #(#match_fields)* ] < $($lt1:lifetime),+ $(, $t1:ty)* >::< $($lt2:lifetime),+ $(, $t2:ty)* > #(, #match_args)* ) => {
205 #st_name::< $($lt1),* $(, $t1)* >::#func_name::< $($lt2),* $(, $t2)* >( #(#expan_args_scope),* );
206 };
207 ( [ #(#match_fields)* ] < $($lt1:lifetime),+ $(, $t1:ty)* >::< $($t2:ty),+ > #(, #match_args)* ) => {
208 #st_name::< $($lt1),* $(, $t1)* >::#func_name::< $($t2),* >( #(#expan_args_scope),* );
209 };
210 ( [ #(#match_fields)* ] < $($t1:ty),+ >::< $($lt2:lifetime),+ $(, $t2:ty)* > #(, #match_args)* ) => {
211 #st_name::< $($t1),* >::#func_name::< $($lt2),* $(, $t2)* >( #(#expan_args_scope),* );
212 };
213 ( [ #(#match_fields)* ] < $($t1:ty),+ >::< $($t2:ty),+ > #(, #match_args)* ) => {
214 #st_name::< $($t1),* >::#func_name::< $($t2),* >( #(#expan_args_scope),* );
215 };
216
217 ( [ #(#match_fields)* ] < $($lt:lifetime),+ $(, $t:ty)* >:: #(, #match_args)* ) => {
218 #st_name::< $($lt),* $(, $t)* >::#func_name( #(#expan_args_scope),* );
219 };
220 ( [ #(#match_fields)* ] < $($t:ty),+ >:: #(, #match_args)* ) => {
221 #st_name::< $($t),* >::#func_name( #(#expan_args_scope),* );
222 };
223
224 ( [ #(#match_fields)* ] < $($lt:lifetime),+ $(, $t:ty)* > #(, #match_args)* ) => {
225 #st_name::#func_name::< $($lt),* $(, $t)* >( #(#expan_args_scope),* );
226 };
227 ( [ #(#match_fields)* ] < $($t:ty)+ > #(, #match_args)* ) => {
228 #st_name::#func_name::< $($t),* >( #(#expan_args_scope),* );
229 };
230
231 ( [ #(#match_fields)* ] #(#match_args),* ) => {
232 #st_name::#func_name( #(#expan_args_scope),* );
233 };
234 }
235 });
236 }).collect();
237
238 let macs = match macs {Ok(x) => x, Err(x) => return x.to_compile_error().into()};
239 let res = quote! {
240 #item_impl
241
242 #(#macs)*
243 };
244 return res.into();
245 },
246 _ => {
247 return syn::Error::new_spanned(
248 item,
249 "clasma::partial must be applied to an `fn` or `impl` block.",
250 ).to_compile_error().into()
251 },
252 }
253}