1use std::collections::BTreeMap;
2use proc_macro::TokenStream;
3use proc_macro2::{Span, TokenStream as TokenStream2};
4use quote::{quote, quote_spanned, ToTokens};
5use syn::parse::Parser;
6use syn::punctuated::Punctuated;
7use syn::spanned::Spanned;
8use syn::{
9 parse_macro_input, Attribute, Field,
10 Fields, GenericParam, Ident, Item, ItemEnum, ItemStruct, Meta, MetaList,
11 MetaNameValue, PathArguments, Token, Type, TypeParam, Variant,
12};
13
14#[proc_macro_attribute]
42pub fn derive_n_functor(args: TokenStream, item: TokenStream) -> TokenStream {
43 let _args: TokenStream2 = args.clone().into();
44 let _item: TokenStream2 = item.clone().into();
45 let args = Args::from_token_stream(args);
46 let mut input = parse_macro_input!(item as Item);
47 let output = match &mut input {
48 Item::Enum(_enum) => AbstractFunctorFactory::from_item_enum(args, _enum),
49 Item::Struct(_struct) => AbstractFunctorFactory::from_item_struct(args, _struct),
50 _ => {
51 quote_spanned! {_args.span() => compile_error!("Could not derive n-functor for this, it is neither an enum or struct.")}
52 }
53 };
54 quote! {
55 #input
56 #output
57 }
58 .into()
59}
60
61struct Args {
62 pub parameter_names: BTreeMap<Ident, Ident>,
63 pub mapping_name: String,
64 }
67
68impl Args {
69 fn from_token_stream(stream: TokenStream) -> Self {
70 let parsed_attrs: Punctuated<MetaNameValue, Token![,]> =
71 Parser::parse2(Punctuated::parse_terminated, stream.into()).unwrap();
72 Args::from_iter(parsed_attrs.into_iter())
73 }
74
75 fn from_iter(input: impl Iterator<Item = MetaNameValue>) -> Self {
76 let search_for_mapping_token = Ident::new("map_name", Span::call_site());
77 let mut mapping_name = "map".to_string();
78 let parameter_names = input
79 .filter_map(|name_val| {
80 if name_val.path.segments.last().unwrap().ident == search_for_mapping_token {
81 if let syn::Expr::Path(path) = name_val.value {
83 mapping_name = path.path.segments.last()?.ident.to_string();
84 }
85 return None
87 }
88 let ty_ident = &name_val.path.segments.last()?.ident;
90 let rename_ident = &match &name_val.value {
91 syn::Expr::Path(path) => path.path.segments.last(),
92 _ => None,
93 }?
94 .ident;
95 Some((ty_ident.clone(), rename_ident.clone()))
96 })
97 .collect();
98 Self { parameter_names, mapping_name }
99 }
100
101 fn get_suffix_for(&self, ident: &Ident) -> Ident {
102 self.parameter_names
103 .get(ident)
104 .cloned()
105 .unwrap_or_else(|| Ident::new(&format!("{ident}"), Span::call_site()))
106 }
107
108 fn get_whole_map_name(&self, ident: &Ident) -> Ident {
109 let suffix = self.get_suffix_for(ident);
110 Ident::new(&format!("map_{suffix}"), Span::call_site())
111 }
112
113 fn get_map_all_name(&self) -> Ident {
114 Ident::new(&self.mapping_name, Span::call_site())
115 }
116}
117
118enum FieldMapping {
119 Trivial(Ident),
120 SubMapForArgs(Vec<Ident>),
121}
122
123type FieldNameMapping = Option<Vec<Ident>>;
124
125struct AbstractFunctorFactory {
126 pub args: Args,
127 pub type_maps_to_type: Vec<(Ident, Ident)>,
129 pub type_name: Ident,
130}
131
132impl AbstractFunctorFactory {
133 fn from_generics<'a>(
134 args: Args,
135 generics: impl Iterator<Item = &'a GenericParam>,
136 type_name: Ident,
137 ) -> Self {
138 let mut type_maps_to_type = vec![];
139 for generic in generics {
140 match generic {
141 GenericParam::Lifetime(_) => {}
142 GenericParam::Type(ty) => type_maps_to_type.push((
143 ty.ident.clone(),
144 Ident::new(&format!("{}2", ty.ident), Span::call_site()),
145 )),
146 GenericParam::Const(_) => {}
147 }
148 }
149 AbstractFunctorFactory {
150 args,
151 type_maps_to_type,
152 type_name,
153 }
154 }
155
156 fn from_item_enum(args: Args, source: &mut ItemEnum) -> TokenStream2 {
157 let name = source.ident.clone();
158 let factory = AbstractFunctorFactory::from_generics(
159 args,
160 source.generics.params.iter(),
161 source.ident.clone(),
162 );
163 let map_name = factory.args.get_map_all_name();
164 let (impl_gen, type_gen, where_clause) = source.generics.split_for_impl();
165 let mapped_params: Punctuated<TypeParam, Token![,]> = factory
166 .type_maps_to_type
167 .iter()
168 .map(|a| TypeParam::from(a.1.clone()))
169 .collect();
170 let fn_args = factory.make_fn_arguments();
171 let implemented: Punctuated<TokenStream2, Token![,]> = source
172 .variants
173 .iter_mut()
174 .map(|variant| factory.implement_body_for_variant(variant))
175 .collect();
176 quote! {
177 impl #impl_gen #name #type_gen #where_clause {
178 pub fn #map_name<#mapped_params>(self, #fn_args) -> #name<#mapped_params> {
179 match self {
180 #implemented
181 }
182 }
183 }
184 }
185 }
186
187 fn from_item_struct(args: Args, source: &mut ItemStruct) -> TokenStream2 {
188 let name = source.ident.clone();
189 let factory = AbstractFunctorFactory::from_generics(
190 args,
191 source.generics.params.iter(),
192 source.ident.clone(),
193 );
194 let map_name = factory.args.get_map_all_name();
195 let (impl_gen, type_gen, where_clause) = source.generics.split_for_impl();
196 let mapped_params: Punctuated<TypeParam, Token![,]> = factory
197 .type_maps_to_type
198 .iter()
199 .map(|a| TypeParam::from(a.1.clone()))
200 .collect();
201 let fn_args = factory.make_fn_arguments();
202 let (fields, names_for_unnamed) = Self::unpack_fields(&source.fields);
203 let expanded = match source.fields {
204 Fields::Named(_) => quote! {#name {#fields}},
205 Fields::Unnamed(_) => quote! {#name(#fields)},
206 Fields::Unit => quote! {#name},
207 };
208 let implemented =
209 factory.apply_mapping_to_fields(&mut source.fields, name.clone(), names_for_unnamed);
210 quote! {
211 impl #impl_gen #name #type_gen #where_clause {
212 pub fn #map_name<#mapped_params>(self, #fn_args) -> #name<#mapped_params> {
213 let #expanded = self;
214 #implemented
215 }
216 }
217 }
218 }
219
220 fn implement_body_for_variant(&self, variant: &mut Variant) -> TokenStream2 {
221 let type_name = &self.type_name;
222 let name = &variant.ident;
223 let (unpacked, name_mapping) = Self::unpack_fields(&variant.fields);
224 match variant.fields {
225 Fields::Named(_) => {
226 let implemented =
227 self.apply_mapping_to_fields(&mut variant.fields, name.clone(), name_mapping);
228 quote! {
229 #type_name::#name{#unpacked} => #type_name::#implemented
230 }
231 }
232 Fields::Unnamed(_) => {
233 let implemented =
234 self.apply_mapping_to_fields(&mut variant.fields, name.clone(), name_mapping);
235 quote! {
236 #type_name::#name(#unpacked) => #type_name::#implemented
237 }
238 }
239 Fields::Unit => quote! {
240 #type_name::#name => #type_name::#name
241 },
242 }
243 }
244
245 fn get_mappable_generics_of_type(&self, ty: &Type) -> Option<FieldMapping> {
247 if let Type::Path(path) = ty {
248 let last_segment = path.path.segments.last();
249 if path.path.segments.len() == 1
251 && self
252 .type_maps_to_type
253 .iter()
254 .any(|(gen, _)| *gen == last_segment.unwrap().ident)
255 {
256 return Some(FieldMapping::Trivial(last_segment.unwrap().ident.clone()));
258 }
259 }
260 let mut buffer = Vec::new();
261 self.recursive_get_generics_of_type_to_buffer(ty, &mut buffer);
262 (!buffer.is_empty()).then_some(FieldMapping::SubMapForArgs(buffer))
263 }
264
265 fn recursive_get_generics_of_type_to_buffer(&self, ty: &Type, buffer: &mut Vec<Ident>) {
268 match ty {
269 Type::Array(array) => {
270 self.recursive_get_generics_of_type_to_buffer(&array.elem, buffer)
271 }
272 Type::Paren(paren) => {
273 self.recursive_get_generics_of_type_to_buffer(&paren.elem, buffer)
274 }
275 Type::Path(path) => {
276 if let Some(segment) = path.path.segments.last().filter(|segment| {
277 self.type_maps_to_type
278 .iter()
279 .any(|(gen, _)| segment.ident == *gen)
280 }) {
281 if !buffer.contains(&segment.ident) {
282 buffer.push(segment.ident.clone());
283 }
284 if let PathArguments::AngleBracketed(generics) = &segment.arguments {
285 for generic in &generics.args {
286 if let syn::GenericArgument::Type(ty) = generic {
287 self.recursive_get_generics_of_type_to_buffer(ty, buffer)
288 }
289 }
290 }
291 }
292 if let Some(PathArguments::AngleBracketed(generics)) = &path.path.segments.last().map(|segment| &segment.arguments) {
294 for generic in &generics.args {
295 if let syn::GenericArgument::Type(ty) = generic {
296 self.recursive_get_generics_of_type_to_buffer(ty, buffer)
297 }
298 }
299 }
300 }
301 Type::Tuple(tuple) => {
302 for ty in &tuple.elems {
303 self.recursive_get_generics_of_type_to_buffer(ty, buffer)
304 }
305 }
306 _ => {}
307 }
308 }
309
310 fn unpack_fields(fields: &Fields) -> (TokenStream2, FieldNameMapping) {
311 match fields {
312 Fields::Named(named) => {
313 let names: Punctuated<Ident, Token![,]> = named
314 .named
315 .iter()
316 .map(|field| field.ident.clone().unwrap())
317 .collect();
318 let tokens = quote! {
319 #names
320 };
321 (tokens, None)
322 }
323 Fields::Unnamed(unnamed) => {
324 let faux_names: Punctuated<Ident, Token![,]> = unnamed
325 .unnamed
326 .iter()
327 .enumerate()
328 .map(|(num, _)| Ident::new(&format!("field_{num}"), Span::call_site()))
329 .collect();
330 let tokens = quote! {
331 #faux_names
332 };
333 (tokens, Some(faux_names.into_iter().collect()))
334 }
335 Fields::Unit => (quote! {}, None),
336 }
337 }
338
339 fn apply_mapping_to_fields(
340 &self,
341 fields: &mut Fields,
342 name: Ident,
343 names_for_unnamed: FieldNameMapping,
344 ) -> TokenStream2 {
345 match fields {
346 Fields::Named(named) => {
347 let mapped: Punctuated<TokenStream2, Token![,]> = named
348 .named
349 .iter_mut()
350 .map(|field| {
351 let field_name = field.ident.clone().unwrap();
353 let new_field_content =
354 self.apply_mapping_to_field_ref(field, quote! {#field_name});
355 quote! {
356 #field_name: #new_field_content
357 }
358 })
359 .collect();
360 let implemented = mapped.to_token_stream();
361 quote! {
362 #name {
363 #implemented
364 }
365 }
366 }
367 Fields::Unnamed(unnamed) => {
368 let names = names_for_unnamed.unwrap();
369 let mapped: Punctuated<TokenStream2, Token![,]> = unnamed
370 .unnamed
371 .iter_mut()
372 .enumerate()
373 .map(|(field_num, field)| {
374 let name_of_field = &names[field_num];
375 let field_ref = quote! {#name_of_field};
376 let new_field_content = self.apply_mapping_to_field_ref(field, field_ref);
377 quote! {
378 #new_field_content
379 }
380 })
381 .collect();
382 quote! {
383 #name(#mapped)
384 }
385 }
386 Fields::Unit => quote! {#name},
387 }
388 }
389
390 fn apply_mapping_to_field_ref(
391 &self,
392 field: &mut Field,
393 field_ref: TokenStream2,
394 ) -> TokenStream2 {
395 match self.get_mappable_generics_of_type(&field.ty) {
396 Some(mapping) => match mapping {
397 FieldMapping::Trivial(type_to_map) => {
398 let map = self.args.get_whole_map_name(&type_to_map);
399 quote! {
400 #map(#field_ref)
401 }
402 }
403 FieldMapping::SubMapForArgs(sub_maps) => {
405 let map_all_name = self.args.get_map_all_name();
406 let all_fns: Punctuated<TokenStream2, Token![,]> = sub_maps
407 .iter()
408 .map(|ident| {
409 let ident = self.args.get_whole_map_name(ident);
410 quote! {&#ident}
411 })
412 .collect();
413 match FieldArg::find_in_attributes(field.attrs.iter()) {
414 Some(FieldArg { alt_function }) => {
415 FieldArg::remove_from_attributes(&mut field.attrs);
416 quote! {
417 (#alt_function)(#field_ref, #all_fns)
418 }
419 }
420 None => {
421 quote! {
422 #field_ref.#map_all_name(#all_fns)
423 }
424 }
425 }
426 }
427 },
428 None => quote! {#field_ref},
430 }
431 }
432
433 fn make_fn_arguments(&self) -> TokenStream2 {
434 let mapped: Punctuated<TokenStream2, Token![,]> = self
435 .type_maps_to_type
436 .iter()
437 .map(|(from, to)| {
438 let fn_name = self.args.get_whole_map_name(from);
439 quote! {
442 #fn_name: impl Fn(#from) -> #to
443 }
444 })
445 .collect();
446 mapped.into_token_stream()
447 }
448}
449
450struct FieldArg {
451 pub alt_function: TokenStream2,
452}
453
454impl FieldArg {
455 fn map_with_attr_ident() -> Ident {
456 Ident::new("map_with", Span::call_site())
457 }
458
459 fn remove_from_attributes(attributes: &mut Vec<Attribute>) {
460 let ident_to_check = Self::map_with_attr_ident();
461 let to_remove: Vec<_> = attributes
463 .iter()
464 .enumerate()
465 .rev()
466 .filter_map(|(num, attribute)| match &attribute.meta {
467 Meta::Path(_) => None,
468 Meta::List(meta) => {
469 let last = meta.path.segments.last()?;
470 (last.ident == ident_to_check).then_some(num)
471 }
472 Meta::NameValue(_) => None,
473 })
474 .collect();
475 for remove in to_remove {
476 attributes.remove(remove);
477 }
478 }
479
480 fn find_in_attributes<'a>(mut attributes: impl Iterator<Item = &'a Attribute>) -> Option<Self> {
481 attributes.find_map(|attribute| match &attribute.meta {
482 Meta::Path(_) => None,
483 Meta::List(meta) => Self::from_meta_list(meta),
484 Meta::NameValue(_) => None,
485 })
486 }
487
488 fn from_meta_list(meta: &MetaList) -> Option<Self> {
489 let ident_to_check = Self::map_with_attr_ident();
490 if meta.path.segments.iter().last().map(|x| &x.ident) == Some(&ident_to_check) {
491 Some(Self {
492 alt_function: meta.tokens.clone(),
493 })
494 } else {
495 None
496 }
497 }
498}
499