1#![doc = include_str!("../README.md")]
2
3use crate::generate_fmap_body::generate_fmap_body;
4use crate::map::{map_path, map_where};
5use crate::parse_attribute::parse_attribute;
6use proc_macro2::{Ident, TokenStream};
7use proc_macro_error::proc_macro_error;
8use quote::{format_ident, quote};
9use syn::token::Colon;
10use syn::{
11 parse_macro_input, Data, DeriveInput, Expr, ExprPath, GenericArgument, GenericParam, Path,
12 PathSegment, PredicateType, TraitBound, TraitBoundModifier, Type, TypeParamBound, TypePath,
13 WhereClause, WherePredicate,
14};
15
16mod generate_fmap_body;
17mod generate_map;
18mod map;
19mod parse_attribute;
20
21#[proc_macro_derive(Functor, attributes(functor))]
22#[proc_macro_error]
23pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
24 let input = parse_macro_input!(input as DeriveInput);
25
26 let def_name = input.ident.clone();
28
29 let attribute = parse_attribute(&input);
31
32 let source_params = input
34 .generics
35 .params
36 .iter()
37 .map(|param| match param {
38 GenericParam::Type(param) => {
39 let mut param = param.clone();
40 param.eq_token = None;
41 param.default = None;
42 GenericParam::Type(param)
43 }
44 GenericParam::Const(param) => {
45 let mut param = param.clone();
46 param.eq_token = None;
47 param.default = None;
48 GenericParam::Const(param)
49 }
50 param => param.clone(),
51 })
52 .collect::<Vec<_>>();
53
54 let source_args = source_params
56 .iter()
57 .map(|param| match param {
58 GenericParam::Lifetime(l) => GenericArgument::Lifetime(l.lifetime.clone()),
59 GenericParam::Type(t) => GenericArgument::Type(Type::Path(TypePath {
60 qself: None,
61 path: Path::from(PathSegment::from(t.ident.clone())),
62 })),
63 GenericParam::Const(c) => GenericArgument::Const(Expr::Path(ExprPath {
64 attrs: vec![],
65 qself: None,
66 path: Path::from(PathSegment::from(c.ident.clone())),
67 })),
68 })
69 .collect::<Vec<_>>();
70
71 let lints = quote! {
73 #[allow(absolute_paths_not_starting_with_crate)]
74 #[allow(bare_trait_objects)]
75 #[allow(deprecated)]
76 #[allow(drop_bounds)]
77 #[allow(dyn_drop)]
78 #[allow(non_camel_case_types)]
79 #[allow(trivial_bounds)]
80 #[allow(unused_qualifications)]
81 #[allow(clippy::allow)]
82 #[automatically_derived]
83 };
84
85 let mut tokens = TokenStream::new();
86
87 if let Some(default) = attribute.default {
89 tokens.extend(generate_default_impl(
90 &default,
91 &def_name,
92 &source_params,
93 &source_args,
94 &input.generics.where_clause,
95 &lints,
96 ));
97 }
98
99 for (param, name) in attribute.name_map {
101 tokens.extend(generate_named_impl(
102 ¶m,
103 &name,
104 &def_name,
105 &source_params,
106 &source_args,
107 &input.generics.where_clause,
108 &lints,
109 ));
110 }
111
112 tokens.extend(generate_refs_impl(
114 &input.data,
115 &def_name,
116 &source_params,
117 &source_args,
118 &input.generics.where_clause,
119 &lints,
120 ));
121
122 tokens.into()
123}
124
125fn find_index(source_params: &[GenericParam], ident: &Ident) -> usize {
126 for (total, param) in source_params.iter().enumerate() {
127 match param {
128 GenericParam::Type(t) if &t.ident == ident => return total,
129 _ => {}
130 }
131 }
132 unreachable!()
133}
134
135fn generate_refs_impl(
136 data: &Data,
137 def_name: &Ident,
138 source_params: &Vec<GenericParam>,
139 source_args: &Vec<GenericArgument>,
140 where_clause: &Option<WhereClause>,
141 lints: &TokenStream,
142) -> TokenStream {
143 let mut tokens = TokenStream::new();
144 for param in source_params {
145 if let GenericParam::Type(t) = param {
146 let param_ident = t.ident.clone();
147 let param_idx = find_index(source_params, &t.ident);
148
149 let functor_trait_ident = format_ident!("Functor{param_idx}");
150 let fmap_ident = format_ident!("__fmap_{param_idx}_ref");
151 let try_fmap_ident = format_ident!("__try_fmap_{param_idx}_ref");
152
153 let Some(fmap_ref_body) = generate_fmap_body(data, def_name, ¶m_ident, false)
155 else {
156 continue;
157 };
158 let Some(try_fmap_ref_body) = generate_fmap_body(data, def_name, ¶m_ident, true)
159 else {
160 continue;
161 };
162
163 let mut target_args = source_args.clone();
164 target_args[param_idx] = GenericArgument::Type(Type::Path(TypePath {
165 qself: None,
166 path: Path::from(PathSegment::from(format_ident!("__B"))),
167 }));
168
169 if let Some(fn_where_clause) =
170 create_fn_where_clause(where_clause, source_params, ¶m_ident)
171 {
172 tokens.extend(quote!(
173 #lints
174 impl<#(#source_params),*> #def_name<#(#source_args),*> #where_clause {
175 pub fn #fmap_ident<__B>(self, __f: &impl Fn(#param_ident) -> __B) -> #def_name<#(#target_args),*> #fn_where_clause {
176 use ::functor_derive::*;
177 #fmap_ref_body
178 }
179
180 pub fn #try_fmap_ident<__B, __E>(self, __f: &impl Fn(#param_ident) -> Result<__B, __E>) -> Result<#def_name<#(#target_args),*>, __E> #fn_where_clause {
181 use ::functor_derive::*;
182 Ok(#try_fmap_ref_body)
183 }
184 }
185 ))
186 } else {
187 tokens.extend(quote!(
188 #lints
189 impl<#(#source_params),*> ::functor_derive::#functor_trait_ident<#param_ident> for #def_name<#(#source_args),*> #where_clause {
190 type Target<__B> = #def_name<#(#target_args),*>;
191
192 fn #fmap_ident<__B>(self, __f: &impl Fn(#param_ident) -> __B) -> #def_name<#(#target_args),*> {
193 use ::functor_derive::*;
194 #fmap_ref_body
195 }
196
197 fn #try_fmap_ident<__B, __E>(self, __f: &impl Fn(#param_ident) -> Result<__B, __E>) -> Result<#def_name<#(#target_args),*>, __E> {
198 use ::functor_derive::*;
199 Ok(#try_fmap_ref_body)
200 }
201 }
202 ))
203 }
204 }
205 }
206 tokens
207}
208
209fn generate_default_impl(
210 param: &Ident,
211 def_name: &Ident,
212 source_params: &Vec<GenericParam>,
213 source_args: &Vec<GenericArgument>,
214 where_clause: &Option<WhereClause>,
215 lints: &TokenStream,
216) -> TokenStream {
217 let default_idx = find_index(source_params, param);
218
219 let mut target_args = source_args.clone();
221 target_args[default_idx] = GenericArgument::Type(Type::Path(TypePath {
222 qself: None,
223 path: Path::from(PathSegment::from(format_ident!("__B"))),
224 }));
225
226 let default_map = format_ident!("__fmap_{default_idx}_ref");
227 let default_try_map = format_ident!("__try_fmap_{default_idx}_ref");
228
229 if let Some(fn_where_clause) = create_fn_where_clause(where_clause, source_params, param) {
230 quote!(
231 #lints
232 impl<#(#source_params),*> #def_name<#(#source_args),*> #where_clause {
233 pub fn fmap<__B>(self, __f: impl Fn(#param) -> __B) -> #def_name<#(#target_args),*> #fn_where_clause {
234 use ::functor_derive::*;
235 self.#default_map(&__f)
236 }
237
238 pub fn try_fmap<__B, __E>(self, __f: impl Fn(#param) -> Result<__B, __E>) -> Result<#def_name<#(#target_args),*>, __E> #fn_where_clause {
239 use ::functor_derive::*;
240 self.#default_try_map(&__f)
241 }
242 }
243 )
244 } else {
245 quote!(
246 #lints
247 impl<#(#source_params),*> ::functor_derive::Functor<#param> for #def_name<#(#source_args),*> {
248 type Target<__B> = #def_name<#(#target_args),*>;
249
250 fn fmap<__B>(self, __f: impl Fn(#param) -> __B) -> #def_name<#(#target_args),*> {
251 use ::functor_derive::*;
252 self.#default_map(&__f)
253 }
254
255 fn try_fmap<__B, __E>(self, __f: impl Fn(#param) -> Result<__B, __E>) -> Result<#def_name<#(#target_args),*>, __E> {
256 use ::functor_derive::*;
257 self.#default_try_map(&__f)
258 }
259 }
260 )
261 }
262}
263
264fn generate_named_impl(
265 param: &Ident,
266 name: &Ident,
267 def_name: &Ident,
268 source_params: &Vec<GenericParam>,
269 source_args: &Vec<GenericArgument>,
270 where_clause: &Option<WhereClause>,
271 lints: &TokenStream,
272) -> TokenStream {
273 let default_idx = find_index(source_params, param);
274
275 let mut target_args = source_args.clone();
277 target_args[default_idx] = GenericArgument::Type(Type::Path(TypePath {
278 qself: None,
279 path: Path::from(PathSegment::from(format_ident!("__B"))),
280 }));
281
282 let fmap_name = format_ident!("fmap_{name}");
283 let try_fmap_name = format_ident!("try_fmap_{name}");
284
285 let fmap = format_ident!("__fmap_{default_idx}_ref");
286 let fmap_try = format_ident!("__try_fmap_{default_idx}_ref");
287
288 let fn_where_clause = create_fn_where_clause(where_clause, source_params, param);
289
290 quote!(
291 #lints
292 impl<#(#source_params),*> #def_name<#(#source_args),*> #where_clause {
293 pub fn #fmap_name<__B>(self, __f: impl Fn(#param) -> __B) -> #def_name<#(#target_args),*> #fn_where_clause {
294 use ::functor_derive::*;
295 self.#fmap(&__f)
296 }
297
298 pub fn #try_fmap_name<__B, __E>(self, __f: impl Fn(#param) -> Result<__B, __E>) -> Result<#def_name<#(#target_args),*>, __E> #fn_where_clause {
299 use ::functor_derive::*;
300 self.#fmap_try(&__f)
301 }
302 }
303 )
304}
305
306fn create_fn_where_clause(
307 where_clause: &Option<WhereClause>,
308 source_params: &Vec<GenericParam>,
309 param: &Ident,
310) -> Option<WhereClause> {
311 let mut predicates = where_clause
312 .iter()
313 .flat_map(|where_clause| map_where(where_clause, param))
314 .flat_map(|where_clause| where_clause.predicates)
315 .collect::<Vec<_>>();
316
317 for source_param in source_params {
318 if let GenericParam::Type(typ) = source_param {
319 if typ.bounds.is_empty() {
320 continue;
321 };
322
323 let bounds = typ
324 .bounds
325 .iter()
326 .cloned()
327 .flat_map(|bound| {
328 if let TypeParamBound::Trait(mut trt) = bound {
329 match trt.modifier {
330 TraitBoundModifier::Maybe(_) => None,
331 TraitBoundModifier::None => {
332 map_path(&mut trt.path, param, &mut false);
333 Some(TypeParamBound::Trait(trt))
334 }
335 }
336 } else {
337 Some(bound)
338 }
339 })
340 .collect();
341
342 predicates.push(WherePredicate::Type(PredicateType {
343 lifetimes: None,
344 bounded_ty: Type::Path(TypePath {
345 qself: None,
346 path: Path {
347 leading_colon: None,
348 segments: [PathSegment {
349 ident: if &typ.ident == param {
350 format_ident!("__B")
351 } else {
352 typ.ident.clone()
353 },
354 arguments: Default::default(),
355 }]
356 .into_iter()
357 .collect(),
358 },
359 }),
360 colon_token: Colon::default(),
361 bounds,
362 }))
363 }
364 }
365
366 predicates.push(WherePredicate::Type(PredicateType {
368 lifetimes: None,
369 bounded_ty: Type::Path(TypePath {
370 qself: None,
371 path: Path {
372 leading_colon: None,
373 segments: [PathSegment {
374 ident: param.clone(),
375 arguments: Default::default(),
376 }]
377 .into_iter()
378 .collect(),
379 },
380 }),
381 colon_token: Colon::default(),
382 bounds: [TypeParamBound::Trait(TraitBound {
383 paren_token: None,
384 modifier: TraitBoundModifier::None,
385 lifetimes: None,
386 path: Path {
387 leading_colon: None,
388 segments: [PathSegment {
389 ident: format_ident!("Sized"),
390 arguments: Default::default(),
391 }]
392 .into_iter()
393 .collect(),
394 },
395 })]
396 .into_iter()
397 .collect(),
398 }));
399
400 if predicates.is_empty() {
401 None
402 } else {
403 Some(WhereClause {
404 where_token: Default::default(),
405 predicates: predicates.into_iter().collect(),
406 })
407 }
408}