1use proc_macro::TokenStream;
97use proc_macro2::{Span, TokenStream as TokenStream2};
98use quote::{quote, quote_spanned, ToTokens};
99use std::collections::BTreeMap;
100use syn::parse::Parser;
101use syn::punctuated::Punctuated;
102use syn::spanned::Spanned;
103use syn::{
104 parse_macro_input, Attribute, Expr, Field, Fields, GenericParam, Ident, Item, ItemEnum,
105 ItemStruct, Meta, MetaList, MetaNameValue, PathArguments, Token, Type, TypeParam, Variant,
106};
107
108#[proc_macro_attribute]
110pub fn derive_n_functor(args: TokenStream, item: TokenStream) -> TokenStream {
111 let _args: TokenStream2 = args.clone().into();
112 let _item: TokenStream2 = item.clone().into();
113 let args = Args::from_token_stream(args);
114 let mut input = parse_macro_input!(item as Item);
115 let output = match &mut input {
116 Item::Enum(_enum) => AbstractFunctorFactory::from_item_enum(args, _enum),
117 Item::Struct(_struct) => AbstractFunctorFactory::from_item_struct(args, _struct),
118 _ => {
119 quote_spanned! {_args.span() => compile_error!("Could not derive n-functor for this, it is neither an enum or struct.")}
120 }
121 };
122 quote! {
123 #input
124 #output
125 }
126 .into()
127}
128
129#[derive(Clone)]
131struct Args {
132 pub parameter_names: BTreeMap<Ident, Ident>,
133 pub mapping_name: String,
134 pub should_generate_map_res: bool,
135 pub map_res_suffix: String,
136}
137
138impl Args {
139 fn from_token_stream(stream: TokenStream) -> Self {
140 let parsed_attrs: Punctuated<MetaNameValue, Token![,]> =
141 Parser::parse2(Punctuated::parse_terminated, stream.into()).unwrap();
142 Args::from_iter(parsed_attrs.into_iter())
143 }
144
145 fn from_iter(input: impl Iterator<Item = MetaNameValue>) -> Self {
146 let search_for_mapping_token = Ident::new("map_name", Span::call_site());
147 let mut mapping_name = "map".to_string();
148 let mut should_generate_map_res = false;
149 let mut map_res_suffix = "res".to_string();
150 let parameter_names = input
151 .filter_map(|name_val| {
152 if name_val.path.segments.last().unwrap().ident == search_for_mapping_token {
153 if let syn::Expr::Path(path) = name_val.value {
155 mapping_name = path.path.segments.last()?.ident.to_string();
156 }
157 return None;
159 }
160 if name_val.path.segments.len() == 1
161 && name_val.path.segments.get(0).unwrap().ident == "impl_map_res"
162 {
163 should_generate_map_res =
164 name_val.value.to_token_stream().to_string() == "true";
165 return None;
166 }
167 if name_val.path.segments.len() == 1
168 && name_val.path.segments.get(0).unwrap().ident == "map_res_suffix"
169 {
170 map_res_suffix = name_val.value.to_token_stream().to_string();
171 return None;
172 }
173 let ty_ident = &name_val.path.segments.last()?.ident;
175 let rename_ident = &match &name_val.value {
176 syn::Expr::Path(path) => path.path.segments.last(),
177 _ => None,
178 }?
179 .ident;
180 Some((ty_ident.clone(), rename_ident.clone()))
181 })
182 .collect();
183 Self {
184 parameter_names,
185 mapping_name,
186 should_generate_map_res,
187 map_res_suffix,
188 }
189 }
190
191 fn get_suffix_for(&self, ident: &Ident) -> Ident {
192 self.parameter_names
193 .get(ident)
194 .cloned()
195 .unwrap_or_else(|| Ident::new(&format!("{ident}"), Span::call_site()))
196 }
197
198 fn get_whole_map_name(&self, ident: &Ident) -> Ident {
199 let suffix = self.get_suffix_for(ident);
200 Ident::new(&format!("map_{suffix}"), Span::call_site())
201 }
202
203 fn get_map_all_name(&self) -> Ident {
204 Ident::new(&self.mapping_name, Span::call_site())
205 }
206}
207
208enum FieldMapping {
209 Trivial(Ident),
210 SubMapForArgs(Vec<Ident>),
211}
212
213type FieldNameMapping = Option<Vec<Ident>>;
214
215struct AbstractFunctorFactory {
218 pub args: Args,
219 pub type_maps_to_type: Vec<(Ident, Ident)>,
221 pub type_name: Ident,
222}
223
224impl AbstractFunctorFactory {
225 fn from_generics<'a>(
226 args: Args,
227 generics: impl Iterator<Item = &'a GenericParam>,
228 type_name: Ident,
229 ) -> Self {
230 let mut type_maps_to_type = vec![];
231 for generic in generics {
232 match generic {
233 GenericParam::Lifetime(_) => {}
234 GenericParam::Type(ty) => type_maps_to_type.push((
235 ty.ident.clone(),
236 Ident::new(&format!("{}2", ty.ident), Span::call_site()),
237 )),
238 GenericParam::Const(_) => {}
239 }
240 }
241 AbstractFunctorFactory {
242 args,
243 type_maps_to_type,
244 type_name,
245 }
246 }
247
248 fn from_item_enum(args: Args, source: &mut ItemEnum) -> TokenStream2 {
249 let name = source.ident.clone();
250 let factory = AbstractFunctorFactory::from_generics(
251 args,
252 source.generics.params.iter(),
253 source.ident.clone(),
254 );
255 let map_name = factory.args.get_map_all_name();
256 let (impl_gen, type_gen, where_clause) = source.generics.split_for_impl();
257 let mapped_params: Punctuated<TypeParam, Token![,]> = factory
258 .type_maps_to_type
259 .iter()
260 .map(|a| TypeParam::from(a.1.clone()))
261 .collect();
262 let fn_args = factory.make_fn_arguments();
263 let implemented: Punctuated<TokenStream2, Token![,]> = source
264 .variants
265 .iter_mut()
266 .map(|variant| factory.implement_body_for_variant(variant, false))
267 .collect();
268 quote! {
269 impl #impl_gen #name #type_gen #where_clause {
270 pub fn #map_name<#mapped_params>(self, #fn_args) -> #name<#mapped_params> {
271 match self {
272 #implemented
273 }
274 }
275 }
276 }
277 }
278
279 fn from_item_struct(args: Args, source: &mut ItemStruct) -> TokenStream2 {
280 let name = source.ident.clone();
281 let factory = AbstractFunctorFactory::from_generics(
282 args.clone(),
283 source.generics.params.iter(),
284 source.ident.clone(),
285 );
286 let Args {
287 should_generate_map_res,
288 map_res_suffix,
289 ..
290 } = args;
291 let map_name = factory.args.get_map_all_name();
292 let map_res_name = Ident::new(
293 &format!("{}_{}", map_name, map_res_suffix),
294 Span::call_site(),
295 );
296 let (impl_gen, type_gen, where_clause) = source.generics.split_for_impl();
297 let mapped_params: Punctuated<TypeParam, Token![,]> = factory
298 .type_maps_to_type
299 .iter()
300 .map(|a| TypeParam::from(a.1.clone()))
301 .collect();
302 let map_res_error_ident = factory.ident_for_error();
303 let fn_args_for_map_res = factory.make_fn_arguments_map_res(map_res_error_ident.clone());
304 let mut mapped_params_for_map_res = mapped_params.clone();
305 mapped_params_for_map_res.push(map_res_error_ident.clone().into());
306 let fn_args = factory.make_fn_arguments();
307 let (fields, names_for_unnamed) = Self::unpack_fields(&source.fields);
308 let expanded = match source.fields {
309 Fields::Named(_) => quote! {#name {#fields}},
310 Fields::Unnamed(_) => quote! {#name(#fields)},
311 Fields::Unit => quote! {#name},
312 };
313 let mut source_2 = source.fields.clone();
314 let implemented = factory.apply_mapping_to_fields(
315 &mut source.fields,
316 name.clone(),
317 names_for_unnamed.clone(),
318 false,
319 );
320 let implemented_map_res =
321 factory.apply_mapping_to_fields(&mut source_2, name.clone(), names_for_unnamed, true);
322 let map_res_impl = if should_generate_map_res {
323 quote! {
324 pub fn #map_res_name<#mapped_params_for_map_res>(self, #fn_args_for_map_res) -> Result<#name<#mapped_params>, #map_res_error_ident> {
325 let #expanded = self;
326 Ok({#implemented_map_res})
327 }
328 }
329 } else {
330 quote! {}
331 };
332 quote! {
333 impl #impl_gen #name #type_gen #where_clause {
334 pub fn #map_name<#mapped_params>(self, #fn_args) -> #name<#mapped_params> {
335 let #expanded = self;
336 #implemented
337 }
338 #map_res_impl
339 }
340 }
341 }
342
343 fn ident_for_error(&self) -> Ident {
345 let mut candidate = Ident::new("E", Span::call_site());
346 let mut suffix: u32 = 0;
347 loop {
348 let mut changed_this_loop = false;
349 for (key, value) in self.type_maps_to_type.iter() {
350 if &candidate == key || &candidate == value {
351 suffix += 1;
352 candidate = Ident::new(&format!("E{suffix}"), Span::call_site());
353 changed_this_loop = true;
354 }
355 }
356 if !changed_this_loop {
357 break;
358 }
359 }
360 candidate
361 }
362
363 fn implement_body_for_variant(&self, variant: &mut Variant, is_map_res: bool) -> TokenStream2 {
364 let type_name = &self.type_name;
365 let name = &variant.ident;
366 let (unpacked, name_mapping) = Self::unpack_fields(&variant.fields);
367 match variant.fields {
368 Fields::Named(_) => {
369 let implemented = self.apply_mapping_to_fields(
370 &mut variant.fields,
371 name.clone(),
372 name_mapping,
373 is_map_res,
374 );
375 quote! {
376 #type_name::#name{#unpacked} => #type_name::#implemented
377 }
378 }
379 Fields::Unnamed(_) => {
380 let implemented = self.apply_mapping_to_fields(
381 &mut variant.fields,
382 name.clone(),
383 name_mapping,
384 is_map_res,
385 );
386 quote! {
387 #type_name::#name(#unpacked) => #type_name::#implemented
388 }
389 }
390 Fields::Unit => quote! {
391 #type_name::#name => #type_name::#name
392 },
393 }
394 }
395
396 fn get_mappable_generics_of_type(&self, ty: &Type) -> Option<FieldMapping> {
398 if let Type::Path(path) = ty {
399 let last_segment = path.path.segments.last();
400 if path.path.segments.len() == 1
402 && self
403 .type_maps_to_type
404 .iter()
405 .any(|(generic, _)| *generic == last_segment.unwrap().ident)
406 {
407 return Some(FieldMapping::Trivial(last_segment.unwrap().ident.clone()));
409 }
410 }
411 let mut buffer = Vec::new();
412 self.recursive_get_generics_of_type_to_buffer(ty, &mut buffer);
413 (!buffer.is_empty()).then_some(FieldMapping::SubMapForArgs(buffer))
414 }
415
416 fn recursive_get_generics_of_type_to_buffer(&self, ty: &Type, buffer: &mut Vec<Ident>) {
419 match ty {
420 Type::Array(array) => {
421 self.recursive_get_generics_of_type_to_buffer(&array.elem, buffer)
422 }
423 Type::Paren(paren) => {
424 self.recursive_get_generics_of_type_to_buffer(&paren.elem, buffer)
425 }
426 Type::Path(path) => {
427 if let Some(segment) = path.path.segments.last().filter(|segment| {
428 self.type_maps_to_type
429 .iter()
430 .any(|(generic, _)| segment.ident == *generic)
431 }) {
432 if !buffer.contains(&segment.ident) {
433 buffer.push(segment.ident.clone());
434 }
435 if let PathArguments::AngleBracketed(generics) = &segment.arguments {
436 for generic in &generics.args {
437 if let syn::GenericArgument::Type(ty) = generic {
438 self.recursive_get_generics_of_type_to_buffer(ty, buffer)
439 }
440 }
441 }
442 }
443 if let Some(PathArguments::AngleBracketed(generics)) =
445 &path.path.segments.last().map(|segment| &segment.arguments)
446 {
447 for generic in &generics.args {
448 if let syn::GenericArgument::Type(ty) = generic {
449 self.recursive_get_generics_of_type_to_buffer(ty, buffer)
450 }
451 }
452 }
453 }
454 Type::Tuple(tuple) => {
455 for ty in &tuple.elems {
456 self.recursive_get_generics_of_type_to_buffer(ty, buffer)
457 }
458 }
459 _ => {}
460 }
461 }
462
463 fn unpack_fields(fields: &Fields) -> (TokenStream2, FieldNameMapping) {
464 match fields {
465 Fields::Named(named) => {
466 let names: Punctuated<Ident, Token![,]> = named
467 .named
468 .iter()
469 .map(|field| field.ident.clone().unwrap())
470 .collect();
471 let tokens = quote! {
472 #names
473 };
474 (tokens, None)
475 }
476 Fields::Unnamed(unnamed) => {
477 let faux_names: Punctuated<Ident, Token![,]> = unnamed
478 .unnamed
479 .iter()
480 .enumerate()
481 .map(|(num, _)| Ident::new(&format!("field_{num}"), Span::call_site()))
482 .collect();
483 let tokens = quote! {
484 #faux_names
485 };
486 (tokens, Some(faux_names.into_iter().collect()))
487 }
488 Fields::Unit => (quote! {}, None),
489 }
490 }
491
492 fn apply_mapping_to_fields(
493 &self,
494 fields: &mut Fields,
495 name: Ident,
496 names_for_unnamed: FieldNameMapping,
497 is_map_res: bool,
498 ) -> TokenStream2 {
499 match fields {
500 Fields::Named(named) => {
501 let mapped: Punctuated<TokenStream2, Token![,]> = named
502 .named
503 .iter_mut()
504 .map(|field| {
505 let field_name = field.ident.clone().unwrap();
507 let new_field_content = self.apply_mapping_to_field_ref(
508 field,
509 quote! {#field_name},
510 is_map_res,
511 );
512 quote! {
513 #field_name: #new_field_content
514 }
515 })
516 .collect();
517 let implemented = mapped.to_token_stream();
518 quote! {
519 #name {
520 #implemented
521 }
522 }
523 }
524 Fields::Unnamed(unnamed) => {
525 let names = names_for_unnamed.unwrap();
526 let mapped: Punctuated<TokenStream2, Token![,]> = unnamed
527 .unnamed
528 .iter_mut()
529 .enumerate()
530 .map(|(field_num, field)| {
531 let name_of_field = &names[field_num];
532 let field_ref = quote! {#name_of_field};
533 let new_field_content =
534 self.apply_mapping_to_field_ref(field, field_ref, is_map_res);
535 quote! {
536 #new_field_content
537 }
538 })
539 .collect();
540 quote! {
541 #name(#mapped)
542 }
543 }
544 Fields::Unit => quote! {#name},
545 }
546 }
547
548 fn apply_mapping_to_field_ref(
549 &self,
550 field: &mut Field,
551 field_ref: TokenStream2,
552 is_map_res: bool,
553 ) -> TokenStream2 {
554 let postfix = if is_map_res {
555 quote! {?}
556 } else {
557 quote! {}
558 };
559 match self.get_mappable_generics_of_type(&field.ty) {
560 Some(mapping) => match mapping {
561 FieldMapping::Trivial(type_to_map) => {
562 let map = self.args.get_whole_map_name(&type_to_map);
563 quote! {
564 #map(#field_ref)#postfix
565 }
566 }
567 FieldMapping::SubMapForArgs(sub_maps) => {
569 let map_all_name = self.args.get_map_all_name();
570 let all_fns: Punctuated<TokenStream2, Token![,]> = sub_maps
571 .iter()
572 .map(|ident| {
573 let ident = self.args.get_whole_map_name(ident);
574 quote! {&#ident}
575 })
576 .collect();
577 match FieldArg::find_in_attributes(field.attrs.iter()) {
578 Some(FieldArg {
579 alt_function,
580 map_res_with_function,
581 }) => {
582 let function_to_use = if is_map_res && map_res_with_function.is_some() {
583 map_res_with_function.clone().unwrap()
584 } else {
585 alt_function
586 };
587 FieldArg::remove_from_attributes(&mut field.attrs);
588 quote! {
589 (#function_to_use)(#field_ref, #all_fns)#postfix
590 }
591 }
592 None => {
593 quote! {
594 #field_ref.#map_all_name(#all_fns)#postfix
595 }
596 }
597 }
598 }
599 },
600 None => quote! {#field_ref},
602 }
603 }
604
605 fn make_fn_arguments(&self) -> TokenStream2 {
606 let mapped: Punctuated<TokenStream2, Token![,]> = self
607 .type_maps_to_type
608 .iter()
609 .map(|(from, to)| {
610 let fn_name = self.args.get_whole_map_name(from);
611 quote! {
614 #fn_name: impl Fn(#from) -> #to
615 }
616 })
617 .collect();
618 mapped.into_token_stream()
619 }
620
621 fn make_fn_arguments_map_res(&self, err_ident: Ident) -> TokenStream2 {
622 let mapped: Punctuated<TokenStream2, Token![,]> = self
623 .type_maps_to_type
624 .iter()
625 .map(|(from, to)| {
626 let fn_name = self.args.get_whole_map_name(from);
627 quote! {
630 #fn_name: impl Fn(#from) -> Result<#to, #err_ident>
631 }
632 })
633 .collect();
634 mapped.into_token_stream()
635 }
636}
637
638struct FieldArg {
639 pub alt_function: TokenStream2,
640 pub map_res_with_function: Option<TokenStream2>,
641}
642
643impl FieldArg {
644 fn map_with_attr_ident() -> Ident {
645 Ident::new("map_with", Span::call_site())
646 }
647
648 fn remove_from_attributes(attributes: &mut Vec<Attribute>) {
649 let ident_to_check = Self::map_with_attr_ident();
650 let to_remove: Vec<_> = attributes
652 .iter()
653 .enumerate()
654 .rev()
655 .filter_map(|(num, attribute)| match &attribute.meta {
656 Meta::Path(_) => None,
657 Meta::List(meta) => {
658 let last = meta.path.segments.last()?;
659 (last.ident == ident_to_check).then_some(num)
660 }
661 Meta::NameValue(_) => None,
662 })
663 .collect();
664 for remove in to_remove {
665 attributes.remove(remove);
666 }
667 }
668
669 fn find_in_attributes<'a>(mut attributes: impl Iterator<Item = &'a Attribute>) -> Option<Self> {
670 attributes.find_map(|attribute| match &attribute.meta {
671 Meta::Path(_) => None,
672 Meta::List(meta) => Self::from_meta_list(meta),
673 Meta::NameValue(_) => None,
674 })
675 }
676
677 fn from_meta_list(meta: &MetaList) -> Option<Self> {
678 let ident_to_check = Self::map_with_attr_ident();
679 if meta.path.segments.iter().next_back().map(|x| &x.ident) == Some(&ident_to_check) {
680 let punctuated: Punctuated<Expr, Token![,]> = Punctuated::parse_terminated
682 .parse2(meta.tokens.clone())
683 .unwrap();
684 match punctuated.len() {
685 1 => Some(Self {
686 alt_function: punctuated[0].to_token_stream(),
687 map_res_with_function: None,
688 }),
689 2 => Some(Self {
690 alt_function: punctuated[0].to_token_stream(),
691 map_res_with_function: Some(punctuated[1].to_token_stream()),
692 }),
693 _ => Some(Self {
694 alt_function: quote! {compile_error!("Wrong number of arguments passed to map_with, this takes up to 2 arguments: one for regular mapping, and one for the 'map_res' function.")},
695 map_res_with_function: None,
696 }),
697 }
698 } else {
699 None
700 }
701 }
702}