1use proc_macro::TokenStream;
79use proc_macro2::{Span, TokenStream as TokenStream2};
80use quote::{quote, quote_spanned, ToTokens};
81use std::collections::BTreeMap;
82use syn::parse::Parser;
83use syn::punctuated::Punctuated;
84use syn::spanned::Spanned;
85use syn::{
86 parse_macro_input, Attribute, Expr, Field, Fields, GenericParam, Ident, Item, ItemEnum,
87 ItemStruct, Meta, MetaList, MetaNameValue, PathArguments, Token, Type, TypeParam, Variant,
88};
89
90#[proc_macro_attribute]
92pub fn derive_n_functor(args: TokenStream, item: TokenStream) -> TokenStream {
93 let _args: TokenStream2 = args.clone().into();
94 let _item: TokenStream2 = item.clone().into();
95 let args = Args::from_token_stream(args);
96 let mut input = parse_macro_input!(item as Item);
97 let output = match &mut input {
98 Item::Enum(_enum) => AbstractFunctorFactory::from_item_enum(args, _enum),
99 Item::Struct(_struct) => AbstractFunctorFactory::from_item_struct(args, _struct),
100 _ => {
101 quote_spanned! {_args.span() => compile_error!("Could not derive n-functor for this, it is neither an enum or struct.")}
102 }
103 };
104 quote! {
105 #input
106 #output
107 }
108 .into()
109}
110
111#[derive(Clone)]
113struct Args {
114 pub parameter_names: BTreeMap<Ident, Ident>,
115 pub mapping_name: String,
116 pub should_generate_map_res: bool,
117 pub map_res_suffix: String,
118}
119
120impl Args {
121 fn from_token_stream(stream: TokenStream) -> Self {
122 let parsed_attrs: Punctuated<MetaNameValue, Token![,]> =
123 Parser::parse2(Punctuated::parse_terminated, stream.into()).unwrap();
124 Args::from_iter(parsed_attrs.into_iter())
125 }
126
127 fn from_iter(input: impl Iterator<Item = MetaNameValue>) -> Self {
128 let search_for_mapping_token = Ident::new("map_name", Span::call_site());
129 let mut mapping_name = "map".to_string();
130 let mut should_generate_map_res = false;
131 let mut map_res_suffix = "res".to_string();
132 let parameter_names = input
133 .filter_map(|name_val| {
134 if name_val.path.segments.last().unwrap().ident == search_for_mapping_token {
135 if let syn::Expr::Path(path) = name_val.value {
137 mapping_name = path.path.segments.last()?.ident.to_string();
138 }
139 return None;
141 }
142 if name_val.path.segments.len() == 1
143 && name_val.path.segments.get(0).unwrap().ident == "impl_map_res"
144 {
145 should_generate_map_res =
146 name_val.value.to_token_stream().to_string() == "true";
147 return None;
148 }
149 if name_val.path.segments.len() == 1
150 && name_val.path.segments.get(0).unwrap().ident == "map_res_suffix"
151 {
152 map_res_suffix = name_val.value.to_token_stream().to_string();
153 return None;
154 }
155 let ty_ident = &name_val.path.segments.last()?.ident;
157 let rename_ident = &match &name_val.value {
158 syn::Expr::Path(path) => path.path.segments.last(),
159 _ => None,
160 }?
161 .ident;
162 Some((ty_ident.clone(), rename_ident.clone()))
163 })
164 .collect();
165 Self {
166 parameter_names,
167 mapping_name,
168 should_generate_map_res,
169 map_res_suffix,
170 }
171 }
172
173 fn get_suffix_for(&self, ident: &Ident) -> Ident {
174 self.parameter_names
175 .get(ident)
176 .cloned()
177 .unwrap_or_else(|| Ident::new(&format!("{ident}"), Span::call_site()))
178 }
179
180 fn get_whole_map_name(&self, ident: &Ident) -> Ident {
181 let suffix = self.get_suffix_for(ident);
182 Ident::new(&format!("map_{suffix}"), Span::call_site())
183 }
184
185 fn get_map_all_name(&self) -> Ident {
186 Ident::new(&self.mapping_name, Span::call_site())
187 }
188}
189
190enum FieldMapping {
191 Trivial(Ident),
192 SubMapForArgs(Vec<Ident>),
193}
194
195type FieldNameMapping = Option<Vec<Ident>>;
196
197struct AbstractFunctorFactory {
200 pub args: Args,
201 pub type_maps_to_type: Vec<(Ident, Ident)>,
203 pub type_name: Ident,
204}
205
206impl AbstractFunctorFactory {
207 fn from_generics<'a>(
208 args: Args,
209 generics: impl Iterator<Item = &'a GenericParam>,
210 type_name: Ident,
211 ) -> Self {
212 let mut type_maps_to_type = vec![];
213 for generic in generics {
214 match generic {
215 GenericParam::Lifetime(_) => {}
216 GenericParam::Type(ty) => type_maps_to_type.push((
217 ty.ident.clone(),
218 Ident::new(&format!("{}2", ty.ident), Span::call_site()),
219 )),
220 GenericParam::Const(_) => {}
221 }
222 }
223 AbstractFunctorFactory {
224 args,
225 type_maps_to_type,
226 type_name,
227 }
228 }
229
230 fn from_item_enum(args: Args, source: &mut ItemEnum) -> TokenStream2 {
231 let name = source.ident.clone();
232 let factory = AbstractFunctorFactory::from_generics(
233 args,
234 source.generics.params.iter(),
235 source.ident.clone(),
236 );
237 let map_name = factory.args.get_map_all_name();
238 let (impl_gen, type_gen, where_clause) = source.generics.split_for_impl();
239 let mapped_params: Punctuated<TypeParam, Token![,]> = factory
240 .type_maps_to_type
241 .iter()
242 .map(|a| TypeParam::from(a.1.clone()))
243 .collect();
244 let fn_args = factory.make_fn_arguments();
245 let implemented: Punctuated<TokenStream2, Token![,]> = source
246 .variants
247 .iter_mut()
248 .map(|variant| factory.implement_body_for_variant(variant, false))
249 .collect();
250 quote! {
251 impl #impl_gen #name #type_gen #where_clause {
252 pub fn #map_name<#mapped_params>(self, #fn_args) -> #name<#mapped_params> {
253 match self {
254 #implemented
255 }
256 }
257 }
258 }
259 }
260
261 fn from_item_struct(args: Args, source: &mut ItemStruct) -> TokenStream2 {
262 let name = source.ident.clone();
263 let factory = AbstractFunctorFactory::from_generics(
264 args.clone(),
265 source.generics.params.iter(),
266 source.ident.clone(),
267 );
268 let Args {
269 should_generate_map_res,
270 map_res_suffix,
271 ..
272 } = args;
273 let map_name = factory.args.get_map_all_name();
274 let map_res_name = Ident::new(
275 &format!("{}_{}", map_name, map_res_suffix),
276 Span::call_site(),
277 );
278 let (impl_gen, type_gen, where_clause) = source.generics.split_for_impl();
279 let mapped_params: Punctuated<TypeParam, Token![,]> = factory
280 .type_maps_to_type
281 .iter()
282 .map(|a| TypeParam::from(a.1.clone()))
283 .collect();
284 let map_res_error_ident = factory.ident_for_error();
285 let fn_args_for_map_res = factory.make_fn_arguments_map_res(map_res_error_ident.clone());
286 let mut mapped_params_for_map_res = mapped_params.clone();
287 mapped_params_for_map_res.push(map_res_error_ident.clone().into());
288 let fn_args = factory.make_fn_arguments();
289 let (fields, names_for_unnamed) = Self::unpack_fields(&source.fields);
290 let expanded = match source.fields {
291 Fields::Named(_) => quote! {#name {#fields}},
292 Fields::Unnamed(_) => quote! {#name(#fields)},
293 Fields::Unit => quote! {#name},
294 };
295 let mut source_2 = source.fields.clone();
296 let implemented = factory.apply_mapping_to_fields(
297 &mut source.fields,
298 name.clone(),
299 names_for_unnamed.clone(),
300 false,
301 );
302 let implemented_map_res =
303 factory.apply_mapping_to_fields(&mut source_2, name.clone(), names_for_unnamed, true);
304 let map_res_impl = if should_generate_map_res {
305 quote! {
306 pub fn #map_res_name<#mapped_params_for_map_res>(self, #fn_args_for_map_res) -> Result<#name<#mapped_params>, #map_res_error_ident> {
307 let #expanded = self;
308 Ok({#implemented_map_res})
309 }
310 }
311 } else {
312 quote! {}
313 };
314 quote! {
315 impl #impl_gen #name #type_gen #where_clause {
316 pub fn #map_name<#mapped_params>(self, #fn_args) -> #name<#mapped_params> {
317 let #expanded = self;
318 #implemented
319 }
320 #map_res_impl
321 }
322 }
323 }
324
325 fn ident_for_error(&self) -> Ident {
327 let mut candidate = Ident::new("E", Span::call_site());
328 let mut suffix: u32 = 0;
329 loop {
330 let mut changed_this_loop = false;
331 for (key, value) in self.type_maps_to_type.iter() {
332 if &candidate == key || &candidate == value {
333 suffix += 1;
334 candidate = Ident::new(&format!("E{suffix}"), Span::call_site());
335 changed_this_loop = true;
336 }
337 }
338 if !changed_this_loop {
339 break;
340 }
341 }
342 candidate
343 }
344
345 fn implement_body_for_variant(&self, variant: &mut Variant, is_map_res: bool) -> TokenStream2 {
346 let type_name = &self.type_name;
347 let name = &variant.ident;
348 let (unpacked, name_mapping) = Self::unpack_fields(&variant.fields);
349 match variant.fields {
350 Fields::Named(_) => {
351 let implemented = self.apply_mapping_to_fields(
352 &mut variant.fields,
353 name.clone(),
354 name_mapping,
355 is_map_res,
356 );
357 quote! {
358 #type_name::#name{#unpacked} => #type_name::#implemented
359 }
360 }
361 Fields::Unnamed(_) => {
362 let implemented = self.apply_mapping_to_fields(
363 &mut variant.fields,
364 name.clone(),
365 name_mapping,
366 is_map_res,
367 );
368 quote! {
369 #type_name::#name(#unpacked) => #type_name::#implemented
370 }
371 }
372 Fields::Unit => quote! {
373 #type_name::#name => #type_name::#name
374 },
375 }
376 }
377
378 fn get_mappable_generics_of_type(&self, ty: &Type) -> Option<FieldMapping> {
380 if let Type::Path(path) = ty {
381 let last_segment = path.path.segments.last();
382 if path.path.segments.len() == 1
384 && self
385 .type_maps_to_type
386 .iter()
387 .any(|(generic, _)| *generic == last_segment.unwrap().ident)
388 {
389 return Some(FieldMapping::Trivial(last_segment.unwrap().ident.clone()));
391 }
392 }
393 let mut buffer = Vec::new();
394 self.recursive_get_generics_of_type_to_buffer(ty, &mut buffer);
395 (!buffer.is_empty()).then_some(FieldMapping::SubMapForArgs(buffer))
396 }
397
398 fn recursive_get_generics_of_type_to_buffer(&self, ty: &Type, buffer: &mut Vec<Ident>) {
401 match ty {
402 Type::Array(array) => {
403 self.recursive_get_generics_of_type_to_buffer(&array.elem, buffer)
404 }
405 Type::Paren(paren) => {
406 self.recursive_get_generics_of_type_to_buffer(&paren.elem, buffer)
407 }
408 Type::Path(path) => {
409 if let Some(segment) = path.path.segments.last().filter(|segment| {
410 self.type_maps_to_type
411 .iter()
412 .any(|(generic, _)| segment.ident == *generic)
413 }) {
414 if !buffer.contains(&segment.ident) {
415 buffer.push(segment.ident.clone());
416 }
417 if let PathArguments::AngleBracketed(generics) = &segment.arguments {
418 for generic in &generics.args {
419 if let syn::GenericArgument::Type(ty) = generic {
420 self.recursive_get_generics_of_type_to_buffer(ty, buffer)
421 }
422 }
423 }
424 }
425 if let Some(PathArguments::AngleBracketed(generics)) =
427 &path.path.segments.last().map(|segment| &segment.arguments)
428 {
429 for generic in &generics.args {
430 if let syn::GenericArgument::Type(ty) = generic {
431 self.recursive_get_generics_of_type_to_buffer(ty, buffer)
432 }
433 }
434 }
435 }
436 Type::Tuple(tuple) => {
437 for ty in &tuple.elems {
438 self.recursive_get_generics_of_type_to_buffer(ty, buffer)
439 }
440 }
441 _ => {}
442 }
443 }
444
445 fn unpack_fields(fields: &Fields) -> (TokenStream2, FieldNameMapping) {
446 match fields {
447 Fields::Named(named) => {
448 let names: Punctuated<Ident, Token![,]> = named
449 .named
450 .iter()
451 .map(|field| field.ident.clone().unwrap())
452 .collect();
453 let tokens = quote! {
454 #names
455 };
456 (tokens, None)
457 }
458 Fields::Unnamed(unnamed) => {
459 let faux_names: Punctuated<Ident, Token![,]> = unnamed
460 .unnamed
461 .iter()
462 .enumerate()
463 .map(|(num, _)| Ident::new(&format!("field_{num}"), Span::call_site()))
464 .collect();
465 let tokens = quote! {
466 #faux_names
467 };
468 (tokens, Some(faux_names.into_iter().collect()))
469 }
470 Fields::Unit => (quote! {}, None),
471 }
472 }
473
474 fn apply_mapping_to_fields(
475 &self,
476 fields: &mut Fields,
477 name: Ident,
478 names_for_unnamed: FieldNameMapping,
479 is_map_res: bool,
480 ) -> TokenStream2 {
481 match fields {
482 Fields::Named(named) => {
483 let mapped: Punctuated<TokenStream2, Token![,]> = named
484 .named
485 .iter_mut()
486 .map(|field| {
487 let field_name = field.ident.clone().unwrap();
489 let new_field_content = self.apply_mapping_to_field_ref(
490 field,
491 quote! {#field_name},
492 is_map_res,
493 );
494 quote! {
495 #field_name: #new_field_content
496 }
497 })
498 .collect();
499 let implemented = mapped.to_token_stream();
500 quote! {
501 #name {
502 #implemented
503 }
504 }
505 }
506 Fields::Unnamed(unnamed) => {
507 let names = names_for_unnamed.unwrap();
508 let mapped: Punctuated<TokenStream2, Token![,]> = unnamed
509 .unnamed
510 .iter_mut()
511 .enumerate()
512 .map(|(field_num, field)| {
513 let name_of_field = &names[field_num];
514 let field_ref = quote! {#name_of_field};
515 let new_field_content =
516 self.apply_mapping_to_field_ref(field, field_ref, is_map_res);
517 quote! {
518 #new_field_content
519 }
520 })
521 .collect();
522 quote! {
523 #name(#mapped)
524 }
525 }
526 Fields::Unit => quote! {#name},
527 }
528 }
529
530 fn apply_mapping_to_field_ref(
531 &self,
532 field: &mut Field,
533 field_ref: TokenStream2,
534 is_map_res: bool,
535 ) -> TokenStream2 {
536 let postfix = if is_map_res {
537 quote! {?}
538 } else {
539 quote! {}
540 };
541 match self.get_mappable_generics_of_type(&field.ty) {
542 Some(mapping) => match mapping {
543 FieldMapping::Trivial(type_to_map) => {
544 let map = self.args.get_whole_map_name(&type_to_map);
545 quote! {
546 #map(#field_ref)#postfix
547 }
548 }
549 FieldMapping::SubMapForArgs(sub_maps) => {
551 let map_all_name = self.args.get_map_all_name();
552 let all_fns: Punctuated<TokenStream2, Token![,]> = sub_maps
553 .iter()
554 .map(|ident| {
555 let ident = self.args.get_whole_map_name(ident);
556 quote! {&#ident}
557 })
558 .collect();
559 match FieldArg::find_in_attributes(field.attrs.iter()) {
560 Some(FieldArg {
561 alt_function,
562 map_res_with_function,
563 }) => {
564 let function_to_use = if is_map_res && map_res_with_function.is_some() {
565 map_res_with_function.clone().unwrap()
566 } else {
567 alt_function
568 };
569 FieldArg::remove_from_attributes(&mut field.attrs);
570 quote! {
571 (#function_to_use)(#field_ref, #all_fns)#postfix
572 }
573 }
574 None => {
575 quote! {
576 #field_ref.#map_all_name(#all_fns)#postfix
577 }
578 }
579 }
580 }
581 },
582 None => quote! {#field_ref},
584 }
585 }
586
587 fn make_fn_arguments(&self) -> TokenStream2 {
588 let mapped: Punctuated<TokenStream2, Token![,]> = self
589 .type_maps_to_type
590 .iter()
591 .map(|(from, to)| {
592 let fn_name = self.args.get_whole_map_name(from);
593 quote! {
596 #fn_name: impl Fn(#from) -> #to
597 }
598 })
599 .collect();
600 mapped.into_token_stream()
601 }
602
603 fn make_fn_arguments_map_res(&self, err_ident: Ident) -> TokenStream2 {
604 let mapped: Punctuated<TokenStream2, Token![,]> = self
605 .type_maps_to_type
606 .iter()
607 .map(|(from, to)| {
608 let fn_name = self.args.get_whole_map_name(from);
609 quote! {
612 #fn_name: impl Fn(#from) -> Result<#to, #err_ident>
613 }
614 })
615 .collect();
616 mapped.into_token_stream()
617 }
618}
619
620struct FieldArg {
621 pub alt_function: TokenStream2,
622 pub map_res_with_function: Option<TokenStream2>,
623}
624
625impl FieldArg {
626 fn map_with_attr_ident() -> Ident {
627 Ident::new("map_with", Span::call_site())
628 }
629
630 fn remove_from_attributes(attributes: &mut Vec<Attribute>) {
631 let ident_to_check = Self::map_with_attr_ident();
632 let to_remove: Vec<_> = attributes
634 .iter()
635 .enumerate()
636 .rev()
637 .filter_map(|(num, attribute)| match &attribute.meta {
638 Meta::Path(_) => None,
639 Meta::List(meta) => {
640 let last = meta.path.segments.last()?;
641 (last.ident == ident_to_check).then_some(num)
642 }
643 Meta::NameValue(_) => None,
644 })
645 .collect();
646 for remove in to_remove {
647 attributes.remove(remove);
648 }
649 }
650
651 fn find_in_attributes<'a>(mut attributes: impl Iterator<Item = &'a Attribute>) -> Option<Self> {
652 attributes.find_map(|attribute| match &attribute.meta {
653 Meta::Path(_) => None,
654 Meta::List(meta) => Self::from_meta_list(meta),
655 Meta::NameValue(_) => None,
656 })
657 }
658
659 fn from_meta_list(meta: &MetaList) -> Option<Self> {
660 let ident_to_check = Self::map_with_attr_ident();
661 if meta.path.segments.iter().next_back().map(|x| &x.ident) == Some(&ident_to_check) {
662 let punctuated: Punctuated<Expr, Token![,]> = Punctuated::parse_terminated
664 .parse2(meta.tokens.clone())
665 .unwrap();
666 match punctuated.len() {
667 1 => Some(Self {
668 alt_function: punctuated[0].to_token_stream(),
669 map_res_with_function: None,
670 }),
671 2 => Some(Self {
672 alt_function: punctuated[0].to_token_stream(),
673 map_res_with_function: Some(punctuated[1].to_token_stream()),
674 }),
675 _ => Some(Self {
676 alt_function: quote! {compile_error!("Wrong number of arguments passed to map_with, this takes up to 2 arguments: one for regular mapping, and one for the 'map_res' function.")},
677 map_res_with_function: None,
678 }),
679 }
680 } else {
681 None
682 }
683 }
684}