1use std::collections::HashMap;
19
20use heck::ToPascalCase;
21use itertools::Itertools;
22use proc_macro2::{Span, TokenStream};
23use quote::quote;
24use syn::parse::{Parse, ParseStream};
25use syn::punctuated::Punctuated;
26use syn::spanned::Spanned;
27use syn::{parse_quote, parse_quote_spanned, Error, Result};
28
29#[proc_macro_attribute]
30pub fn derive_trait(
31 attr: proc_macro::TokenStream,
32 item: proc_macro::TokenStream,
33) -> proc_macro::TokenStream {
34 real_derive_trait(attr.into(), item.into()).unwrap_or_else(Error::into_compile_error).into()
35}
36
37fn real_derive_trait(attr: TokenStream, item: TokenStream) -> Result<TokenStream> {
38 let attr: Attr = syn::parse2(attr)?;
39 let Attr { debug_print, vis, trait_ident, supers, generics: trait_generics, fixed_assoc_types } =
40 &attr;
41 let debug_print = debug_print.is_some();
42 let (_, trait_generics_ref, _) = trait_generics.split_for_impl();
43
44 let inherent_impl: syn::ItemImpl = syn::parse2(item)?;
45
46 let mut item_trait = syn::ItemTrait {
47 attrs: Vec::new(),
48 vis: vis.clone(),
49 unsafety: None,
50 auto_token: None,
51 restriction: None,
52 trait_token: syn::Token),
53 ident: trait_ident.clone(),
54 generics: trait_generics.clone(),
55 colon_token: supers.as_ref().map(|&(colon, _)| colon),
56 supertraits: supers.clone().map(|(_, supers)| supers).unwrap_or_default(),
57 brace_token: syn::token::Brace(Span::call_site()),
58 items: Vec::new(),
59 };
60
61 let mut item_impl = syn::ItemImpl {
62 attrs: Vec::new(),
63 unsafety: None,
64 defaultness: None,
65 impl_token: syn::Token),
66 generics: inherent_impl.generics.clone(),
67 trait_: Some((
68 None,
69 syn::parse_quote!(#trait_ident #trait_generics_ref),
70 syn::Token),
71 )),
72 self_ty: inherent_impl.self_ty.clone(),
73 brace_token: syn::token::Brace(Span::call_site()),
74 items: Vec::new(),
75 };
76
77 let mut ident_assoc_map = HashMap::new();
78
79 for assoc in fixed_assoc_types {
80 if assoc.generics.params.len() > 0 {
81 return Err(syn::Error::new_spanned(&assoc.generics, "GATs here are not supported"));
82 }
83
84 let Some((
85 eq_token,
86 default @ syn::Type::Path(syn::TypePath { qself: None, path: default_ident }),
87 )) = &assoc.default
88 else {
89 return Err(syn::Error::new_spanned(assoc, "expected `type Type: Bounds = Ident;`"));
90 };
91 let default_ident = default_ident.require_ident()?;
92
93 ident_assoc_map.insert(default_ident.clone(), assoc.ident.clone());
94
95 item_trait.items.push(syn::TraitItem::Type(syn::TraitItemType {
96 attrs: assoc.attrs.clone(),
97 type_token: assoc.type_token,
98 ident: assoc.ident.clone(),
99 generics: assoc.generics.clone(),
100 colon_token: assoc.colon_token,
101 bounds: assoc.bounds.clone(),
102 default: None,
103 semi_token: assoc.semi_token,
104 }));
105
106 item_impl.items.push(syn::ImplItem::Type(syn::ImplItemType {
107 attrs: Vec::new(),
108 vis: syn::Visibility::Inherited,
109 defaultness: None,
110 type_token: assoc.type_token,
111 ident: assoc.ident.clone(),
112 generics: assoc.generics.clone(),
113 eq_token: *eq_token,
114 ty: default.clone(),
115 semi_token: assoc.semi_token,
116 }));
117 }
118
119 struct ReplaceIdentVisitor<'t>(&'t HashMap<syn::Ident, syn::Ident>);
120 impl<'t> syn::visit_mut::VisitMut for ReplaceIdentVisitor<'t> {
121 fn visit_type_path_mut(&mut self, type_path: &mut syn::TypePath) {
122 if let Some(ident) = type_path.path.segments.first() {
123 if let Some(target) = self.0.get(&ident.ident) {
124 let mut segments: Vec<_> =
125 type_path.path.segments.clone().into_pairs().collect();
126 segments.insert(
127 0,
128 syn::punctuated::Pair::Punctuated(
129 syn::PathSegment {
130 ident: syn::Ident::new("Self", ident.span()),
131 arguments: syn::PathArguments::None,
132 },
133 syn::Token),
134 ),
135 );
136 *segments[1].value_mut() = syn::PathSegment {
137 ident: target.clone(),
138 arguments: syn::PathArguments::None,
139 };
140 type_path.path.segments = segments.into_iter().collect();
141 }
142 }
143
144 if let Some(qself) = &mut type_path.qself {
145 self.visit_type_mut(&mut qself.ty);
146 }
147 self.visit_path_mut(&mut type_path.path);
148 }
149 }
150 let mut replace_ident_visitor = ReplaceIdentVisitor(&ident_assoc_map);
151
152 let self_ty = &*inherent_impl.self_ty;
153
154 for item in &inherent_impl.items {
155 match item {
156 syn::ImplItem::Fn(item) => {
157 let mut sig = item.sig.clone();
158 let sig_span = sig.span();
159
160 if let syn::ReturnType::Type(r_arrow, ret_ty) = &sig.output {
161 let transformed = for_each_impl_trait(ret_ty, &mut |tit| {
162 let span = tit.span();
163
164 let assoc_ident = item.sig.ident.to_string().to_pascal_case();
166 let assoc_ident = syn::Ident::new(&assoc_ident, span);
167 let ty_bounds = &tit.bounds;
168
169 let (assoc_generics, assoc_generics_names, assoc_where) = if sig
175 .generics
176 .params
177 .is_empty()
178 {
179 (None, None, None)
180 } else {
181 let (sig_impl_generics, sig_ty_generics, sig_where_generics) =
182 sig.generics.split_for_impl();
183 let mut sig_impl_generics: syn::Generics =
184 syn::parse_quote!(#sig_impl_generics);
185 let mut sig_ty_generics: syn::AngleBracketedGenericArguments =
186 syn::parse_quote!(#sig_ty_generics);
187 let mut sig_where_generics = sig_where_generics.cloned();
188
189 if let Some(recv) = sig.receiver() {
190 if let Some((and, lt)) = &recv.reference {
191 let lt = match lt {
192 Some(lt) => lt.clone(),
193 None => {
194 let lt: syn::Lifetime =
195 syn::parse_quote_spanned!(and.span() => '__self);
196 sig_impl_generics.params.push(
197 syn::GenericParam::Lifetime(parse_quote!(#lt)),
198 );
199 sig_ty_generics
200 .args
201 .push(syn::parse_quote_spanned!(and.span() => '_));
202 lt
203 }
204 };
205 let where_predicate: syn::WherePredicate =
206 syn::parse_quote_spanned!(and.span() => Self: #lt);
207 sig_where_generics.get_or_insert(syn::WhereClause {
208 where_token: syn::Token,
209 predicates: Punctuated::new(),
210 }).predicates.push(syn::parse_quote_spanned!(and.span() => #where_predicate));
211 }
212 }
213
214 (
215 Some(quote!(#sig_impl_generics)),
216 Some(quote!(#sig_ty_generics)),
217 Some(quote!(#sig_where_generics)),
218 )
219 };
220
221 let assoc_doc = format!(
222 "Return value for [`{fn_ident}`](Self::{fn_ident})",
223 fn_ident = &sig.ident
224 );
225 let mut trait_item_ty: syn::TraitItemType = parse_quote_spanned! { span =>
226 #[doc = #assoc_doc]
227 type #assoc_ident #assoc_generics: #ty_bounds #assoc_where;
228 };
229 syn::visit_mut::visit_trait_item_type_mut(
230 &mut replace_ident_visitor,
231 &mut trait_item_ty,
232 );
233 item_trait.items.push(syn::TraitItem::Type(trait_item_ty));
234 item_impl.items.push(parse_quote_spanned! { span =>
235 type #assoc_ident #assoc_generics = #tit #assoc_where;
236 });
237
238 parse_quote_spanned! { span =>
239 Self::#assoc_ident #assoc_generics_names
240 }
241 });
242 sig.output = syn::ReturnType::Type(*r_arrow, Box::new(transformed));
243 }
244
245 let sig_ident = &sig.ident;
246 let sig_args: Vec<syn::Pat> = sig
247 .inputs
248 .iter()
249 .map(|input| match input {
250 syn::FnArg::Receiver(syn::Receiver { self_token, .. }) => {
251 parse_quote!(#self_token)
252 }
253 syn::FnArg::Typed(typed) => (*typed.pat).clone(),
254 })
255 .collect();
256
257 let fn_docs: Vec<_> =
258 item.attrs.iter().filter(|attr| attr.path().is_ident("doc")).cloned().collect();
259
260 let mut trait_item_fn = syn::TraitItemFn {
261 attrs: fn_docs.clone(),
262 sig: sig.clone(),
263 default: None,
264 semi_token: Some(syn::Token)),
265 };
266 syn::visit_mut::visit_trait_item_fn_mut(
267 &mut replace_ident_visitor,
268 &mut trait_item_fn,
269 );
270 item_trait.items.push(syn::TraitItem::Fn(trait_item_fn));
271
272 item_impl.items.push(syn::ImplItem::Fn(syn::ImplItemFn {
273 attrs: fn_docs.clone(),
274 vis: syn::Visibility::Inherited,
275 defaultness: None,
276 sig: sig.clone(),
277 block: parse_quote_spanned! { item.span() => {
278 <#self_ty>::#sig_ident(#(#sig_args),*)
279 }},
280 }));
281 }
282 _ => return Err(Error::new_spanned(item, "only associated functions are supported")),
283 }
284 }
285
286 let trait_item_doc = format!(
287 "Derived trait for [`{}`].",
288 match self_ty {
289 syn::Type::Path(path) =>
290 path.path.segments.iter().map(|ident| ident.ident.to_string()).join("::"),
291 _ => quote!(#self_ty).to_string(),
292 }
293 );
294
295 let output = quote! {
296 #[allow(clippy::needless_lifetimes)]
297 #inherent_impl
298 #[allow(clippy::needless_lifetimes, non_camel_case_types)]
299 #[doc = #trait_item_doc]
300 #item_trait
301 #[automatically_derived]
302 #[allow(clippy::needless_lifetimes, non_camel_case_types)]
303 #item_impl
304 };
305 if debug_print {
306 println!("{}", output);
307 }
308 Ok(output)
309}
310
311fn for_each_impl_trait(
312 ty: &syn::Type,
313 f: &mut impl FnMut(&syn::TypeImplTrait) -> syn::Type,
314) -> syn::Type {
315 match ty {
316 syn::Type::Array(ty) => syn::Type::Array(syn::TypeArray {
317 elem: Box::new(for_each_impl_trait(&*ty.elem, f)),
318 ..ty.clone()
319 }),
320 syn::Type::BareFn(ty) => syn::Type::BareFn(syn::TypeBareFn {
321 inputs: ty
322 .inputs
323 .clone()
324 .into_pairs()
325 .map(|mut pair| {
326 let value = pair.value_mut();
327 value.ty = for_each_impl_trait(&value.ty, f);
328 pair
329 })
330 .collect(),
331 ..ty.clone()
332 }),
333 syn::Type::Group(_) => ty.clone(),
334 syn::Type::ImplTrait(ty) => f(ty),
335 syn::Type::Infer(_) => ty.clone(),
336 syn::Type::Macro(_) => ty.clone(),
337 syn::Type::Never(_) => ty.clone(),
338 syn::Type::Paren(ty) => syn::Type::Paren(syn::TypeParen {
339 elem: Box::new(for_each_impl_trait(&*ty.elem, f)),
340 ..ty.clone()
341 }),
342 syn::Type::Path(ty) => syn::Type::Path(syn::TypePath {
343 qself: ty.qself.clone().map(|mut qself| {
344 qself.ty = Box::new(for_each_impl_trait(&*qself.ty, f));
345 qself
346 }),
347 path: for_each_impl_trait_in_path(&ty.path, f),
348 }),
349 syn::Type::Ptr(ty) => syn::Type::Ptr(syn::TypePtr {
350 elem: Box::new(for_each_impl_trait(&*ty.elem, f)),
351 ..ty.clone()
352 }),
353 syn::Type::Reference(ty) => syn::Type::Reference(syn::TypeReference {
354 elem: Box::new(for_each_impl_trait(&*ty.elem, f)),
355 ..ty.clone()
356 }),
357 syn::Type::Slice(ty) => syn::Type::Slice(syn::TypeSlice {
358 elem: Box::new(for_each_impl_trait(&*ty.elem, f)),
359 ..ty.clone()
360 }),
361 syn::Type::TraitObject(ty) => syn::Type::TraitObject(syn::TypeTraitObject {
362 bounds: ty
363 .bounds
364 .clone()
365 .into_pairs()
366 .map(|mut pair| {
367 if let syn::TypeParamBound::Trait(bound) = pair.value_mut() {
368 bound.path = for_each_impl_trait_in_path(&bound.path, f);
369 }
370 pair
371 })
372 .collect(),
373 ..ty.clone()
374 }),
375 syn::Type::Tuple(ty) => syn::Type::Tuple(syn::TypeTuple {
376 elems: ty
377 .elems
378 .clone()
379 .into_pairs()
380 .map(|mut pair| {
381 let value = pair.value_mut();
382 *value = for_each_impl_trait(&value, f);
383 pair
384 })
385 .collect(),
386 ..ty.clone()
387 }),
388 syn::Type::Verbatim(_) => ty.clone(),
389 _ => ty.clone(),
390 }
391}
392
393fn for_each_impl_trait_in_path(
394 path: &syn::Path,
395 f: &mut impl FnMut(&syn::TypeImplTrait) -> syn::Type,
396) -> syn::Path {
397 syn::Path {
398 leading_colon: path.leading_colon,
399 segments: path
400 .segments
401 .clone()
402 .into_pairs()
403 .map(|mut pair| {
404 let value = pair.value_mut();
405 match &mut value.arguments {
406 syn::PathArguments::None => {}
407 syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments {
408 args,
409 ..
410 }) => {
411 for pair in args.pairs_mut() {
412 match pair.into_value() {
413 syn::GenericArgument::Type(ty) => *ty = for_each_impl_trait(ty, f),
414 syn::GenericArgument::AssocType(ty) => {
415 ty.ty = for_each_impl_trait(&ty.ty, f)
416 }
417 _ => {}
418 }
419 }
420 }
421 syn::PathArguments::Parenthesized(syn::ParenthesizedGenericArguments {
422 inputs,
423 output,
424 ..
425 }) => {
426 for input in inputs {
427 *input = for_each_impl_trait(input, f);
428 }
429 if let syn::ReturnType::Type(_, ty) = output {
430 *ty = Box::new(for_each_impl_trait(ty, f));
431 }
432 }
433 }
434 pair
435 })
436 .collect(),
437 }
438}
439
440struct Attr {
441 debug_print: Option<kw::__debug_print>,
442 vis: syn::Visibility,
443 trait_ident: syn::Ident,
444 generics: syn::Generics,
445 supers: Option<(syn::Token![:], Punctuated<syn::TypeParamBound, syn::Token![+]>)>,
446 fixed_assoc_types: Vec<syn::TraitItemType>,
447}
448
449impl Parse for Attr {
450 fn parse(input: ParseStream) -> Result<Self> {
451 let debug_print = input.parse::<kw::__debug_print>().ok();
452 let vis = input.parse()?;
453 let trait_ident = input.parse()?;
454 let mut generics = syn::Generics::default();
455 let mut supers = None;
456 let mut fixed_assoc_types = Vec::new();
457
458 while !input.is_empty() {
459 let lh = input.lookahead1();
460 if generics.lt_token.is_none() && lh.peek(syn::Token![<]) {
461 generics = input.parse()?;
462 } else if lh.peek(syn::Token![:]) {
463 supers = Some((input.parse()?, Punctuated::parse_separated_nonempty(input)?));
464 } else if !generics.params.is_empty() && lh.peek(syn::Token![where]) {
465 generics.where_clause = Some(input.parse()?);
466 } else if lh.peek(syn::token::Brace) {
467 let inner;
468 _ = syn::braced!(inner in input);
469 while !inner.is_empty() {
470 fixed_assoc_types.push(inner.parse()?);
471 }
472 } else {
473 return Err(lh.error());
474 }
475 }
476
477 Ok(Self { debug_print, vis, trait_ident, supers, generics, fixed_assoc_types })
478 }
479}
480
481mod kw {
482 use syn::custom_keyword;
483
484 custom_keyword!(Sized);
485 custom_keyword!(__debug_print);
486}