1#![recursion_limit = "256"]
2
3extern crate proc_macro;
10
11use std::collections::HashMap;
12
13use proc_macro2::{Span, TokenStream};
14use quote::{quote, quote_spanned};
15use syn::{
16 parse_macro_input, punctuated::Punctuated, spanned::Spanned, token::Comma, Data, DataEnum,
17 DeriveInput, Field, Fields, FieldsNamed, FieldsUnnamed, GenericParam, Ident, Index, Type,
18 TypeParam,
19};
20
21fn type_params_replace(
22 input_params: &Punctuated<GenericParam, Comma>,
23 replace: &TypeParam,
24 with: Ident,
25) -> Punctuated<GenericParam, Comma> {
26 let mut output = input_params.clone();
27 for param in output.iter_mut() {
28 match param {
29 GenericParam::Type(ref mut type_param) if type_param == replace => {
30 *(&mut type_param.ident) = with;
31 break;
32 }
33 _ => {}
34 }
35 }
36 output
37}
38
39fn report_error(span: Span, msg: &str) -> proc_macro::TokenStream {
40 (quote_spanned! {span => compile_error! {#msg}}).into()
41}
42
43fn decide_functor_generic_type<'a>(
44 input: &'a DeriveInput,
45) -> Result<&'a TypeParam, proc_macro::TokenStream> {
46 let mut generics_iter = input.generics.type_params();
47 let generic_type = match generics_iter.next() {
48 Some(t) => t,
49 None => {
50 return Err(report_error(
51 input.ident.span(),
52 "can't derive Functor for a type without type parameters",
53 ));
54 }
55 };
56
57 if let Some(next_type_param) = generics_iter.next() {
58 return Err(report_error(
59 next_type_param.span(),
60 "can't derive Functor for a type with multiple type parameters; did you mean Bifunctor?",
61 ));
62 }
63
64 return Ok(generic_type);
65}
66
67fn decide_bifunctor_generic_types<'a>(
68 input: &'a DeriveInput,
69) -> Result<(&'a TypeParam, &'a TypeParam), proc_macro::TokenStream> {
70 let mut generics_iter = input.generics.type_params();
71 let generic_type_a = match generics_iter.next() {
72 Some(t) => t,
73 None => {
74 return Err(report_error(
75 input.ident.span(),
76 "can't derive Bifunctor for a type without type parameters",
77 ))
78 }
79 };
80
81 let generic_type_b = match generics_iter.next() {
82 Some(t) => t,
83 None => return Err(report_error(
84 input.ident.span(),
85 "can't derive Bifunctor for a type with only one type parameter; did you mean Functor?",
86 )),
87 };
88
89 if let Some(next_type_param) = generics_iter.next() {
90 return Err(report_error(
91 next_type_param.span(),
92 "can't derive Functor for a type with three or more type parameters",
93 ));
94 }
95
96 return Ok((generic_type_a, generic_type_b));
97}
98
99#[proc_macro_derive(Bifunctor)]
100pub fn derive_bifunctor(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
101 let input = parse_macro_input!(input as DeriveInput);
102 let name = &input.ident;
103 let type_params = &input.generics.params;
104 let where_clause = &input.generics.where_clause;
105
106 let (generic_type_a, generic_type_b) = match decide_bifunctor_generic_types(&input) {
107 Ok(t) => t,
108 Err(err) => return err,
109 };
110
111 let type_map = HashMap::from([
112 (
113 generic_type_a.ident.clone(),
114 Ident::new("left", Span::call_site()),
115 ),
116 (
117 generic_type_b.ident.clone(),
118 Ident::new("right", Span::call_site()),
119 ),
120 ]);
121
122 let bimap_impl = match &input.data {
123 Data::Struct(data) => match &data.fields {
124 Fields::Named(fields) => derive_functor_named_struct(name, fields, &type_map),
125 Fields::Unnamed(fields) => derive_functor_unnamed_struct(name, fields, &type_map),
126 Fields::Unit => {
127 return report_error(
128 input.ident.span(),
129 "can't derive Bifunctor for an empty struct",
130 );
131 }
132 },
133 Data::Enum(data) => derive_functor_enum(name, data, &type_map),
134 Data::Union(_) => {
135 return report_error(
136 input.ident.span(),
137 "can't derive Bifunctor for a union type",
138 );
139 }
140 };
141
142 let type_params_generic = type_params_replace(
143 &type_params_replace(
144 type_params,
145 generic_type_a,
146 Ident::new("DerivedTargetTypeA", Span::call_site()),
147 ),
148 generic_type_b,
149 Ident::new("DerivedTargetTypeB", Span::call_site()),
150 );
151
152 quote!(
153 impl<#type_params> ::higher::Bifunctor<'_, #generic_type_a, #generic_type_b> for #name<#type_params> #where_clause {
154 type Target<DerivedTargetTypeA, DerivedTargetTypeB> = #name<#type_params_generic>;
155 fn bimap<DerivedTypeA, DerivedTypeB, L, R>(self, left: L, right: R) -> Self::Target<DerivedTypeA, DerivedTypeB>
156 where
157 L: Fn(#generic_type_a) -> DerivedTypeA,
158 R: Fn(#generic_type_b) -> DerivedTypeB
159 {
160 #bimap_impl
161 }
162 }
163 )
164 .into()
165}
166
167#[proc_macro_derive(Functor)]
168pub fn derive_functor(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
169 let input = parse_macro_input!(input as DeriveInput);
170 let name = &input.ident;
171 let type_params = &input.generics.params;
172 let where_clause = &input.generics.where_clause;
173
174 let generic_type = match decide_functor_generic_type(&input) {
175 Ok(t) => t,
176 Err(err) => return err,
177 };
178
179 let type_map = HashMap::from([(
180 generic_type.ident.clone(),
181 Ident::new("f", Span::call_site()),
182 )]);
183
184 let fmap_impl = match &input.data {
185 Data::Struct(data) => match &data.fields {
186 Fields::Named(fields) => derive_functor_named_struct(name, fields, &type_map),
187 Fields::Unnamed(fields) => derive_functor_unnamed_struct(name, fields, &type_map),
188 Fields::Unit => {
189 return report_error(
190 input.ident.span(),
191 "can't derive Functor for an empty struct",
192 );
193 }
194 },
195 Data::Enum(data) => derive_functor_enum(name, data, &type_map),
196 Data::Union(_) => {
197 return report_error(input.ident.span(), "can't derive Functor for a union type");
198 }
199 };
200
201 let type_params_with_t = type_params_replace(
202 type_params,
203 generic_type,
204 Ident::new("DerivedTargetType", Span::call_site()),
205 );
206
207 quote!(
208 impl<#type_params> ::higher::Functor<'_, #generic_type> for #name<#type_params> #where_clause {
209 type Target<DerivedTargetType> = #name<#type_params_with_t>;
210 fn fmap<DerivedType, F>(self, f: F) -> Self::Target<DerivedType>
211 where
212 F: Fn(#generic_type) -> DerivedType
213 {
214 #fmap_impl
215 }
216 }
217 )
218 .into()
219}
220
221fn match_type_param<'a>(params: &'a HashMap<Ident, Ident>, ty: &Type) -> Option<&'a Ident> {
222 if let Type::Path(path) = ty {
223 if let Some(segment) = path.path.segments.iter().next() {
224 return params.get(&segment.ident);
225 }
226 }
227 None
228}
229
230fn filter_fields<'a, P, F1, F2>(
231 fields: &'a Punctuated<Field, P>,
232 ty: &HashMap<Ident, Ident>,
233 transform: F1,
234 copy: F2,
235) -> Vec<TokenStream>
236where
237 F1: Fn(&Ident, &Ident) -> TokenStream,
238 F2: Fn(&Ident) -> TokenStream,
239{
240 fields
241 .iter()
242 .map(|field| {
243 if let Some(f) = match_type_param(ty, &field.ty) {
244 transform(&field.ident.clone().unwrap(), f)
245 } else {
246 copy(&field.ident.clone().unwrap())
247 }
248 })
249 .collect()
250}
251
252fn derive_functor_named_struct(
253 name: &Ident,
254 fields: &FieldsNamed,
255 generic_types: &HashMap<Ident, Ident>,
256) -> TokenStream {
257 let apply_fields = filter_fields(
258 &fields.named,
259 generic_types,
260 |field, function_name| {
261 quote! {
262 #field: #function_name(self.#field),
263 }
264 },
265 |field| {
266 quote! {
267 #field: self.#field,
268 }
269 },
270 )
271 .into_iter();
272 quote! {
273 #name {
274 #(#apply_fields)*
275 }
276 }
277}
278
279fn derive_functor_unnamed_struct(
280 name: &Ident,
281 fields: &FieldsUnnamed,
282 generic_types: &HashMap<Ident, Ident>,
283) -> TokenStream {
284 let fields = fields.unnamed.iter().enumerate().map(|(index, field)| {
285 let index = Index::from(index);
286 if let Some(function_name) = match_type_param(generic_types, &field.ty) {
287 quote! { #function_name(self.#index), }
288 } else {
289 quote! { self.#index, }
290 }
291 });
292 quote! { #name(#(#fields)*) }
293}
294
295fn derive_functor_enum(
296 name: &Ident,
297 data: &DataEnum,
298 generic_types: &HashMap<Ident, Ident>,
299) -> TokenStream {
300 let variants = data.variants.iter().map(|variant| {
301 let ident = &variant.ident;
302 match &variant.fields {
303 Fields::Named(fields) => {
304 let args: Vec<Ident> = fields
305 .named
306 .iter()
307 .map(|field| {
308 Ident::new(
309 &format!("arg_{}", field.ident.clone().unwrap()),
310 field.ident.clone().unwrap().span(),
311 )
312 })
313 .collect();
314 let apply =
315 fields
316 .named
317 .iter()
318 .zip(args.clone().into_iter())
319 .map(|(field, arg)| {
320 let name = &field.ident;
321 if let Some(function_name) = match_type_param(generic_types, &field.ty)
322 {
323 quote! { #name: #function_name(#arg) }
324 } else {
325 quote! { #name: #arg }
326 }
327 });
328 let args = fields
329 .named
330 .iter()
331 .zip(args.into_iter())
332 .map(|(field, arg)| {
333 let name = &field.ident;
334 quote! { #name:#arg }
335 });
336 quote! {
337 #name::#ident { #(#args,)* } => #name::#ident { #(#apply,)* },
338 }
339 }
340 Fields::Unnamed(fields) => {
341 let args: Vec<Ident> = fields
342 .unnamed
343 .iter()
344 .enumerate()
345 .map(|(index, _)| Ident::new(&format!("arg{}", index), Span::call_site()))
346 .collect();
347 let fields = fields.unnamed.iter().zip(args.iter()).map(|(field, arg)| {
348 if let Some(function_name) = match_type_param(generic_types, &field.ty) {
349 quote! { #function_name(#arg) }
350 } else {
351 quote! { #arg }
352 }
353 });
354 let args = args.iter();
355 quote! {
356 #name::#ident(#(#args,)*) => #name::#ident(#(#fields,)*),
357 }
358 }
359 Fields::Unit => quote! {
360 #name::#ident => #name::#ident,
361 },
362 }
363 });
364 quote! {
365 match self {
366 #(#variants)*
367 }
368 }
369}
370
#[cfg(test)]
mod test {
    //! End-to-end checks of the `Functor` and `Bifunctor` derives, imported here
    //! through the `higher` crate.

    use higher::{Bifunctor, Functor};

    #[derive(PartialEq, Eq, Debug, Functor)]
    struct FunctorNamed<A> {
        named: A,
    }

    #[derive(PartialEq, Eq, Debug, Functor)]
    struct FunctorUnnamed<A>(A);

    #[derive(PartialEq, Eq, Debug, Functor)]
    #[allow(dead_code)]
    enum FunctorEnum<A> {
        Some(A),
        SomeNumber(usize),
        SomeOther(A),
        None,
    }

    #[test]
    fn derive_functor() {
        let named = (FunctorNamed { named: 2u32 }).fmap(|x| x + 3);
        assert_eq!(named, FunctorNamed { named: 5u32 });

        let unnamed = FunctorUnnamed(2u32).fmap(|x| x + 3);
        assert_eq!(unnamed, FunctorUnnamed(5u32));

        // Variants carrying an `A` are mapped...
        let some = FunctorEnum::Some(2u32).fmap(|x| x + 3);
        assert_eq!(some, FunctorEnum::Some(5u32));
        let some_other = FunctorEnum::SomeOther(2u32).fmap(|x| x + 3);
        assert_eq!(some_other, FunctorEnum::SomeOther(5u32));

        // ...while variants without one pass through untouched.
        let number = FunctorEnum::<u32>::SomeNumber(2).fmap(|x| x + 3);
        assert_eq!(number, FunctorEnum::<u32>::SomeNumber(2));
        let none = FunctorEnum::<u32>::None.fmap(|x| x + 3);
        assert_eq!(none, FunctorEnum::None);
    }

    #[derive(PartialEq, Eq, Debug, Bifunctor)]
    struct BifunctorNamed<A, B> {
        a: A,
        b: B,
    }

    #[derive(PartialEq, Eq, Debug, Bifunctor)]
    struct BifunctorUnnamed<A, B>(A, B);

    #[derive(PartialEq, Eq, Debug, Bifunctor)]
    #[allow(dead_code)]
    enum BifunctorEnum<A, B> {
        Ok(A),
        Err(B),
        Number(usize),
        Nothing,
    }

    #[test]
    fn derive_bifunctor() {
        let named = (BifunctorNamed { a: 2u32, b: 2u8 }).bimap(|x| x + 3, |x| x + 4);
        assert_eq!(named, BifunctorNamed { a: 5u32, b: 6u8 });

        let unnamed = BifunctorUnnamed(2u32, 2u8).bimap(|x| x + 3, |x| x + 4);
        assert_eq!(unnamed, BifunctorUnnamed(5u32, 6u8));

        // `left` applies to the first parameter, `right` to the second...
        let ok = BifunctorEnum::<u32, u8>::Ok(2u32).bimap(|x| x + 3, |x| x + 4);
        assert_eq!(ok, BifunctorEnum::Ok(5u32));
        let err = BifunctorEnum::<u32, u8>::Err(2u8).bimap(|x| x + 3, |x| x + 4);
        assert_eq!(err, BifunctorEnum::Err(6u8));

        // ...and variants holding neither pass through untouched.
        let number = BifunctorEnum::<u32, u8>::Number(2).bimap(|x| x + 3, |x| x + 4);
        assert_eq!(number, BifunctorEnum::Number(2));
        let nothing = BifunctorEnum::<u32, u8>::Nothing.bimap(|x| x + 3, |x| x + 4);
        assert_eq!(nothing, BifunctorEnum::Nothing);
    }
}